os.path.join

Here are the examples of the python api os.path.join taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.

200 Examples (page 7)

Example 1

Project: auto-sklearn
Source File: plot_metafeatures.py
View license
def plot_metafeatures(metafeatures_plot_dir, metafeatures, metafeatures_times,
                      runs, method='pca', seed=1, depth=1, distance='l2'):
    """Project datasets in a 2d space and plot them.

    Also writes two LaTeX tables (dataset sizes and per-metafeature
    min/mean/max values and computation times) into
    ``metafeatures_plot_dir``.

    arguments:
      * metafeatures_plot_dir: a directory to save the generated plots
      * metafeatures: a pandas Dataframe from the MetaBase
      * metafeatures_times: a pandas DataFrame of metafeature computation
        times, indexed like ``metafeatures``
      * runs: a dictionary of runs from the MetaBase
      * method: either pca or t-sne
      * seed: only used for t-sne
      * depth: if 1, a one-step look-ahead is performed
      * distance: 'l2', 'l1' or 'runs'; only used for t-sne
    """
    # isinstance also accepts DataFrame subclasses, which behave the same here.
    if not isinstance(metafeatures, pd.DataFrame):
        raise ValueError("Argument metafeatures must be of type pd.Dataframe "
                         "but is %s" % str(type(metafeatures)))

    ############################################################################
    # Write out the datasets and their size as a TEX table
    # TODO put this in an own function
    dataset_tex = StringIO.StringIO()
    dataset_tex.write('\\begin{tabular}{lrrr}\n')
    dataset_tex.write('\\textbf{Dataset name} & '
                      '\\textbf{\#features} & '
                      '\\textbf{\#patterns} & '
                      '\\textbf{\#classes} \\\\\n')

    num_features = []
    num_instances = []
    num_classes = []

    for dataset in sorted(metafeatures.index):
        dataset_tex.write('%s & %d & %d & %d \\\\\n' % (
                        dataset.replace('larochelle_etal_2007_', '').replace(
                            '_', '-'),
                        metafeatures.loc[dataset]['number_of_features'],
                        metafeatures.loc[dataset]['number_of_instances'],
                        metafeatures.loc[dataset]['number_of_classes']))
        num_features.append(metafeatures.loc[dataset]['number_of_features'])
        num_instances.append(metafeatures.loc[dataset]['number_of_instances'])
        num_classes.append(metafeatures.loc[dataset]['number_of_classes'])

    dataset_tex.write('Minimum & %.1f & %.1f & %.1f \\\\\n' %
        (np.min(num_features), np.min(num_instances), np.min(num_classes)))
    dataset_tex.write('Maximum & %.1f & %.1f & %.1f \\\\\n' %
        (np.max(num_features), np.max(num_instances), np.max(num_classes)))
    dataset_tex.write('Mean & %.1f & %.1f & %.1f \\\\\n' %
        (np.mean(num_features), np.mean(num_instances), np.mean(num_classes)))

    dataset_tex.write('10\\%% quantile & %.1f & %.1f & %.1f \\\\\n' % (
        np.percentile(num_features, 10), np.percentile(num_instances, 10),
        np.percentile(num_classes, 10)))
    dataset_tex.write('90\\%% quantile & %.1f & %.1f & %.1f \\\\\n' % (
        np.percentile(num_features, 90), np.percentile(num_instances, 90),
        np.percentile(num_classes, 90)))
    dataset_tex.write('median & %.1f & %.1f & %.1f \\\\\n' % (
        np.percentile(num_features, 50), np.percentile(num_instances, 50),
        np.percentile(num_classes, 50)))
    dataset_tex.write('\\end{tabular}')
    dataset_tex.seek(0)

    dataset_tex_output = os.path.join(metafeatures_plot_dir, 'datasets.tex')
    with open(dataset_tex_output, 'w') as fh:
        fh.write(dataset_tex.getvalue())

    ############################################################################
    # Write out a list of metafeatures, each with the min/max/mean
    # calculation time and the min/max/mean value
    metafeatures_tex = StringIO.StringIO()
    metafeatures_tex.write('\\begin{tabular}{lrrrrrr}\n')
    metafeatures_tex.write('\\textbf{Metafeature} & '
                      '\\textbf{Minimum} & '
                      '\\textbf{Mean} & '
                      '\\textbf{Maximum} &'
                      '\\textbf{Minimum time} &'
                      '\\textbf{Mean time} &'
                      '\\textbf{Maximum time} '
                      '\\\\\n')

    for mf_name in sorted(metafeatures.columns):
        # BUGFIX: the times table is the ``metafeatures_times`` parameter;
        # the original referenced an undefined name ``metafeature_times``.
        metafeatures_tex.write('%s & %.2f & %.2f & %.2f & %.2f & %.2f & %.2f \\\\\n'
                               % (mf_name.replace('_', '-'),
                                  metafeatures.loc[:,mf_name].min(),
                                  metafeatures.loc[:,mf_name].mean(),
                                  metafeatures.loc[:,mf_name].max(),
                                  metafeatures_times.loc[:, mf_name].min(),
                                  metafeatures_times.loc[:, mf_name].mean(),
                                  metafeatures_times.loc[:, mf_name].max()))

    metafeatures_tex.write('\\end{tabular}')
    metafeatures_tex.seek(0)

    metafeatures_tex_output = os.path.join(metafeatures_plot_dir, 'metafeatures.tex')
    with open(metafeatures_tex_output, 'w') as fh:
        fh.write(metafeatures_tex.getvalue())

    # Without this scaling the transformation for visualization purposes is
    # useless
    metafeatures = metafeatures.copy()
    X_min = np.nanmin(metafeatures, axis=0)
    X_max = np.nanmax(metafeatures, axis=0)
    metafeatures = (metafeatures - X_min) / (X_max - X_min)

    # PCA
    if method == 'pca':
        pca = PCA(2)
        transformation = pca.fit_transform(metafeatures.values)

    elif method == 't-sne':
        if distance == 'l2':
            distance_matrix = sklearn.metrics.pairwise.pairwise_distances(
                metafeatures.values, metric='l2')
        elif distance == 'l1':
            distance_matrix = sklearn.metrics.pairwise.pairwise_distances(
                metafeatures.values, metric='l1')
        elif distance == 'runs':
            names_to_indices = dict()
            for metafeature in metafeatures.index:
                idx = len(names_to_indices)
                names_to_indices[metafeature] = idx

            X, Y = pyMetaLearn.metalearning.create_datasets\
                .create_predict_spearman_rank(metafeatures, runs,
                                              'combination')
            # Make a metric matrix out of Y
            distance_matrix = np.zeros((metafeatures.shape[0],
                                        metafeatures.shape[0]), dtype=np.float64)

            # Y is indexed by "<dataset1>_<dataset2>" pairs; fill both
            # triangles so the matrix is symmetric.
            # NOTE(review): this assumes dataset names contain no '_';
            # otherwise the split below mis-parses the pair — TODO confirm.
            for idx in Y.index:
                dataset_names = idx.split("_")
                d1 = names_to_indices[dataset_names[0]]
                d2 = names_to_indices[dataset_names[1]]
                distance_matrix[d1][d2] = Y.loc[idx]
                distance_matrix[d2][d1] = Y.loc[idx]

        else:
            raise NotImplementedError()

        # For whatever reason, tsne doesn't accept l1 metric
        tsne = TSNE(random_state=seed, perplexity=50, verbose=1)
        transformation = tsne.fit_transform(distance_matrix)

    # Transform the transformation back to range [0, 1] to ease plotting
    transformation_min = np.nanmin(transformation, axis=0)
    transformation_max = np.nanmax(transformation, axis=0)
    transformation = (transformation - transformation_min) / \
                     (transformation_max - transformation_min)
    print(transformation_min, transformation_max)

    #for i, dataset in enumerate(directory_content):
    #    print dataset, meta_feature_array[i]
    fig = plt.figure(dpi=600, figsize=(12, 12))
    ax = plt.subplot(111)

    # The dataset names must be aligned at the borders of the plot in a way
    # the arrows don't cross each other. First, define the different slots
    # where the labels will be positioned and then figure out the optimal
    # order of the labels
    slots = []
    # 25 datasets on the top y-axis
    slots.extend([(-0.1 + 0.05 * i, 1.1) for i in range(25)])
    # 24 datasets on the right x-axis
    slots.extend([(1.1, 1.05 - 0.05 * i) for i in range(24)])
    # 25 datasets on the bottom y-axis
    slots.extend([(-0.1 + 0.05 * i, -0.1) for i in range(25)])
    # 24 datasets on the left x-axis
    slots.extend([(-0.1, 1.05 - 0.05 * i) for i in range(24)])

    # Align the labels on the outer axis
    labels_top = []
    labels_left = []
    labels_right = []
    labels_bottom = []

    # Assign each point to one of the four borders depending on which
    # diagonal half of the unit square it falls into.
    for values in zip(metafeatures.index,
                      transformation[:, 0], transformation[:, 1]):
        label, x, y = values
        # Although all plot area goes up to 1.1, 1.1, the range of all the
        # points lies inside [0,1]
        if x >= y and x < 1.0 - y:
            labels_bottom.append((x, label))
        elif x >= y and x >= 1.0 - y:
            labels_right.append((y, label))
        elif y > x and x <= 1.0 -y:
             labels_left.append((y, label))
        else:
            labels_top.append((x, label))

    # Sort the labels according to their alignment
    labels_bottom.sort()
    labels_left.sort()
    labels_left.reverse()
    labels_right.sort()
    labels_right.reverse()
    labels_top.sort()

    # Build an index label -> x, y
    points = {}
    for values in zip(metafeatures.index,
                      transformation[:, 0], transformation[:, 1]):
        label, x, y = values
        points[label] = (x, y)

    # Find out the final positions...
    positions_top = {}
    positions_left = {}
    positions_right = {}
    positions_bottom = {}

    # Find the actual positions
    for i, values in enumerate(labels_bottom):
        y, label = values
        margin = 1.2 / len(labels_bottom)
        positions_bottom[label] = (-0.05 + i * margin, -0.1,)
    for i, values in enumerate(labels_left):
        x, label = values
        margin = 1.2 / len(labels_left)
        positions_left[label] = (-0.1, 1.1 - i * margin)
    for i, values in enumerate(labels_top):
        y, label = values
        margin = 1.2 / len(labels_top)
        positions_top[label] = (-0.05 + i * margin, 1.1)
    for i, values in enumerate(labels_right):
        y, label = values
        margin = 1.2 / len(labels_right)
        positions_right[label] = (1.1, 1.05 - i * margin)

    # Do greedy resorting if it decreases the number of intersections...
    def resort(label_positions, marker_positions, maxdepth=1):
        # TODO: are the inputs dicts or lists
        # TODO: two-step look-ahead
        def intersect(start1, end1, start2, end2):
            # Compute if there is an intersection, for the algorithm see
            # Computer Graphics by F.S.Hill

            # If one vector is just a point, it cannot intersect with a line...
            for v in [start1, start2, end1, end2]:
                if not np.isfinite(v).all():
                    return False     # Obviously there is no intersection

            def perpendicular(d):
                return np.array((-d[1], d[0]))

            d1 = end1 - start1      # denoted b
            d2 = end2 - start2      # denoted d
            d2_1 = start2 - start1  # denoted c
            d1_perp = perpendicular(d1)   # denoted by b_perp
            d2_perp = perpendicular(d2)   # denoted by d_perp

            t = np.dot(d2_1, d2_perp) / np.dot(d1, d2_perp)
            u = - np.dot(d2_1, d1_perp) / np.dot(d2, d1_perp)

            if 0 <= t <= 1 and 0 <= u <= 1:
                return True    # There is an intersection
            else:
                return False     # There is no intersection

        def number_of_intersections(label_positions, marker_positions):
            # Count crossing label->marker arrows over all ordered pairs.
            num = 0
            for key1, key2 in itertools.permutations(label_positions, r=2):
                s1 = np.array(label_positions[key1])
                e1 = np.array(marker_positions[key1])
                s2 = np.array(label_positions[key2])
                e2 = np.array(marker_positions[key2])
                if intersect(s1, e1, s2, e2):
                    num += 1
            return num

        # test if swapping two lines would decrease the number of intersections
        # TODO: if this was done with a datastructure different than dicts,
        # it could be much faster, because there is a lot of redundant
        # computing performed in the second iteration
        def swap(label_positions, marker_positions, depth=0,
                 maxdepth=maxdepth, best_found=sys.maxsize):
            # sys.maxsize works on both Python 2 and 3 (sys.maxint is Py2-only)
            if len(label_positions) <= 1:
                return

            two_step_look_ahead = False
            while True:
                improvement = False
                for key1, key2 in itertools.combinations(label_positions, r=2):
                    before = number_of_intersections(label_positions, marker_positions)
                    # swap:
                    tmp = label_positions[key1]
                    label_positions[key1] = label_positions[key2]
                    label_positions[key2] = tmp
                    if depth < maxdepth and two_step_look_ahead:
                        swap(label_positions, marker_positions,
                             depth=depth+1, best_found=before)

                    after = number_of_intersections(label_positions, marker_positions)

                    if best_found > after and before > after:
                        improvement = True
                        print(before, after)
                        print("Depth %d: Swapped %s with %s" %
                              (depth, key1, key2))
                    else:       # swap back...
                        tmp = label_positions[key1]
                        label_positions[key1] = label_positions[key2]
                        label_positions[key2] = tmp

                    if after == 0:
                        break

                # If it is not yet sorted perfectly, do another pass with
                # two-step lookahead
                if before == 0:
                    print("Sorted perfectly...")
                    break
                print(depth, two_step_look_ahead)
                if two_step_look_ahead:
                    break
                if maxdepth == depth:
                    print("Reached maximum recursion depth...")
                    break
                if not improvement and depth < maxdepth:
                    print("Still %d errors, trying two-step lookahead" % before)
                    two_step_look_ahead = True

        swap(label_positions, marker_positions, maxdepth=maxdepth)

    resort(positions_bottom, points, maxdepth=depth)
    resort(positions_left, points, maxdepth=depth)
    resort(positions_right, points, maxdepth=depth)
    resort(positions_top, points, maxdepth=depth)

    # Helper function
    def plot(x, y, label_x, label_y, label, ha, va, relpos, rotation=0):
        # Draws the marker at (x, y) and an annotated label at
        # (label_x, label_y) with an arrow back to the marker.
        ax.scatter(x, y, marker='o', label=label, s=80, linewidths=0.1,
                   color='blue', edgecolor='black')

        label = label.replace('larochelle_etal_2007_', '')

        x = ax.annotate(label, xy=(x, y), xytext=(label_x, label_y),
                    ha=ha, va=va, rotation=rotation,
                    bbox=dict(boxstyle='round', fc='gray', alpha=0.5),
                    arrowprops=dict(arrowstyle='->', color='black',
                                    relpos=relpos))

    # Do the plotting
    for i, key in enumerate(positions_bottom):
        x, y = positions_bottom[key]
        plot(points[key][0], points[key][1], x, y,
             key, ha='right', va='top', rotation=45, relpos=(1, 1))
    for i, key in enumerate(positions_left):
        x, y = positions_left[key]
        plot(points[key][0], points[key][1], x, y, key,
             ha='right', va='top', rotation=45, relpos=(1, 1))
    for i, key in enumerate(positions_top):
        x, y = positions_top[key]
        plot(points[key][0], points[key][1], x, y, key,
             ha='left', va='bottom', rotation=45, relpos=(0, 0))
    for i, key in enumerate(positions_right):
        x, y = positions_right[key]
        plot(points[key][0], points[key][1], x, y, key,
             ha='left', va='bottom', rotation=45, relpos=(0, 0))

    # Resize everything
    box = ax.get_position()
    remove = 0.05 * box.width
    ax.set_position([box.x0 + remove, box.y0 + remove,
                     box.width - remove*2, box.height - remove*2])

    locs_x = ax.get_xticks()
    locs_y = ax.get_yticks()
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xlim((-0.1, 1.1))
    ax.set_ylim((-0.1, 1.1))
    plt.savefig(os.path.join(metafeatures_plot_dir, "pca.png"))
    plt.savefig(os.path.join(metafeatures_plot_dir, "pca.pdf"))
    plt.clf()

Example 2

Project: pupil
Source File: eye.py
View license
def eye(timebase, is_alive_flag, ipc_pub_url, ipc_sub_url,ipc_push_url, user_dir, version, eye_id,overwrite_cap_settings=None):
    """reads eye video and detects the pupil.

    Creates a window, gl context.
    Grabs images from a capture.
    Streams Pupil coordinates.

    Reacts to notifications:
       ``set_detection_mapping_mode``: Sets detection method
       ``eye_process.should_stop``: Stops the eye process
       ``recording.started``: Starts recording eye video
       ``recording.stopped``: Stops recording eye video
       ``frame_publishing.started``: Starts frame publishing
       ``frame_publishing.stopped``: Stops frame publishing

    Emits notifications:
        ``eye_process.started``: Eye process started
        ``eye_process.stopped``: Eye process stopped

    Emits data:
        ``pupil.<eye id>``: Pupil data for eye with id ``<eye id>``
        ``frame.eye.<eye id>``: Eye frames with id ``<eye id>``
    """


    # We deferr the imports becasue of multiprocessing.
    # Otherwise the world process each process also loads the other imports.
    import zmq
    import zmq_tools
    zmq_ctx = zmq.Context()
    ipc_socket = zmq_tools.Msg_Dispatcher(zmq_ctx,ipc_push_url)
    pupil_socket = zmq_tools.Msg_Streamer(zmq_ctx,ipc_pub_url)
    notify_sub = zmq_tools.Msg_Receiver(zmq_ctx,ipc_sub_url,topics=("notify",))

    with Is_Alive_Manager(is_alive_flag,ipc_socket,eye_id):

        #logging setup
        import logging
        logging.getLogger("OpenGL").setLevel(logging.ERROR)
        logger = logging.getLogger()
        logger.handlers = []
        logger.addHandler(zmq_tools.ZMQ_handler(zmq_ctx,ipc_push_url))
        # create logger for the context of this function
        logger = logging.getLogger(__name__)

        #general imports
        import numpy as np
        import cv2

        #display
        import glfw
        from pyglui import ui,graph,cygl
        from pyglui.cygl.utils import draw_points, RGBA, draw_polyline, Named_Texture, Sphere
        import OpenGL.GL as gl
        from gl_utils import basic_gl_setup,adjust_gl_view, clear_gl_screen ,make_coord_system_pixel_based,make_coord_system_norm_based, make_coord_system_eye_camera_based,is_window_visible
        from ui_roi import UIRoi
        #monitoring
        import psutil
        import math


        # helpers/utils
        from uvc import get_time_monotonic, StreamError
        from file_methods import Persistent_Dict
        from version_utils import VersionFormat
        from methods import normalize, denormalize, Roi, timer
        from av_writer import JPEG_Writer,AV_Writer

        from video_capture import InitialisationError,StreamError, Fake_Source,EndofVideoFileError, source_classes, manager_classes
        source_by_name = {src.class_name():src for src in source_classes}

        # Pupil detectors
        from pupil_detectors import Detector_2D, Detector_3D
        pupil_detectors = {Detector_2D.__name__:Detector_2D,Detector_3D.__name__:Detector_3D}



        #UI Platform tweaks
        if platform.system() == 'Linux':
            scroll_factor = 10.0
            window_position_default = (600,300*eye_id)
        elif platform.system() == 'Windows':
            scroll_factor = 1.0
            window_position_default = (600,31+300*eye_id)
        else:
            scroll_factor = 1.0
            window_position_default = (600,300*eye_id)


        #g_pool holds variables for this process
        g_pool = Global_Container()

        # make some constants avaiable
        g_pool.user_dir = user_dir
        g_pool.version = version
        g_pool.app = 'capture'
        g_pool.timebase = timebase

        g_pool.ipc_pub = ipc_socket

        def get_timestamp():
            return get_time_monotonic()-g_pool.timebase.value
        g_pool.get_timestamp = get_timestamp
        g_pool.get_now = get_time_monotonic

        # Callback functions
        def on_resize(window,w, h):
            if is_window_visible(window):
                active_window = glfw.glfwGetCurrentContext()
                glfw.glfwMakeContextCurrent(window)
                g_pool.gui.update_window(w,h)
                graph.adjust_size(w,h)
                adjust_gl_view(w,h)
                glfw.glfwMakeContextCurrent(active_window)

        def on_key(window, key, scancode, action, mods):
            g_pool.gui.update_key(key,scancode,action,mods)

        def on_char(window,char):
            g_pool.gui.update_char(char)

        def on_iconify(window,iconified):
            g_pool.iconified = iconified

        def on_button(window,button, action, mods):
            if g_pool.display_mode == 'roi':
                if action == glfw.GLFW_RELEASE and g_pool.u_r.active_edit_pt:
                    g_pool.u_r.active_edit_pt = False
                    return # if the roi interacts we dont what the gui to interact as well
                elif action == glfw.GLFW_PRESS:
                    pos = glfw.glfwGetCursorPos(window)
                    pos = normalize(pos,glfw.glfwGetWindowSize(main_window))
                    if g_pool.flip:
                        pos = 1-pos[0],1-pos[1]
                    pos = denormalize(pos,(frame.width,frame.height)) # Position in img pixels
                    if g_pool.u_r.mouse_over_edit_pt(pos,g_pool.u_r.handle_size+40,g_pool.u_r.handle_size+40):
                        return # if the roi interacts we dont what the gui to interact as well

            g_pool.gui.update_button(button,action,mods)



        def on_pos(window,x, y):
            hdpi_factor = float(glfw.glfwGetFramebufferSize(window)[0]/glfw.glfwGetWindowSize(window)[0])
            g_pool.gui.update_mouse(x*hdpi_factor,y*hdpi_factor)

            if g_pool.u_r.active_edit_pt:
                pos = normalize((x,y),glfw.glfwGetWindowSize(main_window))
                if g_pool.flip:
                    pos = 1-pos[0],1-pos[1]
                pos = denormalize(pos,(frame.width,frame.height) )
                g_pool.u_r.move_vertex(g_pool.u_r.active_pt_idx,pos)

        def on_scroll(window,x,y):
            g_pool.gui.update_scroll(x,y*scroll_factor)

        g_pool.on_frame_size_change = lambda new_size: None

        # load session persistent settings
        session_settings = Persistent_Dict(os.path.join(g_pool.user_dir,'user_settings_eye%s'%eye_id))
        if session_settings.get("version",VersionFormat('0.0')) < g_pool.version:
            logger.info("Session setting are from older version of this app. I will not use those.")
            session_settings.clear()

        capture_manager_settings = session_settings.get(
            'capture_manager_settings', ('UVC_Manager',{}))

        if eye_id == 0:
            cap_src = ["Pupil Cam1 ID0","HD-6000","Integrated Camera","HD USB Camera","USB 2.0 Camera"]
        else:
            cap_src = ["Pupil Cam1 ID1","HD-6000","Integrated Camera"]

        # Initialize capture
        default_settings = {
            'source_class_name': 'UVC_Source',
            'preferred_names'  : cap_src,
            'frame_size': (640,480),
            'frame_rate': 60
        }
        settings = overwrite_cap_settings or session_settings.get('capture_settings', default_settings)
        try:
            cap = source_by_name[settings['source_class_name']](g_pool, **settings)
        except (KeyError,InitialisationError) as e:
            if isinstance(e,KeyError):
                logger.warning('Incompatible capture setting encountered. Falling back to fake source.')
            cap = Fake_Source(g_pool, **settings)

        g_pool.iconified = False
        g_pool.capture = cap
        g_pool.capture_manager = None
        g_pool.flip = session_settings.get('flip',False)
        g_pool.display_mode = session_settings.get('display_mode','camera_image')
        g_pool.display_mode_info_text = {'camera_image': "Raw eye camera image. This uses the least amount of CPU power",
                                    'roi': "Click and drag on the blue circles to adjust the region of interest. The region should be as small as possible, but large enough to capture all pupil movements.",
                                    'algorithm': "Algorithm display mode overlays a visualization of the pupil detection parameters on top of the eye video. Adjust parameters within the Pupil Detection menu below."}



        g_pool.u_r = UIRoi((g_pool.capture.frame_size[1],g_pool.capture.frame_size[0]))
        roi_user_settings = session_settings.get('roi')
        if roi_user_settings and roi_user_settings[-1] == g_pool.u_r.get()[-1]:
            g_pool.u_r.set(roi_user_settings)

        def on_frame_size_change(new_size):
            g_pool.u_r = UIRoi((new_size[1],new_size[0]))

        g_pool.on_frame_size_change = on_frame_size_change

        writer = None

        pupil_detector_settings = session_settings.get('pupil_detector_settings',None)
        last_pupil_detector = pupil_detectors[session_settings.get('last_pupil_detector',Detector_2D.__name__)]
        g_pool.pupil_detector = last_pupil_detector(g_pool,pupil_detector_settings)

        # UI callback functions
        def set_scale(new_scale):
            g_pool.gui.scale = new_scale
            g_pool.gui.collect_menus()


        def set_display_mode_info(val):
            g_pool.display_mode = val
            g_pool.display_mode_info.text = g_pool.display_mode_info_text[val]


        def set_detector(new_detector):
            g_pool.pupil_detector.cleanup()
            g_pool.pupil_detector = new_detector(g_pool)
            g_pool.pupil_detector.init_gui(g_pool.sidebar)


        # Initialize glfw
        glfw.glfwInit()
        title = "eye %s"%eye_id
        width,height = session_settings.get('window_size',g_pool.capture.frame_size)
        main_window = glfw.glfwCreateWindow(width,height, title, None, None)
        window_pos = session_settings.get('window_position',window_position_default)
        glfw.glfwSetWindowPos(main_window,window_pos[0],window_pos[1])
        glfw.glfwMakeContextCurrent(main_window)
        cygl.utils.init()

        # gl_state settings
        basic_gl_setup()
        g_pool.image_tex = Named_Texture()
        glfw.glfwSwapInterval(0)

        #setup GUI
        g_pool.gui = ui.UI()
        g_pool.gui.scale = session_settings.get('gui_scale',1)
        g_pool.sidebar = ui.Scrolling_Menu("Settings",pos=(-300,0),size=(0,0),header_pos='left')
        general_settings = ui.Growing_Menu('General')
        general_settings.append(ui.Slider('scale',g_pool.gui, setter=set_scale,step = .05,min=1.,max=2.5,label='Interface Size'))
        general_settings.append(ui.Button('Reset window size',lambda: glfw.glfwSetWindowSize(main_window,frame.width,frame.height)) )
        general_settings.append(ui.Switch('flip',g_pool,label='Flip image display'))
        general_settings.append(ui.Selector('display_mode',g_pool,setter=set_display_mode_info,selection=['camera_image','roi','algorithm'], labels=['Camera Image', 'ROI', 'Algorithm'], label="Mode") )
        g_pool.display_mode_info = ui.Info_Text(g_pool.display_mode_info_text[g_pool.display_mode])
        general_settings.append(g_pool.display_mode_info)
        g_pool.gui.append(g_pool.sidebar)
        detector_selector = ui.Selector('pupil_detector',getter = lambda: g_pool.pupil_detector.__class__ ,setter=set_detector,selection=[Detector_2D, Detector_3D],labels=['C++ 2d detector', 'C++ 3d detector'], label="Detection method")
        general_settings.append(detector_selector)

        g_pool.capture_selector_menu = ui.Growing_Menu('Capture Selection')
        g_pool.capture_source_menu = ui.Growing_Menu('Capture Source')
        g_pool.capture.init_gui()

        g_pool.sidebar.append(general_settings)
        g_pool.sidebar.append(g_pool.capture_selector_menu)
        g_pool.sidebar.append(g_pool.capture_source_menu)

        g_pool.pupil_detector.init_gui(g_pool.sidebar)

        manager_class_name, manager_settings = capture_manager_settings
        manager_class_by_name = {c.__name__:c for c in manager_classes}
        g_pool.capture_manager = manager_class_by_name[manager_class_name](g_pool,**manager_settings)
        g_pool.capture_manager.init_gui()

        def open_manager(manager_class):
            g_pool.capture_manager.cleanup()
            g_pool.capture_manager = manager_class(g_pool)
            g_pool.capture_manager.init_gui()

        #We add the capture selection menu, after a manager has been added:
        g_pool.capture_selector_menu.insert(0,ui.Selector(
            'capture_manager',g_pool,
            setter    = open_manager,
            getter    = lambda: g_pool.capture_manager.__class__,
            selection = manager_classes,
            labels    = [b.gui_name for b in manager_classes],
            label     = 'Manager'
        ))

        # Register callbacks main_window
        glfw.glfwSetFramebufferSizeCallback(main_window,on_resize)
        glfw.glfwSetWindowIconifyCallback(main_window,on_iconify)
        glfw.glfwSetKeyCallback(main_window,on_key)
        glfw.glfwSetCharCallback(main_window,on_char)
        glfw.glfwSetMouseButtonCallback(main_window,on_button)
        glfw.glfwSetCursorPosCallback(main_window,on_pos)
        glfw.glfwSetScrollCallback(main_window,on_scroll)

        #set the last saved window size
        on_resize(main_window, *glfw.glfwGetWindowSize(main_window))


        # load last gui configuration
        g_pool.gui.configuration = session_settings.get('ui_config',{})


        #set up performance graphs
        pid = os.getpid()
        ps = psutil.Process(pid)
        ts = g_pool.get_timestamp()

        cpu_graph = graph.Bar_Graph()
        cpu_graph.pos = (20,130)
        cpu_graph.update_fn = ps.cpu_percent
        cpu_graph.update_rate = 5
        cpu_graph.label = 'CPU %0.1f'

        fps_graph = graph.Bar_Graph()
        fps_graph.pos = (140,130)
        fps_graph.update_rate = 5
        fps_graph.label = "%0.0f FPS"

        should_publish_frames = False
        frame_publish_format = 'jpeg'

        #create a timer to control window update frequency
        window_update_timer = timer(1/60.)
        def window_should_update():
            return next(window_update_timer)

        logger.warning('Process started.')

        # Event loop
        while not glfw.glfwWindowShouldClose(main_window):

            if notify_sub.new_data:
                t,notification = notify_sub.recv()
                subject = notification['subject']
                if subject == 'eye_process.should_stop':
                    if notification['eye_id'] == eye_id:
                        break
                elif subject == 'set_detection_mapping_mode':
                    if notification['mode'] == '3d':
                        if not isinstance(g_pool.pupil_detector,Detector_3D):
                            set_detector(Detector_3D)
                        detector_selector.read_only  = True
                    else:
                        if not isinstance(g_pool.pupil_detector,Detector_2D):
                            set_detector(Detector_2D)
                        detector_selector.read_only = False
                elif subject == 'recording.started':
                    if notification['record_eye']:
                        record_path = notification['rec_path']
                        raw_mode = notification['compression']
                        logger.info("Will save eye video to: %s"%record_path)
                        timestamps_path = os.path.join(record_path, "eye%s_timestamps.npy"%eye_id)
                        if raw_mode and frame.jpeg_buffer:
                            video_path = os.path.join(record_path, "eye%s.mp4"%eye_id)
                            writer = JPEG_Writer(video_path,g_pool.capture.frame_rate)
                        else:
                            video_path = os.path.join(record_path, "eye%s.mp4"%eye_id)
                            writer = AV_Writer(video_path,g_pool.capture.frame_rate)
                        timestamps = []
                elif subject == 'recording.stopped':
                    if writer:
                        logger.info("Done recording.")
                        writer.release()
                        writer = None
                        np.save(timestamps_path,np.asarray(timestamps))
                        del timestamps
                elif subject.startswith('meta.should_doc'):
                    ipc_socket.notify({
                        'subject':'meta.doc',
                        'actor':'eye%i'%eye_id,
                        'doc':eye.__doc__
                        })
                elif subject.startswith('frame_publishing.started'):
                    should_publish_frames = True
                    frame_publish_format = notification.get('format','jpeg')
                elif subject.startswith('frame_publishing.stopped'):
                    should_publish_frames = False
                    frame_publish_format = 'jpeg'
                else:
                    g_pool.capture_manager.on_notify(notification)

            # Get an image from the grabber
            try:
                frame = g_pool.capture.get_frame()
            except StreamError as e:
                logger.error("Error getting frame. Stopping eye process.")
                logger.debug("Caught error: %s"%e)
                break
            except EndofVideoFileError:
                logger.warning("Video File is done. Stopping")
                g_pool.capture.seek_to_frame(0)
                frame = g_pool.capture.get_frame()

            g_pool.capture_manager.update(frame, {})

            if should_publish_frames and frame.jpeg_buffer:
                if   frame_publish_format == "jpeg":
                    data = frame.jpeg_buffer
                elif frame_publish_format == "yuv":
                    data = frame.yuv_buffer
                elif frame_publish_format == "bgr":
                    data = frame.bgr
                elif frame_publish_format == "gray":
                    data = frame.gray
                pupil_socket.send('frame.eye.%s'%eye_id,{
                    'width': frame.width,
                    'height': frame.width,
                    'index': frame.index,
                    'timestamp': frame.timestamp,
                    'format': frame_publish_format,
                    '__raw_data__': [data]
                })


            #update performace graphs
            t = frame.timestamp
            dt,ts = t-ts,t
            try:
                fps_graph.add(1./dt)
            except ZeroDivisionError:
                pass
            cpu_graph.update()



            if writer:
                writer.write_video_frame(frame)
                timestamps.append(frame.timestamp)


            # pupil ellipse detection
            result = g_pool.pupil_detector.detect(frame, g_pool.u_r, g_pool.display_mode == 'algorithm')
            result['id'] = eye_id
            # stream the result
            pupil_socket.send('pupil.%s'%eye_id,result)

            # GL drawing
            if window_should_update():
                if is_window_visible(main_window):
                    glfw.glfwMakeContextCurrent(main_window)
                    clear_gl_screen()

                    # switch to work in normalized coordinate space
                    if g_pool.display_mode == 'algorithm':
                        g_pool.image_tex.update_from_ndarray(frame.img)
                    elif g_pool.display_mode in ('camera_image','roi'):
                        g_pool.image_tex.update_from_ndarray(frame.gray)
                    else:
                        pass

                    make_coord_system_norm_based(g_pool.flip)
                    g_pool.image_tex.draw()

                    window_size =  glfw.glfwGetWindowSize(main_window)
                    make_coord_system_pixel_based((frame.height,frame.width,3),g_pool.flip)
                    g_pool.capture.gl_display()

                    if result['method'] == '3d c++':

                        eye_ball = result['projected_sphere']
                        try:
                            pts = cv2.ellipse2Poly( (int(eye_ball['center'][0]),int(eye_ball['center'][1])),
                                                (int(eye_ball['axes'][0]/2),int(eye_ball['axes'][1]/2)),
                                                int(eye_ball['angle']),0,360,8)
                        except ValueError as e:
                            pass
                        else:
                            draw_polyline(pts,2,RGBA(0.,.9,.1,result['model_confidence']) )

                    if result['confidence'] >0:
                        if result.has_key('ellipse'):
                            pts = cv2.ellipse2Poly( (int(result['ellipse']['center'][0]),int(result['ellipse']['center'][1])),
                                            (int(result['ellipse']['axes'][0]/2),int(result['ellipse']['axes'][1]/2)),
                                            int(result['ellipse']['angle']),0,360,15)
                            confidence = result['confidence'] * 0.7 #scale it a little
                            draw_polyline(pts,1,RGBA(1.,0,0,confidence))
                            draw_points([result['ellipse']['center']],size=20,color=RGBA(1.,0.,0.,confidence),sharpness=1.)

                    # render graphs
                    graph.push_view()
                    fps_graph.draw()
                    cpu_graph.draw()
                    graph.pop_view()

                    # render GUI
                    g_pool.gui.update()

                    #render the ROI
                    g_pool.u_r.draw(g_pool.gui.scale)
                    if g_pool.display_mode == 'roi':
                        g_pool.u_r.draw_points(g_pool.gui.scale)

                    #update screen
                    glfw.glfwSwapBuffers(main_window)
                glfw.glfwPollEvents()
                g_pool.pupil_detector.visualize() #detector decides if we visualize or not


        # END while running

        # in case eye recording was still runnnig: Save&close
        if writer:
            logger.info("Done recording eye.")
            writer = None
            np.save(timestamps_path,np.asarray(timestamps))

        glfw.glfwRestoreWindow(main_window) #need to do this for windows os
        # save session persistent settings
        session_settings['gui_scale'] = g_pool.gui.scale
        session_settings['roi'] = g_pool.u_r.get()
        session_settings['flip'] = g_pool.flip
        session_settings['display_mode'] = g_pool.display_mode
        session_settings['ui_config'] = g_pool.gui.configuration
        session_settings['capture_settings'] = g_pool.capture.settings
        session_settings['capture_manager_settings'] = g_pool.capture_manager.class_name, g_pool.capture_manager.get_init_dict()
        session_settings['window_size'] = glfw.glfwGetWindowSize(main_window)
        session_settings['window_position'] = glfw.glfwGetWindowPos(main_window)
        session_settings['version'] = g_pool.version
        session_settings['last_pupil_detector'] = g_pool.pupil_detector.__class__.__name__
        session_settings['pupil_detector_settings'] = g_pool.pupil_detector.get_settings()
        session_settings.close()

        g_pool.capture.deinit_gui()
        g_pool.pupil_detector.cleanup()
        g_pool.gui.terminate()
        glfw.glfwDestroyWindow(main_window)
        glfw.glfwTerminate()
        g_pool.capture_manager.cleanup()
        g_pool.capture.cleanup()
        logger.info("Process shutting down.")

Example 3

Project: pupil
Source File: main.py
View license
def session(rec_dir):
    """Run one Pupil Player session on the recording at ``rec_dir``.

    Builds the plugin registry, opens the main GLFW window, loads the world
    video plus pupil/gaze data, runs the render/event loop until the window
    is closed, then persists session settings and tears everything down.

    :param rec_dir: path to a valid pupil recording directory.
    """
    system_plugins = [Log_Display,Seek_Bar,Trim_Marks]
    vis_plugins = sorted([Vis_Circle,Vis_Polyline,Vis_Light_Points,Vis_Cross,Vis_Watermark,Eye_Video_Overlay,Scan_Path], key=lambda x: x.__name__)
    # NOTE: Pupil_Angle_3D_Fixation_Detector was listed twice here, which
    # produced a duplicate entry in the plugin selector - deduplicated.
    analysis_plugins = sorted([Gaze_Position_2D_Fixation_Detector,Pupil_Angle_3D_Fixation_Detector,Manual_Gaze_Correction,Video_Export_Launcher,Offline_Surface_Tracker,Raw_Data_Exporter,Batch_Exporter,Annotation_Player], key=lambda x: x.__name__)
    other_plugins = sorted([Show_Calibration,Log_History], key=lambda x: x.__name__)
    user_plugins = sorted(import_runtime_plugins(os.path.join(user_dir,'plugins')), key=lambda x: x.__name__)
    user_launchable_plugins = vis_plugins + analysis_plugins + other_plugins + user_plugins
    available_plugins = system_plugins + user_launchable_plugins
    name_by_index = [p.__name__ for p in available_plugins]
    index_by_name = dict(zip(name_by_index,range(len(name_by_index))))
    plugin_by_name = dict(zip(name_by_index,available_plugins))


    # Callback functions
    def on_resize(window,w, h):
        g_pool.gui.update_window(w,h)
        g_pool.gui.collect_menus()
        graph.adjust_size(w,h)
        adjust_gl_view(w,h)
        for p in g_pool.plugins:
            p.on_window_resize(window,w,h)

    def on_key(window, key, scancode, action, mods):
        g_pool.gui.update_key(key,scancode,action,mods)

    def on_char(window,char):
        g_pool.gui.update_char(char)

    def on_button(window,button, action, mods):
        g_pool.gui.update_button(button,action,mods)
        pos = glfwGetCursorPos(window)
        pos = normalize(pos,glfwGetWindowSize(window))
        pos = denormalize(pos,(frame.img.shape[1],frame.img.shape[0]) ) # Position in img pixels
        for p in g_pool.plugins:
            p.on_click(pos,button,action)

    def on_pos(window,x, y):
        # scale cursor coords to framebuffer pixels on HiDPI displays
        hdpi_factor = float(glfwGetFramebufferSize(window)[0]/glfwGetWindowSize(window)[0])
        g_pool.gui.update_mouse(x*hdpi_factor,y*hdpi_factor)

    def on_scroll(window,x,y):
        g_pool.gui.update_scroll(x,y*y_scroll_factor)


    def on_drop(window,count,paths):
        # Dropping a recording dir onto the window restarts the session with it.
        for x in range(count):
            new_rec_dir =  paths[x]
            if is_pupil_rec_dir(new_rec_dir):
                logger.debug("Starting new session with '%s'"%new_rec_dir)
                global rec_dir
                rec_dir = new_rec_dir
                glfwSetWindowShouldClose(window,True)
            else:
                logger.error("'%s' is not a valid pupil recording"%new_rec_dir)




    tick = delta_t()
    def get_dt():
        return next(tick)

    update_recording_to_recent(rec_dir)

    video_path = [f for f in glob(os.path.join(rec_dir,"world.*")) if f[-3:] in ('mp4','mkv','avi')][0]
    timestamps_path = os.path.join(rec_dir, "world_timestamps.npy")
    pupil_data_path = os.path.join(rec_dir, "pupil_data")

    meta_info = load_meta_info(rec_dir)
    rec_version = read_rec_version(meta_info)
    app_version = get_version(version_file)

    # log info about Pupil Platform and Platform in player.log
    logger.info('Application Version: %s'%app_version)
    logger.info('System Info: %s'%get_system_info())

    timestamps = np.load(timestamps_path)

    # create container for globally scoped vars
    g_pool = Global_Container()
    g_pool.app = 'player'

    # Initialize capture
    cap = File_Source(g_pool,video_path,timestamps=list(timestamps))

    # load session persistent settings
    session_settings = Persistent_Dict(os.path.join(user_dir,"user_settings"))
    if session_settings.get("version",VersionFormat('0.0')) < get_version(version_file):
        logger.info("Session setting are from older version of this app. I will not use those.")
        session_settings.clear()

    width,height = session_settings.get('window_size',cap.frame_size)
    window_pos = session_settings.get('window_position',(0,0))
    main_window = glfwCreateWindow(width, height, "Pupil Player: "+meta_info["Recording Name"]+" - "+ rec_dir.split(os.path.sep)[-1], None, None)
    glfwSetWindowPos(main_window,window_pos[0],window_pos[1])
    glfwMakeContextCurrent(main_window)
    cygl.utils.init()

    # load pupil_positions, gaze_positions
    pupil_data = load_object(pupil_data_path)
    pupil_list = pupil_data['pupil_positions']
    gaze_list = pupil_data['gaze_positions']

    g_pool.binocular = meta_info.get('Eye Mode','monocular') == 'binocular'
    g_pool.version = app_version
    g_pool.capture = cap
    g_pool.timestamps = timestamps
    g_pool.play = False
    g_pool.new_seek = True
    g_pool.user_dir = user_dir
    g_pool.rec_dir = rec_dir
    g_pool.rec_version = rec_version
    g_pool.meta_info = meta_info
    g_pool.min_data_confidence = session_settings.get('min_data_confidence',0.6)
    g_pool.pupil_positions_by_frame = correlate_data(pupil_list,g_pool.timestamps)
    g_pool.gaze_positions_by_frame = correlate_data(gaze_list,g_pool.timestamps)
    g_pool.fixations_by_frame = [[] for x in g_pool.timestamps] #populated by the fixation detector plugin

    def next_frame(_):
        try:
            cap.seek_to_frame(cap.get_frame_index())
        except FileSeekError:
            logger.warning("Could not seek to next frame.")
        else:
            g_pool.new_seek = True

    def prev_frame(_):
        try:
            cap.seek_to_frame(cap.get_frame_index()-2)
        except FileSeekError:
            logger.warning("Could not seek to previous frame.")
        else:
            g_pool.new_seek = True

    def toggle_play(new_state):
        if cap.get_frame_index() >= cap.get_frame_count()-5:
            cap.seek_to_frame(1) #avoid pause set by hitting trimmark pause.
            logger.warning("End of video - restart at beginning.")
        g_pool.play = new_state

    def set_scale(new_scale):
        g_pool.gui.scale = new_scale
        g_pool.gui.collect_menus()

    def set_data_confidence(new_confidence):
        g_pool.min_data_confidence = new_confidence
        notification = {'subject':'min_data_confidence_changed'}
        # delay the notification slightly so rapid slider moves coalesce
        notification['_notify_time_'] = time()+.8
        g_pool.delayed_notifications[notification['subject']] = notification

    def open_plugin(plugin):
        if plugin ==  "Select to load":
            return
        g_pool.plugins.add(plugin)

    def purge_plugins():
        # close every user-launched plugin; system plugins stay alive
        for p in g_pool.plugins:
            if p.__class__ in user_launchable_plugins:
                p.alive = False
        g_pool.plugins.clean()

    def do_export(_):
        export_range = slice(g_pool.trim_marks.in_mark,g_pool.trim_marks.out_mark)
        export_dir = os.path.join(g_pool.rec_dir,'exports','%s-%s'%(export_range.start,export_range.stop))
        try:
            os.makedirs(export_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                logger.error("Could not create export dir")
                raise e
            else:
                logger.warning("Previous export for range [%s-%s] already exsits - overwriting."%(export_range.start,export_range.stop))
        else:
            logger.info('Created export dir at "%s"'%export_dir)

        notification = {'subject':'should_export','range':export_range,'export_dir':export_dir}
        g_pool.notifications.append(notification)

    g_pool.gui = ui.UI()
    g_pool.gui.scale = session_settings.get('gui_scale',1)
    g_pool.main_menu = ui.Scrolling_Menu("Settings",pos=(-350,20),size=(300,500))
    g_pool.main_menu.append(ui.Button("Close Pupil Player",lambda:glfwSetWindowShouldClose(main_window,True)))
    g_pool.main_menu.append(ui.Slider('scale',g_pool.gui, setter=set_scale,step = .05,min=0.75,max=2.5,label='Interface Size'))
    g_pool.main_menu.append(ui.Info_Text('Player Version: %s'%g_pool.version))
    g_pool.main_menu.append(ui.Info_Text('Recording Version: %s'%rec_version))
    g_pool.main_menu.append(ui.Slider('min_data_confidence',g_pool, setter=set_data_confidence,step=.05 ,min=0.0,max=1.0,label='Confidence threshold'))

    selector_label = "Select to load"

    vis_labels = ["   " + p.__name__.replace('_',' ') for p in vis_plugins]
    analysis_labels = ["   " + p.__name__.replace('_',' ') for p in analysis_plugins]
    other_labels = ["   " + p.__name__.replace('_',' ') for p in other_plugins]
    user_labels = ["   " + p.__name__.replace('_',' ') for p in user_plugins]

    # selector_label entries act as non-selectable category headers; the
    # plugins list must stay index-aligned with the labels list below.
    plugins = [selector_label, selector_label] + vis_plugins + [selector_label] + analysis_plugins + [selector_label] + other_plugins + [selector_label] + user_plugins
    labels = [selector_label, "Visualization"] + vis_labels + ["Analysis"] + analysis_labels + ["Other"] + other_labels + ["User added"] + user_labels

    g_pool.main_menu.append(ui.Selector('Open plugin:',
                                        selection = plugins,
                                        labels    = labels,
                                        setter    = open_plugin,
                                        getter    = lambda: selector_label))

    g_pool.main_menu.append(ui.Button('Close all plugins',purge_plugins))
    g_pool.main_menu.append(ui.Button('Reset window size',lambda: glfwSetWindowSize(main_window,cap.frame_size[0],cap.frame_size[1])) )
    g_pool.quickbar = ui.Stretching_Menu('Quick Bar',(0,100),(120,-100))
    g_pool.play_button = ui.Thumb('play',g_pool,label=unichr(0xf04b).encode('utf-8'),setter=toggle_play,hotkey=GLFW_KEY_SPACE,label_font='fontawesome',label_offset_x=5,label_offset_y=0,label_offset_size=-24)
    g_pool.play_button.on_color[:] = (0,1.,.0,.8)
    g_pool.forward_button = ui.Thumb('forward',label=unichr(0xf04e).encode('utf-8'),getter = lambda: False,setter= next_frame, hotkey=GLFW_KEY_RIGHT,label_font='fontawesome',label_offset_x=5,label_offset_y=0,label_offset_size=-24)
    g_pool.backward_button = ui.Thumb('backward',label=unichr(0xf04a).encode('utf-8'),getter = lambda: False, setter = prev_frame, hotkey=GLFW_KEY_LEFT,label_font='fontawesome',label_offset_x=-5,label_offset_y=0,label_offset_size=-24)
    g_pool.export_button = ui.Thumb('export',label=unichr(0xf063).encode('utf-8'),getter = lambda: False, setter = do_export, hotkey='e',label_font='fontawesome',label_offset_x=0,label_offset_y=2,label_offset_size=-24)
    g_pool.quickbar.extend([g_pool.play_button,g_pool.forward_button,g_pool.backward_button,g_pool.export_button])
    g_pool.gui.append(g_pool.quickbar)
    g_pool.gui.append(g_pool.main_menu)


    # we always load these plugins (renamed from 'system_plugins' to stop
    # shadowing the class list defined at the top of this function)
    system_plugin_initializers = [('Trim_Marks',{}),('Seek_Bar',{})]
    default_plugins = [('Log_Display',{}),('Scan_Path',{}),('Vis_Polyline',{}),('Vis_Circle',{}),('Video_Export_Launcher',{})]
    previous_plugins = session_settings.get('loaded_plugins',default_plugins)
    g_pool.notifications = []
    g_pool.delayed_notifications = {}
    g_pool.plugins = Plugin_List(g_pool,plugin_by_name,system_plugin_initializers+previous_plugins)


    # Register callbacks main_window
    glfwSetFramebufferSizeCallback(main_window,on_resize)
    glfwSetKeyCallback(main_window,on_key)
    glfwSetCharCallback(main_window,on_char)
    glfwSetMouseButtonCallback(main_window,on_button)
    glfwSetCursorPosCallback(main_window,on_pos)
    glfwSetScrollCallback(main_window,on_scroll)
    glfwSetDropCallback(main_window,on_drop)
    #trigger on_resize
    on_resize(main_window, *glfwGetFramebufferSize(main_window))

    g_pool.gui.configuration = session_settings.get('ui_config',{})

    # gl_state settings
    basic_gl_setup()
    g_pool.image_tex = Named_Texture()

    #set up performace graphs:
    pid = os.getpid()
    ps = psutil.Process(pid)
    ts = None

    cpu_graph = graph.Bar_Graph()
    cpu_graph.pos = (20,110)
    cpu_graph.update_fn = ps.cpu_percent
    cpu_graph.update_rate = 5
    cpu_graph.label = 'CPU %0.1f'

    fps_graph = graph.Bar_Graph()
    fps_graph.pos = (140,110)
    fps_graph.update_rate = 5
    fps_graph.label = "%0.0f REC FPS"

    pupil_graph = graph.Bar_Graph(max_val=1.0)
    pupil_graph.pos = (260,110)
    pupil_graph.update_rate = 5
    pupil_graph.label = "Confidence: %0.2f"

    # Main event/render loop: runs until the window is closed (or a new
    # recording is dropped onto the window, which also closes it).
    while not glfwWindowShouldClose(main_window):


        #grab new frame
        if g_pool.play or g_pool.new_seek:
            g_pool.new_seek = False
            try:
                new_frame = cap.get_frame_nowait()
            except EndofVideoFileError:
                #end of video logic: pause at last frame.
                g_pool.play=False
                logger.warning("end of video")
            update_graph = True
        else:
            update_graph = False


        frame = new_frame.copy()
        events = {}
        #report time between now and the last loop interation
        events['dt'] = get_dt()
        #new positons we make a deepcopy just like the image is a copy.
        events['gaze_positions'] = deepcopy(g_pool.gaze_positions_by_frame[frame.index])
        events['pupil_positions'] = deepcopy(g_pool.pupil_positions_by_frame[frame.index])

        if update_graph:
            #update performace graphs
            for p in  events['pupil_positions']:
                pupil_graph.add(p['confidence'])

            t = new_frame.timestamp
            if ts and ts != t:
                dt,ts = t-ts,t
                fps_graph.add(1./dt)

            g_pool.play_button.status_text = str(frame.index)
        #always update the CPU graph
        cpu_graph.update()


        # publish delayed notifiactions when their time has come.
        # iterate over a snapshot: we delete keys from the dict inside the
        # loop (safe on a list copy under both Python 2 and 3)
        for n in list(g_pool.delayed_notifications.values()):
            if n['_notify_time_'] < time():
                del n['_notify_time_']
                del g_pool.delayed_notifications[n['subject']]
                g_pool.notifications.append(n)

        # notify each plugin if there are new notifactions:
        while g_pool.notifications:
            n = g_pool.notifications.pop(0)
            for p in g_pool.plugins:
                p.on_notify(n)

        # allow each Plugin to do its work.
        for p in g_pool.plugins:
            p.update(frame,events)

        #check if a plugin need to be destroyed
        g_pool.plugins.clean()

        # render camera image
        glfwMakeContextCurrent(main_window)
        make_coord_system_norm_based()
        g_pool.image_tex.update_from_frame(frame)
        g_pool.image_tex.draw()
        make_coord_system_pixel_based(frame.img.shape)
        # render visual feedback from loaded plugins
        for p in g_pool.plugins:
            p.gl_display()

        graph.push_view()
        fps_graph.draw()
        cpu_graph.draw()
        pupil_graph.draw()
        graph.pop_view()
        g_pool.gui.update()

        #present frames at appropriate speed
        cap.wait(frame)

        glfwSwapBuffers(main_window)
        glfwPollEvents()

    # persist session state for the next launch
    session_settings['loaded_plugins'] = g_pool.plugins.get_initializers()
    session_settings['min_data_confidence'] = g_pool.min_data_confidence
    session_settings['gui_scale'] = g_pool.gui.scale
    session_settings['ui_config'] = g_pool.gui.configuration
    session_settings['window_size'] = glfwGetWindowSize(main_window)
    session_settings['window_position'] = glfwGetWindowPos(main_window)
    session_settings['version'] = g_pool.version
    session_settings.close()

    # de-init all running plugins
    for p in g_pool.plugins:
        p.alive = False
    g_pool.plugins.clean()

    cap.cleanup()
    g_pool.gui.terminate()
    glfwDestroyWindow(main_window)
Example 4

View license
def run(test, params, env):
    """
    Test command: virsh update-device.

    Update device from an XML <file>.
    1.Prepare test environment, adding a cdrom/floppy to VM.
    2.Perform virsh update-device operation.
    3.Recover test environment.
    4.Confirm the test result.

    :param test: avocado test object
    :param params: dict of test parameters from the cfg file
    :param env: avocado-vt environment holding the VM objects
    """

    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    pre_vm_state = params.get("at_dt_device_pre_vm_state")
    virsh_dargs = {"debug": True, "ignore_status": True}

    def is_attached(vmxml_devices, disk_type, source_file, target_dev):
        """
        Check attached device and disk exist or not.

        :param vmxml_devices: VMXMLDevices instance
        :param disk_type: disk's device type: cdrom or floppy
        :param source_file : disk's source file to check
        :param target_dev : target device name
        :return: True/False if backing file and device found
        """
        disks = vmxml_devices.by_device_tag('disk')
        for disk in disks:
            logging.debug("Check disk XML:\n%s", open(disk['xml']).read())
            if disk.device != disk_type:
                continue
            if disk.target['dev'] != target_dev:
                continue
            # A disk with no <source> element (empty cdrom/floppy) can never
            # match a concrete source file.
            if disk.xmltreefile.find('source') is not None:
                if disk.source.attrs['file'] != source_file:
                    continue
            else:
                continue
            # All three conditions met
            logging.debug("Find %s in given disk XML", source_file)
            return True
        logging.debug("Not find %s in gievn disk XML", source_file)
        return False

    def check_result(disk_source, disk_type, disk_target,
                     flags, attach=True):
        """
        Check the test result of update-device command.

        Verifies that the active and/or inactive domain XML reflects the
        attach/detach depending on which of --config/--live/--current
        flags were used and the VM state before the operation.

        :param disk_source: source file expected (attach) or gone (detach)
        :param disk_type: cdrom or floppy
        :param disk_target: target device name, e.g. hdc
        :param flags: flag string passed to virsh update-device
        :param attach: True when checking an attach, False for a detach
        :raise exceptions.TestFail: when the XMLs do not match expectation
        """
        vm_state = pre_vm_state
        active_vmxml = VMXML.new_from_dumpxml(vm_name)
        active_attached = is_attached(active_vmxml.devices, disk_type,
                                      disk_source, disk_target)
        # A transient domain has no inactive (persistent) XML to inspect.
        if vm_state != "transient":
            inactive_vmxml = VMXML.new_from_dumpxml(vm_name,
                                                    options="--inactive")
            inactive_attached = is_attached(inactive_vmxml.devices, disk_type,
                                            disk_source, disk_target)

        if flags.count("config") and not flags.count("live"):
            # --config only: persistent XML changes, live XML must not.
            if vm_state != "transient":
                if attach:
                    if not inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML not updated"
                                                  " when --config options used for"
                                                  " attachment")
                    if vm_state != "shutoff":
                        if active_attached:
                            raise exceptions.TestFail("Active domain XML updated "
                                                      "when --config options used"
                                                      " for attachment")
                else:
                    if inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML not updated"
                                                  " when --config options used for"
                                                  " detachment")
                    if vm_state != "shutoff":
                        if not active_attached:
                            raise exceptions.TestFail("Active domain XML updated "
                                                      "when --config options used"
                                                      " for detachment")
        elif flags.count("live") and not flags.count("config"):
            # --live only: live XML changes, persistent XML must not.
            if attach:
                if vm_state in ["paused", "running", "transient"]:
                    if not active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --live options used for"
                                                  " attachment")
                if vm_state in ["paused", "running"]:
                    if inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML updated "
                                                  "when --live options used for"
                                                  " attachment")
            else:
                if vm_state in ["paused", "running", "transient"]:
                    if active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --live options used for"
                                                  " detachment")
                if vm_state in ["paused", "running"]:
                    if not inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML updated "
                                                  "when --live options used for"
                                                  " detachment")
        elif flags.count("live") and flags.count("config"):
            # --live --config: both XMLs must change.
            if attach:
                if vm_state in ["paused", "running"]:
                    if not active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --live --config options"
                                                  " used for attachment")
                    if not inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML not updated"
                                                  " when --live --config options "
                                                  "used for attachment")
            else:
                if vm_state in ["paused", "running"]:
                    if active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --live --config options"
                                                  " used for detachment")
                    if inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML not updated"
                                                  " when --live --config options "
                                                  "used for detachment")
        elif flags.count("current") or flags == "":
            # --current (or no flag): affects live XML on a running/paused
            # domain, the persistent XML on a shutoff one.
            if attach:
                if vm_state in ["paused", "running", "transient"]:
                    if not active_attached:
                        raise exceptions.TestFail("Active domain XML not updated "
                                                  "when --current options used "
                                                  "for attachment")
                if vm_state in ["paused", "running"]:
                    if inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML updated "
                                                  "when --current options used "
                                                  "for live attachment")
                if vm_state == "shutoff" and not inactive_attached:
                    raise exceptions.TestFail("Inactive domain XML not updated "
                                              "when --current options used for "
                                              "attachment")
            else:
                if vm_state in ["paused", "running", "transient"]:
                    if active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --current options used "
                                                  "for detachment")
                if vm_state in ["paused", "running"]:
                    if not inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML updated "
                                                  "when --current options used "
                                                  "for live detachment")
                if vm_state == "shutoff" and inactive_attached:
                    raise exceptions.TestFail("Inactive domain XML not updated"
                                              " when --current options used "
                                              "for detachment")

    def check_rhel_version(release_ver, session=None):
        """
        Login to guest and check its release version

        :param release_ver: key into the known-release table
                            (rhel6/rhel7/fedora)
        :param session: guest session; when None, check the host instead
        :return: True if /etc/redhat-release matches the release string
        """
        rhel_release = {"rhel6": "Red Hat Enterprise Linux Server release 6",
                        "rhel7": "Red Hat Enterprise Linux Server release 7",
                        "fedora": "Fedora release"}
        version_file = "/etc/redhat-release"
        # dict.has_key() was removed in Python 3; use the `in` operator.
        if release_ver not in rhel_release:
            logging.error("Can't support this version of guest: %s",
                          release_ver)
            return False

        cmd = "grep '%s' %s" % (rhel_release[release_ver], version_file)
        if session:
            s = session.cmd_status(cmd)
        else:
            s = process.run(cmd, ignore_status=True, shell=True).exit_status

        logging.debug("Check version cmd return:%s", s)
        if s == 0:
            return True
        else:
            return False

    vmxml_backup = VMXML.new_from_dumpxml(vm_name, options="--inactive")
    # Before doing anything - let's be sure we can support this test
    # Parse flag list, skip testing early if flag is not supported
    # NOTE: "".split("--") returns [''] which messes up later empty test
    at_flag = params.get("at_dt_device_at_options", "")
    dt_flag = params.get("at_dt_device_dt_options", "")
    flag_list = []
    if at_flag.count("--"):
        flag_list.extend(at_flag.split("--"))
    if dt_flag.count("--"):
        flag_list.extend(dt_flag.split("--"))
    for item in flag_list:
        option = item.strip()
        if option == "":
            continue
        if not bool(virsh.has_command_help_match("update-device", option)):
            raise exceptions.TestSkipError("virsh update-device doesn't support "
                                           "--%s" % option)

    # As per RH BZ 961443 avoid testing before behavior changes
    if 'config' in flag_list:
        # SKIP tests using --config if libvirt is 0.9.10 or earlier
        if not libvirt_version.version_compare(0, 9, 10):
            raise exceptions.TestSkipError("BZ 961443: --config behavior change "
                                           "in version 0.9.10")
    if 'persistent' in flag_list or 'live' in flag_list:
        # SKIP tests using --persistent if libvirt 1.0.5 or earlier
        if not libvirt_version.version_compare(1, 0, 5):
            raise exceptions.TestSkipError("BZ 961443: --persistent behavior "
                                           "change in version 1.0.5")

    # Get the target bus/dev
    disk_type = params.get("disk_type", "cdrom")
    target_bus = params.get("updatedevice_target_bus", "ide")
    target_dev = params.get("updatedevice_target_dev", "hdc")
    disk_mode = params.get("disk_mode", "")
    support_mode = ['readonly', 'shareable']
    # FIX: the original condition `not disk_mode and disk_mode not in
    # support_mode` rejected only the empty default (since "" is never in
    # support_mode) and let any invalid non-empty mode through.  Validate
    # the mode only when one was actually supplied.
    if disk_mode and disk_mode not in support_mode:
        raise exceptions.TestError("%s not in support mode %s"
                                   % (disk_mode, support_mode))

    # Prepare tmp directory and files.
    orig_iso = os.path.join(data_dir.get_tmp_dir(), "orig.iso")
    test_iso = os.path.join(data_dir.get_tmp_dir(), "test.iso")

    # Check the version first.
    host_rhel6 = check_rhel_version('rhel6')
    guest_rhel6 = False
    if not vm.is_alive():
        vm.start()
    session = vm.wait_for_login()
    if check_rhel_version('rhel6', session):
        guest_rhel6 = True
    session.close()
    vm.destroy(gracefully=False)

    try:
        # Prepare the disk first.
        create_disk(vm_name, orig_iso, disk_type, target_dev, disk_mode)
        vmxml_for_test = VMXML.new_from_dumpxml(vm_name,
                                                options="--inactive")

        # Turn VM into certain state.
        if pre_vm_state == "running":
            if at_flag == "--config" or dt_flag == "--config":
                if host_rhel6:
                    raise exceptions.TestSkipError("Config option not supported"
                                                   " on this host")
            logging.info("Starting %s..." % vm_name)
            if vm.is_dead():
                vm.start()
                vm.wait_for_login().close()
        elif pre_vm_state == "shutoff":
            if not at_flag or not dt_flag:
                if host_rhel6:
                    raise exceptions.TestSkipError("Default option not supported"
                                                   " on this host")
            logging.info("Shuting down %s..." % vm_name)
            if vm.is_alive():
                vm.destroy(gracefully=False)
        elif pre_vm_state == "paused":
            if at_flag == "--config" or dt_flag == "--config":
                if host_rhel6:
                    raise exceptions.TestSkipError("Config option not supported"
                                                   " on this host")
            logging.info("Pausing %s..." % vm_name)
            if vm.is_dead():
                vm.start()
                vm.wait_for_login().close()
            if not vm.pause():
                raise exceptions.TestSkipError("Cann't pause the domain")
        elif pre_vm_state == "transient":
            logging.info("Creating %s..." % vm_name)
            # Undefine + create turns the persistent domain into a
            # transient one; restore the definition if creation fails.
            vm.undefine()
            if virsh.create(vmxml_for_test.xml, **virsh_dargs).exit_status:
                vmxml_backup.define()
                raise exceptions.TestSkipError("Cann't create the domain")
            vm.wait_for_login().close()
    except Exception as e:
        # `except Exception, e` is Python-2-only syntax; `as` works on 2.6+
        # and 3.  Any setup failure is treated as a skip, after cleanup.
        logging.error(str(e))
        if os.path.exists(orig_iso):
            os.remove(orig_iso)
        vmxml_backup.sync()
        raise exceptions.TestSkipError(str(e))

    # Get remaining parameters for configuration.
    vm_ref = params.get("updatedevice_vm_ref", "domname")
    at_status_error = "yes" == params.get("at_status_error", "no")
    dt_status_error = "yes" == params.get("dt_status_error", "no")

    dom_uuid = vm.get_uuid()
    dom_id = vm.get_id()
    # Set domain reference.
    if vm_ref == "domname":
        vm_ref = vm_name
    elif vm_ref == "domid":
        vm_ref = dom_id
    elif vm_ref == "domuuid":
        vm_ref = dom_uuid
    elif vm_ref == "hexdomid" and dom_id is not None:
        vm_ref = hex(int(dom_id))

    try:

        # Firstly detach the disk.
        update_xmlfile = os.path.join(data_dir.get_tmp_dir(),
                                      "update.xml")
        create_attach_xml(update_xmlfile, disk_type, target_bus,
                          target_dev, "", disk_mode)
        ret = virsh.update_device(vm_ref, filearg=update_xmlfile,
                                  flagstr=dt_flag, ignore_status=True,
                                  debug=True)
        if vm.is_paused():
            vm.resume()
            vm.wait_for_login().close()
        if vm.is_alive() and not guest_rhel6:
            time.sleep(5)
            # For rhel7 guest, need to update twice for it to take effect.
            ret = virsh.update_device(vm_ref, filearg=update_xmlfile,
                                      flagstr=dt_flag, ignore_status=True,
                                      debug=True)
        os.remove(update_xmlfile)
        libvirt.check_exit_status(ret, dt_status_error)
        if not ret.exit_status:
            check_result(orig_iso, disk_type, target_dev, dt_flag, False)

        # Then attach the disk.
        if pre_vm_state == "paused":
            if not vm.pause():
                raise exceptions.TestFail("Cann't pause the domain")
        create_attach_xml(update_xmlfile, disk_type, target_bus,
                          target_dev, test_iso, disk_mode)
        ret = virsh.update_device(vm_ref, filearg=update_xmlfile,
                                  flagstr=at_flag, ignore_status=True,
                                  debug=True)
        if vm.is_paused():
            vm.resume()
            vm.wait_for_login().close()
        update_twice = False
        if vm.is_alive() and not guest_rhel6:
            # For rhel7 guest, need to update twice for it to take effect.
            if (pre_vm_state in ["running", "paused"] and
                    dt_flag == "--config" and at_flag != "--config"):
                update_twice = True
            elif (pre_vm_state == "transient" and
                    dt_flag.count("config") and not at_flag.count("config")):
                update_twice = True
        if update_twice:
            time.sleep(5)
            ret = virsh.update_device(vm_ref, filearg=update_xmlfile,
                                      flagstr=at_flag, ignore_status=True,
                                      debug=True)
        libvirt.check_exit_status(ret, at_status_error)
        os.remove(update_xmlfile)
        if not ret.exit_status:
            check_result(test_iso, disk_type, target_dev, at_flag)
        # Try to start vm at last.
        if vm.is_dead():
            vm.start()
            vm.wait_for_login().close()

    finally:
        # Always restore the original domain definition and remove the
        # temporary ISO images, even when the test body raised.
        vm.destroy(gracefully=False, free_mac_addresses=False)
        vmxml_backup.sync()
        if os.path.exists(orig_iso):
            os.remove(orig_iso)
        if os.path.exists(test_iso):
            os.remove(test_iso)

Example 5

Project: tp-libvirt
Source File: virsh_snapshot_disk.py
View license
def run(test, params, env):
    """
    Test virsh snapshot command when disk in all kinds of type.

    (1). Init the variables from params.
    (2). Create a image by specifice format.
    (3). Attach disk to vm.
    (4). Snapshot create.
    (5). Snapshot revert.
    (6). cleanup.
    """
    # Init variables.
    vm_name = params.get("main_vm", "avocado-vt-vm1")
    vm = env.get_vm(vm_name)
    vm_state = params.get("vm_state", "running")
    image_format = params.get("snapshot_image_format", "qcow2")
    snapshot_del_test = "yes" == params.get("snapshot_del_test", "no")
    status_error = ("yes" == params.get("status_error", "no"))
    snapshot_from_xml = ("yes" == params.get("snapshot_from_xml", "no"))
    snapshot_current = ("yes" == params.get("snapshot_current", "no"))
    snapshot_revert_paused = ("yes" == params.get("snapshot_revert_paused",
                                                  "no"))
    replace_vm_disk = "yes" == params.get("replace_vm_disk", "no")
    disk_source_protocol = params.get("disk_source_protocol")
    vol_name = params.get("vol_name")
    tmp_dir = data_dir.get_tmp_dir()
    pool_name = params.get("pool_name", "gluster-pool")
    brick_path = os.path.join(tmp_dir, pool_name)
    multi_gluster_disks = "yes" == params.get("multi_gluster_disks", "no")

    # Pool variables.
    snapshot_with_pool = "yes" == params.get("snapshot_with_pool", "no")
    pool_name = params.get("pool_name")
    pool_type = params.get("pool_type")
    pool_target = params.get("pool_target")
    emulated_image = params.get("emulated_image", "emulated-image")
    vol_format = params.get("vol_format")
    lazy_refcounts = "yes" == params.get("lazy_refcounts")
    options = params.get("snapshot_options", "")
    export_options = params.get("export_options", "rw,no_root_squash,fsid=0")

    # Set volume xml attribute dictionary, extract all params start with 'vol_'
    # which are for setting volume xml, except 'lazy_refcounts'.
    vol_arg = {}
    for key in params.keys():
        if key.startswith('vol_'):
            if key[4:] in ['capacity', 'allocation', 'owner', 'group']:
                vol_arg[key[4:]] = int(params[key])
            else:
                vol_arg[key[4:]] = params[key]
    vol_arg['lazy_refcounts'] = lazy_refcounts

    supported_pool_list = ["dir", "fs", "netfs", "logical", "iscsi",
                           "disk", "gluster"]
    if snapshot_with_pool:
        if pool_type not in supported_pool_list:
            raise error.TestNAError("%s not in support list %s" %
                                    (pool_target, supported_pool_list))

    # Do xml backup for final recovery
    vmxml_backup = libvirt_xml.VMXML.new_from_inactive_dumpxml(vm_name)
    # Some variable for xmlfile of snapshot.
    snapshot_memory = params.get("snapshot_memory", "internal")
    snapshot_disk = params.get("snapshot_disk", "internal")
    no_memory_snap = "yes" == params.get("no_memory_snap", "no")

    # Skip 'qed' cases for libvirt version greater than 1.1.0
    if libvirt_version.version_compare(1, 1, 0):
        if vol_format == "qed" or image_format == "qed":
            raise error.TestNAError("QED support changed, check bug: "
                                    "https://bugzilla.redhat.com/show_bug.cgi"
                                    "?id=731570")

    if not libvirt_version.version_compare(1, 2, 7):
        # As bug 1017289 closed as WONTFIX, the support only
        # exist on 1.2.7 and higher
        if disk_source_protocol == 'gluster':
            raise error.TestNAError("Snapshot on glusterfs not support in "
                                    "current version. Check more info with "
                                    "https://bugzilla.redhat.com/buglist.cgi?"
                                    "bug_id=1017289,1032370")

    # Init snapshot_name
    snapshot_name = None
    snapshot_external_disk = []
    snapshot_xml_path = None
    del_status = None
    image = None
    pvt = None
    # Get a tmp dir
    snap_cfg_path = "/var/lib/libvirt/qemu/snapshot/%s/" % vm_name
    try:
        if replace_vm_disk:
            utlv.set_vm_disk(vm, params, tmp_dir)
            if multi_gluster_disks:
                new_params = params.copy()
                new_params["pool_name"] = "gluster-pool2"
                new_params["vol_name"] = "gluster-vol2"
                new_params["disk_target"] = "vdf"
                new_params["image_convert"] = 'no'
                utlv.set_vm_disk(vm, new_params, tmp_dir)

        if snapshot_with_pool:
            # Create dst pool for create attach vol img
            pvt = utlv.PoolVolumeTest(test, params)
            pvt.pre_pool(pool_name, pool_type, pool_target,
                         emulated_image, image_size="1G",
                         pre_disk_vol=["20M"],
                         source_name=vol_name,
                         export_options=export_options)

            if pool_type in ["iscsi", "disk"]:
                # iscsi and disk pool did not support create volume in libvirt,
                # logical pool could use libvirt to create volume but volume
                # format is not supported and will be 'raw' as default.
                pv = libvirt_storage.PoolVolume(pool_name)
                vols = pv.list_volumes().keys()
                if vols:
                    vol_name = vols[0]
                else:
                    raise error.TestNAError("No volume in pool: %s" % pool_name)
            else:
                # Set volume xml file
                volxml = libvirt_xml.VolXML()
                newvol = volxml.new_vol(**vol_arg)
                vol_xml = newvol['xml']

                # Run virsh_vol_create to create vol
                logging.debug("create volume from xml: %s" % newvol.xmltreefile)
                cmd_result = virsh.vol_create(pool_name, vol_xml,
                                              ignore_status=True,
                                              debug=True)
                if cmd_result.exit_status:
                    raise error.TestNAError("Failed to create attach volume.")

            cmd_result = virsh.vol_path(vol_name, pool_name, debug=True)
            if cmd_result.exit_status:
                raise error.TestNAError("Failed to get volume path from pool.")
            img_path = cmd_result.stdout.strip()

            if pool_type in ["logical", "iscsi", "disk"]:
                # Use qemu-img to format logical, iscsi and disk block device
                if vol_format != "raw":
                    cmd = "qemu-img create -f %s %s 10M" % (vol_format,
                                                            img_path)
                    cmd_result = utils.run(cmd, ignore_status=True)
                    if cmd_result.exit_status:
                        raise error.TestNAError("Failed to format volume, %s" %
                                                cmd_result.stdout.strip())
            extra = "--persistent --subdriver %s" % vol_format
        else:
            # Create a image.
            params['image_name'] = "snapshot_test"
            params['image_format'] = image_format
            params['image_size'] = "1M"
            image = qemu_storage.QemuImg(params, tmp_dir, "snapshot_test")
            img_path, _ = image.create(params)
            extra = "--persistent --subdriver %s" % image_format

        if not multi_gluster_disks:
            # Do the attach action.
            out = utils.run("qemu-img info %s" % img_path)
            logging.debug("The img info is:\n%s" % out.stdout.strip())
            result = virsh.attach_disk(vm_name, source=img_path, target="vdf",
                                       extra=extra, debug=True)
            if result.exit_status:
                raise error.TestNAError("Failed to attach disk %s to VM."
                                        "Detail: %s." % (img_path, result.stderr))

        # Create snapshot.
        if snapshot_from_xml:
            snap_xml = libvirt_xml.SnapshotXML()
            snapshot_name = "snapshot_test"
            snap_xml.snap_name = snapshot_name
            snap_xml.description = "Snapshot Test"
            if not no_memory_snap:
                if "--disk-only" not in options:
                    if snapshot_memory == "external":
                        memory_external = os.path.join(tmp_dir,
                                                       "snapshot_memory")
                        snap_xml.mem_snap_type = snapshot_memory
                        snap_xml.mem_file = memory_external
                        snapshot_external_disk.append(memory_external)
                    else:
                        snap_xml.mem_snap_type = snapshot_memory

            # Add all disks into xml file.
            vmxml = libvirt_xml.VMXML.new_from_inactive_dumpxml(vm_name)
            disks = vmxml.devices.by_device_tag('disk')
            new_disks = []
            for src_disk_xml in disks:
                disk_xml = snap_xml.SnapDiskXML()
                disk_xml.xmltreefile = src_disk_xml.xmltreefile
                del disk_xml.device
                del disk_xml.address
                disk_xml.snapshot = snapshot_disk
                disk_xml.disk_name = disk_xml.target['dev']

                # Only qcow2 works as external snapshot file format, update it
                # here
                driver_attr = disk_xml.driver
                driver_attr.update({'type': 'qcow2'})
                disk_xml.driver = driver_attr

                if snapshot_disk == 'external':
                    new_attrs = disk_xml.source.attrs
                    if disk_xml.source.attrs.has_key('file'):
                        new_file = "%s.snap" % disk_xml.source.attrs['file']
                        snapshot_external_disk.append(new_file)
                        new_attrs.update({'file': new_file})
                        hosts = None
                    elif disk_xml.source.attrs.has_key('name'):
                        new_name = "%s.snap" % disk_xml.source.attrs['name']
                        new_attrs.update({'name': new_name})
                        hosts = disk_xml.source.hosts
                    elif (disk_xml.source.attrs.has_key('dev') and
                          disk_xml.type_name == 'block'):
                        # Use local file as external snapshot target for block type.
                        # As block device will be treat as raw format by default,
                        # it's not fit for external disk snapshot target. A work
                        # around solution is use qemu-img again with the target.
                        disk_xml.type_name = 'file'
                        del new_attrs['dev']
                        new_file = "%s/blk_src_file.snap" % tmp_dir
                        snapshot_external_disk.append(new_file)
                        new_attrs.update({'file': new_file})
                        hosts = None

                    new_src_dict = {"attrs": new_attrs}
                    if hosts:
                        new_src_dict.update({"hosts": hosts})
                    disk_xml.source = disk_xml.new_disk_source(**new_src_dict)
                else:
                    del disk_xml.source

                new_disks.append(disk_xml)

            snap_xml.set_disks(new_disks)
            snapshot_xml_path = snap_xml.xml
            logging.debug("The snapshot xml is: %s" % snap_xml.xmltreefile)

            options += " --xmlfile %s " % snapshot_xml_path

            if vm_state == "shut off":
                vm.destroy(gracefully=False)

            snapshot_result = virsh.snapshot_create(
                vm_name, options, debug=True)
            out_err = snapshot_result.stderr.strip()
            if snapshot_result.exit_status:
                if status_error:
                    return
                else:
                    if re.search("live disk snapshot not supported with this "
                                 "QEMU binary", out_err):
                        raise error.TestNAError(out_err)

                    if libvirt_version.version_compare(1, 2, 5):
                        # As commit d2e668e in 1.2.5, internal active snapshot
                        # without memory state is rejected. Handle it as SKIP
                        # for now. This could be supportted in future by bug:
                        # https://bugzilla.redhat.com/show_bug.cgi?id=1103063
                        if re.search("internal snapshot of a running VM" +
                                     " must include the memory state",
                                     out_err):
                            raise error.TestNAError("Check Bug #1083345, %s" %
                                                    out_err)

                    raise error.TestFail("Failed to create snapshot. Error:%s."
                                         % out_err)
        else:
            snapshot_result = virsh.snapshot_create(vm_name, options,
                                                    debug=True)
            if snapshot_result.exit_status:
                if status_error:
                    return
                else:
                    raise error.TestFail("Failed to create snapshot. Error:%s."
                                         % snapshot_result.stderr.strip())
            snapshot_name = re.search(
                "\d+", snapshot_result.stdout.strip()).group(0)

            if snapshot_current:
                snap_xml = libvirt_xml.SnapshotXML()
                new_snap = snap_xml.new_from_snapshot_dumpxml(vm_name,
                                                              snapshot_name)
                # update an element
                new_snap.creation_time = snapshot_name
                snapshot_xml_path = new_snap.xml
                options += "--redefine %s --current" % snapshot_xml_path
                snapshot_result = virsh.snapshot_create(vm_name,
                                                        options, debug=True)
                if snapshot_result.exit_status:
                    raise error.TestFail("Failed to create snapshot --current."
                                         "Error:%s." %
                                         snapshot_result.stderr.strip())

        if status_error:
            if not snapshot_del_test:
                raise error.TestFail("Success to create snapshot in negative"
                                     " case\nDetail: %s" % snapshot_result)

        # Touch a file in VM.
        if vm.is_dead():
            vm.start()
        session = vm.wait_for_login()

        # Init a unique name for tmp_file.
        tmp_file = tempfile.NamedTemporaryFile(prefix=("snapshot_test_"),
                                               dir="/tmp")
        tmp_file_path = tmp_file.name
        tmp_file.close()

        echo_cmd = "echo SNAPSHOT_DISK_TEST >> %s" % tmp_file_path
        status, output = session.cmd_status_output(echo_cmd)
        logging.debug("The echo output in domain is: '%s'", output)
        if status:
            raise error.TestFail("'%s' run failed with '%s'" %
                                 (tmp_file_path, output))
        status, output = session.cmd_status_output("cat %s" % tmp_file_path)
        logging.debug("File created with content: '%s'", output)

        session.close()

        # As only internal snapshot revert works now, let's only do revert
        # with internal, and move the all skip external cases back to pass.
        # After external also supported, just move the following code back.
        if snapshot_disk == 'internal':
            # Destroy vm for snapshot revert.
            if not libvirt_version.version_compare(1, 2, 3):
                virsh.destroy(vm_name)
            # Revert snapshot.
            revert_options = ""
            if snapshot_revert_paused:
                revert_options += " --paused"
            revert_result = virsh.snapshot_revert(vm_name, snapshot_name,
                                                  revert_options,
                                                  debug=True)
            if revert_result.exit_status:
                # Attempts to revert external snapshots will FAIL with an error
                # "revert to external disk snapshot not supported yet" or "revert
                # to external snapshot not supported yet" since d410e6f. Thus,
                # let's check for that and handle as a SKIP for now. Check bug:
                # https://bugzilla.redhat.com/show_bug.cgi?id=1071264
                if re.search("revert to external \w* ?snapshot not supported yet",
                             revert_result.stderr):
                    raise error.TestNAError(revert_result.stderr.strip())
                else:
                    raise error.TestFail("Revert snapshot failed. %s" %
                                         revert_result.stderr.strip())

            if vm.is_dead():
                raise error.TestFail("Revert snapshot failed.")

            if snapshot_revert_paused:
                if vm.is_paused():
                    vm.resume()
                else:
                    raise error.TestFail("Revert command successed, but VM is not "
                                         "paused after reverting with --paused"
                                         "  option.")
            # login vm.
            session = vm.wait_for_login()
            # Check the result of revert.
            status, output = session.cmd_status_output("cat %s" % tmp_file_path)
            logging.debug("After revert cat file output='%s'", output)
            if not status:
                raise error.TestFail("Tmp file exists, revert failed.")

            # Close the session.
            session.close()

        # Test delete snapshot without "--metadata", delete external disk
        # snapshot will fail for now.
        # Only do this when snapshot creat succeed which filtered in cfg file.
        if snapshot_del_test:
            if snapshot_name:
                del_result = virsh.snapshot_delete(vm_name, snapshot_name,
                                                   debug=True,
                                                   ignore_status=True)
                del_status = del_result.exit_status
                snap_xml_path = snap_cfg_path + "%s.xml" % snapshot_name
                if del_status:
                    if not status_error:
                        raise error.TestFail("Failed to delete snapshot.")
                    else:
                        if not os.path.exists(snap_xml_path):
                            raise error.TestFail("Snapshot xml file %s missing"
                                                 % snap_xml_path)
                else:
                    if status_error:
                        err_msg = "Snapshot delete succeed but expect fail."
                        raise error.TestFail(err_msg)
                    else:
                        if os.path.exists(snap_xml_path):
                            raise error.TestFail("Snapshot xml file %s still"
                                                 % snap_xml_path + " exist")

    finally:
        if vm.is_alive():
            vm.destroy(gracefully=False)
        virsh.detach_disk(vm_name, target="vdf", extra="--persistent")
        if image:
            image.remove()
        if del_status and snapshot_name:
            virsh.snapshot_delete(vm_name, snapshot_name, "--metadata")
        for disk in snapshot_external_disk:
            if os.path.exists(disk):
                os.remove(disk)
        vmxml_backup.sync("--snapshots-metadata")

        libvirtd = utils_libvirtd.Libvirtd()
        if disk_source_protocol == 'gluster':
            utlv.setup_or_cleanup_gluster(False, vol_name, brick_path)
            if multi_gluster_disks:
                brick_path = os.path.join(tmp_dir, "gluster-pool2")
                utlv.setup_or_cleanup_gluster(False, "gluster-vol2", brick_path)
            libvirtd.restart()

        if snapshot_xml_path:
            if os.path.exists(snapshot_xml_path):
                os.unlink(snapshot_xml_path)
        if pvt:
            try:
                pvt.cleanup_pool(pool_name, pool_type, pool_target,
                                 emulated_image, source_name=vol_name)
            except error.TestFail, detail:
                libvirtd.restart()
                logging.error(str(detail))

Example 6

Project: tp-libvirt
Source File: virtual_disks_ceph.py
View license
def run(test, params, env):
    """
    Test rbd disk device.

    1.Prepare test environment,destroy or suspend a VM.
    2.Prepare disk image.
    3.Edit disks xml and start the domain.
    4.Perform test operation.
    5.Recover test environment.

    :param test: avocado test object (provides tmpdir, fail helpers).
    :param params: cartesian config parameters for this test variant.
    :param env: test environment holding the VM objects.
    """
    # Name of the VM under test, taken from the cartesian config.
    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    # Common kwargs for virsh calls: log the command, don't raise on
    # nonzero exit (callers inspect exit_status themselves).
    virsh_dargs = {'debug': True, 'ignore_status': True}

    def config_ceph():
        """
        Write the ceph monitor endpoints into the ceph config file.

        Builds a ``mon_host = host:port[,host:port...]`` line from the
        whitespace-separated host and port parameter lists and writes it
        to ``disk_src_config``.
        """
        endpoints = ["%s:%s" % (host, port)
                     for host, port in zip(disk_src_host.split(),
                                           disk_src_port.split())]
        with open(disk_src_config, 'w') as conf_file:
            conf_file.write("mon_host = " + ','.join(endpoints) + '\n')

    def create_pool():
        """
        Define, build and start a storage pool for the rbd source, then
        exercise the basic virsh pool operations against it.
        """
        sp = libvirt_storage.StoragePool()
        if create_by_xml:
            # Construct the full pool definition as XML objects.
            p_xml = pool_xml.PoolXML(pool_type=pool_type)
            p_xml.name = pool_name
            s_xml = pool_xml.SourceXML()
            s_xml.vg_name = disk_src_pool
            source_host = []
            # One <host> element per (name, port) pair from the params.
            for (host_name, host_port) in zip(
                    disk_src_host.split(), disk_src_port.split()):
                source_host.append({'name': host_name,
                                    'port': host_port})

            s_xml.hosts = source_host
            # Only set the auth elements that were actually provided.
            if auth_type:
                s_xml.auth_type = auth_type
            if auth_user:
                s_xml.auth_username = auth_user
            if auth_usage:
                s_xml.secret_usage = auth_usage
            p_xml.source = s_xml
            logging.debug("Pool xml: %s", p_xml)
            p_xml.xmltreefile.write()
            # define -> build -> start; each step must succeed.
            ret = virsh.pool_define(p_xml.xml, **virsh_dargs)
            libvirt.check_exit_status(ret)
            ret = virsh.pool_build(pool_name, **virsh_dargs)
            libvirt.check_exit_status(ret)
            ret = virsh.pool_start(pool_name, **virsh_dargs)
            libvirt.check_exit_status(ret)
        else:
            # Use the pool-define-as style helpers instead of raw XML.
            auth_opt = ""
            if client_name and client_key:
                auth_opt = ("--auth-type %s --auth-username %s --secret-usage '%s'"
                            % (auth_type, auth_user, auth_usage))
            if not sp.define_rbd_pool(pool_name, mon_host,
                                      disk_src_pool, extra=auth_opt):
                raise error.TestFail("Failed to define storage pool")
            if not sp.build_pool(pool_name):
                raise error.TestFail("Failed to build storage pool")
            if not sp.start_pool(pool_name):
                raise error.TestFail("Failed to start storage pool")

        # Check pool operation
        ret = virsh.pool_refresh(pool_name, **virsh_dargs)
        libvirt.check_exit_status(ret)
        ret = virsh.pool_uuid(pool_name, **virsh_dargs)
        libvirt.check_exit_status(ret)
        # pool-info: a freshly defined pool must not autostart.
        pool_info = sp.pool_info(pool_name)
        if pool_info["Autostart"] != 'no':
            raise error.TestFail("Failed to check pool information")
        # pool-autostart: enabling must be reflected in pool-info.
        if not sp.set_pool_autostart(pool_name):
            raise error.TestFail("Failed to set pool autostart")
        pool_info = sp.pool_info(pool_name)
        if pool_info["Autostart"] != 'yes':
            raise error.TestFail("Failed to check pool information")
        # pool-autostart --disable
        if not sp.set_pool_autostart(pool_name, "--disable"):
            raise error.TestFail("Failed to set pool autostart")
        # find-storage-pool-sources-as: expected to be unsupported for
        # rbd; check_result matches the unsupported message.
        ret = virsh.find_storage_pool_sources_as("rbd", mon_host)
        libvirt.check_result(ret, unsupported_msg)

    def create_vol(vol_params):
        """
        Create a volume in the already-created pool ``pool_name``.

        :param vol_params: Volume parameters dict (name, capacity, ...),
                           used only on the XML-based path.
        """
        pvt = libvirt.PoolVolumeTest(test, params)
        if create_by_xml:
            # Build the volume from an explicit volume XML.
            pvt.pre_vol_by_xml(pool_name, **vol_params)
        else:
            # Plain vol-create-as path with a fixed 2G capacity.
            pvt.pre_vol(vol_name, None, '2G', None, pool_name)

    def check_vol(vol_params):
        """
        Check volume information and exercise both the supported and the
        expected-unsupported virsh vol-* operations on an rbd volume.

        :param vol_params: base volume parameters dict; mutated here to
                           name a second volume ("atest.vol") for the
                           vol-create-from attempt.
        """
        pv = libvirt_storage.PoolVolume(pool_name)
        # Supported operation
        if vol_name not in pv.list_volumes():
            raise error.TestFail("Volume %s doesn't exist" % vol_name)
        ret = virsh.vol_dumpxml(vol_name, pool_name)
        libvirt.check_exit_status(ret)
        # vol-info
        if not pv.volume_info(vol_name):
            raise error.TestFail("Can't see volmue info")
        # vol-key: key must embed "<pool>/<name>".
        ret = virsh.vol_key(vol_name, pool_name)
        libvirt.check_exit_status(ret)
        if "%s/%s" % (disk_src_pool, vol_name) not in ret.stdout:
            raise error.TestFail("Volume key isn't correct")
        # vol-path
        ret = virsh.vol_path(vol_name, pool_name)
        libvirt.check_exit_status(ret)
        if "%s/%s" % (disk_src_pool, vol_name) not in ret.stdout:
            raise error.TestFail("Volume path isn't correct")
        # vol-pool: reverse lookup from "<pool>/<name>" to the pool.
        ret = virsh.vol_pool("%s/%s" % (disk_src_pool, vol_name))
        libvirt.check_exit_status(ret)
        if pool_name not in ret.stdout:
            raise error.TestFail("Volume pool isn't correct")
        # vol-name: reverse lookup to the volume name.
        ret = virsh.vol_name("%s/%s" % (disk_src_pool, vol_name))
        libvirt.check_exit_status(ret)
        if vol_name not in ret.stdout:
            raise error.TestFail("Volume name isn't correct")
        # vol-resize
        ret = virsh.vol_resize(vol_name, "2G", pool_name)
        libvirt.check_exit_status(ret)

        # Not supported operation: each call below is expected to fail
        # with unsupported_msg, which check_result verifies.
        # vol-clone
        ret = virsh.vol_clone(vol_name, "atest.vol", pool_name)
        libvirt.check_result(ret, unsupported_msg)
        # vol-create-from
        volxml = vol_xml.VolXML()
        vol_params.update({"name": "atest.vol"})
        v_xml = volxml.new_vol(**vol_params)
        v_xml.xmltreefile.write()
        ret = virsh.vol_create_from(pool_name, v_xml.xml, vol_name, pool_name)
        libvirt.check_result(ret, unsupported_msg)

        # vol-wipe
        ret = virsh.vol_wipe(vol_name, pool_name)
        libvirt.check_result(ret, unsupported_msg)
        # vol-upload
        ret = virsh.vol_upload(vol_name, vm.get_first_disk_devices()['source'],
                               "--pool %s" % pool_name)
        libvirt.check_result(ret, unsupported_msg)
        # vol-download
        ret = virsh.vol_download(vol_name, "atest.vol", "--pool %s" % pool_name)
        libvirt.check_result(ret, unsupported_msg)

    def check_qemu_cmd():
        """
        Check qemu command line options.

        Builds a ``ps | grep`` pipeline that matches the expected rbd
        file=, conf=/mon_host= and iothread= fragments of the running
        qemu process and runs it; a non-matching pipeline exits nonzero
        and process.run raises, failing the test.
        """
        cmd = ("ps -ef | grep %s | grep -v grep " % vm_name)
        if disk_src_name:
            cmd += " | grep file=rbd:%s:" % disk_src_name
            if auth_user and auth_key:
                # Appended to the same grep pattern: the id= field
                # follows the rbd name directly on the qemu command line.
                cmd += ('id=%s:auth_supported=cephx' % auth_user)
        if disk_src_config:
            cmd += " | grep 'conf=%s'" % disk_src_config
        elif mon_host:
            # Monitors are joined as "host\:6789\;host\:6789" with the
            # separators escaped for grep.
            hosts = '\:6789\;'.join(mon_host.split())
            cmd += " | grep 'mon_host=%s'" % hosts
        if driver_iothread:
            cmd += " | grep iothread=iothread%s" % driver_iothread
        # Run the command
        process.run(cmd, shell=True)

    def check_save_restore():
        """
        Save the domain to a file, restore it, and verify it comes back.
        """
        save_file = os.path.join(test.tmpdir, "%s.save" % vm_name)
        # Round-trip the domain through a save image; either step
        # failing aborts the test.
        libvirt.check_exit_status(virsh.save(vm_name, save_file,
                                             **virsh_dargs))
        libvirt.check_exit_status(virsh.restore(save_file, **virsh_dargs))
        if os.path.exists(save_file):
            os.remove(save_file)
        # A successful login proves the guest is running again.
        vm.wait_for_login().close()

    def check_snapshot(snap_option):
        """
        Test snapshot operation.

        Creates snapshot "s1" with external disk (and optionally memory)
        spec files under test.tmpdir, then verifies the snapshot XML and
        the domain XML contain the expected external/backing-store
        fragments. If an error_msg param is set, a failing
        snapshot-create-as is checked against it instead.

        :param snap_option: "disk-only", "disk-mem" or anything else for
                            a plain internal snapshot attempt.
        """
        snap_name = "s1"
        snap_mem = os.path.join(test.tmpdir, "rbd.mem")
        snap_disk = os.path.join(test.tmpdir, "rbd.disk")
        expected_fails = []
        # Fragments expected in the snapshot XML / domain XML when the
        # external snapshot succeeds.
        xml_snap_exp = ["disk name='vda' snapshot='external' type='file'"]
        xml_dom_exp = ["source file='%s'" % snap_disk,
                       "backingStore type='network' index='1'",
                       "source protocol='rbd' name='%s'" % disk_src_name]
        if snap_option.count("disk-only"):
            options = ("%s --diskspec vda,file=%s --disk-only" %
                       (snap_name, snap_disk))
        elif snap_option.count("disk-mem"):
            options = ("%s --memspec file=%s --diskspec vda,file="
                       "%s" % (snap_name, snap_mem, snap_disk))
            xml_snap_exp.append("memory snapshot='external' file='%s'"
                                % snap_mem)
        else:
            options = snap_name

        error_msg = params.get("error_msg")
        if error_msg:
            expected_fails.append(error_msg)
        ret = virsh.snapshot_create_as(vm_name, options)
        if ret.exit_status:
            # Failure path: must match one of the expected messages.
            libvirt.check_result(ret, expected_fails)

        # check xml file.
        if not ret.exit_status:
            snap_xml = virsh.snapshot_dumpxml(vm_name, snap_name,
                                              debug=True).stdout.strip()
            dom_xml = virsh.dumpxml(vm_name, debug=True).stdout.strip()
            # Delete snapshots.
            libvirt.clean_up_snapshots(vm_name)
            if os.path.exists(snap_mem):
                os.remove(snap_mem)
            if os.path.exists(snap_disk):
                os.remove(snap_disk)

            # Verify every expected fragment appears in the dumped XML.
            if not all([x in snap_xml for x in xml_snap_exp]):
                raise error.TestFail("Failed to check snapshot xml")
            if not all([x in dom_xml for x in xml_dom_exp]):
                raise error.TestFail("Failed to check domain xml")

    def check_blockcopy(target):
        """
        Block copy operation test.

        Runs blockcopy on *target*, verifies the mirror job shows up in
        the domain XML, aborts it, then runs blockcopy again, waits for
        completion and pivots to the copy. If an error_msg param is set,
        a failing first blockcopy is checked against it and the test
        returns early.

        :param target: disk target dev (e.g. "vdb") to block-copy.
        """
        blk_file = os.path.join(test.tmpdir, "blk.rbd")
        if os.path.exists(blk_file):
            os.remove(blk_file)
        # Mirror element expected in the domain XML while the job runs.
        blk_mirror = ("mirror type='file' file='%s' "
                      "format='raw' job='copy'" % blk_file)

        # Do blockcopy
        ret = virsh.blockcopy(vm_name, target, blk_file)
        if ret.exit_status:
            error_msg = params.get("error_msg")
            if not error_msg:
                libvirt.check_exit_status(ret)
            else:
                libvirt.check_result(ret, [error_msg])
            # Passed error check, return
            return

        dom_xml = virsh.dumpxml(vm_name, debug=True).stdout.strip()
        if not dom_xml.count(blk_mirror):
            raise error.TestFail("Can't see block job in domain xml")

        # Abort: the mirror element must disappear afterwards.
        ret = virsh.blockjob(vm_name, target, "--abort")
        libvirt.check_exit_status(ret)
        dom_xml = virsh.dumpxml(vm_name, debug=True).stdout.strip()
        if dom_xml.count(blk_mirror):
            raise error.TestFail("Failed to abort block job")
        if os.path.exists(blk_file):
            os.remove(blk_file)

        # Sleep for a while after abort operation.
        time.sleep(5)
        # Do blockcopy again
        ret = virsh.blockcopy(vm_name, target, blk_file)
        libvirt.check_exit_status(ret)

        # Wait for complete
        def wait_func():
            # blockjob --info reports progress on stderr.
            ret = virsh.blockjob(vm_name, target, "--info")
            return ret.stderr.count("Block Copy: [100 %]")
        timeout = params.get("blockjob_timeout", 600)
        utils_misc.wait_for(wait_func, int(timeout))

        # Pivot: switch the domain onto the copied file.
        ret = virsh.blockjob(vm_name, target, "--pivot")
        libvirt.check_exit_status(ret)
        dom_xml = virsh.dumpxml(vm_name, debug=True).stdout.strip()
        if not dom_xml.count("source file='%s'" % blk_file):
            raise error.TestFail("Failed to pivot block job")
        # Remove the disk file.
        if os.path.exists(blk_file):
            os.remove(blk_file)

    def check_in_vm(vm_obj, target, old_parts, read_only=False):
        """
        Check mount/read/write disk in VM.

        Finds the one partition added since *old_parts*, mounts it,
        lists it and tries to create a file on it.

        :param vm_obj: VM guest object to log into.
        :param target: disk dev name in VM config (vdX / hdX).
        :param old_parts: partition list captured before the attach.
        :param read_only: if True, expect the write to fail with a
                          read-only-filesystem error instead.
        :return: True if check successfully.
        """
        try:
            session = vm_obj.wait_for_login()
            new_parts = libvirt.get_parts_list(session)
            added_parts = list(set(new_parts).difference(set(old_parts)))
            logging.info("Added parts:%s", added_parts)
            # Exactly one new partition must have appeared.
            if len(added_parts) != 1:
                logging.error("The number of new partitions is invalid in VM")
                return False

            added_part = None
            # Map the configured bus to the device name seen in-guest:
            # virtio targets show up as vdX, IDE (hdX) targets as sdX.
            if target.startswith("vd"):
                if added_parts[0].startswith("vd"):
                    added_part = added_parts[0]
            elif target.startswith("hd"):
                if added_parts[0].startswith("sd"):
                    added_part = added_parts[0]

            if not added_part:
                logging.error("Cann't see added partition in VM")
                return False

            cmd = ("mount /dev/{0} /mnt && ls /mnt && (sleep 15;"
                   " touch /mnt/testfile; umount /mnt)"
                   .format(added_part))
            s, o = session.cmd_status_output(cmd, timeout=60)
            session.close()
            logging.info("Check disk operation in VM:\n, %s, %s", s, o)
            # Readonly fs, check the error messages.
            # The command may return True, read-only
            # messges can be found from the command output
            if read_only:
                if "Read-only file system" not in o:
                    return False
                else:
                    return True

            # Other errors
            if s != 0:
                return False
            return True

        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError), e:
            # Python 2 except syntax; any session/VM error counts as a
            # failed check rather than a test error.
            logging.error(str(e))
            return False

    # ---- Read all cartesian parameters for this variant ----
    mon_host = params.get("mon_host")
    disk_src_name = params.get("disk_source_name")
    disk_src_config = params.get("disk_source_config")
    disk_src_host = params.get("disk_source_host")
    disk_src_port = params.get("disk_source_port")
    disk_src_pool = params.get("disk_source_pool")
    disk_format = params.get("disk_format", "raw")
    driver_iothread = params.get("driver_iothread")
    pre_vm_state = params.get("pre_vm_state", "running")
    snap_name = params.get("disk_snap_name")
    # "yes"/"no" string params converted to booleans.
    attach_device = "yes" == params.get("attach_device", "no")
    attach_disk = "yes" == params.get("attach_disk", "no")
    test_save_restore = "yes" == params.get("test_save_restore", "no")
    test_snapshot = "yes" == params.get("test_snapshot", "no")
    test_blockcopy = "yes" == params.get("test_blockcopy", "no")
    test_qemu_cmd = "yes" == params.get("test_qemu_cmd", "no")
    test_vm_parts = "yes" == params.get("test_vm_parts", "no")
    additional_guest = "yes" == params.get("additional_guest", "no")
    create_snapshot = "yes" == params.get("create_snapshot", "no")
    convert_image = "yes" == params.get("convert_image", "no")
    create_volume = "yes" == params.get("create_volume", "no")
    create_by_xml = "yes" == params.get("create_by_xml", "no")
    client_key = params.get("client_key")
    client_name = params.get("client_name")
    auth_key = params.get("auth_key")
    auth_user = params.get("auth_user")
    auth_type = params.get("auth_type")
    auth_usage = params.get("secret_usage")
    pool_name = params.get("pool_name")
    pool_type = params.get("pool_type")
    vol_name = params.get("vol_name")
    vol_cap = params.get("vol_cap")
    vol_cap_unit = params.get("vol_cap_unit")
    start_error_msg = params.get("start_error_msg")
    attach_error_msg = params.get("attach_error_msg")
    unsupported_msg = params.get("unsupported_msg")

    # Start vm and get all partitions in vm, as the baseline for the
    # in-guest disk checks after attaching the rbd disk.
    if vm.is_dead():
        vm.start()
    session = vm.wait_for_login()
    old_parts = libvirt.get_parts_list(session)
    session.close()
    vm.destroy(gracefully=False)
    if additional_guest:
        # Clone a second guest so the shared rbd disk can be attached
        # to two domains at once.
        guest_name = "%s_%s" % (vm_name, '1')
        timeout = params.get("clone_timeout", 360)
        utils_libguestfs.virt_clone_cmd(vm_name, guest_name,
                                        True, timeout=timeout,
                                        ignore_status=False)
        additional_vm = vm.clone(guest_name)
        if pre_vm_state == "running":
            virsh.start(guest_name)

    # Back up xml file so it can be restored in the finally block.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
    key_opt = ""
    secret_uuid = None
    key_file = os.path.join(test.tmpdir, "ceph.key")
    img_file = os.path.join(test.tmpdir,
                            "%s_test.img" % vm_name)

    try:
        # Set domain state
        libvirt.set_domain_state(vm, pre_vm_state)

        # Install ceph-common package which includes the rbd command.
        if utils_misc.yum_install(["ceph-common"]):
            if client_name and client_key:
                # Write the cephx keyring used by the rbd CLI below.
                with open(key_file, 'w') as f:
                    f.write("[%s]\n\tkey = %s\n" %
                            (client_name, client_key))
                key_opt = "--keyring %s" % key_file

                # Create secret xml holding the ceph auth for libvirt.
                sec_xml = secret_xml.SecretXML("no", "no")
                sec_xml.usage = auth_type
                sec_xml.usage_name = auth_usage
                sec_xml.xmltreefile.write()

                logging.debug("Secret xml: %s", sec_xml)
                ret = virsh.secret_define(sec_xml.xml)
                libvirt.check_exit_status(ret)

                # Parse the uuid out of "Secret <uuid> created".
                secret_uuid = re.findall(r".+\S+(\ +\S+)\ +.+\S+",
                                         ret.stdout)[0].lstrip()
                logging.debug("Secret uuid %s", secret_uuid)
                # NOTE(review): findall()[0].lstrip() can never yield
                # None (it raises IndexError on no match), so this guard
                # looks unreachable — kept for safety.
                if secret_uuid is None:
                    raise error.TestNAError("Failed to get secret uuid")

                # Set secret value
                auth_key = params.get("auth_key")
                ret = virsh.secret_set_value(secret_uuid, auth_key,
                                             **virsh_dargs)
                libvirt.check_exit_status(ret)

            # TODO - Delete the disk if it exists
            #cmd = ("rbd -m {0} {1} info {2} && rbd -m {0} {1} rm "
            #       "{2}".format(mon_host, key_opt, disk_src_name))
            #process.run(cmd, ignore_status=True, shell=True)
        else:
            raise error.TestNAError("Failed to install ceph-common")

        if disk_src_config:
            config_ceph()
        # qemu-style rbd URI used by qemu-img/attach-disk.
        disk_path = ("rbd:%s:mon_host=%s" %
                     (disk_src_name, mon_host))
        if auth_user and auth_key:
            disk_path += (":id=%s:key=%s" %
                          (auth_user, auth_key))
        targetdev = params.get("disk_target", "vdb")
        # To be compatible with create_disk_xml function,
        # some parameters need to be updated.
        params.update({
            "type_name": params.get("disk_type", "network"),
            "target_bus": params.get("disk_target_bus"),
            "target_dev": targetdev,
            "secret_uuid": secret_uuid,
            "source_protocol": params.get("disk_source_protocol"),
            "source_name": disk_src_name,
            "source_host_name": disk_src_host,
            "source_host_port": disk_src_port})
        # Prepare disk image
        if convert_image:
            # Push the guest's existing system disk to rbd storage.
            first_disk = vm.get_first_disk_devices()
            blk_source = first_disk['source']
            # Convert the image to remote storage
            disk_cmd = ("rbd -m %s %s info %s || qemu-img convert"
                        " -O %s %s %s" % (mon_host, key_opt,
                                          disk_src_name, disk_format,
                                          blk_source, disk_path))
            process.run(disk_cmd, ignore_status=False, shell=True)

        elif create_volume:
            # "unknow" is the literal value expected by the volume
            # helpers for an undefined format — presumably intentional.
            vol_params = {"name": vol_name, "capacity": int(vol_cap),
                          "capacity_unit": vol_cap_unit, "format": "unknow"}

            create_pool()
            create_vol(vol_params)
            check_vol(vol_params)
        else:
            # Create a local image and make FS on it.
            disk_cmd = ("qemu-img create -f %s %s 10M && mkfs.ext4 -F %s" %
                        (disk_format, img_file, img_file))
            process.run(disk_cmd, ignore_status=False, shell=True)
            # Convert the image to remote storage
            disk_cmd = ("rbd -m %s %s info %s || qemu-img convert -O"
                        " %s %s %s" % (mon_host, key_opt, disk_src_name,
                                       disk_format, img_file, disk_path))
            process.run(disk_cmd, ignore_status=False, shell=True)
            # Create disk snapshot if needed.
            # FIX: the format string was mangled ("%[email protected]%s")
            # by a scrape/email-obfuscation artifact; the rbd syntax is
            # "snap create <image>@<snap>", i.e. "%s@%s".
            if create_snapshot:
                snap_cmd = ("rbd -m %s %s snap create %s@%s" %
                            (mon_host, key_opt, disk_src_name, snap_name))
                process.run(snap_cmd, ignore_status=False, shell=True)
        if attach_device:
            if create_volume:
                params.update({"type_name": "volume"})
                # No need auth options for volume
                if "auth_user" in params:
                    params.pop("auth_user")
                if "auth_type" in params:
                    params.pop("auth_type")
                if "secret_type" in params:
                    params.pop("secret_type")
                if "secret_uuid" in params:
                    params.pop("secret_uuid")
                if "secret_usage" in params:
                    params.pop("secret_usage")
            xml_file = libvirt.create_disk_xml(params)
            opts = params.get("attach_option", "")
            ret = virsh.attach_device(vm_name, xml_file,
                                      flagstr=opts, debug=True)
            if attach_error_msg:
                libvirt.check_result(ret, attach_error_msg)
            else:
                libvirt.check_exit_status(ret)
            if additional_guest:
                # Attach the same disk to the cloned guest as well.
                ret = virsh.attach_device(guest_name, xml_file,
                                          "", debug=True)
                libvirt.check_exit_status(ret)
        elif attach_disk:
            ret = virsh.attach_disk(vm_name, disk_path,
                                    targetdev, **virsh_dargs)
            libvirt.check_exit_status(ret)
        elif not create_volume:
            # Replace the guest's disk definition in its XML directly.
            libvirt.set_vm_disk(vm, params)

        if pre_vm_state == "transient":
            # Re-create the domain as transient (virsh create, no
            # persistent definition); restore the backup if it fails.
            logging.info("Creating %s...", vm_name)
            vmxml_for_test = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
            if vm.is_alive():
                vm.destroy(gracefully=False)
            vm.undefine()
            if virsh.create(vmxml_for_test.xml, **virsh_dargs).exit_status:
                vmxml_backup.define()
                raise error.TestFail("Cann't create the domain")
        elif vm.is_dead():
            vm.start()
        # Wait for vm is running
        vm.wait_for_login(timeout=600).close()
        if additional_guest:
            if additional_vm.is_dead():
                additional_vm.start()
        # Check qemu command line
        if test_qemu_cmd:
            check_qemu_cmd()
        # Check partitions in vm; a disk backed by an rbd snapshot is
        # expected to be read-only in the guest.
        if test_vm_parts:
            if not check_in_vm(vm, targetdev, old_parts,
                               read_only=create_snapshot):
                raise error.TestFail("Failed to check vm partitions")
            if additional_guest:
                if not check_in_vm(additional_vm, targetdev, old_parts):
                    raise error.TestFail("Failed to check vm partitions")
        # Save and restore operation
        if test_save_restore:
            check_save_restore()
        if test_snapshot:
            snap_option = params.get("snapshot_option", "")
            check_snapshot(snap_option)
        if test_blockcopy:
            check_blockcopy(targetdev)

        # Detach the device.
        if attach_device and not attach_error_msg:
            xml_file = libvirt.create_disk_xml(params)
            ret = virsh.detach_device(vm_name, xml_file)
            libvirt.check_exit_status(ret)
            if additional_guest:
                ret = virsh.detach_device(guest_name, xml_file)
                libvirt.check_exit_status(ret)
        elif attach_disk:
            ret = virsh.detach_disk(vm_name, targetdev)
            libvirt.check_exit_status(ret)

        # Check disk in vm after detachment: the partition count must
        # be back to the baseline captured before the attach.
        if (attach_device or attach_disk) and not attach_error_msg:
            session = vm.wait_for_login()
            new_parts = libvirt.get_parts_list(session)
            if len(new_parts) != len(old_parts):
                raise error.TestFail("Disk still exists in vm"
                                     " after detachment")
            session.close()

    except virt_vm.VMStartError as details:
        # A start failure is only acceptable when it matches the
        # expected start_error_msg for this variant.
        if start_error_msg in str(details):
            pass
        else:
            raise error.TestFail("VM failed to start."
                                 "Error: %s" % str(details))
    finally:
        # Delete snapshots.
        snapshot_lists = virsh.snapshot_list(vm_name)
        if len(snapshot_lists) > 0:
            libvirt.clean_up_snapshots(vm_name, snapshot_lists)
            for snap in snapshot_lists:
                virsh.snapshot_delete(vm_name, snap, "--metadata")

        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        if additional_guest:
            # FIX: was "ignore_stauts=True" — typo meant the kwarg was
            # never the intended ignore_status.
            virsh.remove_domain(guest_name,
                                "--remove-all-storage",
                                ignore_status=True)
        # Remove the snapshot.
        if create_snapshot:
            cmd = ("rbd -m {0} {1} info {2} && rbd -m {0} {1} snap"
                   " purge {2} && rbd -m {0} {1} rm {2}"
                   "".format(mon_host, key_opt, disk_src_name))
            process.run(cmd, ignore_status=True, shell=True)
        elif attach_device or attach_disk:
            cmd = ("rbd -m {0} {1} info {2} && rbd -m {0} {1} rm {2}"
                   "".format(mon_host, key_opt, disk_src_name))
            process.run(cmd, ignore_status=True, shell=True)

        # Delete tmp files.
        if os.path.exists(key_file):
            os.remove(key_file)
        if os.path.exists(img_file):
            os.remove(img_file)
        # Clean up volume, pool
        if vol_name and vol_name in str(virsh.vol_list(pool_name).stdout):
            virsh.vol_delete(vol_name, pool_name)
        # "in" instead of the Python-2-only dict.has_key().
        if pool_name and pool_name in virsh.pool_state_dict():
            virsh.pool_destroy(pool_name, **virsh_dargs)
            virsh.pool_undefine(pool_name, **virsh_dargs)

        # Clean up secret
        if secret_uuid:
            virsh.secret_undefine(secret_uuid)

        logging.info("Restoring vm...")
        vmxml_backup.sync()

Example 7

View license
def run(test, params, env):
    """
    Test multiple disks attachment.

    1.Prepare test environment,destroy or suspend a VM.
    2.Perform 'qemu-img create' operation.
    3.Edit disks xml and start the domain.
    4.Perform test operation.
    5.Recover test environment.
    6.Confirm the test result.
    """
    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    virsh_dargs = {'debug': True, 'ignore_status': True}

    def check_disk_order(targets_name):
        """
        Verify that virtio disks sit on the pci bus in target-name order.

        :param targets_name: list of disk target names to validate.
        :return: True when the pci slot order matches the target order.
        """
        logging.info("Checking VM disks order...")
        dom_xml = vm_xml.VMXML.new_from_dumpxml(vm_name)
        # Map each virtio disk target to its pci slot number.
        dev_to_slot = {}
        for dsk in dom_xml.devices.by_device_tag("disk"):
            if dsk.target['bus'] == 'virtio':
                dev_to_slot[dsk.target['dev']] = int(
                    dsk.address.attrs['slot'], base=16)
        # Sorting targets and slots independently must pair them back up:
        # the alphabetically first target should own the lowest slot.
        sorted_devs = sorted(dev_to_slot.keys())
        sorted_slots = sorted(dev_to_slot.values())
        for dev, slot in zip(sorted_devs, sorted_slots):
            if dev in targets_name and dev_to_slot[dev] != slot:
                return False
        return True

    def setup_nfs_disk(disk_name, disk_type, disk_format="raw"):
        """
        Setup nfs disk.

        Exports a local directory over NFS, mounts it read-only and, for
        file-backed types, creates and formats a disk image inside the
        export first.

        :param disk_name: file name of the disk image inside the export.
        :param disk_type: disk type ("file", "floppy", "iso", ...).
        :param disk_format: image format passed to the disk creator.
        :return: dict describing the nfs disk, or None if mounting failed.
        :raises error.TestNAError: when formatting the image fails.
        """
        # Export source and mount point both live under the test tmpdir.
        mount_src = os.path.join(test.tmpdir, "nfs-export")
        if not os.path.exists(mount_src):
            os.mkdir(mount_src)
        mount_dir = os.path.join(test.tmpdir, "nfs-mount")

        if disk_type in ["file", "floppy", "iso"]:
            disk_path = "%s/%s" % (mount_src, disk_name)
            device_source = libvirt.create_local_disk(disk_type, disk_path, "2",
                                                      disk_format=disk_format)
            #Format the disk.
            if disk_type in ["file", "floppy"]:
                cmd = ("mkfs.ext3 -F %s && setsebool virt_use_nfs true"
                       % device_source)
                if utils.run(cmd, ignore_status=True).exit_status:
                    raise error.TestNAError("Format disk failed")

        # Export read-write on the server but mount read-only locally.
        nfs_params = {"nfs_mount_dir": mount_dir, "nfs_mount_options": "ro",
                      "nfs_mount_src": mount_src, "setup_local_nfs": "yes",
                      "export_options": "rw,no_root_squash"}

        nfs_obj = nfs.Nfs(nfs_params)
        nfs_obj.setup()
        if not nfs_obj.mount():
            return None

        # NOTE(review): device_source is only bound in the
        # file/floppy/iso branch above; any other disk_type would raise
        # NameError here — confirm callers never pass other types.
        disk = {"disk_dev": nfs_obj, "format": "nfs", "source":
                "%s/%s" % (mount_dir, os.path.split(device_source)[-1])}

        return disk

    def prepare_disk(path, disk_format):
        """
        Prepare the disk for a given disk format.

        Creates (or locates) a disk backing of the requested format and
        returns a dict with at least "format" and "source" keys.

        :param path: file path used for file-backed formats; its basename
            selects the nfs image name and the "notexist." negative case.
        :param disk_format: "scsi", "iso", "floppy", "nfs", "iscsi",
            "raw" or "qcow2".
        :return: dict describing the prepared disk; empty dict when no
            branch matched.
        :raises error.TestNAError: when scsi disk creation fails.
        """
        disk = {}
        # Check if we test with a non-existed disk.
        if os.path.split(path)[-1].startswith("notexist."):
            disk.update({"format": disk_format,
                         "source": path})

        elif disk_format == "scsi":
            scsi_option = params.get("virt_disk_device_scsi_option", "")
            disk_source = libvirt.create_scsi_disk(scsi_option)
            if disk_source:
                disk.update({"format": "scsi",
                             "source": disk_source})
            else:
                raise error.TestNAError("Get scsi disk failed")

        elif disk_format in ["iso", "floppy"]:
            disk_path = libvirt.create_local_disk(disk_format, path)
            disk.update({"format": disk_format,
                         "source": disk_path})
        elif disk_format == "nfs":
            nfs_disk_type = params.get("nfs_disk_type", None)
            disk.update(setup_nfs_disk(os.path.split(path)[-1], nfs_disk_type))

        elif disk_format == "iscsi":
            # Create iscsi device if needed.
            image_size = params.get("image_size", "2G")
            device_source = libvirt.setup_or_cleanup_iscsi(
                is_setup=True, is_login=True, image_size=image_size)
            logging.debug("iscsi dev name: %s", device_source)

            # Format the disk and make file system.
            libvirt.mk_part(device_source)
            # Run partprobe to make the change take effect.
            utils.run("partprobe", ignore_status=True)
            libvirt.mkfs("%s1" % device_source, "ext3")
            # Use the first partition, not the whole device, as source.
            device_source += "1"
            disk.update({"format": disk_format,
                         "source": device_source})
        elif disk_format in ["raw", "qcow2"]:
            disk_size = params.get("virt_disk_device_size", "1")
            device_source = libvirt.create_local_disk(
                "file", path, disk_size, disk_format=disk_format)
            disk.update({"format": disk_format,
                         "source": device_source})

        return disk

    def check_disk_format(targets_name, targets_format):
        """
        Verify each VM disk's driver type against its expected format.

        :param targets_name: device name list.
        :param targets_format: expected format list, paired with the names.
        :return: True if every disk matches (an absent attribute passes).
        """
        logging.info("Checking VM disks type... ")
        for dev, expected in zip(targets_name, targets_format):
            actual = vm_xml.VMXML.get_disk_attr(vm_name, dev,
                                                "driver", "type")
            # A missing driver type is tolerated; any other value must
            # equal the expected format exactly.
            if actual is not None and actual != expected:
                return False
        return True

    def check_vm_partitions(devices, targets_name, exists=True):
        """
        Check that disk devices are (or are not) visible inside the VM.

        :param devices: device type list ("cdrom", "floppy" or disk).
        :param targets_name: target name list, paired with devices.
        :param exists: True to expect the devices present, False absent.
        :return: True if every device matches the expectation.
        """
        logging.info("Checking VM partittion...")
        try:
            session = vm.wait_for_login()
            for device, target in zip(devices, targets_name):
                if device == "cdrom":
                    s, o = session.cmd_status_output(
                        "ls /dev/sr0 && mount /dev/sr0 /mnt &&"
                        " ls /mnt && umount /mnt")
                    logging.info("cdrom devices in VM:\n%s", o)
                elif device == "floppy":
                    s, o = session.cmd_status_output(
                        "modprobe floppy && ls /dev/fd0")
                    logging.info("floppy devices in VM:\n%s", o)
                else:
                    # The IDE disk "hda" shows up as "sda" in the guest.
                    if target == "hda":
                        target = "sda"
                    s, o = session.cmd_status_output(
                        "grep %s /proc/partitions" % target)
                    logging.info("Disk devices in VM:\n%s", o)
                # Zero status means the device was found; fail as soon as
                # reality disagrees with the expectation.
                if (s != 0) == exists:
                    session.close()
                    return False
            session.close()
            return True
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError) as e:
            logging.error(str(e))
            return False

    def check_vm_block_size(targets_name, log_size, phy_size):
        """
        Check VM disk's blocksize.

        Reads the logical and physical block sizes from sysfs inside the
        guest and compares them against the expected values.

        :param targets_name: disk target list ("vda", ...).
        :param log_size: expected logical block size (string).
        :param phy_size: expected physical block size (string).
        :return: True if check successfully.
        """
        logging.info("Checking VM block size...")
        try:
            session = vm.wait_for_login()
            for target in targets_name:
                base_cmd = "cat /sys/block/%s/queue/" % target
                # Both sysfs attributes must exist and match expectations.
                for attr, expected in (("logical_block_size", log_size),
                                       ("physical_block_size", phy_size)):
                    s, o = session.cmd_status_output(base_cmd + attr)
                    logging.debug("%s in VM:\n%s",
                                  attr.replace("_", " "), o)
                    if s != 0 or o.strip() != expected:
                        session.close()
                        return False
            session.close()
            return True
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError) as e:
            logging.error(str(e))
            return False

    def check_readonly(targets_name):
        """
        Check disk readonly option.

        Mounts each target in the guest and verifies it can be listed.
        The write probe is tolerant: on a read-only mount the echo fails
        and the fallback umount keeps the command's exit status zero.

        :param targets_name: disk target list ("hdc", "fda", "vda", ...).
        :return: True if every target mounts and checks out.
        """
        logging.info("Checking disk readonly option...")
        try:
            session = vm.wait_for_login()
            for target in targets_name:
                # cdrom and floppy need their special device nodes.
                if target == "hdc":
                    mount_cmd = "mount /dev/cdrom /mnt"
                elif target == "fda":
                    mount_cmd = "modprobe floppy && mount /dev/fd0 /mnt"
                else:
                    mount_cmd = "mount /dev/%s /mnt" % target
                cmd = ("(%s && ls /mnt || exit 1) && (echo "
                       "'test' > /mnt/test || umount /mnt)" % mount_cmd)
                s, o = session.cmd_status_output(cmd)
                logging.debug("cmd exit: %s, output: %s", s, o)
                if s:
                    session.close()
                    return False
            session.close()
            return True
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError) as e:
            logging.error(str(e))
            return False

    def check_bootorder_snapshot(disk_name):
        """
        Check VM disk's bootorder option with snapshot.

        Creates an internal snapshot, a disk-only snapshot and an
        external-memory snapshot, verifies the snapshot image name and its
        "boot order" attribute in the domain xml, and finally checks that
        virsh save/restore still work after snapshotting.

        :param disk_name: the target disk to be checked.
        :raises error.TestError: when any verification step fails.
        """
        logging.info("Checking diskorder option with snapshot...")
        snapshot1 = "s1"
        snapshot2 = "s2"
        snapshot2_file = os.path.join(test.tmpdir, "s2")
        # Plain internal snapshot first.
        ret = virsh.snapshot_create(vm_name, "", **virsh_dargs)
        libvirt.check_exit_status(ret)

        # Disk-only snapshot names the new image "<disk>.<snapshot>".
        ret = virsh.snapshot_create_as(vm_name, "%s --disk-only" % snapshot1,
                                       **virsh_dargs)
        libvirt.check_exit_status(ret)

        ret = virsh.snapshot_dumpxml(vm_name, snapshot1)
        libvirt.check_exit_status(ret)

        # Snapshot xml must reference the disk-only snapshot image.
        cmd = "echo \"%s\" | grep %s.%s" % (ret.stdout, disk_name, snapshot1)
        if utils.run(cmd, ignore_status=True).exit_status:
            raise error.TestError("Check snapshot disk failed")

        # External memory snapshot on top of the previous ones.
        ret = virsh.snapshot_create_as(vm_name,
                                       "%s --memspec file=%s,snapshot=external"
                                       % (snapshot2, snapshot2_file),
                                       **virsh_dargs)
        libvirt.check_exit_status(ret)

        ret = virsh.dumpxml(vm_name)
        libvirt.check_exit_status(ret)

        # The snapshot disk must keep the configured boot order.
        cmd = ("echo \"%s\" | grep -A 16 %s.%s | grep \"boot order='%s'\""
               % (ret.stdout, disk_name, snapshot2, bootorder))
        if utils.run(cmd, ignore_status=True).exit_status:
            raise error.TestError("Check snapshot disk with bootorder failed")

        # Both snapshots must be listed for the domain.
        snap_lists = virsh.snapshot_list(vm_name)
        if snapshot1 not in snap_lists or snapshot2 not in snap_lists:
            raise error.TestError("Check snapshot list failed")

        # Check virsh save command after snapshot.
        save_file = "/tmp/%s.save" % vm_name
        ret = virsh.save(vm_name, save_file, **virsh_dargs)
        libvirt.check_exit_status(ret)

        # Check virsh restore command after snapshot.
        ret = virsh.restore(save_file, **virsh_dargs)
        libvirt.check_exit_status(ret)

        #Passed all test.
        os.remove(save_file)

    def check_boot_console(bootorders):
        """
        Read the serial console and confirm the boot-attempt order.

        :param bootorders: expected "Booting from ..." device sequence.
        :return: True when the console output matches exactly.
        """
        # Wait until the BIOS reaches the hard-disk boot entry, then
        # collect everything printed so far.
        vm.serial_console.read_until_output_matches(
            ["Hard Disk"], utils_misc.strip_console_codes)
        output = vm.serial_console.get_stripped_output()
        logging.debug("serial output: %s", output)
        booted = re.findall(r"^Booting from (.+)...", output, re.M)
        logging.debug("lines: %s", booted)
        # Every boot attempt must match the expected order, one for one.
        if len(booted) != len(bootorders):
            return False
        return all(seen == wanted
                   for seen, wanted in zip(booted, bootorders))

    def check_disk_save_restore(save_file, device_targets,
                                startup_policy):
        """
        Check domain save and restore operation.

        :param save_file: path used by virsh save/restore.
        :param device_targets: disk target list; the first one is probed.
        :param startup_policy: disk startupPolicy list; "optional" makes
            restore expected to fail after the disk source is removed.
        :raises error.TestError: when the disk is unusable after restore.
        """
        # Save the domain.
        ret = virsh.save(vm_name, save_file,
                         **virsh_dargs)
        libvirt.check_exit_status(ret)

        # Restore the domain.
        restore_error = False
        # Check disk startup policy option: with "optional" the source
        # may legally be missing, so removing it must make restore fail.
        if "optional" in startup_policy:
            os.remove(disks[0]["source"])
            restore_error = True
        ret = virsh.restore(save_file, **virsh_dargs)
        libvirt.check_exit_status(ret, restore_error)
        if restore_error:
            return

        # Connect to the domain and verify the first disk is still
        # readable and writable.
        target = device_targets[0]
        try:
            session = vm.wait_for_login()
            cmd = ("ls /dev/%s && mkfs.ext3 -F /dev/%s && mount /dev/%s"
                   " /mnt && ls /mnt && touch /mnt/test && umount /mnt"
                   % (target, target, target))
            s, o = session.cmd_status_output(cmd)
            if s:
                session.close()
                raise error.TestError("Failed to read/write disk in VM:"
                                      " %s" % o)
            session.close()
        except (remote.LoginError, virt_vm.VMError, aexpect.ShellError) as e:
            raise error.TestError(str(e))

    def check_dom_iothread():
        """
        Query QEMU for its iothreads and compare the count against the
        configured dom_iothreads value.

        :raises error.TestFail: when the counts differ.
        """
        result = virsh.qemu_monitor_command(vm_name,
                                            '{"execute": "query-iothreads"}',
                                            "--pretty")
        libvirt.check_exit_status(result)
        logging.debug("Domain iothreads: %s", result.stdout)
        # The monitor reply carries one entry per iothread under "return".
        reply = json.loads(result.stdout)
        expected = int(dom_iothreads)
        if len(reply['return']) != expected:
            raise error.TestFail("Failed to check domain iothreads")

    status_error = "yes" == params.get("status_error", "no")
    define_error = "yes" == params.get("define_error", "no")
    dom_iothreads = params.get("dom_iothreads")

    # Disk specific attributes.
    devices = params.get("virt_disk_device", "disk").split()
    device_source_names = params.get("virt_disk_device_source").split()
    device_targets = params.get("virt_disk_device_target", "vda").split()
    device_formats = params.get("virt_disk_device_format", "raw").split()
    device_types = params.get("virt_disk_device_type", "file").split()
    device_bus = params.get("virt_disk_device_bus", "virtio").split()
    driver_options = params.get("driver_option", "").split()
    device_bootorder = params.get("virt_disk_boot_order", "").split()
    device_readonly = params.get("virt_disk_option_readonly",
                                 "no").split()
    device_attach_error = params.get("disks_attach_error", "").split()
    device_attach_option = params.get("disks_attach_option", "").split(';')
    device_address = params.get("virt_disk_addr_options", "").split()
    startup_policy = params.get("virt_disk_device_startuppolicy", "").split()
    bootorder = params.get("disk_bootorder", "")
    bootdisk_target = params.get("virt_disk_bootdisk_target", "vda")
    bootdisk_bus = params.get("virt_disk_bootdisk_bus", "virtio")
    bootdisk_driver = params.get("virt_disk_bootdisk_driver", "")
    serial = params.get("virt_disk_serial", "")
    wwn = params.get("virt_disk_wwn", "")
    vendor = params.get("virt_disk_vendor", "")
    product = params.get("virt_disk_product", "")
    add_disk_driver = params.get("add_disk_driver")
    iface_driver = params.get("iface_driver_option", "")
    bootdisk_snapshot = params.get("bootdisk_snapshot", "")
    snapshot_option = params.get("snapshot_option", "")
    snapshot_error = "yes" == params.get("snapshot_error", "no")
    add_usb_device = "yes" == params.get("add_usb_device", "no")
    input_usb_address = params.get("input_usb_address", "")
    hub_usb_address = params.get("hub_usb_address", "")
    hotplug = "yes" == params.get(
        "virt_disk_device_hotplug", "no")
    device_at_dt_disk = "yes" == params.get("virt_disk_at_dt_disk", "no")
    device_with_source = "yes" == params.get(
        "virt_disk_with_source", "yes")
    virtio_scsi_controller = "yes" == params.get(
        "virtio_scsi_controller", "no")
    virtio_scsi_controller_driver = params.get(
        "virtio_scsi_controller_driver", "")
    source_path = "yes" == params.get(
        "virt_disk_device_source_path", "yes")
    check_patitions = "yes" == params.get(
        "virt_disk_check_partitions", "yes")
    check_patitions_hotunplug = "yes" == params.get(
        "virt_disk_check_partitions_hotunplug", "yes")
    test_slots_order = "yes" == params.get(
        "virt_disk_device_test_order", "no")
    test_disks_format = "yes" == params.get(
        "virt_disk_device_test_format", "no")
    test_block_size = "yes" == params.get(
        "virt_disk_device_test_block_size", "no")
    test_file_img_on_disk = "yes" == params.get(
        "test_file_image_on_disk", "no")
    test_with_boot_disk = "yes" == params.get(
        "virt_disk_with_boot_disk", "no")
    test_disk_option_cmd = "yes" == params.get(
        "test_disk_option_cmd", "no")
    test_disk_type_dir = "yes" == params.get(
        "virt_disk_test_type_dir", "no")
    test_disk_bootorder = "yes" == params.get(
        "virt_disk_test_bootorder", "no")
    test_disk_bootorder_snapshot = "yes" == params.get(
        "virt_disk_test_bootorder_snapshot", "no")
    test_boot_console = "yes" == params.get(
        "virt_disk_device_boot_console", "no")
    test_disk_readonly = "yes" == params.get(
        "virt_disk_device_test_readonly", "no")
    test_disk_snapshot = "yes" == params.get(
        "virt_disk_test_snapshot", "no")
    test_disk_save_restore = "yes" == params.get(
        "virt_disk_test_save_restore", "no")
    test_bus_device_option = "yes" == params.get(
        "test_bus_option_cmd", "no")
    snapshot_before_start = "yes" == params.get(
        "snapshot_before_start", "no")

    if dom_iothreads:
        if not libvirt_version.version_compare(1, 2, 8):
            raise error.TestNAError("iothreads not supported for"
                                    " this libvirt version")

    if test_block_size:
        logical_block_size = params.get("logical_block_size")
        physical_block_size = params.get("physical_block_size")

    if any([test_boot_console, add_disk_driver]):
        if vm.is_dead():
            vm.start()
        session = vm.wait_for_login()
        if test_boot_console:
            # Setting console to kernel parameters
            vm.set_kernel_console("ttyS0", "115200")
        if add_disk_driver:
            # Ignore errors here
            session.cmd("dracut --force --add-drivers '%s'"
                        % add_disk_driver)
        session.close()
        vm.shutdown()

    # Destroy VM.
    if vm.is_alive():
        vm.destroy(gracefully=False)

    # Back up xml file.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    # Get device path.
    device_source_path = ""
    if source_path:
        device_source_path = test.virtdir

    # Prepare test environment.
    qemu_config = LibvirtQemuConfig()
    if test_disks_format:
        qemu_config.allow_disk_format_probing = True
        utils_libvirtd.libvirtd_restart()

    # Create virtual device file.
    disks = []
    try:
        for i in range(len(device_source_names)):
            if test_disk_type_dir:
                # If we testing disk type dir option,
                # it needn't to create disk image
                disks.append({"format": "dir",
                              "source": device_source_names[i]})
            else:
                path = "%s/%s.%s" % (device_source_path,
                                     device_source_names[i], device_formats[i])
                disk = prepare_disk(path, device_formats[i])
                if disk:
                    disks.append(disk)

    except Exception, e:
        logging.error(repr(e))
        for img in disks:
            if img.has_key("disk_dev"):
                if img["format"] == "nfs":
                    img["disk_dev"].cleanup()
            else:
                if img["format"] == "iscsi":
                    libvirt.setup_or_cleanup_iscsi(is_setup=False)
                if img["format"] not in ["dir", "scsi"]:
                    os.remove(img["source"])
        raise error.TestNAError("Creating disk failed")

    # Build disks xml.
    disks_xml = []
    # Additional disk images.
    disks_img = []
    vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
    try:
        for i in range(len(disks)):
            disk_xml = Disk(type_name=device_types[i])
            # If we are testing image file on iscsi disk,
            # mount the disk and then create the image.
            if test_file_img_on_disk:
                mount_path = "/tmp/diskimg"
                if utils.run("mkdir -p %s && mount %s %s"
                             % (mount_path, disks[i]["source"],
                                mount_path), ignore_status=True).exit_status:
                    raise error.TestNAError("Prepare disk failed")
                disk_path = "%s/%s.qcow2" % (mount_path, device_source_names[i])
                disk_source = libvirt.create_local_disk("file", disk_path, "1",
                                                        disk_format="qcow2")
                disks_img.append({"format": "qcow2",
                                  "source": disk_source, "path": mount_path})
            else:
                disk_source = disks[i]["source"]

            disk_xml.device = devices[i]

            if device_with_source:
                if device_types[i] == "file":
                    dev_attrs = "file"
                elif device_types[i] == "dir":
                    dev_attrs = "dir"
                else:
                    dev_attrs = "dev"
                source_dict = {dev_attrs: disk_source}
                if len(startup_policy) > i:
                    source_dict.update({"startupPolicy": startup_policy[i]})
                disk_xml.source = disk_xml.new_disk_source(
                    **{"attrs": source_dict})

            if len(device_bootorder) > i:
                disk_xml.boot = device_bootorder[i]

            if test_block_size:
                disk_xml.blockio = {"logical_block_size": logical_block_size,
                                    "physical_block_size": physical_block_size}

            if wwn != "":
                disk_xml.wwn = wwn
            if serial != "":
                disk_xml.serial = serial
            if vendor != "":
                disk_xml.vendor = vendor
            if product != "":
                disk_xml.product = product

            disk_xml.target = {"dev": device_targets[i], "bus": device_bus[i]}
            if len(device_readonly) > i:
                disk_xml.readonly = "yes" == device_readonly[i]

            # Add driver options from parameters
            driver_dict = {"name": "qemu"}
            if len(driver_options) > i:
                for driver_option in driver_options[i].split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        driver_dict.update({d[0].strip(): d[1].strip()})
            disk_xml.driver = driver_dict

            # Add disk address from parameters.
            if len(device_address) > i:
                addr_dict = {}
                for addr_option in device_address[i].split(','):
                    if addr_option != "":
                        d = addr_option.split('=')
                        addr_dict.update({d[0].strip(): d[1].strip()})
                disk_xml.address = disk_xml.new_disk_address(
                    **{"attrs": addr_dict})

            logging.debug("disk xml: %s", disk_xml)
            if hotplug:
                disks_xml.append(disk_xml)
            else:
                vmxml.add_device(disk_xml)

        # If we want to test with bootdisk.
        # just edit the bootdisk xml.
        if test_with_boot_disk:
            xml_devices = vmxml.devices
            disk_index = xml_devices.index(xml_devices.by_device_tag("disk")[0])
            disk = xml_devices[disk_index]
            if bootorder != "":
                disk.boot = bootorder
                osxml = vm_xml.VMOSXML()
                osxml.type = vmxml.os.type
                osxml.arch = vmxml.os.arch
                osxml.machine = vmxml.os.machine
                if test_boot_console:
                    osxml.loader = "/usr/share/seabios/bios.bin"
                    osxml.bios_useserial = "yes"
                    osxml.bios_reboot_timeout = "-1"

                del vmxml.os
                vmxml.os = osxml
            driver_dict = {"name": disk.driver["name"],
                           "type": disk.driver["type"]}
            if bootdisk_driver != "":
                for driver_option in bootdisk_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        driver_dict.update({d[0].strip(): d[1].strip()})
            disk.driver = driver_dict

            if iface_driver != "":
                driver_dict = {"name": "vhost"}
                for driver_option in iface_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        driver_dict.update({d[0].strip(): d[1].strip()})
                iface_list = xml_devices.by_device_tag("interface")[0]
                iface_index = xml_devices.index(iface_list)
                iface = xml_devices[iface_index]
                iface.driver = iface.new_driver(**{"driver_attr": driver_dict})
                iface.model = "virtio"
                del iface.address

            if bootdisk_snapshot != "":
                disk.snapshot = bootdisk_snapshot

            disk.target = {"dev": bootdisk_target, "bus": bootdisk_bus}
            device_source = disk.source.attrs["file"]

            del disk.address
            vmxml.devices = xml_devices
            vmxml.define()

        # Add virtio_scsi controller.
        if virtio_scsi_controller:
            scsi_controller = Controller("controller")
            scsi_controller.type = "scsi"
            scsi_controller.index = "0"
            ctl_model = params.get("virtio_scsi_controller_model")
            if ctl_model:
                scsi_controller.model = ctl_model
            if virtio_scsi_controller_driver != "":
                driver_dict = {}
                for driver_option in virtio_scsi_controller_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        driver_dict.update({d[0].strip(): d[1].strip()})
                scsi_controller.driver = driver_dict
            vmxml.del_controller("scsi")
            vmxml.add_device(scsi_controller)

        # Test usb devices.
        usb_devices = {}
        if add_usb_device:
            # Delete all usb devices first.
            controllers = vmxml.get_devices(device_type="controller")
            for ctrl in controllers:
                if ctrl.type == "usb":
                    vmxml.del_device(ctrl)

            inputs = vmxml.get_devices(device_type="input")
            for input in inputs:
                if input.type_name == "tablet":
                    vmxml.del_device(input)

            # Add new usb controllers.
            usb_controller1 = Controller("controller")
            usb_controller1.type = "usb"
            usb_controller1.index = "0"
            usb_controller1.model = "piix3-uhci"
            vmxml.add_device(usb_controller1)
            usb_controller2 = Controller("controller")
            usb_controller2.type = "usb"
            usb_controller2.index = "1"
            usb_controller2.model = "ich9-ehci1"
            vmxml.add_device(usb_controller2)

            input_obj = Input("tablet")
            input_obj.input_bus = "usb"
            addr_dict = {}
            if input_usb_address != "":
                for addr_option in input_usb_address.split(','):
                    if addr_option != "":
                        d = addr_option.split('=')
                        addr_dict.update({d[0].strip(): d[1].strip()})
            if addr_dict:
                input_obj.address = input_obj.new_input_address(
                    **{"attrs": addr_dict})
            vmxml.add_device(input_obj)
            usb_devices.update({"input": addr_dict})

            hub_obj = Hub("usb")
            addr_dict = {}
            if hub_usb_address != "":
                for addr_option in hub_usb_address.split(','):
                    if addr_option != "":
                        d = addr_option.split('=')
                        addr_dict.update({d[0].strip(): d[1].strip()})
            if addr_dict:
                hub_obj.address = hub_obj.new_hub_address(
                    **{"attrs": addr_dict})
            vmxml.add_device(hub_obj)
            usb_devices.update({"hub": addr_dict})

        if dom_iothreads:
            # Delete cputune/iothreadids section, it may have conflict
            # with domain iothreads
            del vmxml.cputune
            del vmxml.iothreadids
            vmxml.iothreads = int(dom_iothreads)

        # After compose the disk xml, redefine the VM xml.
        vmxml.sync()

        # Test snapshot before vm start.
        if test_disk_snapshot:
            if snapshot_before_start:
                ret = virsh.snapshot_create_as(vm_name, "s1 %s"
                                               % snapshot_option)
                libvirt.check_exit_status(ret, snapshot_error)

        # Start the VM.
        vm.start()
        if status_error:
            raise error.TestFail("VM started unexpectedly")

        # Hotplug the disks.
        if device_at_dt_disk:
            for i in range(len(disks)):
                attach_option = ""
                if len(device_attach_option) > i:
                    attach_option = device_attach_option[i]
                ret = virsh.attach_disk(vm_name, disks[i]["source"],
                                        device_targets[i],
                                        attach_option)
                libvirt.check_exit_status(ret)

        elif hotplug:
            for i in range(len(disks_xml)):
                disks_xml[i].xmltreefile.write()
                attach_option = ""
                if len(device_attach_option) > i:
                    attach_option = device_attach_option[i]
                ret = virsh.attach_device(vm_name, disks_xml[i].xml,
                                          flagstr=attach_option)
                attach_error = False
                if len(device_attach_error) > i:
                    attach_error = "yes" == device_attach_error[i]
                libvirt.check_exit_status(ret, attach_error)

    except virt_vm.VMStartError as details:
        if not status_error:
            raise error.TestFail('VM failed to start:\n%s' % details)
    except xcepts.LibvirtXMLError:
        if not define_error:
            raise error.TestFail("Failed to define VM")
    else:
        # VM is started, perform the tests.
        if test_slots_order:
            if not check_disk_order(device_targets):
                raise error.TestFail("Disks slots order error in domain xml")

        if test_disks_format:
            if not check_disk_format(device_targets, device_formats):
                raise error.TestFail("Disks type error in VM xml")

        if test_boot_console:
            # Check if disks bootorder is as expected.
            expected_order = params.get("expected_order").split(',')
            if not check_boot_console(expected_order):
                raise error.TestFail("Test VM bootorder failed")

        if test_block_size:
            # Check disk block size in VM.
            if not check_vm_block_size(device_targets,
                                       logical_block_size, physical_block_size):
                raise error.TestFail("Test disk block size in VM failed")

        if test_disk_option_cmd:
            # Check if disk options take affect in qemu commmand line.
            cmd = ("ps -ef | grep %s | grep -v grep " % vm_name)
            if test_with_boot_disk:
                d_target = bootdisk_target
            else:
                d_target = device_targets[0]
                if device_with_source:
                    cmd += (" | grep %s" %
                            (device_source_names[0].replace(',', ',,')))
            io = vm_xml.VMXML.get_disk_attr(vm_name, d_target, "driver", "io")
            if io:
                cmd += " | grep aio=%s" % io
            ioeventfd = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                   "driver", "ioeventfd")
            if ioeventfd:
                cmd += " | grep ioeventfd=%s" % ioeventfd
            event_idx = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                   "driver", "event_idx")
            if event_idx:
                cmd += " | grep event_idx=%s" % event_idx

            discard = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                 "driver", "discard")
            if discard:
                cmd += " | grep discard=%s" % discard
            copy_on_read = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                      "driver", "copy_on_read")
            if copy_on_read:
                cmd += " | grep copy-on-read=%s" % copy_on_read

            iothread = vm_xml.VMXML.get_disk_attr(vm_name, d_target,
                                                  "driver", "iothread")
            if iothread:
                cmd += " | grep iothread=iothread%s" % iothread

            if serial != "":
                cmd += " | grep serial=%s" % serial
            if wwn != "":
                cmd += " | grep -E \"wwn=(0x)?%s\"" % wwn
            if vendor != "":
                cmd += " | grep vendor=%s" % vendor
            if product != "":
                cmd += " | grep \"product=%s\"" % product

            num_queues = ""
            ioeventfd = ""
            if virtio_scsi_controller_driver != "":
                for driver_option in virtio_scsi_controller_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        if d[0].strip() == "queues":
                            num_queues = d[1].strip()
                        elif d[0].strip() == "ioeventfd":
                            ioeventfd = d[1].strip()
            if num_queues != "":
                cmd += " | grep num_queues=%s" % num_queues
            if ioeventfd:
                cmd += " | grep ioeventfd=%s" % ioeventfd

            iface_event_idx = ""
            if iface_driver != "":
                for driver_option in iface_driver.split(','):
                    if driver_option != "":
                        d = driver_option.split('=')
                        if d[0].strip() == "event_idx":
                            iface_event_idx = d[1].strip()
            if iface_event_idx != "":
                cmd += " | grep virtio-net-pci,event_idx=%s" % iface_event_idx

            if utils.run(cmd, ignore_status=True).exit_status:
                raise error.TestFail("Check disk driver option failed")

        if test_disk_snapshot:
            ret = virsh.snapshot_create_as(vm_name, "s1 %s" % snapshot_option)
            libvirt.check_exit_status(ret, snapshot_error)

        # Check the disk bootorder.
        if test_disk_bootorder:
            for i in range(len(device_targets)):
                if len(device_attach_error) > i:
                    if device_attach_error[i] == "yes":
                        continue
                if device_bootorder[i] != vm_xml.VMXML.get_disk_attr(
                        vm_name, device_targets[i], "boot", "order"):
                    raise error.TestFail("Check bootorder failed")

        # Check disk bootorder with snapshot.
        if test_disk_bootorder_snapshot:
            disk_name = os.path.splitext(device_source)[0]
            check_bootorder_snapshot(disk_name)

        # Check disk readonly option.
        if test_disk_readonly:
            if not check_readonly(device_targets):
                raise error.TestFail("Checking disk readonly option failed")

        # Check disk bus device option in qemu command line.
        if test_bus_device_option:
            cmd = ("ps -ef | grep %s | grep -v grep " % vm_name)
            dev_bus = int(vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                     "address", "bus"), 16)
            if device_bus[0] == "virtio":
                pci_slot = int(vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                          "address", "slot"), 16)
                if devices[0] == "lun":
                    device_option = "scsi=on"
                else:
                    device_option = "scsi=off"
                cmd += (" | grep virtio-blk-pci,%s,bus=pci.%x,addr=0x%x"
                        % (device_option, dev_bus, pci_slot))
            if device_bus[0] in ["ide", "sata", "scsi"]:
                dev_unit = int(vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                          "address", "unit"), 16)
                dev_id = vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                    "alias", "name")
            if device_bus[0] == "ide":
                check_cmd = "/usr/libexec/qemu-kvm -device ? 2>&1 |grep -E 'ide-cd|ide-hd'"
                if utils.run(check_cmd, ignore_status=True).exit_status:
                    raise error.TestNAError("ide-cd/ide-hd not supported by this qemu-kvm")

                if devices[0] == "cdrom":
                    device_option = "ide-cd"
                else:
                    device_option = "ide-hd"
                cmd += (" | grep %s,bus=ide.%d,unit=%d,drive=drive-%s,id=%s"
                        % (device_option, dev_bus, dev_unit, dev_id, dev_id))
            if device_bus[0] == "sata":
                cmd += (" | grep 'device ahci,.*,bus=pci.%s'" % dev_bus)
            if device_bus[0] == "scsi":
                if devices[0] == "lun":
                    device_option = "scsi-block"
                elif devices[0] == "cdrom":
                    device_option = "scsi-cd"
                else:
                    device_option = "scsi-hd"
                cmd += (" | grep %s,bus=scsi%d.%d,.*drive=drive-%s,id=%s"
                        % (device_option, dev_bus, dev_unit, dev_id, dev_id))
            if device_bus[0] == "usb":
                dev_port = vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                      "address", "port")
                dev_id = vm_xml.VMXML.get_disk_attr(vm_name, device_targets[0],
                                                    "alias", "name")
                if devices[0] == "disk":
                    cmd += (" | grep usb-storage,bus=usb%s.0,port=%s,"
                            "drive=drive-%s,id=%s"
                            % (dev_bus, dev_port, dev_id, dev_id))
                if usb_devices.has_key("input"):
                    cmd += (" | grep usb-tablet,id=input[0-9],bus=usb.%s,"
                            "port=%s" % (usb_devices["input"]["bus"],
                                         usb_devices["input"]["port"]))
                if usb_devices.has_key("hub"):
                    cmd += (" | grep usb-hub,id=hub0,bus=usb.%s,"
                            "port=%s" % (usb_devices["hub"]["bus"],
                                         usb_devices["hub"]["port"]))

            if utils.run(cmd, ignore_status=True).exit_status:
                raise error.TestFail("Cann't see disk option"
                                     " in command line")

        if dom_iothreads:
            check_dom_iothread()

        # Check in VM after command.
        if check_patitions:
            if not check_vm_partitions(devices,
                                       device_targets):
                raise error.TestFail("Cann't see device in VM")

        # Check disk save and restore.
        if test_disk_save_restore:
            save_file = "/tmp/%s.save" % vm_name
            check_disk_save_restore(save_file, device_targets,
                                    startup_policy)
            if os.path.exists(save_file):
                os.remove(save_file)

        # If we testing hotplug, detach the disk at last.
        if device_at_dt_disk:
            for i in range(len(disks)):
                dt_options = ""
                if devices[i] == "cdrom":
                    dt_options = "--config"
                ret = virsh.detach_disk(vm_name, device_targets[i],
                                        dt_options, **virsh_dargs)
                libvirt.check_exit_status(ret)
            # Check disks in VM after hotunplug.
            if check_patitions_hotunplug:
                if not check_vm_partitions(devices,
                                           device_targets, False):
                    raise error.TestFail("See device in VM after hotunplug")

        elif hotplug:
            for i in range(len(disks_xml)):
                if len(device_attach_error) > i:
                    if device_attach_error[i] == "yes":
                        continue
                ret = virsh.detach_device(vm_name, disks_xml[i].xml,
                                          flagstr=attach_option, **virsh_dargs)
                os.remove(disks_xml[i].xml)
                libvirt.check_exit_status(ret)

            # Check disks in VM after hotunplug.
            if check_patitions_hotunplug:
                if not check_vm_partitions(devices,
                                           device_targets, False):
                    raise error.TestFail("See device in VM after hotunplug")

    finally:
        # Delete snapshots.
        libvirt.clean_up_snapshots(vm_name, domxml=vmxml_backup)

        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        vmxml_backup.sync("--snapshots-metadata")

        # Restore qemu_config file.
        qemu_config.restore()
        utils_libvirtd.libvirtd_restart()

        for img in disks_img:
            os.remove(img["source"])
            if os.path.exists(img["path"]):
                utils.run("umount %s && rmdir %s"
                          % (img["path"], img["path"]), ignore_status=True)

        for img in disks:
            if img.has_key("disk_dev"):
                if img["format"] == "nfs":
                    img["disk_dev"].cleanup()

                del img["disk_dev"]
            else:
                if img["format"] == "scsi":
                    libvirt.delete_scsi_disk()
                elif img["format"] == "iscsi":
                    libvirt.setup_or_cleanup_iscsi(is_setup=False)
                elif img["format"] not in ["dir"]:
                    if os.path.exists(img["source"]):
                        os.remove(img["source"])

Example 8

Project: tp-qemu
Source File: cpuflags.py
View license
def run(test, params, env):
    """
    Boot guest with different cpu flags and check if guest works correctly.

    :param test: kvm test object.
    :param params: Dictionary with the test parameters.
    :param env: Dictionary with test environment.
    """
    utils_misc.Flag.aliases = utils_misc.kvm_map_flags_aliases
    qemu_binary = utils_misc.get_qemu_binary(params)

    cpuflags_src = os.path.join(data_dir.get_deps_dir("cpu_flags"), "src")
    cpuflags_def = os.path.join(data_dir.get_deps_dir("cpu_flags"),
                                "cpu_map.xml")
    smp = int(params.get("smp", 1))

    all_host_supported_flags = params.get("all_host_supported_flags", "no")

    mig_timeout = float(params.get("mig_timeout", "3600"))
    mig_protocol = params.get("migration_protocol", "tcp")
    mig_speed = params.get("mig_speed", "1G")

    cpu_model_black_list = params.get("cpu_model_blacklist", "").split(" ")

    multi_host_migration = params.get("multi_host_migration", "no")

    class HgFlags(object):
        """Set algebra over host, qemu and guest cpu flags for one cpu model.

        Exposes the derived flag sets (supported, unsupported, possible)
        used by the individual subtests below.
        """

        def __init__(self, cpu_model, extra_flags=set([])):
            """
            :param cpu_model: cpu model name passed to qemu's -cpu option.
            :param extra_flags: extra flags to add to the guest flag set.
                NOTE(review): mutable default argument; it is only read here
                (via |=, which rebinds/updates self.guest_flags, not the
                default), but a frozenset() default would be safer.
            """
            # Flags that exist only inside virtualized guests (from config).
            virtual_flags = set(map(utils_misc.Flag,
                                    params.get("guest_spec_flags", "").split()))
            # Host-specific flags declared in the test config.
            self.hw_flags = set(map(utils_misc.Flag,
                                    params.get("host_spec_flags", "").split()))
            # Everything this qemu binary knows about.
            self.qemu_support_flags = get_all_qemu_flags()
            # Everything the host cpu actually provides.
            self.host_support_flags = set(map(utils_misc.Flag,
                                              utils_misc.get_cpu_flags()))
            # Flags the chosen cpu model defines, minus virtualization-only ones.
            self.quest_cpu_model_flags = (get_guest_host_cpuflags(cpu_model) -
                                          virtual_flags)

            # Flags usable at all: known to qemu AND present on the host.
            self.supported_flags = (self.qemu_support_flags &
                                    self.host_support_flags)
            # Usable flags the cpu model itself does not include.
            self.cpumodel_unsupport_flags = (self.supported_flags -
                                             self.quest_cpu_model_flags)

            # Model flags the host cpu cannot provide.
            self.host_unsupported_flags = (self.quest_cpu_model_flags -
                                           self.host_support_flags)

            # Every flag the guest could possibly be given.
            self.all_possible_guest_flags = (self.quest_cpu_model_flags -
                                             self.host_unsupported_flags)
            self.all_possible_guest_flags |= self.cpumodel_unsupport_flags

            # Flags the guest is expected to end up with (+ requested extras).
            self.guest_flags = (self.quest_cpu_model_flags -
                                self.host_unsupported_flags)
            self.guest_flags |= extra_flags

            # qemu-known flags that neither the host nor virtualization provide.
            self.host_all_unsupported_flags = set([])
            self.host_all_unsupported_flags |= self.qemu_support_flags
            self.host_all_unsupported_flags -= (self.host_support_flags |
                                                virtual_flags)

    def start_guest_with_cpuflags(cpuflags, smp=None, migration=False,
                                  wait=True):
        """
        Try to boot guest with special cpu flags and try login in to them.

        :param cpuflags: value used as the guest's "cpu_model" parameter
            (model name, optionally with +flag/-flag suffixes).
        :param smp: override for the "smp" parameter; None keeps the default.
        :param migration: if True, create the vm in incoming-migration mode
            (using the module-level mig_protocol).
        :param wait: if True, block until a login session is available.
        :return: tuple (vm, session); session is None when wait is False.
        :raise qemu_vm.ImageUnbootableError: re-raised after force-destroying
            the vm when the guest image fails to boot.
        """
        # Work on a copy so the shared test params are not mutated.
        params_b = params.copy()
        params_b["cpu_model"] = cpuflags
        if smp is not None:
            params_b["smp"] = smp

        vm_name = "vm1-cpuflags"
        vm = qemu_vm.VM(vm_name, params_b, test.bindir, env['address_cache'])
        # Register so the framework can clean the vm up on test teardown.
        env.register_vm(vm_name, vm)
        if (migration is True):
            vm.create(migration_mode=mig_protocol)
        else:
            vm.create()

        session = None
        try:
            vm.verify_alive()

            if wait:
                session = vm.wait_for_login()
        except qemu_vm.ImageUnbootableError:
            # Unbootable guests are destroyed here; other errors propagate
            # with the vm left registered for framework cleanup.
            vm.destroy(gracefully=False)
            raise

        return (vm, session)

    def get_guest_system_cpuflags(vm_session):
        """
        Read the cpu flags the guest kernel reports in /proc/cpuinfo.

        :param vm_session: session to the vm being checked.
        :return: set of utils_misc.Flag objects for the reported flags.
        """
        cpuinfo = vm_session.cmd_output("cat /proc/cpuinfo")
        match = re.search(r'^flags\s*:(.*)$', cpuinfo, re.MULTILINE)
        return set(utils_misc.Flag(name) for name in match.group(1).split())

    def get_guest_host_cpuflags_legacy(cpumodel):
        """
        Get cpu flags corresponding to a cpu model (legacy "?dump" output).

        :param cpumodel: cpu model name passed to <qemu-kvm-cmd>.
        :return: set of utils_misc.Flag the model defines.
        :raise error.TestFail: if the model is not found in the ?dump output.
        """
        cmd = qemu_binary + " -cpu ?dump"
        output = utils.run(cmd).stdout
        # BUG FIX: re.escape(cpumodel) was previously computed and its result
        # discarded; interpolate the escaped name so model names containing
        # regex metacharacters match literally.
        pattern = (r".+%s.*\n.*\n +feature_edx .+ \((.*)\)\n +feature_"
                   r"ecx .+ \((.*)\)\n +extfeature_edx .+ \((.*)\)\n +"
                   r"extfeature_ecx .+ \((.*)\)\n" % re.escape(cpumodel))
        flags = []
        model = re.search(pattern, output)
        if model is None:
            raise error.TestFail("Cannot find %s cpu model." % (cpumodel))
        for flag_group in model.groups():
            flags += flag_group.split()
        return set(map(utils_misc.Flag, flags))

    class ParseCpuFlags(object):
        """
        Expat-based parser for qemu/libvirt cpu_map.xml.

        parse_file() returns {arch_name: {model_name: [feature, ...]}};
        <feature> elements outside any <model> accumulate in self.all_flags.
        """

        def __init__(self, encoding=None):
            """
            :param encoding: optional encoding passed to expat.ParserCreate.
            """
            self.cpus = {}
            self.parser = expat.ParserCreate(encoding)
            self.parser.StartElementHandler = self.start_element
            self.parser.EndElementHandler = self.end_element
            self.last_arch = None      # dict of the <arch> being parsed
            self.last_model = None     # flag list of the <model> being parsed
            self.sub_model = False     # inside a nested <model> element
            self.all_flags = []        # features seen outside any model

        def start_element(self, name, attrs):
            """Expat handler: open an <arch>/<model>/<feature> context."""
            if name == "cpus":
                self.cpus = {}
            elif name == "arch":
                self.last_arch = self.cpus[attrs['name']] = {}
            elif name == "model":
                if self.last_model is None:
                    self.last_model = self.last_arch[attrs['name']] = []
                else:
                    # Nested <model>: merge the referenced model's flags.
                    # NOTE(review): assumes the referenced model was already
                    # defined in this arch -- confirm for the cpu_map.xml in use.
                    self.last_model += self.last_arch[attrs['name']]
                    self.sub_model = True
            elif name == "feature":
                if self.last_model is not None:
                    self.last_model.append(attrs['name'])
                else:
                    self.all_flags.append(attrs['name'])

        def end_element(self, name):
            """Expat handler: close the current <arch>/<model> context."""
            if name == "arch":
                self.last_arch = None
            elif name == "model":
                if self.sub_model is False:
                    self.last_model = None
                else:
                    self.sub_model = False

        def parse_file(self, file_path):
            """
            Parse *file_path* and return the arch/model/flags mapping.

            :param file_path: path to a cpu_map.xml file.
            :return: dict {arch: {model: [flags]}}.
            """
            # BUG FIX: the file handle was opened and never closed (leak);
            # use a context manager, in binary mode as expat's ParseFile
            # reads raw bytes.
            with open(file_path, 'rb') as xml_file:
                self.parser.ParseFile(xml_file)
            return self.cpus

    def get_guest_host_cpuflags_1350(cpumodel):
        """
        Get cpu flags corresponding to a cpu model (qemu >= 1.3.50 cpu map).

        :param cpumodel: cpu model name to look up in cpu_map.xml.
        :return: set of utils_misc.Flag defined for the model.
        :raise error.TestFail: if the model is not present in the cpu map.
        """
        p = ParseCpuFlags()
        cpus = p.parse_file(cpuflags_def)
        # BUG FIX: a missing model previously raised a confusing
        # UnboundLocalError on 'flags'; fail explicitly instead, matching the
        # legacy variant's error message.
        for arch in cpus.values():
            if cpumodel in arch.keys():
                return set(map(utils_misc.Flag, arch[cpumodel]))
        raise error.TestFail("Cannot find %s cpu model." % (cpumodel))

    # Unrecognized "-cpu ?" output format: fall back to the cpu_map.xml parser.
    get_guest_host_cpuflags_BAD = get_guest_host_cpuflags_1350

    def get_all_qemu_flags_legacy():
        """Return every cpuid flag this qemu knows (legacy ?cpuid output)."""
        output = utils.run(qemu_binary + " -cpu ?cpuid").stdout

        match = re.search(r".*\n.*f_edx:(.*)\n.*f_ecx:(.*)\n"
                          ".*extf_edx:(.*)\n.*extf_ecx:(.*)", output)
        collected = []
        for group in match.groups():
            collected.extend(group.split())

        return set(map(utils_misc.Flag, collected))

    def get_all_qemu_flags_1350():
        """Return every CPUID flag this qemu recognizes (>= 1.3.50 output)."""
        output = utils.run(qemu_binary + " -cpu ?").stdout

        match = re.search(r".*Recognized CPUID flags:\n(.*)", output, re.DOTALL)
        collected = []
        for group in match.groups():
            collected.extend(group.split())

        return set(map(utils_misc.Flag, collected))

    def get_all_qemu_flags_BAD():
        """
        Get all cpu flags listed in the bundled cpu_map.xml; fallback used
        when the qemu "-cpu ?" output format is not recognized.

        :return: set of utils_misc.Flag for every flag in the cpu map.
        """
        p = ParseCpuFlags()
        # parse_file side effect: fills p.all_flags with features defined
        # outside any <model> element.
        p.parse_file(cpuflags_def)
        return set(map(utils_misc.Flag, p.all_flags))

    def get_cpu_models_legacy():
        """
        Get all cpu model names known to qemu (legacy output format).

        :return: list of cpu model names.
        """
        listing = utils.run(qemu_binary + " -cpu ?").stdout
        return re.findall(r"\w+\s+\[?(\w+)\]?", listing)

    def get_cpu_models_1350():
        """
        Get all cpu model names known to qemu (>= 1.3.50 output format).

        :return: list of cpu model names.
        """
        listing = utils.run(qemu_binary + " -cpu ?").stdout
        return re.findall(r"x86\s+\[?(\w+)\]?", listing)

    # Unrecognized "-cpu ?" output format: reuse the 1.3.50-style model listing.
    get_cpu_models_BAD = get_cpu_models_1350

    def get_qemu_cpu_cmd_version():
        """
        Probe which "-cpu ?" output format this qemu binary produces.

        :return: "legacy" (supports ?cpuid), "1350" (new-style CPUID
            listing) or "BAD" (unrecognized format).
        """
        cmd = qemu_binary + " -cpu ?cpuid"
        try:
            utils.run(cmd).stdout
            return "legacy"
        # BUG FIX: the bare "except:" also swallowed SystemExit and
        # KeyboardInterrupt; only ordinary probe failures should fall through.
        except Exception:
            cmd = qemu_binary + " -cpu ?"
            output = utils.run(cmd).stdout
            if "CPUID" in output:
                return "1350"
            else:
                return "BAD"

    qcver = get_qemu_cpu_cmd_version()

    # Select the helper variants matching the detected "-cpu ?" output format
    # ("legacy", "1350" or "BAD") by name from the local namespace.
    get_guest_host_cpuflags = locals()["get_guest_host_cpuflags_%s" % qcver]
    get_all_qemu_flags = locals()["get_all_qemu_flags_%s" % qcver]
    get_cpu_models = locals()["get_cpu_models_%s" % qcver]

    def get_flags_full_name(cpu_flag):
        """
        Resolve a flag to qemu's canonical Flag object for it.

        :param cpu_flag: flag name (or Flag) to look up.
        :return: matching utils_misc.Flag from qemu's flag set, or [] when
            the flag is unknown to qemu.
            NOTE(review): the [] fallback is an odd sentinel (a list, not a
            Flag) and is unhashable if later placed in a set -- confirm how
            callers handle the not-found case.
        """
        cpu_flag = utils_misc.Flag(cpu_flag)
        for f in get_all_qemu_flags():
            # Flag comparison is presumably alias-aware (Flag.aliases is
            # configured above), so return qemu's own spelling of the flag.
            if f == cpu_flag:
                return utils_misc.Flag(f)
        return []

    def parse_qemu_cpucommand(cpumodel):
        """
        Parse a qemu -cpu argument ("model[,+flag][,-flag]...") into the set
        of flags the guest must end up with.

        :param cpumodel: full -cpu argument string.
        :return: set of flags the guest is expected to have (model flags
            limited to host support, plus "+" flags, minus "-" flags).
        """
        flags = cpumodel.split(",")
        cpumodel = flags[0]

        # Flags the model defines, restricted to what the host cpu provides.
        qemu_model_flag = get_guest_host_cpuflags(cpumodel)
        host_support_flag = set(map(utils_misc.Flag,
                                    utils_misc.get_cpu_flags()))
        real_flags = qemu_model_flag & host_support_flag

        for f in flags[1:]:
            # NOTE(review): f[0].startswith("+") tests only the first
            # character, i.e. it is equivalent to f[0] == "+"; presumably
            # f.startswith("+") was intended -- behavior is the same here.
            if f[0].startswith("+"):
                real_flags |= set([get_flags_full_name(f[1:])])
            if f[0].startswith("-"):
                real_flags -= set([get_flags_full_name(f[1:])])

        return real_flags

    def check_cpuflags(cpumodel, vm_session):
        """
        Compare the flags the guest actually shows with those requested
        through the -cpu command line.

        :param cpumodel: value of the -cpu parameter given to qemu-kvm.
        :param vm_session: session to the vm whose flags are checked.
        :return: set of requested flags that are missing in the guest.
        """
        guest_flags = get_guest_system_cpuflags(vm_session)
        requested_flags = parse_qemu_cpucommand(cpumodel)

        logging.debug("Guest flags: %s", guest_flags)
        logging.debug("Host flags: %s", requested_flags)
        logging.debug("Flags on guest not defined by host: %s",
                      (guest_flags - requested_flags))
        return requested_flags - guest_flags

    def get_cpu_models_supported_by_host():
        """
        List cpu models whose flag set the host can fully satisfy.

        :return: list of cpu model names.
        """
        return [model for model in get_cpu_models()
                if HgFlags(model).host_unsupported_flags == set([])]

    def disable_cpu(vm_session, cpu, disable=True):
        """
        Disable or enable a cpu in the guest system via sysfs.

        :param vm_session: session to the guest.
        :param cpu: CPU id to change.
        :param disable: if True disable the cpu, else enable it.
        """
        system_cpu_dir = "/sys/devices/system/cpu/"
        cpu_online = system_cpu_dir + "cpu%d/online" % (cpu)
        cpu_state = vm_session.cmd_output("cat %s" % cpu_online).strip()
        if disable and cpu_state == "1":
            vm_session.cmd("echo 0 > %s" % cpu_online)
            logging.debug("Guest cpu %d is disabled.", cpu)
        # BUG FIX: an already-offline cpu was previously re-enabled even when
        # disable=True; only bring it online when enabling was requested.
        elif not disable and cpu_state == "0":
            vm_session.cmd("echo 1 > %s" % cpu_online)
            logging.debug("Guest cpu %d is enabled.", cpu)

    def check_online_cpus(vm_session, smp, disabled_cpu):
        """
        Check which guest cpus are online and that sysfs agrees with
        /proc/cpuinfo.

        :param vm_session: session to the guest.
        :param smp: count of cpu cores in the system.
        :param disabled_cpu: list of cpus expected to be disabled.
        :return: set of cpus that are online and not in disabled_cpu.
        :raise error.TestError: if sysfs and /proc/cpuinfo disagree.
        """
        # cpu0 has no "online" knob in sysfs, so it is always counted.
        online = [0]
        system_cpu_dir = "/sys/devices/system/cpu/"
        for cpu in range(1, smp):
            cpu_online = system_cpu_dir + "cpu%d/online" % (cpu)
            cpu_state = vm_session.cmd_output("cat %s" % cpu_online).strip()
            if cpu_state == "1":
                online.append(cpu)
        cpu_proc = vm_session.cmd_output("cat /proc/cpuinfo")
        cpu_state_proc = [int(x) for x in
                          re.findall(r"processor\s+:\s*(\d+)\n", cpu_proc)]
        if set(online) != set(cpu_state_proc):
            raise error.TestError("Some cpus are disabled but %s are still "
                                  "visible like online in /proc/cpuinfo." %
                                  (set(cpu_state_proc) - set(online)))

        return set(online) - set(disabled_cpu)

    def install_cpuflags_test_on_vm(vm, dst_dir):
        """
        Copy the cpuflags-test sources to the guest and build them there.

        :param vm: virtual machine to install into.
        :param dst_dir: installation path inside the guest.
        """
        session = vm.wait_for_login()
        vm.copy_files_to(cpuflags_src, dst_dir)
        # Flush the copied sources to disk before building.
        session.cmd("sync")
        session.cmd("cd %s; make EXTRA_FLAGS='';" %
                    os.path.join(dst_dir, "cpu_flags"))
        session.cmd("sync")
        session.close()

    def check_cpuflags_work(vm, path, flags):
        """
        Check which flags actually work in the guest by running the
        cpuflags-test binary for each flag's test cases.

        :param vm: virtual machine with cpuflags-test installed.
        :param path: installation path of cpuflags-test in the guest.
        :param flags: iterable of flags to test.
        :return: tuple of flag sets (working, not working, not tested).
        """
        pass_Flags = []
        not_tested = []
        not_working = []
        session = vm.wait_for_login()
        try:
            for f in flags:
                try:
                    for tc in utils_misc.kvm_map_flags_to_test[f]:
                        session.cmd("%s/cpuflags-test --%s" %
                                    (os.path.join(path, "cpu_flags"), tc))
                    # All test cases for this flag succeeded.
                    pass_Flags.append(f)
                except aexpect.ShellCmdError:
                    # The test binary failed: the flag does not work.
                    not_working.append(f)
                except KeyError:
                    # No test case is mapped for this flag.
                    not_tested.append(f)
        finally:
            # BUG FIX: the login session was previously leaked; close it like
            # the other helpers in this test do.
            session.close()
        return (set(map(utils_misc.Flag, pass_Flags)),
                set(map(utils_misc.Flag, not_working)),
                set(map(utils_misc.Flag, not_tested)))

    def run_stress(vm, timeout, guest_flags):
        """
        Run cpuflags stress (plus a background dd) on the vm for *timeout*
        seconds.

        :param vm: virtual machine to stress.
        :param timeout: how long the stress should run, in seconds.
        :param guest_flags: flags whose stress tests should be executed.
        :return: True if the stress survived the whole timeout (the command
            timing out is the success condition), False if it exited early.
        """
        ret = False
        install_path = "/tmp"
        install_cpuflags_test_on_vm(vm, install_path)
        # (working, not_working, not_tested) -- only working flags are stressed.
        flags = check_cpuflags_work(vm, install_path, guest_flags)
        dd_session = vm.wait_for_login()
        stress_session = vm.wait_for_login()
        # Background disk load to stress I/O alongside the cpu test.
        dd_session.sendline("dd if=/dev/[svh]da of=/tmp/stressblock"
                            " bs=10MB count=100 &")
        try:
            stress_session.cmd("%s/cpuflags-test --stress %s%s" %
                               (os.path.join(install_path, "cpu_flags"), smp,
                                utils_misc.kvm_flags_to_stresstests(flags[0])),
                               timeout=timeout)
        except aexpect.ShellTimeoutError:
            # Expected: the stress ran for the full timeout without failing.
            ret = True
        stress_session.close()
        dd_session.close()
        return ret

    def separe_cpu_model(cpu_model):
        """Return the model part of a "model:flags" spec; the input is
        returned unchanged unless it contains exactly one ':'."""
        parts = cpu_model.split(":")
        if len(parts) == 2:
            return parts[0]
        return cpu_model

    def parse_cpu_model():
        """
        Parse the "cpu_model" test parameter ("model[:flag,flag,...]").

        :return: tuple (cpumodel, set of extra utils_misc.Flag).
        """
        raw_model = params.get("cpu_model", "")
        logging.debug("CPU model found: %s", str(raw_model))

        parts = raw_model.split(":")
        if len(parts) == 2:
            model, flag_spec = parts
            extra_flags = set(map(utils_misc.Flag, flag_spec.split(",")))
        else:
            model = raw_model
            extra_flags = set([])
        return (model, extra_flags)

    class MiniSubtest(object):
        """Base for inline subtests: instantiating a subclass runs its
        test() immediately and evaluates to test()'s return value (not the
        instance); clean() is always invoked afterwards if defined."""

        def __new__(cls, *args, **kargs):
            instance = super(MiniSubtest, cls).__new__(cls)
            result = None
            if args is None:
                args = []
            try:
                result = instance.test(*args, **kargs)
            finally:
                if hasattr(instance, "clean"):
                    instance.clean()
            return result

    def print_exception(called_object):
        """Log the current exception together with the call site of
        *called_object* (intended to be used inside an except block)."""
        exc_type, exc_value, exc_traceback = sys.exc_info()
        logging.error("In function (" + called_object.__name__ + "):")
        logging.error("Call from:\n" + traceback.format_stack()[-2][:-1])
        formatted = "".join(traceback.format_exception(exc_type, exc_value,
                                                       exc_traceback.tb_next))
        logging.error("Exception from:\n" + formatted)

    class Test_temp(MiniSubtest):
        """MiniSubtest variant that force-destroys self.vm (when set) during
        cleanup."""

        def clean(self):
            logging.info("cleanup")
            if hasattr(self, "vm"):
                self.vm.destroy(gracefully=False)

    # 1) <qemu-kvm-cmd> -cpu ?model
    class test_qemu_cpu_model(MiniSubtest):
        """Verify every configured cpu model appears in "-cpu ?model"."""

        def test(self):
            if qcver == "legacy":
                cmd = qemu_binary + " -cpu ?model"
                result = utils.run(cmd)
                models = [separe_cpu_model(m) for m in
                          params.get("cpu_models", "core2duo").split()]
                missing = [m for m in models if m not in result.stdout]
                if missing:
                    raise error.TestFail("CPU models %s are not in output "
                                         "'%s' of command \n%s" %
                                         (missing, cmd, result.stdout))
            elif qcver == "1350":
                raise error.TestNAError("New qemu use new -cpu ? cmd.")

    # 2) <qemu-kvm-cmd> -cpu ?dump
    class test_qemu_dump(MiniSubtest):
        """Verify every configured cpu model appears in "-cpu ?dump"."""

        def test(self):
            if qcver == "legacy":
                cmd = qemu_binary + " -cpu ?dump"
                result = utils.run(cmd)
                models = [separe_cpu_model(m) for m in
                          params.get("cpu_models", "core2duo").split()]
                missing = [m for m in models if m not in result.stdout]
                if missing:
                    raise error.TestFail("CPU models %s are not in output "
                                         "'%s' of command \n%s" %
                                         (missing, cmd, result.stdout))
            elif qcver == "1350":
                raise error.TestNAError(
                    "New qemu does not support -cpu ?dump.")

    # 3) <qemu-kvm-cmd> -cpu ?cpuid
    class test_qemu_cpuid(MiniSubtest):
        """Verify "-cpu ?cpuid" prints at least one cpu flag."""

        def test(self):
            if qcver == "legacy":
                cmd = qemu_binary + " -cpu ?cpuid"
                result = utils.run(cmd)
                # BUG FIX: 'result.stdout is ""' compared object identity and
                # was effectively never true; test for emptiness instead.
                if not result.stdout:
                    raise error.TestFail("There aren't any cpu Flag in output"
                                         " '%s' of command \n%s" %
                                         (cmd, result.stdout))
            elif qcver == "1350":
                raise error.TestNAError("New qemu use new -cpu ? cmd.")

    # 1) boot with cpu_model
    class test_boot_cpu_model(Test_temp):
        """Boot a guest with the configured cpu model and verify all host
        flags show up inside it."""

        def test(self):
            cpu_model, _ = parse_cpu_model()
            logging.debug("Run tests with cpu model %s", cpu_model)
            flags = HgFlags(cpu_model)
            (self.vm, session) = start_guest_with_cpuflags(cpu_model)
            missing = check_cpuflags(cpu_model, session) - flags.hw_flags
            if missing != set([]):
                raise error.TestFail("Flags defined on host but not found "
                                     "on guest: %s" % (missing))

    # 2) success boot with supported flags
    class test_boot_cpu_model_and_additional_flags(Test_temp):

        """Boot a guest with the model plus additional flags, verify the
        guest sees them, and run the instruction-set checks on the guest."""

        def test(self):
            cpu_model, extra_flags = parse_cpu_model()
            flags = HgFlags(cpu_model, extra_flags)

            logging.debug("Cpu mode flags %s.",
                          str(flags.quest_cpu_model_flags))

            # Decide which extra flags to enable and which set the guest is
            # then expected to expose.
            if all_host_supported_flags == "yes":
                add_flags = flags.cpumodel_unsupport_flags
                guest_flags = flags.all_possible_guest_flags
            else:
                add_flags = extra_flags
                guest_flags = flags.guest_flags

            # Build the -cpu argument: enable the extras, disable whatever
            # the host cannot provide.
            cpuf_model = cpu_model
            for fadd in add_flags:
                cpuf_model += ",+" + str(fadd)
            for fdel in flags.host_unsupported_flags:
                cpuf_model += ",-" + str(fdel)

            (self.vm, session) = start_guest_with_cpuflags(cpuf_model)

            not_enable_flags = (check_cpuflags(cpuf_model, session) -
                                flags.hw_flags)
            if not_enable_flags != set([]):
                logging.info("Model unsupported flags: %s",
                             str(flags.cpumodel_unsupport_flags))
                logging.error("Flags defined on host but not on found "
                              "on guest: %s", str(not_enable_flags))
            logging.info("Check main instruction sets.")

            install_path = "/tmp"
            install_cpuflags_test_on_vm(self.vm, install_path)

            Flags = check_cpuflags_work(self.vm, install_path,
                                        flags.all_possible_guest_flags)
            logging.info("Woking CPU flags: %s", str(Flags[0]))
            logging.info("Not working CPU flags: %s", str(Flags[1]))
            logging.warning("Flags works even if not defined on guest cpu "
                            "flags: %s", str(Flags[0] - guest_flags))
            logging.warning("Not tested CPU flags: %s", str(Flags[2]))

            if Flags[1] & guest_flags:
                raise error.TestFail("Some flags do not work: %s" %
                                     (str(Flags[1])))

    # 3) fail boot unsupported flags
    class test_boot_warn_with_host_unsupported_flags(MiniSubtest):

        """Start qemu with ",check" plus flags the host lacks; qemu should
        warn about each of them but still keep running."""

        def test(self):
            # This is virtual cpu flags which are supported by
            # qemu but no with host cpu.
            cpu_model, extra_flags = parse_cpu_model()

            flags = HgFlags(cpu_model, extra_flags)

            logging.debug("Unsupported flags %s.",
                          str(flags.host_all_unsupported_flags))
            # ",check" asks qemu to verify flags and warn rather than abort.
            cpuf_model = cpu_model + ",check"

            # Add unsupported flags.
            for fadd in flags.host_all_unsupported_flags:
                cpuf_model += ",+" + str(fadd)

            vnc_port = utils_misc.find_free_port(5900, 6100) - 5900
            cmd = "%s -cpu %s -vnc :%d -enable-kvm" % (qemu_binary,
                                                       cpuf_model,
                                                       vnc_port)
            out = None

            try:
                try:
                    # With ignore_status=True, run() returning within the 5s
                    # timeout means qemu exited -- i.e. it did NOT stay up as
                    # ",check" mode should allow; that is a failure.
                    out = utils.run(cmd, timeout=5, ignore_status=True).stderr
                    raise error.TestFail("Guest not boot with unsupported "
                                         "flags.")
                except error.CmdError, e:
                    # Timeout path: qemu kept running; harvest its stderr.
                    out = e.result_obj.stderr
            finally:
                # Every host-unsupported flag must appear either in a
                # "warning: ... flag '...'" line or a "CPU feature ... not
                # found" line; anything left over was silently accepted.
                uns_re = re.compile(r"^warning:.*flag '(.+)'", re.MULTILINE)
                nf_re = re.compile(
                    r"^CPU feature (.+) not found", re.MULTILINE)
                warn_flags = set([utils_misc.Flag(x)
                                  for x in uns_re.findall(out)])
                not_found = set([utils_misc.Flag(x)
                                 for x in nf_re.findall(out)])
                fwarn_flags = flags.host_all_unsupported_flags - warn_flags
                fwarn_flags -= not_found
                if fwarn_flags:
                    raise error.TestFail("Qemu did not warn the use of "
                                         "flags %s" % str(fwarn_flags))

    # 3) fail boot unsupported flags
    class test_fail_boot_with_host_unsupported_flags(MiniSubtest):

        """Start qemu with ",enforce" plus flags the host lacks; every such
        flag must be warned about or reported as not found."""

        def test(self):
            # This is virtual cpu flags which are supported by
            # qemu but no with host cpu.
            cpu_model, extra_flags = parse_cpu_model()

            flags = HgFlags(cpu_model, extra_flags)
            # ",enforce" makes qemu reject flags the host cannot provide.
            cpuf_model = cpu_model + ",enforce"

            logging.debug("Unsupported flags %s.",
                          str(flags.host_all_unsupported_flags))

            # Add unsupported flags.
            for fadd in flags.host_all_unsupported_flags:
                cpuf_model += ",+" + str(fadd)

            vnc_port = utils_misc.find_free_port(5900, 6100) - 5900
            cmd = "%s -cpu %s -vnc :%d -enable-kvm" % (qemu_binary,
                                                       cpuf_model,
                                                       vnc_port)
            out = None
            try:
                try:
                    # With ignore_status=True a CmdError can only be the 5s
                    # timeout expiring, i.e. qemu stayed up despite
                    # ",enforce" -- hence the error log below.
                    out = utils.run(cmd, timeout=5, ignore_status=True).stderr
                except error.CmdError:
                    logging.error("Host boot with unsupported flag")
            finally:
                # Every host-unsupported flag must appear either as a
                # "warning: ... flag" or a "CPU feature ... not found" line.
                uns_re = re.compile(r"^warning:.*flag '(.+)'", re.MULTILINE)
                nf_re = re.compile(
                    r"^CPU feature (.+) not found", re.MULTILINE)
                warn_flags = set([utils_misc.Flag(x)
                                  for x in uns_re.findall(out)])
                not_found = set([utils_misc.Flag(x)
                                 for x in nf_re.findall(out)])
                fwarn_flags = flags.host_all_unsupported_flags - warn_flags
                fwarn_flags -= not_found
                if fwarn_flags:
                    raise error.TestFail("Qemu did not warn the use of "
                                         "flags %s" % str(fwarn_flags))

    # 4) check guest flags under load cpu, stress and system (dd)
    class test_boot_guest_and_try_flags_under_load(Test_temp):

        """Boot with the (optionally extended) cpu model and verify the
        guest cpuflags keep working while the guest runs under stress."""

        def test(self):
            logging.info("Check guest working cpuflags under load "
                         "cpu and stress and system (dd)")
            cpu_model, extra_flags = parse_cpu_model()
            flags = HgFlags(cpu_model, extra_flags)

            logging.debug("Cpu mode flags %s.",
                          str(flags.quest_cpu_model_flags))

            cpuf_model = cpu_model
            if all_host_supported_flags == "yes":
                logging.debug("Added flags %s.",
                              str(flags.cpumodel_unsupport_flags))
                # Enable everything the model misses, then strip whatever
                # the host cannot deliver.
                cpuf_model += "".join(",+" + str(flag)
                                      for flag in flags.cpumodel_unsupport_flags)
                cpuf_model += "".join(",-" + str(flag)
                                      for flag in flags.host_unsupported_flags)

            (self.vm, _) = start_guest_with_cpuflags(cpuf_model, smp)

            if not run_stress(self.vm, 60, flags.guest_flags):
                raise error.TestFail("Stress test ended before"
                                     " end of test.")

        def clean(self):
            """Destroy the guest regardless of test outcome."""
            logging.info("cleanup")
            self.vm.destroy(gracefully=False)

    # 5) Online/offline CPU
    class test_online_offline_guest_CPUs(Test_temp):

        """Randomly offline/online guest vCPUs while the stress workload
        runs; both activities must survive the whole timeout window."""

        def test(self):
            cpu_model, extra_flags = parse_cpu_model()

            logging.debug("Run tests with cpu model %s.", (cpu_model))
            flags = HgFlags(cpu_model, extra_flags)

            (self.vm, session) = start_guest_with_cpuflags(cpu_model, smp)

            def encap(timeout):
                """Toggle random vCPUs (never CPU 0) for `timeout` seconds.

                Returns True on completion, False when smp <= 1 (nothing
                to toggle).
                """
                random.seed()
                begin = time.time()
                end = begin
                if smp > 1:
                    # BUG FIX: this loop previously ran for a hard-coded 60
                    # seconds, silently ignoring the `timeout` argument.
                    while end - begin < timeout:
                        cpu = random.randint(1, smp - 1)
                        if random.randint(0, 1):
                            disable_cpu(session, cpu, True)
                        else:
                            disable_cpu(session, cpu, False)
                        end = time.time()
                    return True
                else:
                    logging.warning("For this test is necessary smp > 1.")
                    return False
            timeout = 60

            test_flags = flags.guest_flags
            if all_host_supported_flags == "yes":
                test_flags = flags.all_possible_guest_flags

            # CPU hotplug toggling and the stress run happen concurrently;
            # both must report success.
            result = utils_misc.parallel([(encap, [timeout]),
                                          (run_stress, [self.vm, timeout,
                                                        test_flags])])
            if not (result[0] and result[1]):
                raise error.TestFail("Stress tests failed before"
                                     " end of testing.")

    # 6) migration test
    class test_migration_with_additional_flags(Test_temp):

        """Migrate a guest (with extended cpu flags) while dd and
        cpuflags-test stress run inside it; the stress must survive."""

        def test(self):
            cpu_model, extra_flags = parse_cpu_model()

            flags = HgFlags(cpu_model, extra_flags)

            logging.debug("Cpu mode flags %s.",
                          str(flags.quest_cpu_model_flags))
            logging.debug("Added flags %s.",
                          str(flags.cpumodel_unsupport_flags))
            cpuf_model = cpu_model

            # Add unsupported flags.
            for fadd in flags.cpumodel_unsupport_flags:
                cpuf_model += ",+" + str(fadd)

            for fdel in flags.host_unsupported_flags:
                cpuf_model += ",-" + str(fdel)

            (self.vm, _) = start_guest_with_cpuflags(cpuf_model, smp)

            install_path = "/tmp"
            install_cpuflags_test_on_vm(self.vm, install_path)
            # NOTE: `flags` is rebound here -- from the HgFlags object to
            # the (working, not_working, untested) tuple returned by
            # check_cpuflags_work.  Below, flags[0] is the working set.
            flags = check_cpuflags_work(self.vm, install_path,
                                        flags.guest_flags)
            dd_session = self.vm.wait_for_login()
            stress_session = self.vm.wait_for_login()

            # Disk I/O load plus the cpuflags stress tool, both in the
            # background, so they keep running across migration.
            dd_session.sendline("nohup dd if=/dev/[svh]da of=/tmp/"
                                "stressblock bs=10MB count=100 &")
            cmd = ("nohup %s/cpuflags-test --stress  %s%s &" %
                   (os.path.join(install_path, "cpu_flags"), smp,
                    utils_misc.kvm_flags_to_stresstests(flags[0])))
            stress_session.sendline(cmd)

            time.sleep(5)

            self.vm.monitor.migrate_set_speed(mig_speed)
            self.clone = self.vm.migrate(
                mig_timeout, mig_protocol, offline=False,
                not_wait_for_migration=True)

            time.sleep(5)

            try:
                self.vm.wait_for_migration(10)
            except virt_vm.VMMigrateTimeoutError:
                # Relax the allowed downtime and wait the full timeout.
                self.vm.monitor.migrate_set_downtime(1)
                self.vm.wait_for_migration(mig_timeout)

            # Swap due to test cleaning.
            temp = self.vm.clone(copy_state=True)
            self.vm.__dict__ = self.clone.__dict__
            self.clone = temp

            self.vm.resume()
            self.clone.destroy(gracefully=False)

            stress_session = self.vm.wait_for_login()

            # If cpuflags-test hang up during migration test raise exception
            try:
                stress_session.cmd('killall cpuflags-test')
            except aexpect.ShellCmdError:
                raise error.TestFail("Cpuflags-test should work after"
                                     " migration.")

    def net_send_object(socket, obj):
        """
        Send python object over network.

        The object is pickled and framed as a 6-character, space-padded
        decimal length header followed by the pickle payload -- the format
        net_recv_object expects.

        :param socket: connected socket to send through.
        :param obj: object to send
        """
        data = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        # Encode the header explicitly so the whole frame is a byte stream
        # (a no-op on Python 2 str, required for Python 3 sockets).
        socket.sendall(("%6d" % len(data)).encode())
        socket.sendall(data)

    def net_recv_object(socket, timeout=60):
        """
        Receive python object over network.

        Reads the 6-character decimal length header, keeps reading until
        the whole pickle payload has arrived or `timeout` seconds elapse,
        then unpickles it.

        :param socket: connected socket to receive from.
        :param timeout: maximum seconds to wait for the complete payload.
        :return: object from network
        """
        try:
            time_start = time.time()
            data = b""  # bytes accumulator (str on Python 2, bytes on 3)
            d_len = int(socket.recv(6))

            while (len(data) < d_len and (time.time() - time_start) < timeout):
                data += socket.recv(d_len - len(data))

            data = pickle.loads(data)
            return data
        except Exception:
            # BUG FIX: the original constructed an error.TestFail instance
            # here and dropped it without raising it -- a silent no-op.
            # Log the failure and re-raise the real error instead.  The
            # bare `except:` is also narrowed so KeyboardInterrupt and
            # SystemExit propagate untouched.
            logging.error("Failed to receive python object over the network")
            raise

    class test_multi_host_migration(Test_temp):

        """Migrate a guest between two hosts while dd and cpuflags-test
        stress run inside it; cpuflags-test must survive the migration."""

        def test(self):
            """
            Test migration between multiple hosts.
            """
            cpu_model, extra_flags = parse_cpu_model()

            flags = HgFlags(cpu_model, extra_flags)

            logging.debug("Cpu mode flags %s.",
                          str(flags.quest_cpu_model_flags))
            logging.debug("Added flags %s.",
                          str(flags.cpumodel_unsupport_flags))
            cpuf_model = cpu_model

            # Enable the extra flags requested by the configuration...
            for fadd in extra_flags:
                cpuf_model += ",+" + str(fadd)

            # ...and disable whatever the host cannot provide.
            for fdel in flags.host_unsupported_flags:
                cpuf_model += ",-" + str(fdel)

            install_path = "/tmp"

            class testMultihostMigration(migration.MultihostMigration):

                def __init__(self, test, params, env):
                    migration.MultihostMigration.__init__(self, test, params,
                                                          env)

                def migration_scenario(self):
                    # Source and destination come from the "hosts" config
                    # list, in that order.
                    srchost = self.params.get("hosts")[0]
                    dsthost = self.params.get("hosts")[1]

                    def worker(mig_data):
                        # Runs on the source before migration: check the
                        # flags and start the dd + cpuflags-test workloads.
                        vm = env.get_vm("vm1")
                        session = vm.wait_for_login(timeout=self.login_timeout)

                        install_cpuflags_test_on_vm(vm, install_path)

                        Flags = check_cpuflags_work(vm, install_path,
                                                    flags.all_possible_guest_flags)
                        logging.info("Woking CPU flags: %s", str(Flags[0]))
                        logging.info("Not working CPU flags: %s",
                                     str(Flags[1]))
                        logging.warning("Flags works even if not defined on"
                                        " guest cpu flags: %s",
                                        str(Flags[0] - flags.guest_flags))
                        logging.warning("Not tested CPU flags: %s",
                                        str(Flags[2]))
                        session.sendline("nohup dd if=/dev/[svh]da of=/tmp/"
                                         "stressblock bs=10MB count=100 &")

                        # Stress only the flags that both work and belong to
                        # the expected guest flag set.
                        cmd = ("nohup %s/cpuflags-test --stress  %s%s &" %
                               (os.path.join(install_path, "cpu_flags"),
                                smp,
                                utils_misc.kvm_flags_to_stresstests(Flags[0] &
                                                                    flags.guest_flags)))
                        logging.debug("Guest_flags: %s",
                                      str(flags.guest_flags))
                        logging.debug("Working_flags: %s", str(Flags[0]))
                        logging.debug("Start stress on guest: %s", cmd)
                        session.sendline(cmd)

                    def check_worker(mig_data):
                        # Runs on the destination after migration: the
                        # stress tool must still be alive and the flags
                        # must still work.
                        vm = env.get_vm("vm1")

                        vm.verify_illegal_instruction()

                        session = vm.wait_for_login(timeout=self.login_timeout)

                        try:
                            session.cmd('killall cpuflags-test')
                        except aexpect.ShellCmdError:
                            raise error.TestFail("The cpuflags-test program"
                                                 " should be active after"
                                                 " migration and it's not.")

                        Flags = check_cpuflags_work(vm, install_path,
                                                    flags.all_possible_guest_flags)
                        logging.info("Woking CPU flags: %s",
                                     str(Flags[0]))
                        logging.info("Not working CPU flags: %s",
                                     str(Flags[1]))
                        logging.warning("Flags works even if not defined on"
                                        " guest cpu flags: %s",
                                        str(Flags[0] - flags.guest_flags))
                        logging.warning("Not tested CPU flags: %s",
                                        str(Flags[2]))

                    self.migrate_wait(["vm1"], srchost, dsthost,
                                      worker, check_worker)

            params_b = params.copy()
            # NOTE(review): cpuf_model is built above but never used -- the
            # migration runs with the bare cpu_model.  The flag-augmented
            # model looks like it was intended here; confirm before changing.
            params_b["cpu_model"] = cpu_model
            mig = testMultihostMigration(test, params_b, env)
            mig.run()

    class test_multi_host_migration_onoff_cpu(Test_temp):

        """Ping-pong migrate a guest between two hosts after offlining some
        vCPUs on the source; verify they stay offline on the destination."""

        def test(self):
            """
            Test migration between multiple hosts.
            """
            cpu_model, extra_flags = parse_cpu_model()

            flags = HgFlags(cpu_model, extra_flags)

            logging.debug("Cpu mode flags %s.",
                          str(flags.quest_cpu_model_flags))
            logging.debug("Added flags %s.",
                          str(flags.cpumodel_unsupport_flags))
            cpuf_model = cpu_model

            for fadd in extra_flags:
                cpuf_model += ",+" + str(fadd)

            for fdel in flags.host_unsupported_flags:
                cpuf_model += ",-" + str(fdel)

            # Rebind smp locally from the config for this subtest.
            smp = int(params["smp"])
            # vCPU indices to offline inside the guest before migrating.
            disable_cpus = map(lambda cpu: int(cpu),
                               params.get("disable_cpus", "").split())

            install_path = "/tmp"

            class testMultihostMigration(migration.MultihostMigration):

                def __init__(self, test, params, env):
                    migration.MultihostMigration.__init__(self, test, params,
                                                          env)
                    self.srchost = self.params.get("hosts")[0]
                    self.dsthost = self.params.get("hosts")[1]
                    # Shared identifier used by the SyncData barrier below.
                    self.id = {'src': self.srchost,
                               'dst': self.dsthost,
                               "type": "disable_cpu"}
                    self.migrate_count = int(self.params.get('migrate_count',
                                                             '2'))

                def ping_pong_migrate(self, sync, worker, check_worker):
                    # Migrate back and forth migrate_count times, swapping
                    # the source/destination roles after every round.
                    for _ in range(self.migrate_count):
                        logging.info("File transfer not ended, starting"
                                     " a round of migration...")
                        sync.sync(True, timeout=mig_timeout)
                        if self.hostid == self.srchost:
                            self.migrate_wait(["vm1"],
                                              self.srchost,
                                              self.dsthost,
                                              start_work=worker)
                        elif self.hostid == self.dsthost:
                            self.migrate_wait(["vm1"],
                                              self.srchost,
                                              self.dsthost,
                                              check_work=check_worker)
                        tmp = self.dsthost
                        self.dsthost = self.srchost
                        self.srchost = tmp

                def migration_scenario(self):

                    sync = SyncData(self.master_id(), self.hostid, self.hosts,
                                    self.id, self.sync_server)

                    def worker(mig_data):
                        # Source side: run the flag checks, then offline the
                        # requested vCPUs.
                        vm = env.get_vm("vm1")
                        session = vm.wait_for_login(timeout=self.login_timeout)

                        install_cpuflags_test_on_vm(vm, install_path)

                        Flags = check_cpuflags_work(vm, install_path,
                                                    flags.all_possible_guest_flags)
                        logging.info("Woking CPU flags: %s", str(Flags[0]))
                        logging.info("Not working CPU flags: %s",
                                     str(Flags[1]))
                        logging.warning("Flags works even if not defined on"
                                        " guest cpu flags: %s",
                                        str(Flags[0] - flags.guest_flags))
                        logging.warning("Not tested CPU flags: %s",
                                        str(Flags[2]))
                        for cpu in disable_cpus:
                            if cpu < smp:
                                disable_cpu(session, cpu, True)
                            else:
                                logging.warning("There is no enouth cpu"
                                                " in Guest. It is trying to"
                                                "remove cpu:%s from guest with"
                                                " smp:%s." % (cpu, smp))
                        logging.debug("Guest_flags: %s",
                                      str(flags.guest_flags))
                        logging.debug("Working_flags: %s", str(Flags[0]))

                    def check_worker(mig_data):
                        # Destination side: the offlined vCPUs must still be
                        # offline, and the flags must still work.
                        vm = env.get_vm("vm1")

                        vm.verify_illegal_instruction()

                        session = vm.wait_for_login(timeout=self.login_timeout)

                        really_disabled = check_online_cpus(session, smp,
                                                            disable_cpus)

                        not_disabled = set(really_disabled) & set(disable_cpus)
                        if not_disabled:
                            raise error.TestFail("Some of disabled cpus are "
                                                 "online. This shouldn't "
                                                 "happen. Cpus disabled on "
                                                 "srchost:%s, Cpus not "
                                                 "disabled on dsthost:%s" %
                                                 (disable_cpus, not_disabled))

                        Flags = check_cpuflags_work(vm, install_path,
                                                    flags.all_possible_guest_flags)
                        logging.info("Woking CPU flags: %s",
                                     str(Flags[0]))
                        logging.info("Not working CPU flags: %s",
                                     str(Flags[1]))
                        logging.warning("Flags works even if not defined on"
                                        " guest cpu flags: %s",
                                        str(Flags[0] - flags.guest_flags))
                        logging.warning("Not tested CPU flags: %s",
                                        str(Flags[2]))

                    self.ping_pong_migrate(sync, worker, check_worker)

            params_b = params.copy()
            # NOTE(review): cpuf_model is computed above but unused; the
            # plain cpu_model is what actually gets run.  Confirm intent.
            params_b["cpu_model"] = cpu_model
            mig = testMultihostMigration(test, params_b, env)
            mig.run()

    # Dispatch: pick the subtest class whose name matches "test_type" and
    # run it -- once for a configured cpu_model, or once per CPU model the
    # host supports (minus the blacklist) when no model is configured.
    test_type = params.get("test_type")
    if (test_type in locals()):
        tests_group = locals()[test_type]
        if params.get("cpu_model"):
            tests_group()
        else:
            cpu_models = (set(get_cpu_models_supported_by_host()) -
                          set(cpu_model_black_list))
            logging.info("Start test with cpu models %s", str(cpu_models))
            failed = []
            for cpumodel in cpu_models:
                params["cpu_model"] = cpumodel
                try:
                    tests_group()
                # BUG FIX: this was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit and made aborting the whole
                # run impossible mid-sweep.
                except Exception:
                    print_exception(tests_group)
                    failed.append(cpumodel)
            if failed:
                raise error.TestFail("Test of cpu models %s failed." %
                                     (str(failed)))
    else:
        raise error.TestFail("Test group '%s' is not defined in"
                             " cpuflags test" % test_type)

Example 9

Project: tp-qemu
Source File: cpuid.py
View license
def run(test, params, env):
    """
    Boot guest with different cpu_models and cpu flags and check if guest works correctly.

    :param test: kvm test object.
    :param params: Dictionary with the test parameters.
    :param env: Dictionary with test environment.
    """
    # Locate the QEMU binary under test and the default model to probe.
    qemu_binary = utils_misc.get_qemu_binary(params)

    cpu_model = params.get("cpu_model", "qemu64")

    # "xfail" == "yes" marks the run as expected-to-fail.  The comparison
    # with "yes" already rules out a missing (None) option, so the former
    # extra `is not None` guard was redundant.
    xfail = params.get("xfail") == "yes"

    def cpu_models_to_test():
        """Build the set of CPU models to test from the "cpu_models" and
        "cpu_model" config options.

        "cpu_model" selects a single model; "cpu_models" selects several
        (whitespace separated).  A "cpu_models" value of "*" expands to
        every model QEMU reports.

        :return: set of model names.
        :raise error.TestError: when neither option is set.
        """
        multi_opt = params.get("cpu_models")
        single_opt = params.get("cpu_model")

        if multi_opt is None and single_opt is None:
            raise error.TestError("No cpu_models or cpu_model option is set")

        selected = set()
        if multi_opt == '*':
            selected.update(utils_misc.get_qemu_cpu_models(qemu_binary))
        elif multi_opt:
            selected.update(multi_opt.split())
        if single_opt:
            selected.add(single_opt)
        return selected

    def test_qemu_cpu_models_list(self):
        """
        Check CPU models returned by <qemu> -cpu '?' are what is expected.

        Models that are expected but missing from QEMU's list fail the
        test; extra models QEMU knows about are only logged.

        (FIX: the original had a second, dead triple-quoted string
        statement after the docstring; it has been merged into this one.)
        """
        cpu_models = cpu_models_to_test()
        qemu_models = utils_misc.get_qemu_cpu_models(qemu_binary)
        missing = set(cpu_models) - set(qemu_models)
        if missing:
            raise error.TestFail(
                "Some CPU models not in QEMU CPU model list: %r" % (missing))
        added = set(qemu_models) - set(cpu_models)
        if added:
            logging.info("Extra CPU models in QEMU CPU listing: %s", added)

    def compare_cpuid_output(a, b):
        """
        Yield a (bit, va, vb) tuple for every bit position (0..31) whose
        value differs between a and b.  When b is None, vb is reported as
        None (so every bit of a is yielded).
        """
        for bit in range(32):
            mask = 1 << bit
            va = 1 if a & mask else 0
            vb = None if b is None else (1 if b & mask else 0)
            if va != vb:
                yield (bit, va, vb)

    def parse_cpuid_dump(output):
        """
        Parse the test kernel's CPUID dump.

        :param output: raw serial console text containing the dump between
            "==START TEST==" and "==END TEST==" markers.
        :return: dict keyed by (in_eax, in_ecx, register-name) mapping to
            the register value, or None when the text doesn't match the
            expected format.
        """
        dbg("parsing cpuid dump: %r", output)
        cpuid_re = re.compile(
            "^ *(0x[0-9a-f]+) +0x([0-9a-f]+): +eax=0x([0-9a-f]+) ebx=0x([0-9a-f]+) ecx=0x([0-9a-f]+) edx=0x([0-9a-f]+)$")
        output_match = re.search('(==START TEST==.*==END TEST==)', output, re.M | re.DOTALL)
        if output_match is None:
            dbg("cpuid dump doesn't follow expected pattern")
            return None
        dump_lines = output_match.group(1).splitlines()
        if dump_lines[0] != '==START TEST==' or dump_lines[-1] != '==END TEST==':
            dbg("cpuid dump doesn't have expected delimiters")
            return None
        if dump_lines[1] != 'CPU:':
            dbg("cpuid dump doesn't start with 'CPU:' line")
            return None
        parsed = {}
        for line in dump_lines[2:-1]:
            match = cpuid_re.match(line)
            if match is None:
                dbg("invalid cpuid dump line: %r", line)
                return None
            in_eax = int(match.group(1), 16)
            in_ecx = int(match.group(2), 16)
            for pos, reg in enumerate(('eax', 'ebx', 'ecx', 'edx')):
                parsed[in_eax, in_ecx, reg] = int(match.group(3 + pos), 16)
        return parsed

    def get_test_kernel_cpuid(self, vm):
        """
        Resume the guest, wait for the CPUID test kernel to finish dumping
        on the serial console, parse the dump and destroy the VM.

        :param vm: VM booted with the cpuid_dump test kernel.
        :return: dict as produced by parse_cpuid_dump().
        :raise error.TestFail: if the dump never completes or cannot be
            parsed.
        """
        vm.resume()

        timeout = float(params.get("login_timeout", 240))
        logging.debug("Will wait for CPUID serial output at %r",
                      vm.serial_console)
        if not utils_misc.wait_for(lambda:
                                   re.search("==END TEST==",
                                             vm.serial_console.get_output()),
                                   timeout, 1):
            raise error.TestFail("Could not get test complete message.")

        test_output = parse_cpuid_dump(vm.serial_console.get_output())
        logging.debug("Got CPUID serial output: %r", test_output)
        if test_output is None:
            # BUG FIX: the message was previously passed logging-style
            # ("...%s", arg) to TestFail, so the console output was never
            # interpolated into the failure text.
            raise error.TestFail("Test output signature not found in "
                                 "output:\n %s" %
                                 vm.serial_console.get_output())
        vm.destroy(gracefully=False)
        return test_output

    def find_cpu_obj(vm):
        """Return the QOM path of the first VCPU object found, or None."""
        candidate_roots = ['/machine/icc-bridge/icc',
                           '/machine/unattached/device']
        for root in candidate_roots:
            children = vm.monitor.cmd('qom-list', dict(path=root))
            for child in children:
                logging.debug('child: %r', child)
                # VCPU objects have a type name ending in "-cpu"
                # (possibly wrapped in "child<...>").
                if child['type'].rstrip('>').endswith('-cpu'):
                    return root + '/' + child['name']

    def get_qom_cpuid(self, vm):
        """
        Read the guest's CPUID feature bits through QOM properties.

        Combines the VCPU object's "feature-words" and "filtered-features"
        properties into a single dict keyed by
        (cpuid-input-eax, cpuid-input-ecx, register-name), OR-ing feature
        bits that land on the same key.
        """
        assert vm.monitor.protocol == "qmp"
        cpu_path = find_cpu_obj(vm)
        logging.debug('cpu path: %r', cpu_path)
        cpuid_bits = {}
        for prop in ('feature-words', 'filtered-features'):
            words = vm.monitor.cmd('qom-get', dict(path=cpu_path, property=prop))
            logging.debug('%s property: %r', prop, words)
            for word in words:
                reg = word['cpuid-register'].lower()
                key = (word['cpuid-input-eax'],
                       word.get('cpuid-input-ecx', 0),
                       reg)
                cpuid_bits[key] = cpuid_bits.get(key, 0) | word['features']
        return cpuid_bits

    def get_guest_cpuid(self, cpu_model, feature=None, extra_params=None, qom_mode=False):
        """
        Boot a guest with the given CPU model/features and collect its
        CPUID data.

        :param cpu_model: CPU model name to boot with.
        :param feature: extra cpu_model_flags string (e.g. "vendor=...").
        :param extra_params: optional dict merged into the VM params.
        :param qom_mode: if True, read CPUID via QOM properties; otherwise
            boot the cpuid_dump test kernel and parse its serial output.
        :return: dict keyed by (in_eax, in_ecx, register-name).
        """
        if not qom_mode:
            # Build the CPUID dump kernel from the test's deps directory.
            # NOTE: os.chdir changes the process-wide cwd as a side effect.
            test_kernel_dir = os.path.join(data_dir.get_deps_dir(), "cpuid", "src")
            os.chdir(test_kernel_dir)
            utils.make("cpuid_dump_kernel.bin")

        vm_name = params['main_vm']
        params_b = params.copy()
        if not qom_mode:
            params_b["kernel"] = os.path.join(
                test_kernel_dir, "cpuid_dump_kernel.bin")
        params_b["cpu_model"] = cpu_model
        params_b["cpu_model_flags"] = feature
        # Boot the VM without the configured images and nics.
        del params_b["images"]
        del params_b["nics"]
        if extra_params:
            params_b.update(extra_params)
        env_process.preprocess_vm(self, params_b, env, vm_name)
        vm = env.get_vm(vm_name)
        dbg('is dead: %r', vm.is_dead())
        vm.create()
        # Keep a handle on the VM so the caller's cleanup can reach it.
        self.vm = vm
        if qom_mode:
            return get_qom_cpuid(self, vm)
        else:
            return get_test_kernel_cpuid(self, vm)

    def cpuid_to_vendor(cpuid_dump, idx):
        """
        Decode the 12-character CPU vendor string from a CPUID dump.

        The vendor string is stored little-endian, one byte per character,
        across the ebx, edx and ecx registers (in that order) of leaf
        `idx` -- e.g. "GenuineIntel" for leaf 0.

        :param cpuid_dump: dict keyed by (in_eax, in_ecx, register-name).
        :param idx: CPUID input-eax leaf to decode (0x0 or 0x80000000).
        :return: decoded vendor string.
        """
        # BUG FIX (idiom): the original used map() purely for its side
        # effect of appending to `dst`; on Python 3 map() is lazy, so the
        # function would silently return ''.  A plain loop is correct on
        # both versions and clearer.
        dst = []
        for reg in ('ebx', 'edx', 'ecx'):
            value = cpuid_dump[idx, 0, reg]
            for i in range(4):
                dst.append(chr((value >> (8 * i)) & 0xff))
        return ''.join(dst)

    def default_vendor(self):
        """
        Boot qemu with each tested cpu model and check that the guest's
        CPU vendor string matches the expected one.
        """
        cpu_models = cpu_models_to_test()

        vendor = params.get("vendor")
        if vendor is None or vendor == "host":
            # Fall back to the host's vendor string from /proc/cpuinfo.
            cmd = "grep 'vendor_id' /proc/cpuinfo | head -n1 | awk '{print $3}'"
            cmd_result = utils.run(cmd, ignore_status=True)
            vendor = cmd_result.stdout.strip()

        # Drop any cpu models the configuration asked us to skip.
        cpu_models = cpu_models - set(params.get("ignore_cpu_models", "").split(' '))

        for cpu_model in cpu_models:
            out = get_guest_cpuid(self, cpu_model)
            guest_vendor = cpuid_to_vendor(out, 0x00000000)
            logging.debug("Guest's vendor: " + guest_vendor)
            if guest_vendor != vendor:
                raise error.TestFail("Guest vendor [%s], doesn't match "
                                     "required vendor [%s] for CPU [%s]" %
                                     (guest_vendor, vendor, cpu_model))

    def custom_vendor(self):
        """
        Boot qemu with a user-specified vendor string and verify it shows
        up in both leaf 0 and leaf 0x80000000.
        """
        has_error = False
        vendor = params["vendor"]

        try:
            out = get_guest_cpuid(self, cpu_model, "vendor=" + vendor)
            guest_vendor0 = cpuid_to_vendor(out, 0x00000000)
            guest_vendor80000000 = cpuid_to_vendor(out, 0x80000000)
            logging.debug("Guest's vendor[0]: " + guest_vendor0)
            logging.debug("Guest's vendor[0x80000000]: " +
                          guest_vendor80000000)
            if guest_vendor0 != vendor:
                raise error.TestFail("Guest vendor[0] [%s], doesn't match "
                                     "required vendor [%s] for CPU [%s]" %
                                     (guest_vendor0, vendor, cpu_model))
            if guest_vendor80000000 != vendor:
                raise error.TestFail("Guest vendor[0x80000000] [%s], "
                                     "doesn't match required vendor "
                                     "[%s] for CPU [%s]" %
                                     (guest_vendor80000000, vendor,
                                      cpu_model))
        except Exception:
            # Narrowed from a bare 'except:' so KeyboardInterrupt/SystemExit
            # are no longer swallowed by the expected-failure (xfail) logic.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_to_level(cpuid_dump):
        """Return the highest standard CPUID function (leaf 0, EAX)."""
        return cpuid_dump[0, 0]['eax']

    def custom_level(self):
        """
        Boot qemu with specified level and verify the guest reports it.
        """
        has_error = False
        level = params["level"]
        try:
            out = get_guest_cpuid(self, cpu_model, "level=" + level)
            guest_level = str(cpuid_to_level(out))
            if guest_level != level:
                raise error.TestFail("Guest's level [%s], doesn't match "
                                     "required level [%s]" %
                                     (guest_level, level))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_to_family(cpuid_dump):
        """Return the CPU display family decoded from leaf 1 EAX.

        Per the Intel CPUID application note (241618.pdf, section 5.1.2)
        the family is in bits 11:8; when it equals 0xF the extended
        family field (bits 27:20) is added in.
        """
        eax = cpuid_dump[1, 0]['eax']
        base_family = (eax >> 8) & 0xf
        if base_family != 0xf:
            return base_family
        return base_family + ((eax >> 20) & 0xff)

    def custom_family(self):
        """
        Boot qemu with specified family and verify the guest reports it.
        """
        has_error = False
        family = params["family"]
        try:
            out = get_guest_cpuid(self, cpu_model, "family=" + family)
            guest_family = str(cpuid_to_family(out))
            if guest_family != family:
                raise error.TestFail("Guest's family [%s], doesn't match "
                                     "required family [%s]" %
                                     (guest_family, family))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_to_model(cpuid_dump):
        """Return the CPU display model decoded from leaf 1 EAX.

        Per the Intel CPUID application note (241618.pdf, section 5.1.2)
        the model is bits 7:4, with the extended model (bits 19:16)
        shifted into the high nibble.
        """
        eax = cpuid_dump[1, 0]['eax']
        return ((eax >> 4) & 0xf) | ((eax >> 12) & 0xf0)

    def custom_model(self):
        """
        Boot qemu with specified model and verify the guest reports it.
        """
        has_error = False
        model = params["model"]
        try:
            out = get_guest_cpuid(self, cpu_model, "model=" + model)
            guest_model = str(cpuid_to_model(out))
            if guest_model != model:
                raise error.TestFail("Guest's model [%s], doesn't match "
                                     "required model [%s]" %
                                     (guest_model, model))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_to_stepping(cpuid_dump):
        """Return the stepping id (leaf 1 EAX bits 3:0).

        See the Intel CPUID application note (241618.pdf), section 5.1.2.
        """
        return cpuid_dump[1, 0]['eax'] & 0xf

    def custom_stepping(self):
        """
        Boot qemu with specified stepping and verify the guest reports it.
        """
        has_error = False
        stepping = params["stepping"]
        try:
            out = get_guest_cpuid(self, cpu_model, "stepping=" + stepping)
            guest_stepping = str(cpuid_to_stepping(out))
            if guest_stepping != stepping:
                raise error.TestFail("Guest's stepping [%s], doesn't match "
                                     "required stepping [%s]" %
                                     (guest_stepping, stepping))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_to_xlevel(cpuid_dump):
        """Return the largest extended CPUID function number.

        This is EAX of leaf 0x80000000; see the Intel CPUID application
        note (241618.pdf), section 5.2.1.
        """
        return cpuid_dump[0x80000000, 0x00]['eax']

    def custom_xlevel(self):
        """
        Boot qemu with specified xlevel and verify the guest reports it.
        """
        has_error = False
        xlevel = params["xlevel"]
        # The value qemu exposes may be normalized; expect_xlevel, when
        # set, overrides the comparison target (not the boot parameter).
        if params.get("expect_xlevel") is not None:
            xlevel = params.get("expect_xlevel")

        try:
            out = get_guest_cpuid(self, cpu_model, "xlevel=" +
                                  params.get("xlevel"))
            guest_xlevel = str(cpuid_to_xlevel(out))
            if guest_xlevel != xlevel:
                raise error.TestFail("Guest's xlevel [%s], doesn't match "
                                     "required xlevel [%s]" %
                                     (guest_xlevel, xlevel))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_to_model_id(cpuid_dump):
        """Assemble the processor brand string from leaves 0x80000002-4.

        See the Intel CPUID application note (241618.pdf), section 5.2.3.
        Bytes are taken little-endian from EAX, EBX, ECX and EDX of each
        leaf; a NUL byte ends the contribution of the current register.
        """
        parts = []
        for leaf in (0x80000002, 0x80000003, 0x80000004):
            regs = cpuid_dump[leaf, 0]
            for reg_name in ('eax', 'ebx', 'ecx', 'edx'):
                for shift in range(4):
                    byte = (regs[reg_name] >> (shift * 8)) & 0xff
                    if byte == 0:  # drop trailing \0-s of this register
                        break
                    parts.append(chr(byte))
        return ''.join(parts)

    def custom_model_id(self):
        """
        Boot qemu with specified model_id and verify the guest reports it.
        """
        has_error = False
        model_id = params["model_id"]

        try:
            out = get_guest_cpuid(self, cpu_model, "model_id='%s'" %
                                  model_id)
            guest_model_id = cpuid_to_model_id(out)
            if guest_model_id != model_id:
                raise error.TestFail("Guest's model_id [%s], doesn't match "
                                     "required model_id [%s]" %
                                     (guest_model_id, model_id))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_regs_to_string(cpuid_dump, leaf, idx, regs):
        """Render the bytes of *regs* at (leaf, idx) as a signature string.

        Printable characters are kept verbatim; any other byte is emitted
        as a literal \\xNN escape sequence.
        """
        r = cpuid_dump[leaf, idx]
        pieces = []
        for reg_name in regs:
            for shift in range(0, 4):
                c = chr((r[reg_name] >> (shift * 8)) & 0xFF)
                if c in string.printable:
                    pieces.append(c)
                else:
                    pieces.append("\\x%02x" % ord(c))
        signature = "".join(pieces)
        logging.debug("(%s.%s:%s: signature: %s" % (leaf, idx, str(regs),
                                                    signature))
        return signature

    def cpuid_signature(self):
        """
        test signature in specified leaf:index:regs
        """
        has_error = False
        flags = params.get("flags", "")
        leaf = int(params.get("leaf", "0x40000000"), 0)
        idx = int(params.get("index", "0x00"), 0)
        regs = params.get("regs", "ebx ecx edx").split()
        signature = params["signature"]
        try:
            out = get_guest_cpuid(self, cpu_model, flags)
            _signature = cpuid_regs_to_string(out, leaf, idx, regs)
            if _signature != signature:
                raise error.TestFail("Guest's signature [%s], doesn't"
                                     "match required signature [%s]" %
                                     (_signature, signature))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_bit_test(self):
        """
        test bits in specified leaf:func:reg
        """
        has_error = False
        flags = params.get("flags", "")
        leaf = int(params.get("leaf", "0x40000000"), 0)
        idx = int(params.get("index", "0x00"), 0)
        reg = params.get("reg", "eax")
        bits = params["bits"].split()
        try:
            out = get_guest_cpuid(self, cpu_model, flags)
            r = out[leaf, idx][reg]
            logging.debug("CPUID(%s.%s).%s=0x%08x" % (leaf, idx, reg, r))
            # Every listed bit must be set in the selected register.
            for i in bits:
                if (r & (1 << int(i))) == 0:
                    raise error.TestFail("CPUID(%s.%s).%s[%s] is not set" %
                                         (leaf, idx, reg, i))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def cpuid_reg_test(self):
        """
        test register value in specified leaf:index:reg
        """
        has_error = False
        flags = params.get("flags", "")
        leaf = int(params.get("leaf", "0x00"), 0)
        idx = int(params.get("index", "0x00"), 0)
        reg = params.get("reg", "eax")
        val = int(params["value"], 0)
        try:
            out = get_guest_cpuid(self, cpu_model, flags)
            r = out[leaf, idx][reg]
            logging.debug("CPUID(%s.%s).%s=0x%08x" % (leaf, idx, reg, r))
            if r != val:
                raise error.TestFail("CPUID(%s.%s).%s is not 0x%08x" %
                                     (leaf, idx, reg, val))
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit
            # are not swallowed by the expected-failure (xfail) handling.
            has_error = True
            if xfail is False:
                raise
        if (has_error is False) and (xfail is True):
            raise error.TestFail("Test was expected to fail, but it didn't")

    def check_cpuid_dump(self):
        """
        Compare the guest's full CPUID dump against a stored reference.

        The reference file is chosen by kvm mode, machine type and cpu
        model; mismatching bits listed in ignore_cpuid_leaves are
        whitelisted and only logged.
        """
        machine_type = params.get("machine_type_to_check", "")
        kvm_enabled = params.get("enable_kvm", "yes") == "yes"

        ignore_cpuid_leaves = params.get("ignore_cpuid_leaves", "")
        ignore_cpuid_leaves = ignore_cpuid_leaves.split()
        whitelist = []
        for leaf_spec in ignore_cpuid_leaves:
            leaf_spec = leaf_spec.split(',')
            # syntax of ignore_cpuid_leaves:
            # <in_eax>[,<in_ecx>[,<register>[ ,<bit>]]] ...
            for i in 0, 1, 3:  # integer fields:
                if len(leaf_spec) > i:
                    leaf_spec[i] = int(leaf_spec[i], 0)
            whitelist.append(tuple(leaf_spec))

        if not machine_type:
            raise error.TestNAError("No machine_type_to_check defined")
        cpu_model_flags = params.get('cpu_model_flags', '')
        full_cpu_model_name = cpu_model
        if cpu_model_flags:
            full_cpu_model_name += ','
            full_cpu_model_name += cpu_model_flags.lstrip(',')
        ref_file = os.path.join(data_dir.get_deps_dir(), 'cpuid',
                                "cpuid_dumps",
                                kvm_enabled and "kvm" or "nokvm",
                                machine_type, '%s-dump.txt' % (full_cpu_model_name))
        if not os.path.exists(ref_file):
            raise error.TestNAError("no cpuid dump file: %s" % (ref_file))
        # Use a context manager so the reference file handle is closed
        # deterministically (was previously leaked).
        with open(ref_file, 'r') as ref:
            reference = ref.read()
        if not reference:
            raise error.TestNAError(
                "no cpuid dump data on file: %s" % (ref_file))
        reference = parse_cpuid_dump(reference)
        if reference is None:
            raise error.TestNAError(
                "couldn't parse reference cpuid dump from file; %s" % (ref_file))
        qom_mode = params.get('qom_mode', "no").lower() == 'yes'
        if not qom_mode:
            # Without QOM introspection, make qemu refuse to start if the
            # host can't provide the requested features.
            cpu_model_flags += ',enforce'
        try:
            out = get_guest_cpuid(
                self, cpu_model, cpu_model_flags,
                extra_params=dict(machine_type=machine_type, smp=1),
                qom_mode=qom_mode)
        except (virt_vm.VMStartError, virt_vm.VMCreateError) as e:
            output = getattr(e, 'reason', getattr(e, 'output', ''))
            if "host doesn't support requested feature:" in output \
                or ("host cpuid" in output and
                    ("lacks requested flag" in output or
                     "flag restricted to guest" in output)) \
                    or ("Unable to find CPU definition:" in output):
                raise error.TestNAError(
                    "Can't run CPU model %s on this host" % (full_cpu_model_name))
            else:
                raise
        dbg('ref_file: %r', ref_file)
        dbg('ref: %r', reference)
        dbg('out: %r', out)
        ok = True
        for k in reference.keys():
            in_eax, in_ecx, reg = k
            diffs = compare_cpuid_output(reference[k], out.get(k))
            for d in diffs:
                bit, vreference, vout = d
                # A mismatch is acceptable if any prefix of its
                # (eax, ecx, reg, bit) coordinates is whitelisted.
                whitelisted = (in_eax,) in whitelist \
                    or (in_eax, in_ecx) in whitelist \
                    or (in_eax, in_ecx, reg) in whitelist \
                    or (in_eax, in_ecx, reg, bit) in whitelist
                silent = False

                if vout is None and params.get('ok_missing', 'no') == 'yes':
                    whitelisted = True
                    silent = True

                if not silent:
                    info(
                        "Non-matching bit: CPUID[0x%x,0x%x].%s[%d]: found %s instead of %s%s",
                        in_eax, in_ecx, reg, bit, vout, vreference,
                        whitelisted and " (whitelisted)" or "")

                if not whitelisted:
                    ok = False
        if not ok:
            raise error.TestFail("Unexpected CPUID data")

    # subtests runner
    # Dispatch to the nested sub-test function whose name is given by the
    # "test_type" parameter; all sub-tests defined above are reachable
    # through locals() of this enclosing function.
    test_type = params["test_type"]
    if test_type not in locals():
        raise error.TestError("Test function '%s' is not defined in"
                              " test" % test_type)

    test_func = locals()[test_type]
    return test_func(test)

Example 10

Project: mpop
Source File: viirs_sdr.py
View license
    def load(self, satscene, calibrate=1, time_interval=None,
             area=None, filename=None, **kwargs):
        """Read viirs SDR reflectances and Tbs from file and load it into
        *satscene*.

        Channels are selected through ``satscene.channels_to_load``.  SDR
        files are either given explicitly via *filename* or located by
        globbing templates from the instrument configuration; for each
        loaded band the data, units, time span and swath geolocation are
        stored back onto the scene.
        """
        if satscene.instrument_name != "viirs":
            raise ValueError("Wrong instrument, expecting viirs")

        if kwargs:
            logger.warning(
                "Unsupported options for viirs reader: %s", str(kwargs))

        # Read the per-satellite configuration file ("<fullname>.cfg").
        conf = ConfigParser()
        conf.read(os.path.join(CONFIG_PATH, satscene.fullname + ".cfg"))
        options = {}
        for option, value in conf.items(satscene.instrument_name + "-level2",
                                        raw=True):
            options[option] = value

        # Nothing to do if no requested channel is a known VIIRS band.
        band_list = [s.name for s in satscene.channels]
        chns = satscene.channels_to_load & set(band_list)
        if len(chns) == 0:
            return

        # Time window used later for swath-segment selection.
        if time_interval:
            time_start, time_end = time_interval
        else:
            time_start, time_end = satscene.time_slot, None

        import glob

        if "filename" not in options:
            raise IOError("No filename given, cannot load")

        # Substitution values for the filename/directory templates.
        values = {"orbit": satscene.orbit,
                  "satname": satscene.satname,
                  "instrument": satscene.instrument_name,
                  "satellite": satscene.satname
                  #"satellite": satscene.fullname
                  }

        file_list = []
        if filename is not None:
            # Explicit file list given: split it into SDR files ("SV*")
            # and geolocation files ("G*").
            if not isinstance(filename, (list, set, tuple)):
                filename = [filename]
            geofile_list = []
            for fname in filename:
                if os.path.basename(fname).startswith("SV"):
                    file_list.append(fname)
                elif os.path.basename(fname).startswith("G"):
                    geofile_list.append(fname)
                else:
                    logger.info("Unrecognized SDR file: %s", fname)
            if file_list:
                directory = os.path.dirname(file_list[0])
            if geofile_list:
                geodirectory = os.path.dirname(geofile_list[0])

        if not file_list:
            # No explicit files: glob for them using the configured
            # directory and filename templates.
            filename_tmpl = strftime(
                satscene.time_slot, options["filename"]) % values

            directory = strftime(satscene.time_slot, options["dir"]) % values

            if not os.path.exists(directory):
                #directory = globify(options["dir"]) % values
                directory = globify(
                    strftime(satscene.time_slot, options["dir"])) % values
                logger.debug(
                    "Looking for files in directory " + str(directory))
                directories = glob.glob(directory)
                if len(directories) > 1:
                    raise IOError("More than one directory for npp scene... " +
                                  "\nSearch path = %s\n\tPlease check npp.cfg file!" % directory)
                elif len(directories) == 0:
                    raise IOError("No directory found for npp scene. " +
                                  "\nSearch path = %s\n\tPlease check npp.cfg file!" % directory)
                else:
                    directory = directories[0]

            file_list = glob.glob(os.path.join(directory, filename_tmpl))

            # Only take the files in the interval given:
            logger.debug("Number of files before segment selection: "
                         + str(len(file_list)))
            for fname in file_list:
                if os.path.basename(fname).startswith("SVM14"):
                    logger.debug("File before segmenting: "
                                 + os.path.basename(fname))
            file_list = _get_swathsegment(
                file_list, time_start, time_end, area)
            logger.debug("Number of files after segment selection: "
                         + str(len(file_list)))

            for fname in file_list:
                if os.path.basename(fname).startswith("SVM14"):
                    logger.debug("File after segmenting: "
                                 + os.path.basename(fname))

            logger.debug("Template = " + str(filename_tmpl))

            # 22 VIIRS bands (16 M-bands + 5 I-bands + DNB)
            if len(file_list) % 22 != 0:
                logger.warning("Number of SDR files is not divisible by 22!")
            if len(file_list) == 0:
                logger.debug(
                    "File template = " + str(os.path.join(directory, filename_tmpl)))
                raise IOError("No VIIRS SDR file matching!: " +
                              "Start time = " + str(time_start) +
                              "  End time = " + str(time_end))

            # Geolocation files live either in their own configured
            # directory or alongside the SDR files.
            geo_dir_string = options.get("geo_dir", None)
            if geo_dir_string:
                geodirectory = strftime(
                    satscene.time_slot, geo_dir_string) % values
            else:
                geodirectory = directory
            logger.debug("Geodir = " + str(geodirectory))

            geofile_list = []
            geo_filenames_string = options.get("geo_filenames", None)
            if geo_filenames_string:
                geo_filenames_tmpl = strftime(satscene.time_slot,
                                              geo_filenames_string) % values
                geofile_list = glob.glob(os.path.join(geodirectory,
                                                      geo_filenames_tmpl))
                logger.debug("List of geo-files: " + str(geofile_list))
                # Only take the files in the interval given:
                geofile_list = _get_swathsegment(
                    geofile_list, time_start, time_end)

            logger.debug("List of geo-files (after time interval selection): "
                         + str(geofile_list))

        filenames = [os.path.basename(s) for s in file_list]

        glob_info = {}

        self.geofiles = geofile_list

        logger.debug("Channels to load: " + str(satscene.channels_to_load))
        for chn in satscene.channels_to_load:
            # Take only those files in the list matching the band:
            # (Filename starts with 'SV' and then the band-name)
            fnames_band = []

            try:
                fnames_band = [s for s in filenames if s.find('SV' + chn) >= 0]
            except TypeError:
                # chn may not be a string (e.g. a frequency was requested).
                logger.warning('Band frequency not available from VIIRS!')
                logger.info('Asking for channel' + str(chn) + '!')

            if len(fnames_band) == 0:
                continue

            filename_band = [
                os.path.join(directory, fname) for fname in fnames_band]
            logger.debug("fnames_band = " + str(filename_band))

            band = ViirsBandData(filename_band, calibrate=calibrate).read()
            logger.debug('Band id = ' + band.band_id)

            # If the list of geo-files is not specified in the config file or
            # some of them do not exist, we rely on what is written in the
            # band-data metadata header:
            if len(geofile_list) < len(filename_band):
                geofilenames_band = [os.path.join(geodirectory, gfile) for
                                     gfile in band.geo_filenames]
                logger.debug("Geolocation filenames for band: " +
                             str(geofilenames_band))
                # Check if the geo-filenames found from the metadata actually
                # exist and issue a warning if they do not:
                for geofilename in geofilenames_band:
                    if not os.path.exists(geofilename):
                        logger.warning("Geo file defined in metadata header " +
                                       "does not exist: " + str(geofilename))

            elif band.band_id.startswith('M'):
                # M-bands: prefer terrain-corrected geolocation (GMTCO),
                # fall back to the geoid version (GMODO).
                geofilenames_band = [geofile for geofile in geofile_list
                                     if os.path.basename(geofile).startswith('GMTCO')]
                if len(geofilenames_band) != len(filename_band):
                    # Try the geoid instead:
                    geofilenames_band = [geofile for geofile in geofile_list
                                         if os.path.basename(geofile).startswith('GMODO')]
                    if len(geofilenames_band) != len(filename_band):
                        raise IOError("Not all geo location files " +
                                      "for this scene are present for band " +
                                      band.band_id + "!")
            elif band.band_id.startswith('I'):
                # I-bands: GITCO with GIMGO as the geoid fallback.
                geofilenames_band = [geofile for geofile in geofile_list
                                     if os.path.basename(geofile).startswith('GITCO')]
                if len(geofilenames_band) != len(filename_band):
                    # Try the geoid instead:
                    geofilenames_band = [geofile for geofile in geofile_list
                                         if os.path.basename(geofile).startswith('GIMGO')]
                    if len(geofilenames_band) != len(filename_band):
                        raise IOError("Not all geo location files " +
                                      "for this scene are present for band " +
                                      band.band_id + "!")
            elif band.band_id.startswith('D'):
                # Day/Night band geolocation (GDNBO).
                geofilenames_band = [geofile for geofile in geofile_list
                                     if os.path.basename(geofile).startswith('GDNBO')]
                if len(geofilenames_band) != len(filename_band):
                    raise IOError("Not all geo-location files " +
                                  "for this scene are present for " +
                                  "the Day Night Band!")

            band.read_lonlat(geofilepaths=geofilenames_band)

            if not band.band_desc:
                logger.warning('Band name = ' + band.band_id)
                raise AttributeError('Band description not supported!')

            # Store band data and metadata on the scene channel.
            satscene[chn].data = band.data
            satscene[chn].info['units'] = band.units
            satscene[chn].info['band_id'] = band.band_id
            satscene[chn].info['start_time'] = band.begin_time
            satscene[chn].info['end_time'] = band.end_time
            if chn in ['M01', 'M02', 'M03', 'M04', 'M05', 'M06', 'M07', 'M08', 'M09', 'M10', 'M11',
                       'I01', 'I02', 'I03']:
                satscene[chn].info['sun_zen_correction_applied'] = True

            # We assume the same geolocation should apply to all M-bands!
            # ...and the same to all I-bands:

            from pyresample import geometry

            # Mask the lon/lat arrays with the band's data mask so invalid
            # pixels carry no geolocation.
            satscene[chn].area = geometry.SwathDefinition(
                lons=np.ma.masked_where(band.data.mask,
                                        band.geolocation.longitudes,
                                        copy=False),
                lats=np.ma.masked_where(band.data.mask,
                                        band.geolocation.latitudes,
                                        copy=False))
            area_name = ("swath_" + satscene.fullname + "_" +
                         str(satscene.time_slot) + "_"
                         + str(satscene[chn].data.shape) + "_" +
                         band.band_uid)

            satscene[chn].area.area_id = area_name
            satscene[chn].area_id = area_name

            if self.shape is None:
                self.shape = band.data.shape

            # except ImportError:
            #    satscene[chn].area = None
            #    satscene[chn].lat = np.ma.array(band.latitude, mask=band.data.mask)
            #    satscene[chn].lon = np.ma.array(band.longitude, mask=band.data.mask)

            # if 'institution' not in glob_info:
            ##    glob_info['institution'] = band.global_info['N_Dataset_Source']
            # if 'mission_name' not in glob_info:
            ##    glob_info['mission_name'] = band.global_info['Mission_Name']

        ViirsGeolocationData.clear_cache()

        # Compulsory global attributes
        satscene.info["title"] = (satscene.satname.capitalize() +
                                  " satellite, " +
                                  satscene.instrument_name.capitalize() +
                                  " instrument.")
        if 'institution' in glob_info:
            satscene.info["institution"] = glob_info['institution']

        if 'mission_name' in glob_info:
            satscene.add_to_history(glob_info['mission_name'] +
                                    " VIIRS SDR read by mpop")
        else:
            satscene.add_to_history("NPP/JPSS VIIRS SDR read by mpop")

        satscene.info["references"] = "No reference."
        satscene.info["comments"] = "No comment."

        # Scene time span is the envelope of all loaded channels.
        satscene.info["start_time"] = min([chn.info["start_time"]
                                           for chn in satscene
                                           if chn.is_loaded()])
        satscene.info["end_time"] = max([chn.info["end_time"]
                                         for chn in satscene
                                         if chn.is_loaded()])

Example 11

Project: AvsPmod
Source File: [7] Optimize Sliders.py
View license
def main():
    """AvsPmod macro: tune the user sliders of the current Avisynth script
    with a simple genetic algorithm (GA).

    Each combination of slider values is encoded as a bit string, rendered
    through avs2avi, and scored by the SSIM value read from the encoder's
    log file; the best-scoring script found is written alongside the
    original as ``<name>-optimized.avs``.

    NOTE(review): this is Python 2 code (print statements, ``xrange``,
    ``dict.has_key``, ``cmp``/``__cmp__``) and relies on AvsPmod's
    ``avsp`` and ``_`` (gettext) globals being injected into the macro's
    namespace.
    """
    import random
    import math
    import subprocess
    import os
    import os.path

    app = avsp.GetWindow()
    params = []  # one [label, valuelist, nbits] entry per optimizable slider
    scriptTemplate = ''
    logfilename = 'log.txt'
    avs2avidir = os.path.join(app.toolsfolder, 'avs2avi.exe')

    # Simple Genetic Algorithm implementation
    class SGA(object):
        """Minimal generational GA: tournament selection with replacement,
        two-point crossover, bit-flip mutation, and elitism of the single
        best member.  Scores are memoized per chromosome."""
        def __init__(self,
                chromosome_length,
                objective_function,
                population_size=100,
                probability_crossover=0.5,
                probability_mutation=0.01,
                selection_pressure=4,
                max_generations=10,
                minimize=True,
                dump_function=None):
            # Define the variables for the key GA parameters
            SGA.length = chromosome_length
            self.objfn = objective_function
            self.n = population_size - population_size % 2  # force an even population
            self.pc = probability_crossover
            self.pm = probability_mutation
            self.s = selection_pressure
            self.maxgen = max_generations
            SGA.minimize = minimize
            self.dump = dump_function
            self.generation = 0
            self.scoreDict = {}  # memo: chromosome-as-int -> score, avoids re-encoding
            # Define the individual class
            class Individual(object):
                """One candidate solution: a bit-list chromosome and its score."""
                def __init__(self, chromosome=None):
                    self.length = SGA.length
                    self.minimize = SGA.minimize
                    self.score = None
                    self.chromosome = chromosome
                    if self.chromosome is None:
                        # start from a uniformly random bit string
                        self.chromosome = [random.choice((0,1)) for i in xrange(self.length)]

                def __cmp__(self, other):
                    # Order by score so min() always yields the "best" individual,
                    # whichever direction we are optimizing (Python 2 only).
                    if self.minimize:
                        return cmp(self.score, other.score)
                    else:
                        return cmp(other.score, self.score)

                def copy(self):
                    twin = self.__class__(self.chromosome[:])
                    twin.score = self.score
                    return twin
            self.Individual = Individual

        def run(self):
            """Run the GA; return False if the user cancelled via the
            progress box, True when all generations completed."""
            # Create the initial population (generation 0)
            self.population = [self.Individual() for i in range(self.n)]
            try:
                pb = avsp.ProgressBox(self.n, _('Initial evaluation...'), _('Generation 0 Progress'))
            except NameError:
                # ProgressBox not available in this AvsPmod version: run headless
                pb = None
            try:
                for i, individual in enumerate(self.population):
                    self.evaluate(individual)
                    if pb is not None:
                        if not pb.Update(i)[0]:
                            pb.Destroy()
                            return False
                # Dump the best data from this generation
                best = min(self.population)
                initialscore = best.score
                if self.dump is not None:
                    self.dump(best.chromosome, best.score)
                if pb is not None:
                    pb.Destroy()
                self.generation += 1
                # Run the genetic algorithm
                while self.generation < self.maxgen:
                    # Create a progress bar for this generation
                    if pb is not None:
                        pb = avsp.ProgressBox(
                            self.n,
                            _('Initial best score: %.3f, Current best score: %.3f') % (initialscore, best.score),
                            'Generation %i Progress' % self.generation
                        )
                    # elitism: the best individual survives unchanged
                    newpopulation = [best.copy()]
                    count = len(newpopulation)
                    while count < self.n:
                    #~ for i in xrange(self.n/2):
                        # Selection
                        mate1 = self.selection()
                        mate2 = self.selection()
                        # Crossover
                        children = self.crossover(mate1, mate2)
                        for individual in children:
                            # Mutation
                            self.mutation(individual)
                            # Evaluate the individual and add it to the new population
                            self.evaluate(individual)
                            newpopulation.append(individual)
                        # Update the progress bar
                        count = len(newpopulation)
                        if pb is not None:
                            i = min(count-1, self.n-1)
                            if not pb.Update(i)[0]:
                                pb.Destroy()
                                return False
                    # Update the internally stored population
                    self.population = newpopulation[:self.n]
                    # Dump the best data from this generation
                    best = min(self.population)
                    if self.dump is not None:
                        self.dump(best.chromosome, best.score)
                    # Destroy the progress bar for this generation
                    if pb is not None:
                        pb.Destroy()
                    self.generation += 1
            finally:
                if pb is not None:
                    pb.Destroy()
            return True

        def crossover(self, individual1, individual2):
            '''Two point crossover'''
            if random.random() < self.pc:
                # Pick the crossover points randomly
                left = random.randrange(1, self.length-2)
                right = random.randrange(left, self.length-1)
                # Create the children chromosomes
                p1 = individual1.chromosome
                p2 = individual2.chromosome
                c1 = p1[:left] + p2[left:right] + p1[right:]
                c2 = p2[:left] + p1[left:right] + p2[right:]
                # Return the new individuals
                return self.Individual(c1), self.Individual(c2)
            else:
                # Don't perform crossover
                return individual1.copy(), individual2.copy()

        def mutation(self, individual):
            '''Bit-flip mutation'''
            # Randomly flip each bit in the chromosome
            chromosome = individual.chromosome
            for gene in xrange(self.length):
                if random.random() < self.pm:
                    chromosome[gene] = int(not chromosome[gene])

        def selection(self):
            '''Tournament selection with replacement'''
            # Return best individual from s randomly selected members
            competitors = [random.choice(self.population) for i in range(self.s)]
            #~ competitors.sort()
            #~ return competitors[0]
            return min(competitors)

        def evaluate(self, individual):
            """Score *individual*, reusing a cached score when the same
            chromosome was seen before (encoding is expensive)."""
            intChromosome = binary2int(individual.chromosome)
            if self.scoreDict.has_key(intChromosome):
                # The chromosome was evaluated previously
                individual.score = self.scoreDict[intChromosome]
            else:
                # Run the objective function to evaluate the chromosome
                individual.score = self.objfn(individual.chromosome)
                self.scoreDict[intChromosome] = individual.score

    def binary2int(x):
        '''decode a binary list to a single unsigned integer'''
        return sum(map(lambda z: int(x[z]) and 2**(len(x) - z - 1),  range(len(x)-1, -1, -1)))

    def decode_params(bitlist, params):
        '''returns dictionary of values for each param'''
        # Walk the bit string, consuming nbits per parameter and mapping the
        # resulting integer onto an index into that parameter's value list.
        iA = 0
        paramDict = {}
        for name, valuelist, nbits in params:
            iB = iA + nbits
            sublist = bitlist[iA:iB]
            #~ value = min + binary2int(sublist) * (max-min)/float(2**nbits - 1)
            #~ if type(min) == bool:
                #~ value = bool(value)
            index = int(binary2int(sublist) * (len(valuelist) - 1) / float(2 ** nbits - 1))
            paramDict[name] = valuelist[index]
            iA = iB
        return paramDict

    def evaluate(chromosome):
        """Objective function: render the decoded script with avs2avi and
        return the SSIM score parsed from the encoder's log file
        (0 when the encode failed and no log was produced)."""
        # Decode the bit string into the individual parameters
        paramDict = decode_params(chromosome, params)
        # Create the AviSynth script
        script = scriptTemplate % paramDict
        inputavsname = os.path.join(scriptdir, 'ga_evaluate.avs')
        script = app.GetEncodedText(script, bom=True)
        f = open(inputavsname, 'w')
        f.write(script)
        f.close()
        # Encode the video to get the results (dumped to log.txt)
        try:
            os.remove(logfilename)
        except OSError:
            pass
        # NOTE(review): shell=True with a list argument is Windows-specific
        # behavior; confirm this macro is only run on Windows.
        subprocess.call([avs2avidir, inputavsname, '-q', '-o', 'n', '-c','null'], shell=True)
        # Read the results in log.txt
        if os.path.isfile(logfilename):
            f = open(logfilename, 'r')
            lines = f.readlines()
            f.close()
            score = float(lines[-1].split()[2])
            #~ print 'good!', score
        else:
            score = 0
            #~ print '*** Error, bad script:'
            #~ print script
            #~ print '*** End script'
        return score

    def dump(chromosome, score=None):
        '''Write the script to a file'''
        paramDict = decode_params(chromosome, params)
        script = scriptTemplate % paramDict
        script = app.GetEncodedText(script, bom=True)
        f = open(os.path.splitext(filename)[0] + '-optimized.avs', 'w')
        f.write(script)
        f.close()
        if score is not None:
            print _('Best score: %.2f') % score

    # MAIN SECTION
    if not avs2avidir or not os.path.isfile(avs2avidir):
        avsp.MsgBox(_('Must configure avs2avi directory to use this macro!'), _('Error'))
        return
    # Save the script
    filename = avsp.SaveScript()
    if not filename:
        return
    if not avsp.UpdateVideo():
        avsp.MsgBox(_('The current Avisynth script contains errors.'), _('Error'))
        return
    scriptdir = os.path.dirname(filename)
    scriptTemplate = avsp.GetText()
    # Parse the script to determine the log filename

    # Create the parameters to optimize based on user sliders in the script
    sliderInfoList = avsp.GetSliderInfo()
    if not sliderInfoList:
        # NOTE(review): message typo ("Not" -> "No"); left unchanged because
        # it is a translatable runtime string.
        avsp.MsgBox(_('Not user sliders on the current Avisynth script!'), _('Error'))
        return
    length = 0
    for text, label, valuelist, nDecimal in sliderInfoList:
        if valuelist is None:
            continue
        # number of bits needed to index valuelist (frexp gives the exponent;
        # exact powers of two need one bit fewer)
        mantissa, nbits = math.frexp(len(valuelist))
        if mantissa == 0.5:
            nbits -= 1
        params.append([label, valuelist, nbits])
        length += nbits
        # replace the slider text with a %-format placeholder for the label
        scriptTemplate = scriptTemplate.replace(text, '%('+label+').'+str(nDecimal)+'f')
    # Get basic encoder options with a dialog box
    title = _('Enter optimization info    (%i bits, %i possibilities)') % (length, 2**length)
    message = [_('SSIM log filename:'), [_('max generations:'), _('population size:'),
              _('crossover probability:'), _('mutation probability:'), _('selection pressure:')]]
    dirname, basename = os.path.split(logfilename)
    if not os.path.isdir(dirname):
        logfilename = os.path.join(app.GetProposedPath(only='dir'), basename)
    default = [logfilename, [(10, 1), (30, 1), (0.6, 0, 1, 2, 0.05), (0.03, 0, 1, 2, 0.05), 4]]
    types = ['file_save', ['spin', 'spin', 'spin', 'spin', 'spin']]
    entries = avsp.GetTextEntry(message, default, title, types)
    if not entries:
        return
    # First clear the AVI from memory (to close the log file)
    txt = avsp.GetText()
    avsp.HideVideoWindow()
    avsp.CloseTab()
    avsp.OpenFile(filename)
    avsp.SetText(txt)
    avsp.SaveScript()
    # Run the optimization
    logfilename, maxgen, n, pc, pm, s = entries
    print _('Begin optimization...')
    print 'n=%s, pc=%s, pm=%s, s=%s, maxgen=%s (%i bits)' % (n, pc, pm, s, maxgen, length)
    sga = SGA(length, evaluate, int(n), float(pc), float(pm), int(s), int(maxgen), False, dump)
    sga.run()
    # remove the temporary evaluation script
    os.remove(os.path.join(scriptdir, 'ga_evaluate.avs'))
    print _('Finished optimization.')
    # Show the optimized results
    avsp.OpenFile(os.path.splitext(filename)[0] + '-optimized.avs')
    avsp.ShowVideoFrame()

Example 12

Project: Devede
Source File: devede_xml_menu.py
View license
	def create_xml(self):

		""" Creates the XML control file for DVDAuthor.

		Writes <filefolder><filename>.xml describing the whole DVD:
		a VMGM block, a first titleset holding all the menu pages, and
		one titleset per video title.  Returns False on success, True
		when the file could not be written (IOError).

		Register usage in the generated dvdauthor code (inferred from
		the jump code and comments below - NOTE(review): confirm
		against devede's menu generator):
		  g0 - target title number; 100 means "go back to the menus"
		  g1 - which menu page to display
		  g2 - caches the currently highlighted button (1024 initially)
		"""

		# calculate the position for each title
		# (element[0] is the title header; the remaining entries are its files,
		# so each title advances the counter by its number of chapters)

		title_list=[]
		counter=1
		for element in self.structure:
			title_list.append(counter)
			counter+=((len(element))-1)

		try:
			fichero=open(self.filefolder+self.filename+".xml","w")
			fichero.write('<dvdauthor dest="'+self.expand_xml(self.filefolder+self.filename)+'">\n')

			if self.onlyone:
				# single-title DVD: no menus at all
				fichero.write('\t<vmgm />\n')
			else:

				fichero.write('\t<vmgm>\n')

				# MENU

				# in the FPC we do a jump to the first menu in the first titleset if we wanted MENU
				# or we jump to the second titleset if we didn't want MENU at startup

				fichero.write('\t\t<fpc>\n')
				fichero.write('\t\t\tg0=100;\n')
				if self.do_menu and self.with_menu:
					fichero.write('\t\t\tg1=0;\n')
				else:
					fichero.write('\t\t\tg1=100;\n')
				fichero.write('\t\t\tg2=1024;\n')
				fichero.write('\t\t\tjump menu 1;\n')
				fichero.write('\t\t</fpc>\n')

				# in the VMGM menu we create a code to jump to the title specified in G0
				# but if the title is 100, we jump to the menus. There we show the menu number
				# contained in G1

				fichero.write("\t\t<menus>\n")

				fichero.write('\t\t\t<video format="')
				if self.menu_PAL:
					fichero.write("pal")
				else:
					fichero.write("ntsc")
				fichero.write('" aspect="4:3"> </video>\n')

				fichero.write('\t\t\t<pgc>\n')
				fichero.write('\t\t\t\t<pre>\n')

				# one "if (g0 eq N) jump" per chapter across all titles;
				# titleset 1 holds the menus, so title N lives in titleset N+1
				counter=1
				for element in self.structure:
					for element2 in element[1:]:
						fichero.write('\t\t\t\t\tif (g0 eq '+str(counter)+') {\n')
						fichero.write('\t\t\t\t\t\tjump titleset '+str(1+counter)+' menu;\n')
						fichero.write('\t\t\t\t\t}\n')
						counter+=1
				fichero.write('\t\t\t\t\tif (g0 eq 100) {\n')
				fichero.write('\t\t\t\t\t\tg2=1024;\n')
				fichero.write('\t\t\t\t\t\tjump titleset 1 menu;\n')
				fichero.write('\t\t\t\t\t}\n')
				fichero.write('\t\t\t\t</pre>\n')
				# fake video (one black picture with one second of sound) to ensure 100% compatibility
				fichero.write('\t\t\t\t<vob file="')
				if self.menu_PAL:
					fichero.write(self.expand_xml(str(os.path.join(self.installpath,"base_pal.mpg"))))
				else:
					fichero.write(self.expand_xml(str(os.path.join(self.installpath,"base_ntsc.mpg"))))
				fichero.write('"></vob>\n')
				fichero.write('\t\t\t</pgc>\n')
				fichero.write('\t\t</menus>\n')
				fichero.write("\t</vmgm>\n")

				fichero.write("\n")

				# the first titleset contains all the menus. G1 allows us to jump to the desired menu

				fichero.write('\t<titleset>\n')
				fichero.write('\t\t<menus>\n')
				fichero.write('\t\t\t<video format="')
				if self.menu_PAL:
					fichero.write("pal")
				else:
					fichero.write("ntsc")
				fichero.write('" aspect="4:3"> </video>\n')

				button_counter=0
				for menu_number in range(self.nmenues):
					fichero.write('\t\t\t<pgc>\n')
					fichero.write('\t\t\t\t<pre>\n')
					# first we recover the currently selected button
					fichero.write('\t\t\t\t\ts8=g2;\n')
					if menu_number==0: # here we add some code to jump to each menu
						for menu2 in range(self.nmenues-1):
							fichero.write('\t\t\t\t\tif (g1 eq '+str(menu2+1)+') {\n')
							fichero.write('\t\t\t\t\t\tjump menu '+str(menu2+2)+';\n')
							fichero.write('\t\t\t\t\t}\n')

						# this code is to fix a bug in some players
						fichero.write('\t\t\t\t\tif (g1 eq 100) {\n')
						fichero.write('\t\t\t\t\t\tjump title 1;\n')#menu '+str(self.nmenues+1)+';\n')
						fichero.write('\t\t\t\t\t}\n')

					fichero.write('\t\t\t\t</pre>\n')
					fichero.write('\t\t\t\t<vob file="')
					if self.with_menu:
						fichero.write(self.expand_xml(str(os.path.join(self.filefolder,self.filename)))+'_menu2_'+str(menu_number)+'.mpg"')
					else:
						# no menu requested: use the bundled black-frame clip instead
						if self.menu_PAL:
							fichero.write(self.expand_xml(str(os.path.join(self.installpath,"base_pal.mpg")))+'"')
						else:
							fichero.write(self.expand_xml(str(os.path.join(self.installpath,"base_ntsc.mpg")))+'"')
					fichero.write('></vob>\n')

					if self.with_menu:
						# one button per title shown on this menu page...
						cantidad=len(self.structure[self.elements_per_menu*menu_number:(menu_number+1)*self.elements_per_menu])
						for nbutton in range(cantidad):
							fichero.write('\t\t\t\t<button name="boton')
							fichero.write(str(menu_number))
							fichero.write('x')
							fichero.write(str(nbutton))
							fichero.write('"> g0='+str(title_list[button_counter])+'; jump vmgm menu; </button>\n')
							button_counter+=1

						# ...plus "previous page" / "next page" buttons where applicable
						if (menu_number!=0):
							fichero.write('\t\t\t\t<button name="boton')
							fichero.write(str(menu_number))
							fichero.write('p"> g1=')
							fichero.write(str(menu_number-1))
							fichero.write('; g2=1024; jump menu ')
							fichero.write(str(menu_number))
							fichero.write('; </button>\n')

						if (menu_number!=self.nmenues-1) and (self.nmenues>1):
							fichero.write('\t\t\t\t<button name="boton')
							fichero.write(str(menu_number))
							fichero.write('n"> g1=')
							fichero.write(str(menu_number+1))
							fichero.write('; g2=1024; jump menu ')
							fichero.write(str(menu_number+2))
							fichero.write('; </button>\n')

					# when the menu video ends, restore the highlighted button and loop
					fichero.write('\t\t\t\t<post>\n')
					fichero.write('\t\t\t\t\tg2=s8;\n')
					fichero.write('\t\t\t\t\tg1='+str(menu_number)+';\n')
					fichero.write('\t\t\t\t\tjump menu '+str(menu_number+1)+';\n')
					fichero.write('\t\t\t\t</post>\n')
					fichero.write('\t\t\t</pgc>\n')

				fichero.write('\t\t</menus>\n')
				fichero.write('\t\t<titles>\n')
				fichero.write('\t\t\t<video format="')
				if self.menu_PAL:
					fichero.write("pal")
				else:
					fichero.write("ntsc")
				fichero.write('" aspect="4:3"> </video>\n')
				fichero.write('\t\t\t<pgc>\n')
				fichero.write('\t\t\t\t<vob file="')
				if self.menu_PAL:
					fichero.write(self.expand_xml(str(os.path.join(self.installpath,"base_pal.mpg"))))
				else:
					fichero.write(self.expand_xml(str(os.path.join(self.installpath,"base_ntsc.mpg"))))
				fichero.write('"></vob>\n')
				fichero.write('\t\t\t\t<post>\n')
				fichero.write('\t\t\t\t\tg0=1;\n')
				fichero.write('\t\t\t\t\tg1=0;\n')
				fichero.write('\t\t\t\t\tg2=1024;\n')
				fichero.write('\t\t\t\t\tcall vmgm menu entry title;\n')
				fichero.write('\t\t\t\t</post>\n')
				fichero.write('\t\t\t</pgc>\n')
				fichero.write('\t\t</titles>\n')
				fichero.write("\t</titleset>\n")

				fichero.write("\n")

			# Now we create the titleset for each video

			total_t=len(self.structure)
			titleset=1
			titles=0
			counter=0
			for element in self.structure:
				files=0
				num_chapters=len(element)-1
				action=element[0]["jumpto"]
				for element2 in element[1:]:
					fichero.write("\n")

					if element2["ismpeg"]:

						# if it's already an MPEG-2 compliant file, we use the original values
						if element2["ofps"]==25:
							pal_ntsc="pal"
							ispal=True
						else:
							pal_ntsc="ntsc"
							ispal=False
						if element2["oaspect"]>1.6:
							faspect='16:9'
							fwide=True
						else:
							faspect='4:3'
							fwide=False
					else:
						# but if we are converting it, we use the desired values
						if element2["fps"]==25:
							pal_ntsc="pal"
							ispal=True
						else:
							pal_ntsc="ntsc"
							ispal=False
						if element2["aspect"]>1.6:
							faspect='16:9'
							fwide=True
						else:
							faspect='4:3'
							fwide=False

					fichero.write("\t<titleset>\n")
					if not self.onlyone:
						fichero.write("\t\t<menus>\n")
						fichero.write('\t\t\t<video format="'+pal_ntsc+'" aspect="'+faspect+'"')
						if fwide:
							fichero.write(' widescreen="nopanscan"')
						fichero.write('> </video>\n')

						fichero.write("\t\t\t<pgc>\n")
						fichero.write("\t\t\t\t<pre>\n")
						fichero.write('\t\t\t\t\tif (g0 eq 100) {\n')
						fichero.write('\t\t\t\t\t\tjump vmgm menu entry title;\n')
						fichero.write('\t\t\t\t\t}\n')
						fichero.write('\t\t\t\t\tg0=100;\n')
						fichero.write('\t\t\t\t\tg1='+str(titles/self.elements_per_menu)+';\n')
						fichero.write('\t\t\t\t\tjump title 1;\n')
						fichero.write('\t\t\t\t</pre>\n')
						# fake video to ensure compatibility
						fichero.write('\t\t\t\t<vob file="')
						if ispal:
							fichero.write(self.expand_xml(str(os.path.join(self.installpath,"base_pal"))))
						else:
							fichero.write(self.expand_xml(str(os.path.join(self.installpath,"base_ntsc"))))
						if fwide:
							fichero.write("_wide")
						fichero.write('.mpg"></vob>\n')
						fichero.write("\t\t\t</pgc>\n")
						fichero.write("\t\t</menus>\n")

					fichero.write("\t\t<titles>\n")
					fichero.write('\t\t\t<video format="'+pal_ntsc+'" aspect="'+faspect+'"')
					if fwide:
						fichero.write(' widescreen="nopanscan"')
					fichero.write('> </video>\n')

					for element3 in element2["sub_list"]:
						fichero.write('\t\t\t<subpicture lang="'+self.expand_xml(str(element3["sub_language"][:2].lower()))+'" />\n')
					fichero.write('\t\t\t<pgc>\n')
					if (element2["force_subs"]) and (len(element2["sub_list"])!=0):
						# force the first subtitle stream on
						fichero.write('\t\t\t\t<pre>\n')
						fichero.write('\t\t\t\t\tsubtitle=64;\n')
						fichero.write('\t\t\t\t</pre>\n')

					currentfile=self.create_filename(self.filefolder+self.filename,titles+1,files+1,False)
					fichero.write('\t\t\t\t<vob file="'+self.expand_xml(currentfile)+'" ')
					fichero.write('chapters="0')
					if (element2["olength"]>5):
						if (element2["lchapters"]!=0): # add chapters
							toadd=int(element2["lchapters"])
							seconds=toadd*60
							# one chapter mark every "lchapters" minutes, stopping
							# short of the end so the final mark isn't duplicated
							while seconds<(element2["olength"]-4):
								thetime=devede_other.return_time(seconds,False)
								fichero.write(","+thetime)
								seconds+=(toadd*60)
						fichero.write(','+devede_other.return_time((element2["olength"]-2),False))
					fichero.write('" />\n')

					if not self.onlyone:
						fichero.write('\t\t\t\t<post>\n')
						files+=1
						fichero.write('\t\t\t\t\tg1='+str(titles/self.elements_per_menu)+';\n')
						if (files==num_chapters) and (action=="menu"): # last chapter
							fichero.write('\t\t\t\t\tg0=100;\n')
							fichero.write('\t\t\t\t\tcall vmgm menu entry title;\n')
						else:
							fichero.write('\t\t\t\t\tg0=')
							if (files==num_chapters): # last chapter; do ACTION
								if action=="prev":
									if titles==0:
										prev_t=total_t-1
									else:
										prev_t=titles-1
									fichero.write(str(title_list[prev_t]))
								elif action=="loop":
									fichero.write(str(title_list[titles]))
								elif action=="next":
									if titles==total_t-1:
										next_t=0
									else:
										next_t=titles+1
									fichero.write(str(title_list[next_t]))
								elif action=="last":
									fichero.write(str(title_list[total_t-1]))
								else:
									fichero.write('1') # first
							else:
								 # jump to next chapter in title
								fichero.write(str(title_list[titles]+files))
							fichero.write(';\n')
							fichero.write('\t\t\t\t\tcall vmgm menu entry title;\n')
						fichero.write('\t\t\t\t</post>\n')
					fichero.write("\t\t\t</pgc>\n")
					fichero.write("\t\t</titles>\n")
					fichero.write("\t</titleset>\n")
					counter+=1
				titles+=1
			fichero.write("</dvdauthor>")
			fichero.close()
			return False
		except IOError:
			# could not create/write the XML file; caller treats True as failure
			return True

Example 13

Project: pygame_sdl2
Source File: run_tests.py
View license
def run(*args, **kwds):
    """Run the Pygame unit test suite and return (total tests run, fails dict)

    Positional arguments (optional):
    The names of tests to include. If omitted then all tests are run. Test
    names need not include the trailing '_test'.

    Keyword arguments:
    incomplete - fail incomplete tests (default False)
    nosubprocess - run all test suites in the current process
                   (default False, use separate subprocesses)
    dump - dump failures/errors as dict ready to eval (default False)
    file - if provided, the name of a file into which to dump failures/errors
    timings - if provided, the number of times to run each individual test to
              get an average run time (default is run each test once)
    exclude - A list of TAG names to exclude from the run. The items may be
              comma or space separated.
    show_output - show silenced stderr/stdout on errors (default False)
    all - dump all results, not just errors (default False)
    randomize - randomize order of tests (default False)
    seed - if provided, a seed randomizer integer
    multi_thread - if provided, the number of THREADS in which to run
                   subprocessed tests
    time_out - if subprocess is True then the time limit in seconds before
               killing a test (default 30)
    fake - if provided, the name of the fake tests package in the
           run_tests__tests subpackage to run instead of the normal
           Pygame tests
    python - the path to a python executable to run subprocessed tests
             (default sys.executable)
    interactive - allow tests tagged 'interactive'.

    Return value:
    A tuple of total number of tests run, dictionary of error information. The
    dictionary is empty if no errors were recorded.

    By default individual test modules are run in separate subprocesses. This
    recreates normal Pygame usage where pygame.init() and pygame.quit() are
    called only once per program execution, and avoids unfortunate
    interactions between test modules. Also, a time limit is placed on test
    execution, so frozen tests are killed when their time allotment expires.
    Use the single process option if threading is not working properly or if
    tests are taking too long. It is not guaranteed that all tests will pass
    in single process mode.

    Tests are run in a randomized order if the randomize argument is True or a
    seed argument is provided. If no seed integer is provided then the system
    time is used.

    Individual test modules may have a corresponding *_tags.py module,
    defining a __tags__ attribute, a list of tag strings used to selectively
    omit modules from a run. By default only the 'interactive', 'ignore', and
    'subprocess_ignore' tags are ignored. 'interactive' is for modules that
    take user input, like cdrom_test.py. 'ignore' and 'subprocess_ignore' are
    for disabling modules for foreground and subprocess modes respectively.
    These are for disabling tests on optional modules or for experimental
    modules with known problems. These modules can be run from the console as
    a Python program.

    This function can only be called once per Python session. It is not
    reentrant.

    """

    global was_run

    if was_run:
        raise RuntimeError("run() was already called this session")
    was_run = True

    # Pop the options consumed here; the remainder is forwarded to run_test /
    # the subprocess command line below.
    options = kwds.copy()
    option_nosubprocess = options.get('nosubprocess', False)
    option_dump = options.pop('dump', False)
    option_file = options.pop('file', None)
    option_all = options.pop('all', False)
    option_randomize = options.get('randomize', False)
    option_seed = options.get('seed', None)
    option_multi_thread = options.pop('multi_thread', 1)
    option_time_out = options.pop('time_out', 120)
    option_fake = options.pop('fake', None)
    option_python = options.pop('python', sys.executable)
    option_exclude = options.pop('exclude', ())
    option_interactive = options.pop('interactive', False)

    # Build the implicit exclusion tag set for this run mode / Python version.
    if not option_interactive and 'interactive' not in option_exclude:
        option_exclude += ('interactive',)
    if not option_nosubprocess and 'subprocess_ignore' not in option_exclude:
        option_exclude += ('subprocess_ignore',)
    elif 'ignore' not in option_exclude:
        option_exclude += ('ignore',)
    if sys.version_info < (3, 0, 0):
        option_exclude += ('python2_ignore',)
    else:
        option_exclude += ('python3_ignore',)

    main_dir, test_subdir, fake_test_subdir = prepare_test_env()
    test_runner_py = os.path.join(test_subdir, "test_utils", "test_runner.py")
    cur_working_dir = os.path.abspath(os.getcwd())

    ###########################################################################
    # Compile a list of test modules. If fake, then compile list of fake
    # xxxx_test.py from run_tests__tests

    TEST_MODULE_RE = re.compile(r'^(.+_test)\.py$')

    test_mods_pkg_name = test_pkg_name

    if option_fake is not None:
        test_mods_pkg_name = '.'.join([test_mods_pkg_name,
                                       'run_tests__tests',
                                       option_fake])
        test_subdir = os.path.join(fake_test_subdir, option_fake)
        working_dir = test_subdir
    else:
        working_dir = main_dir


    # Added in because some machines will need os.environ else there will be
    # false failures in subprocess mode. Same issue as python2.6. Needs some
    # env vars.

    test_env = os.environ

    fmt1 = '%s.%%s' % test_mods_pkg_name
    fmt2 = '%s.%%s_test' % test_mods_pkg_name
    if args:
        # explicit module names; append '_test' when the caller omitted it
        test_modules = [
            m.endswith('_test') and (fmt1 % m) or (fmt2 % m) for m in args
        ]
    else:
        test_modules = []
        for f in sorted(os.listdir(test_subdir)):
            for match in TEST_MODULE_RE.findall(f):
                test_modules.append(fmt1 % (match,))

    ###########################################################################
    # Remove modules to be excluded.

    tmp = test_modules
    test_modules = []
    for name in tmp:
        tag_module_name = "%s_tags" % (name[0:-5],)
        try:
            tag_module = import_submodule(tag_module_name)
        except ImportError:
            # no *_tags.py companion: the module is always included
            test_modules.append(name)
        else:
            try:
                tags = tag_module.__tags__
            except AttributeError:
                print ("%s has no tags: ignoring" % (tag_module_name,))
                # BUGFIX: was `test_module.append(name)` (NameError); a tag
                # module without __tags__ must still include the test module.
                test_modules.append(name)
            else:
                # for/else: include the module only when no tag matched the
                # exclusion set
                for tag in tags:
                    if tag in option_exclude:
                        print ("skipping %s (tag '%s')" % (name, tag))
                        break
                else:
                    test_modules.append(name)
    del tmp, tag_module_name, name

    ###########################################################################
    # Meta results

    results = {}
    meta_results = {'__meta__' : {}}
    meta = meta_results['__meta__']

    ###########################################################################
    # Randomization

    if option_randomize or option_seed is not None:
        if option_seed is None:
            option_seed = time.time()
        meta['random_seed'] = option_seed
        print ("\nRANDOM SEED USED: %s\n" % option_seed)
        random.seed(option_seed)
        random.shuffle(test_modules)

    ###########################################################################
    # Single process mode

    if option_nosubprocess:
        unittest_patch.patch(**options)

        options['exclude'] = option_exclude
        t = time.time()
        for module in test_modules:
            results.update(run_test(module, **options))
        t = time.time() - t

    ###########################################################################
    # Subprocess mode
    #

    if not option_nosubprocess:
        if is_pygame_pkg:
            from pygame.tests.test_utils.async_sub import proc_in_time_or_kill
        else:
            from test.test_utils.async_sub import proc_in_time_or_kill

        # Forward the remaining options to each subprocess's command line.
        pass_on_args = ['--exclude', ','.join(option_exclude)]
        for option in ['timings', 'seed']:
            value = options.pop(option, None)
            if value is not None:
                pass_on_args.append('--%s' % option)
                pass_on_args.append(str(value))
        for option, value in options.items():
            if value:
                pass_on_args.append('--%s' % option)

        def sub_test(module):
            """Run one test module in a subprocess, killed after the timeout."""
            print ('loading %s' % module)

            cmd = [option_python, test_runner_py, module ] + pass_on_args

            return (module,
                    (cmd, test_env, working_dir),
                    proc_in_time_or_kill(cmd, option_time_out, env=test_env,
                                         wd=working_dir))

        if option_multi_thread > 1:
            def tmap(f, args):
                return pygame.threads.tmap (
                    f, args, stop_on_error = False,
                    num_workers = option_multi_thread
                )
        else:
            tmap = map

        t = time.time()

        for module, cmd, (return_code, raw_return) in tmap(sub_test,
                                                           test_modules):
            test_file = '%s.py' % os.path.join(test_subdir, module)
            cmd, test_env, working_dir = cmd

            test_results = get_test_results(raw_return)
            if test_results:
                results.update(test_results)
            else:
                # subprocess produced no parseable results (crash/timeout)
                results[module] = {}

            add_to_results = [
                'return_code', 'raw_return',  'cmd', 'test_file',
                'test_env', 'working_dir', 'module',
            ]

            results[module].update(from_namespace(locals(), add_to_results))

        t = time.time() - t

    ###########################################################################
    # Output Results
    #

    untrusty_total, combined = combine_results(results, t)
    total, fails = test_failures(results)

    meta['total_tests'] = total
    meta['combined'] = combined
    results.update(meta_results)

    if option_nosubprocess:
        assert total == untrusty_total

    if not option_dump:
        print (combined)
    else:
        results = option_all and results or fails
        print (TEST_RESULTS_START)
        print (pformat(results))

    if option_file is not None:
        results_file = open(option_file, 'w')
        try:
            results_file.write(pformat(results))
        finally:
            results_file.close()

    return total, fails

Example 14

Project: youtube-dl
Source File: options.py
View license
def parseOpts(overrideArguments=None):
    def _readOptions(filename_bytes, default=[]):
        try:
            optionf = open(filename_bytes)
        except IOError:
            return default  # silently skip if file is not present
        try:
            # FIXME: https://github.com/rg3/youtube-dl/commit/dfe5fa49aed02cf36ba9f743b11b0903554b5e56
            contents = optionf.read()
            if sys.version_info < (3,):
                contents = contents.decode(preferredencoding())
            res = compat_shlex_split(contents, comments=True)
        finally:
            optionf.close()
        return res

    def _readUserConf():
        """Locate and read the per-user configuration.

        Lookup order: XDG config dir (or ~/.config), then %appdata% on
        Windows, then legacy files in the home directory.  Returns [] when
        no configuration file exists anywhere.
        """
        # Primary location: $XDG_CONFIG_HOME, falling back to ~/.config.
        xdg_config_home = compat_getenv('XDG_CONFIG_HOME')
        if xdg_config_home:
            candidate = os.path.join(xdg_config_home, 'youtube-dl', 'config')
            if not os.path.isfile(candidate):
                candidate = os.path.join(xdg_config_home, 'youtube-dl.conf')
        else:
            candidate = os.path.join(compat_expanduser('~'), '.config', 'youtube-dl', 'config')
            if not os.path.isfile(candidate):
                candidate = os.path.join(compat_expanduser('~'), '.config', 'youtube-dl.conf')
        userConf = _readOptions(candidate, None)

        # Windows fallback: look under %appdata%.
        if userConf is None:
            appdata_dir = compat_getenv('appdata')
            if appdata_dir:
                for leaf in ('config', 'config.txt'):
                    userConf = _readOptions(
                        os.path.join(appdata_dir, 'youtube-dl', leaf),
                        default=None)
                    if userConf is not None:
                        break

        # Legacy dotfile locations directly in the home directory.
        if userConf is None:
            for leaf in ('youtube-dl.conf', 'youtube-dl.conf.txt'):
                userConf = _readOptions(
                    os.path.join(compat_expanduser('~'), leaf),
                    default=None)
                if userConf is not None:
                    break

        return [] if userConf is None else userConf

    def _format_option_string(option):
        ''' ('-o', '--option') -> -o, --format METAVAR'''

        opts = []

        if option._short_opts:
            opts.append(option._short_opts[0])
        if option._long_opts:
            opts.append(option._long_opts[0])
        if len(opts) > 1:
            opts.insert(1, ', ')

        if option.takes_value():
            opts.append(' %s' % option.metavar)

        return ''.join(opts)

    def _comma_separated_values_options_callback(option, opt_str, value, parser):
        setattr(parser.values, option.dest, value.split(','))

    def _hide_login_info(opts):
        PRIVATE_OPTS = ['-p', '--password', '-u', '--username', '--video-password', '--ap-password', '--ap-username']
        eqre = re.compile('^(?P<key>' + ('|'.join(re.escape(po) for po in PRIVATE_OPTS)) + ')=.+$')

        def _scrub_eq(o):
            m = eqre.match(o)
            if m:
                return m.group('key') + '=PRIVATE'
            else:
                return o

        opts = list(map(_scrub_eq, opts))
        for private_opt in PRIVATE_OPTS:
            try:
                i = opts.index(private_opt)
                opts[i + 1] = 'PRIVATE'
            except ValueError:
                pass
        return opts

    # No need to wrap help messages if we're on a wide console
    columns = compat_get_terminal_size().columns
    max_width = columns if columns else 80
    max_help_position = 80

    fmt = optparse.IndentedHelpFormatter(width=max_width, max_help_position=max_help_position)
    fmt.format_option_strings = _format_option_string

    kw = {
        'version': __version__,
        'formatter': fmt,
        'usage': '%prog [OPTIONS] URL [URL...]',
        'conflict_handler': 'resolve',
    }

    parser = optparse.OptionParser(**compat_kwargs(kw))

    general = optparse.OptionGroup(parser, 'General Options')
    general.add_option(
        '-h', '--help',
        action='help',
        help='Print this help text and exit')
    general.add_option(
        '-v', '--version',
        action='version',
        help='Print program version and exit')
    general.add_option(
        '-U', '--update',
        action='store_true', dest='update_self',
        help='Update this program to latest version. Make sure that you have sufficient permissions (run with sudo if needed)')
    general.add_option(
        '-i', '--ignore-errors',
        action='store_true', dest='ignoreerrors', default=False,
        help='Continue on download errors, for example to skip unavailable videos in a playlist')
    general.add_option(
        '--abort-on-error',
        action='store_false', dest='ignoreerrors',
        help='Abort downloading of further videos (in the playlist or the command line) if an error occurs')
    general.add_option(
        '--dump-user-agent',
        action='store_true', dest='dump_user_agent', default=False,
        help='Display the current browser identification')
    general.add_option(
        '--list-extractors',
        action='store_true', dest='list_extractors', default=False,
        help='List all supported extractors')
    general.add_option(
        '--extractor-descriptions',
        action='store_true', dest='list_extractor_descriptions', default=False,
        help='Output descriptions of all supported extractors')
    general.add_option(
        '--force-generic-extractor',
        action='store_true', dest='force_generic_extractor', default=False,
        help='Force extraction to use the generic extractor')
    general.add_option(
        '--default-search',
        dest='default_search', metavar='PREFIX',
        help='Use this prefix for unqualified URLs. For example "gvsearch2:" downloads two videos from google videos for youtube-dl "large apple". Use the value "auto" to let youtube-dl guess ("auto_warning" to emit a warning when guessing). "error" just throws an error. The default value "fixup_error" repairs broken URLs, but emits an error if this is not possible instead of searching.')
    general.add_option(
        '--ignore-config',
        action='store_true',
        help='Do not read configuration files. '
        'When given in the global configuration file /etc/youtube-dl.conf: '
        'Do not read the user configuration in ~/.config/youtube-dl/config '
        '(%APPDATA%/youtube-dl/config.txt on Windows)')
    general.add_option(
        '--flat-playlist',
        action='store_const', dest='extract_flat', const='in_playlist',
        default=False,
        help='Do not extract the videos of a playlist, only list them.')
    general.add_option(
        '--mark-watched',
        action='store_true', dest='mark_watched', default=False,
        help='Mark videos watched (YouTube only)')
    general.add_option(
        '--no-mark-watched',
        action='store_false', dest='mark_watched', default=False,
        help='Do not mark videos watched (YouTube only)')
    general.add_option(
        '--no-color', '--no-colors',
        action='store_true', dest='no_color',
        default=False,
        help='Do not emit color codes in output')

    network = optparse.OptionGroup(parser, 'Network Options')
    network.add_option(
        '--proxy', dest='proxy',
        default=None, metavar='URL',
        help='Use the specified HTTP/HTTPS/SOCKS proxy. To enable experimental '
             'SOCKS proxy, specify a proper scheme. For example '
             'socks5://127.0.0.1:1080/. Pass in an empty string (--proxy "") '
             'for direct connection')
    network.add_option(
        '--socket-timeout',
        dest='socket_timeout', type=float, default=None, metavar='SECONDS',
        help='Time to wait before giving up, in seconds')
    network.add_option(
        '--source-address',
        metavar='IP', dest='source_address', default=None,
        help='Client-side IP address to bind to (experimental)',
    )
    network.add_option(
        '-4', '--force-ipv4',
        action='store_const', const='0.0.0.0', dest='source_address',
        help='Make all connections via IPv4 (experimental)',
    )
    network.add_option(
        '-6', '--force-ipv6',
        action='store_const', const='::', dest='source_address',
        help='Make all connections via IPv6 (experimental)',
    )
    network.add_option(
        '--geo-verification-proxy',
        dest='geo_verification_proxy', default=None, metavar='URL',
        help='Use this proxy to verify the IP address for some geo-restricted sites. '
        'The default proxy specified by --proxy (or none, if the options is not present) is used for the actual downloading. (experimental)'
    )
    network.add_option(
        '--cn-verification-proxy',
        dest='cn_verification_proxy', default=None, metavar='URL',
        help=optparse.SUPPRESS_HELP,
    )

    selection = optparse.OptionGroup(parser, 'Video Selection')
    selection.add_option(
        '--playlist-start',
        dest='playliststart', metavar='NUMBER', default=1, type=int,
        help='Playlist video to start at (default is %default)')
    selection.add_option(
        '--playlist-end',
        dest='playlistend', metavar='NUMBER', default=None, type=int,
        help='Playlist video to end at (default is last)')
    selection.add_option(
        '--playlist-items',
        dest='playlist_items', metavar='ITEM_SPEC', default=None,
        help='Playlist video items to download. Specify indices of the videos in the playlist separated by commas like: "--playlist-items 1,2,5,8" if you want to download videos indexed 1, 2, 5, 8 in the playlist. You can specify range: "--playlist-items 1-3,7,10-13", it will download the videos at index 1, 2, 3, 7, 10, 11, 12 and 13.')
    selection.add_option(
        '--match-title',
        dest='matchtitle', metavar='REGEX',
        help='Download only matching titles (regex or caseless sub-string)')
    selection.add_option(
        '--reject-title',
        dest='rejecttitle', metavar='REGEX',
        help='Skip download for matching titles (regex or caseless sub-string)')
    selection.add_option(
        '--max-downloads',
        dest='max_downloads', metavar='NUMBER', type=int, default=None,
        help='Abort after downloading NUMBER files')
    selection.add_option(
        '--min-filesize',
        metavar='SIZE', dest='min_filesize', default=None,
        help='Do not download any videos smaller than SIZE (e.g. 50k or 44.6m)')
    selection.add_option(
        '--max-filesize',
        metavar='SIZE', dest='max_filesize', default=None,
        help='Do not download any videos larger than SIZE (e.g. 50k or 44.6m)')
    selection.add_option(
        '--date',
        metavar='DATE', dest='date', default=None,
        help='Download only videos uploaded in this date')
    selection.add_option(
        '--datebefore',
        metavar='DATE', dest='datebefore', default=None,
        help='Download only videos uploaded on or before this date (i.e. inclusive)')
    selection.add_option(
        '--dateafter',
        metavar='DATE', dest='dateafter', default=None,
        help='Download only videos uploaded on or after this date (i.e. inclusive)')
    selection.add_option(
        '--min-views',
        metavar='COUNT', dest='min_views', default=None, type=int,
        help='Do not download any videos with less than COUNT views')
    selection.add_option(
        '--max-views',
        metavar='COUNT', dest='max_views', default=None, type=int,
        help='Do not download any videos with more than COUNT views')
    selection.add_option(
        '--match-filter',
        metavar='FILTER', dest='match_filter', default=None,
        help=(
            'Generic video filter (experimental). '
            'Specify any key (see help for -o for a list of available keys) to'
            ' match if the key is present, '
            '!key to check if the key is not present,'
            'key > NUMBER (like "comment_count > 12", also works with '
            '>=, <, <=, !=, =) to compare against a number, and '
            '& to require multiple matches. '
            'Values which are not known are excluded unless you'
            ' put a question mark (?) after the operator.'
            'For example, to only match videos that have been liked more than '
            '100 times and disliked less than 50 times (or the dislike '
            'functionality is not available at the given service), but who '
            'also have a description, use --match-filter '
            '"like_count > 100 & dislike_count <? 50 & description" .'
        ))
    selection.add_option(
        '--no-playlist',
        action='store_true', dest='noplaylist', default=False,
        help='Download only the video, if the URL refers to a video and a playlist.')
    selection.add_option(
        '--yes-playlist',
        action='store_false', dest='noplaylist', default=False,
        help='Download the playlist, if the URL refers to a video and a playlist.')
    selection.add_option(
        '--age-limit',
        metavar='YEARS', dest='age_limit', default=None, type=int,
        help='Download only videos suitable for the given age')
    selection.add_option(
        '--download-archive', metavar='FILE',
        dest='download_archive',
        help='Download only videos not listed in the archive file. Record the IDs of all downloaded videos in it.')
    selection.add_option(
        '--include-ads',
        dest='include_ads', action='store_true',
        help='Download advertisements as well (experimental)')

    authentication = optparse.OptionGroup(parser, 'Authentication Options')
    authentication.add_option(
        '-u', '--username',
        dest='username', metavar='USERNAME',
        help='Login with this account ID')
    authentication.add_option(
        '-p', '--password',
        dest='password', metavar='PASSWORD',
        help='Account password. If this option is left out, youtube-dl will ask interactively.')
    authentication.add_option(
        '-2', '--twofactor',
        dest='twofactor', metavar='TWOFACTOR',
        help='Two-factor auth code')
    authentication.add_option(
        '-n', '--netrc',
        action='store_true', dest='usenetrc', default=False,
        help='Use .netrc authentication data')
    authentication.add_option(
        '--video-password',
        dest='videopassword', metavar='PASSWORD',
        help='Video password (vimeo, smotri, youku)')

    adobe_pass = optparse.OptionGroup(parser, 'Adobe Pass Options')
    adobe_pass.add_option(
        '--ap-mso',
        dest='ap_mso', metavar='MSO',
        help='Adobe Pass multiple-system operator (TV provider) identifier, use --ap-list-mso for a list of available MSOs')
    adobe_pass.add_option(
        '--ap-username',
        dest='ap_username', metavar='USERNAME',
        help='Multiple-system operator account login')
    adobe_pass.add_option(
        '--ap-password',
        dest='ap_password', metavar='PASSWORD',
        help='Multiple-system operator account password. If this option is left out, youtube-dl will ask interactively.')
    adobe_pass.add_option(
        '--ap-list-mso',
        action='store_true', dest='ap_list_mso', default=False,
        help='List all supported multiple-system operators')

    video_format = optparse.OptionGroup(parser, 'Video Format Options')
    video_format.add_option(
        '-f', '--format',
        action='store', dest='format', metavar='FORMAT', default=None,
        help='Video format code, see the "FORMAT SELECTION" for all the info')
    video_format.add_option(
        '--all-formats',
        action='store_const', dest='format', const='all',
        help='Download all available video formats')
    video_format.add_option(
        '--prefer-free-formats',
        action='store_true', dest='prefer_free_formats', default=False,
        help='Prefer free video formats unless a specific one is requested')
    video_format.add_option(
        '-F', '--list-formats',
        action='store_true', dest='listformats',
        help='List all available formats of requested videos')
    video_format.add_option(
        '--youtube-include-dash-manifest',
        action='store_true', dest='youtube_include_dash_manifest', default=True,
        help=optparse.SUPPRESS_HELP)
    video_format.add_option(
        '--youtube-skip-dash-manifest',
        action='store_false', dest='youtube_include_dash_manifest',
        help='Do not download the DASH manifests and related data on YouTube videos')
    video_format.add_option(
        '--merge-output-format',
        action='store', dest='merge_output_format', metavar='FORMAT', default=None,
        help=(
            'If a merge is required (e.g. bestvideo+bestaudio), '
            'output to given container format. One of mkv, mp4, ogg, webm, flv. '
            'Ignored if no merge is required'))

    subtitles = optparse.OptionGroup(parser, 'Subtitle Options')
    subtitles.add_option(
        '--write-sub', '--write-srt',
        action='store_true', dest='writesubtitles', default=False,
        help='Write subtitle file')
    subtitles.add_option(
        '--write-auto-sub', '--write-automatic-sub',
        action='store_true', dest='writeautomaticsub', default=False,
        help='Write automatically generated subtitle file (YouTube only)')
    subtitles.add_option(
        '--all-subs',
        action='store_true', dest='allsubtitles', default=False,
        help='Download all the available subtitles of the video')
    subtitles.add_option(
        '--list-subs',
        action='store_true', dest='listsubtitles', default=False,
        help='List all available subtitles for the video')
    subtitles.add_option(
        '--sub-format',
        action='store', dest='subtitlesformat', metavar='FORMAT', default='best',
        help='Subtitle format, accepts formats preference, for example: "srt" or "ass/srt/best"')
    subtitles.add_option(
        '--sub-lang', '--sub-langs', '--srt-lang',
        action='callback', dest='subtitleslangs', metavar='LANGS', type='str',
        default=[], callback=_comma_separated_values_options_callback,
        help='Languages of the subtitles to download (optional) separated by commas, use --list-subs for available language tags')

    downloader = optparse.OptionGroup(parser, 'Download Options')
    downloader.add_option(
        '-r', '--limit-rate', '--rate-limit',
        dest='ratelimit', metavar='RATE',
        help='Maximum download rate in bytes per second (e.g. 50K or 4.2M)')
    downloader.add_option(
        '-R', '--retries',
        dest='retries', metavar='RETRIES', default=10,
        help='Number of retries (default is %default), or "infinite".')
    downloader.add_option(
        '--fragment-retries',
        dest='fragment_retries', metavar='RETRIES', default=10,
        help='Number of retries for a fragment (default is %default), or "infinite" (DASH and hlsnative only)')
    downloader.add_option(
        '--skip-unavailable-fragments',
        action='store_true', dest='skip_unavailable_fragments', default=True,
        help='Skip unavailable fragments (DASH and hlsnative only)')
    general.add_option(
        '--abort-on-unavailable-fragment',
        action='store_false', dest='skip_unavailable_fragments',
        help='Abort downloading when some fragment is not available')
    downloader.add_option(
        '--buffer-size',
        dest='buffersize', metavar='SIZE', default='1024',
        help='Size of download buffer (e.g. 1024 or 16K) (default is %default)')
    downloader.add_option(
        '--no-resize-buffer',
        action='store_true', dest='noresizebuffer', default=False,
        help='Do not automatically adjust the buffer size. By default, the buffer size is automatically resized from an initial value of SIZE.')
    downloader.add_option(
        '--test',
        action='store_true', dest='test', default=False,
        help=optparse.SUPPRESS_HELP)
    downloader.add_option(
        '--playlist-reverse',
        action='store_true',
        help='Download playlist videos in reverse order')
    downloader.add_option(
        '--xattr-set-filesize',
        dest='xattr_set_filesize', action='store_true',
        help='Set file xattribute ytdl.filesize with expected filesize (experimental)')
    downloader.add_option(
        '--hls-prefer-native',
        dest='hls_prefer_native', action='store_true', default=None,
        help='Use the native HLS downloader instead of ffmpeg')
    downloader.add_option(
        '--hls-prefer-ffmpeg',
        dest='hls_prefer_native', action='store_false', default=None,
        help='Use ffmpeg instead of the native HLS downloader')
    downloader.add_option(
        '--hls-use-mpegts',
        dest='hls_use_mpegts', action='store_true',
        help='Use the mpegts container for HLS videos, allowing to play the '
             'video while downloading (some players may not be able to play it)')
    downloader.add_option(
        '--external-downloader',
        dest='external_downloader', metavar='COMMAND',
        help='Use the specified external downloader. '
             'Currently supports %s' % ','.join(list_external_downloaders()))
    downloader.add_option(
        '--external-downloader-args',
        dest='external_downloader_args', metavar='ARGS',
        help='Give these arguments to the external downloader')

    workarounds = optparse.OptionGroup(parser, 'Workarounds')
    workarounds.add_option(
        '--encoding',
        dest='encoding', metavar='ENCODING',
        help='Force the specified encoding (experimental)')
    workarounds.add_option(
        '--no-check-certificate',
        action='store_true', dest='no_check_certificate', default=False,
        help='Suppress HTTPS certificate validation')
    workarounds.add_option(
        '--prefer-insecure',
        '--prefer-unsecure', action='store_true', dest='prefer_insecure',
        help='Use an unencrypted connection to retrieve information about the video. (Currently supported only for YouTube)')
    workarounds.add_option(
        '--user-agent',
        metavar='UA', dest='user_agent',
        help='Specify a custom user agent')
    workarounds.add_option(
        '--referer',
        metavar='URL', dest='referer', default=None,
        help='Specify a custom referer, use if the video access is restricted to one domain',
    )
    workarounds.add_option(
        '--add-header',
        metavar='FIELD:VALUE', dest='headers', action='append',
        help='Specify a custom HTTP header and its value, separated by a colon \':\'. You can use this option multiple times',
    )
    workarounds.add_option(
        '--bidi-workaround',
        dest='bidi_workaround', action='store_true',
        help='Work around terminals that lack bidirectional text support. Requires bidiv or fribidi executable in PATH')
    workarounds.add_option(
        '--sleep-interval', '--min-sleep-interval', metavar='SECONDS',
        dest='sleep_interval', type=float,
        help=(
            'Number of seconds to sleep before each download when used alone '
            'or a lower bound of a range for randomized sleep before each download '
            '(minimum possible number of seconds to sleep) when used along with '
            '--max-sleep-interval.'))
    workarounds.add_option(
        '--max-sleep-interval', metavar='SECONDS',
        dest='max_sleep_interval', type=float,
        help=(
            'Upper bound of a range for randomized sleep before each download '
            '(maximum possible number of seconds to sleep). Must only be used '
            'along with --min-sleep-interval.'))

    verbosity = optparse.OptionGroup(parser, 'Verbosity / Simulation Options')
    verbosity.add_option(
        '-q', '--quiet',
        action='store_true', dest='quiet', default=False,
        help='Activate quiet mode')
    verbosity.add_option(
        '--no-warnings',
        dest='no_warnings', action='store_true', default=False,
        help='Ignore warnings')
    verbosity.add_option(
        '-s', '--simulate',
        action='store_true', dest='simulate', default=False,
        help='Do not download the video and do not write anything to disk')
    verbosity.add_option(
        '--skip-download',
        action='store_true', dest='skip_download', default=False,
        help='Do not download the video')
    verbosity.add_option(
        '-g', '--get-url',
        action='store_true', dest='geturl', default=False,
        help='Simulate, quiet but print URL')
    verbosity.add_option(
        '-e', '--get-title',
        action='store_true', dest='gettitle', default=False,
        help='Simulate, quiet but print title')
    verbosity.add_option(
        '--get-id',
        action='store_true', dest='getid', default=False,
        help='Simulate, quiet but print id')
    verbosity.add_option(
        '--get-thumbnail',
        action='store_true', dest='getthumbnail', default=False,
        help='Simulate, quiet but print thumbnail URL')
    verbosity.add_option(
        '--get-description',
        action='store_true', dest='getdescription', default=False,
        help='Simulate, quiet but print video description')
    verbosity.add_option(
        '--get-duration',
        action='store_true', dest='getduration', default=False,
        help='Simulate, quiet but print video length')
    verbosity.add_option(
        '--get-filename',
        action='store_true', dest='getfilename', default=False,
        help='Simulate, quiet but print output filename')
    verbosity.add_option(
        '--get-format',
        action='store_true', dest='getformat', default=False,
        help='Simulate, quiet but print output format')
    verbosity.add_option(
        '-j', '--dump-json',
        action='store_true', dest='dumpjson', default=False,
        help='Simulate, quiet but print JSON information. See --output for a description of available keys.')
    verbosity.add_option(
        '-J', '--dump-single-json',
        action='store_true', dest='dump_single_json', default=False,
        help='Simulate, quiet but print JSON information for each command-line argument. If the URL refers to a playlist, dump the whole playlist information in a single line.')
    verbosity.add_option(
        '--print-json',
        action='store_true', dest='print_json', default=False,
        help='Be quiet and print the video information as JSON (video is still being downloaded).',
    )
    verbosity.add_option(
        '--newline',
        action='store_true', dest='progress_with_newline', default=False,
        help='Output progress bar as new lines')
    verbosity.add_option(
        '--no-progress',
        action='store_true', dest='noprogress', default=False,
        help='Do not print progress bar')
    verbosity.add_option(
        '--console-title',
        action='store_true', dest='consoletitle', default=False,
        help='Display progress in console titlebar')
    verbosity.add_option(
        '-v', '--verbose',
        action='store_true', dest='verbose', default=False,
        help='Print various debugging information')
    verbosity.add_option(
        '--dump-pages', '--dump-intermediate-pages',
        action='store_true', dest='dump_intermediate_pages', default=False,
        help='Print downloaded pages encoded using base64 to debug problems (very verbose)')
    verbosity.add_option(
        '--write-pages',
        action='store_true', dest='write_pages', default=False,
        help='Write downloaded intermediary pages to files in the current directory to debug problems')
    verbosity.add_option(
        '--youtube-print-sig-code',
        action='store_true', dest='youtube_print_sig_code', default=False,
        help=optparse.SUPPRESS_HELP)
    verbosity.add_option(
        '--print-traffic', '--dump-headers',
        dest='debug_printtraffic', action='store_true', default=False,
        help='Display sent and read HTTP traffic')
    verbosity.add_option(
        '-C', '--call-home',
        dest='call_home', action='store_true', default=False,
        help='Contact the youtube-dl server for debugging')
    verbosity.add_option(
        '--no-call-home',
        dest='call_home', action='store_false', default=False,
        help='Do NOT contact the youtube-dl server for debugging')

    filesystem = optparse.OptionGroup(parser, 'Filesystem Options')
    filesystem.add_option(
        '-a', '--batch-file',
        dest='batchfile', metavar='FILE',
        help='File containing URLs to download (\'-\' for stdin)')
    filesystem.add_option(
        '--id', default=False,
        action='store_true', dest='useid', help='Use only video ID in file name')
    filesystem.add_option(
        '-o', '--output',
        dest='outtmpl', metavar='TEMPLATE',
        help=('Output filename template, see the "OUTPUT TEMPLATE" for all the info'))
    filesystem.add_option(
        '--autonumber-size',
        dest='autonumber_size', metavar='NUMBER',
        help='Specify the number of digits in %(autonumber)s when it is present in output filename template or --auto-number option is given')
    filesystem.add_option(
        '--restrict-filenames',
        action='store_true', dest='restrictfilenames', default=False,
        help='Restrict filenames to only ASCII characters, and avoid "&" and spaces in filenames')
    filesystem.add_option(
        '-A', '--auto-number',
        action='store_true', dest='autonumber', default=False,
        help='[deprecated; use -o "%(autonumber)s-%(title)s.%(ext)s" ] Number downloaded files starting from 00000')
    filesystem.add_option(
        '-t', '--title',
        action='store_true', dest='usetitle', default=False,
        help='[deprecated] Use title in file name (default)')
    filesystem.add_option(
        '-l', '--literal', default=False,
        action='store_true', dest='usetitle',
        help='[deprecated] Alias of --title')
    filesystem.add_option(
        '-w', '--no-overwrites',
        action='store_true', dest='nooverwrites', default=False,
        help='Do not overwrite files')
    filesystem.add_option(
        '-c', '--continue',
        action='store_true', dest='continue_dl', default=True,
        help='Force resume of partially downloaded files. By default, youtube-dl will resume downloads if possible.')
    filesystem.add_option(
        '--no-continue',
        action='store_false', dest='continue_dl',
        help='Do not resume partially downloaded files (restart from beginning)')
    filesystem.add_option(
        '--no-part',
        action='store_true', dest='nopart', default=False,
        help='Do not use .part files - write directly into output file')
    filesystem.add_option(
        '--no-mtime',
        action='store_false', dest='updatetime', default=True,
        help='Do not use the Last-modified header to set the file modification time')
    filesystem.add_option(
        '--write-description',
        action='store_true', dest='writedescription', default=False,
        help='Write video description to a .description file')
    filesystem.add_option(
        '--write-info-json',
        action='store_true', dest='writeinfojson', default=False,
        help='Write video metadata to a .info.json file')
    filesystem.add_option(
        '--write-annotations',
        action='store_true', dest='writeannotations', default=False,
        help='Write video annotations to a .annotations.xml file')
    filesystem.add_option(
        '--load-info-json', '--load-info',
        dest='load_info_filename', metavar='FILE',
        help='JSON file containing the video information (created with the "--write-info-json" option)')
    filesystem.add_option(
        '--cookies',
        dest='cookiefile', metavar='FILE',
        help='File to read cookies from and dump cookie jar in')
    filesystem.add_option(
        '--cache-dir', dest='cachedir', default=None, metavar='DIR',
        help='Location in the filesystem where youtube-dl can store some downloaded information permanently. By default $XDG_CACHE_HOME/youtube-dl or ~/.cache/youtube-dl . At the moment, only YouTube player files (for videos with obfuscated signatures) are cached, but that may change.')
    filesystem.add_option(
        '--no-cache-dir', action='store_const', const=False, dest='cachedir',
        help='Disable filesystem caching')
    filesystem.add_option(
        '--rm-cache-dir',
        action='store_true', dest='rm_cachedir',
        help='Delete all filesystem cache files')

    thumbnail = optparse.OptionGroup(parser, 'Thumbnail images')
    thumbnail.add_option(
        '--write-thumbnail',
        action='store_true', dest='writethumbnail', default=False,
        help='Write thumbnail image to disk')
    thumbnail.add_option(
        '--write-all-thumbnails',
        action='store_true', dest='write_all_thumbnails', default=False,
        help='Write all thumbnail image formats to disk')
    thumbnail.add_option(
        '--list-thumbnails',
        action='store_true', dest='list_thumbnails', default=False,
        help='Simulate and list all available thumbnail formats')

    postproc = optparse.OptionGroup(parser, 'Post-processing Options')
    postproc.add_option(
        '-x', '--extract-audio',
        action='store_true', dest='extractaudio', default=False,
        help='Convert video files to audio-only files (requires ffmpeg or avconv and ffprobe or avprobe)')
    postproc.add_option(
        '--audio-format', metavar='FORMAT', dest='audioformat', default='best',
        help='Specify audio format: "best", "aac", "vorbis", "mp3", "m4a", "opus", or "wav"; "%default" by default')
    postproc.add_option(
        '--audio-quality', metavar='QUALITY',
        dest='audioquality', default='5',
        help='Specify ffmpeg/avconv audio quality, insert a value between 0 (better) and 9 (worse) for VBR or a specific bitrate like 128K (default %default)')
    postproc.add_option(
        '--recode-video',
        metavar='FORMAT', dest='recodevideo', default=None,
        help='Encode the video to another format if necessary (currently supported: mp4|flv|ogg|webm|mkv|avi)')
    postproc.add_option(
        '--postprocessor-args',
        dest='postprocessor_args', metavar='ARGS',
        help='Give these arguments to the postprocessor')
    postproc.add_option(
        '-k', '--keep-video',
        action='store_true', dest='keepvideo', default=False,
        help='Keep the video file on disk after the post-processing; the video is erased by default')
    postproc.add_option(
        '--no-post-overwrites',
        action='store_true', dest='nopostoverwrites', default=False,
        help='Do not overwrite post-processed files; the post-processed files are overwritten by default')
    postproc.add_option(
        '--embed-subs',
        action='store_true', dest='embedsubtitles', default=False,
        help='Embed subtitles in the video (only for mp4, webm and mkv videos)')
    postproc.add_option(
        '--embed-thumbnail',
        action='store_true', dest='embedthumbnail', default=False,
        help='Embed thumbnail in the audio as cover art')
    postproc.add_option(
        '--add-metadata',
        action='store_true', dest='addmetadata', default=False,
        help='Write metadata to the video file')
    postproc.add_option(
        '--metadata-from-title',
        metavar='FORMAT', dest='metafromtitle',
        help='Parse additional metadata like song title / artist from the video title. '
             'The format syntax is the same as --output, '
             'the parsed parameters replace existing values. '
             'Additional templates: %(album)s, %(artist)s. '
             'Example: --metadata-from-title "%(artist)s - %(title)s" matches a title like '
             '"Coldplay - Paradise"')
    postproc.add_option(
        '--xattrs',
        action='store_true', dest='xattrs', default=False,
        help='Write metadata to the video file\'s xattrs (using dublin core and xdg standards)')
    postproc.add_option(
        '--fixup',
        metavar='POLICY', dest='fixup', default='detect_or_warn',
        help='Automatically correct known faults of the file. '
             'One of never (do nothing), warn (only emit a warning), '
             'detect_or_warn (the default; fix file if we can, warn otherwise)')
    postproc.add_option(
        '--prefer-avconv',
        action='store_false', dest='prefer_ffmpeg',
        help='Prefer avconv over ffmpeg for running the postprocessors (default)')
    postproc.add_option(
        '--prefer-ffmpeg',
        action='store_true', dest='prefer_ffmpeg',
        help='Prefer ffmpeg over avconv for running the postprocessors')
    postproc.add_option(
        '--ffmpeg-location', '--avconv-location', metavar='PATH',
        dest='ffmpeg_location',
        help='Location of the ffmpeg/avconv binary; either the path to the binary or its containing directory.')
    postproc.add_option(
        '--exec',
        metavar='CMD', dest='exec_cmd',
        help='Execute a command on the file after downloading, similar to find\'s -exec syntax. Example: --exec \'adb push {} /sdcard/Music/ && rm {}\'')
    postproc.add_option(
        '--convert-subs', '--convert-subtitles',
        metavar='FORMAT', dest='convertsubtitles', default=None,
        help='Convert the subtitles to other format (currently supported: srt|ass|vtt)')

    parser.add_option_group(general)
    parser.add_option_group(network)
    parser.add_option_group(selection)
    parser.add_option_group(downloader)
    parser.add_option_group(filesystem)
    parser.add_option_group(thumbnail)
    parser.add_option_group(verbosity)
    parser.add_option_group(workarounds)
    parser.add_option_group(video_format)
    parser.add_option_group(subtitles)
    parser.add_option_group(authentication)
    parser.add_option_group(adobe_pass)
    parser.add_option_group(postproc)

    if overrideArguments is not None:
        # Caller supplied argv explicitly (e.g. embedded use): parse only
        # those arguments and ignore every configuration file.
        opts, args = parser.parse_args(overrideArguments)
        if opts.verbose:
            write_string('[debug] Override config: ' + repr(overrideArguments) + '\n')
    else:
        def compat_conf(conf):
            # On Python 2, sys.argv holds bytes; decode to text so option
            # values combine cleanly with str literals elsewhere.
            if sys.version_info < (3,):
                return [a.decode(preferredencoding(), 'replace') for a in conf]
            return conf

        command_line_conf = compat_conf(sys.argv[1:])

        # Precedence (later entries win in optparse): system conf < user
        # conf < command line.  --ignore-config on the command line skips
        # both files; inside the system conf it skips only the user conf.
        if '--ignore-config' in command_line_conf:
            system_conf = []
            user_conf = []
        else:
            system_conf = _readOptions('/etc/youtube-dl.conf')
            if '--ignore-config' in system_conf:
                user_conf = []
            else:
                user_conf = _readUserConf()
        argv = system_conf + user_conf + command_line_conf

        opts, args = parser.parse_args(argv)
        if opts.verbose:
            # _hide_login_info scrubs credentials before they reach the log.
            write_string('[debug] System config: ' + repr(_hide_login_info(system_conf)) + '\n')
            write_string('[debug] User config: ' + repr(_hide_login_info(user_conf)) + '\n')
            write_string('[debug] Command-line args: ' + repr(_hide_login_info(command_line_conf)) + '\n')

    return parser, opts, args

Example 15

Project: script.module.youtube.dl
Source File: options.py
View license
def parseOpts(overrideArguments=None):
    def _readOptions(filename_bytes, default=[]):
        """Read an option file and shlex-split its contents.

        Returns *default* when the file does not exist.  The mutable
        default argument is safe here: it is returned, never mutated.
        """
        try:
            optionf = open(filename_bytes)
        except IOError:
            # Missing file means "no options configured" -- not an error.
            return default
        # FIXME: https://github.com/rg3/youtube-dl/commit/dfe5fa49aed02cf36ba9f743b11b0903554b5e56
        with optionf:
            contents = optionf.read()
        if sys.version_info < (3,):
            contents = contents.decode(preferredencoding())
        return compat_shlex_split(contents, comments=True)

    def _readUserConf():
        """Locate and read the per-user configuration file.

        Tried in order: $XDG_CONFIG_HOME (or ~/.config) locations,
        %appdata% locations (Windows), then dotfiles in the home
        directory.  Returns [] when no configuration file exists.
        """
        xdg_config_home = compat_getenv('XDG_CONFIG_HOME')
        if xdg_config_home:
            userConfFile = os.path.join(xdg_config_home, 'youtube-dl', 'config')
            if not os.path.isfile(userConfFile):
                userConfFile = os.path.join(xdg_config_home, 'youtube-dl.conf')
        else:
            userConfFile = os.path.join(compat_expanduser('~'), '.config', 'youtube-dl', 'config')
            if not os.path.isfile(userConfFile):
                userConfFile = os.path.join(compat_expanduser('~'), '.config', 'youtube-dl.conf')
        userConf = _readOptions(userConfFile, None)

        if userConf is None:
            # Windows fallback: look under %appdata%.
            appdata_dir = compat_getenv('appdata')
            if appdata_dir:
                for leaf in ('config', 'config.txt'):
                    userConf = _readOptions(
                        os.path.join(appdata_dir, 'youtube-dl', leaf),
                        default=None)
                    if userConf is not None:
                        break

        if userConf is None:
            # Last resort: plain dotfiles directly in the home directory.
            for home_name in ('youtube-dl.conf', 'youtube-dl.conf.txt'):
                userConf = _readOptions(
                    os.path.join(compat_expanduser('~'), home_name),
                    default=None)
                if userConf is not None:
                    break

        return [] if userConf is None else userConf

    def _format_option_string(option):
        ''' ('-o', '--option') -> -o, --format METAVAR'''

        opts = []

        if option._short_opts:
            opts.append(option._short_opts[0])
        if option._long_opts:
            opts.append(option._long_opts[0])
        if len(opts) > 1:
            opts.insert(1, ', ')

        if option.takes_value():
            opts.append(' %s' % option.metavar)

        return ''.join(opts)

    def _comma_separated_values_options_callback(option, opt_str, value, parser):
        setattr(parser.values, option.dest, value.split(','))

    def _hide_login_info(opts):
        PRIVATE_OPTS = ['-p', '--password', '-u', '--username', '--video-password', '--ap-password', '--ap-username']
        eqre = re.compile('^(?P<key>' + ('|'.join(re.escape(po) for po in PRIVATE_OPTS)) + ')=.+$')

        def _scrub_eq(o):
            m = eqre.match(o)
            if m:
                return m.group('key') + '=PRIVATE'
            else:
                return o

        opts = list(map(_scrub_eq, opts))
        for private_opt in PRIVATE_OPTS:
            try:
                i = opts.index(private_opt)
                opts[i + 1] = 'PRIVATE'
            except ValueError:
                pass
        return opts

    # No need to wrap help messages if we're on a wide console
    columns = compat_get_terminal_size().columns
    max_width = columns if columns else 80
    max_help_position = 80

    fmt = optparse.IndentedHelpFormatter(width=max_width, max_help_position=max_help_position)
    fmt.format_option_strings = _format_option_string

    kw = {
        'version': __version__,
        'formatter': fmt,
        'usage': '%prog [OPTIONS] URL [URL...]',
        'conflict_handler': 'resolve',
    }

    parser = optparse.OptionParser(**compat_kwargs(kw))

    general = optparse.OptionGroup(parser, 'General Options')
    general.add_option(
        '-h', '--help',
        action='help',
        help='Print this help text and exit')
    general.add_option(
        '-v', '--version',
        action='version',
        help='Print program version and exit')
    general.add_option(
        '-U', '--update',
        action='store_true', dest='update_self',
        help='Update this program to latest version. Make sure that you have sufficient permissions (run with sudo if needed)')
    general.add_option(
        '-i', '--ignore-errors',
        action='store_true', dest='ignoreerrors', default=False,
        help='Continue on download errors, for example to skip unavailable videos in a playlist')
    general.add_option(
        '--abort-on-error',
        action='store_false', dest='ignoreerrors',
        help='Abort downloading of further videos (in the playlist or the command line) if an error occurs')
    general.add_option(
        '--dump-user-agent',
        action='store_true', dest='dump_user_agent', default=False,
        help='Display the current browser identification')
    general.add_option(
        '--list-extractors',
        action='store_true', dest='list_extractors', default=False,
        help='List all supported extractors')
    general.add_option(
        '--extractor-descriptions',
        action='store_true', dest='list_extractor_descriptions', default=False,
        help='Output descriptions of all supported extractors')
    general.add_option(
        '--force-generic-extractor',
        action='store_true', dest='force_generic_extractor', default=False,
        help='Force extraction to use the generic extractor')
    general.add_option(
        '--default-search',
        dest='default_search', metavar='PREFIX',
        help='Use this prefix for unqualified URLs. For example "gvsearch2:" downloads two videos from google videos for youtube-dl "large apple". Use the value "auto" to let youtube-dl guess ("auto_warning" to emit a warning when guessing). "error" just throws an error. The default value "fixup_error" repairs broken URLs, but emits an error if this is not possible instead of searching.')
    general.add_option(
        '--ignore-config',
        action='store_true',
        help='Do not read configuration files. '
        'When given in the global configuration file /etc/youtube-dl.conf: '
        'Do not read the user configuration in ~/.config/youtube-dl/config '
        '(%APPDATA%/youtube-dl/config.txt on Windows)')
    general.add_option(
        '--flat-playlist',
        action='store_const', dest='extract_flat', const='in_playlist',
        default=False,
        help='Do not extract the videos of a playlist, only list them.')
    general.add_option(
        '--mark-watched',
        action='store_true', dest='mark_watched', default=False,
        help='Mark videos watched (YouTube only)')
    general.add_option(
        '--no-mark-watched',
        action='store_false', dest='mark_watched', default=False,
        help='Do not mark videos watched (YouTube only)')
    general.add_option(
        '--no-color', '--no-colors',
        action='store_true', dest='no_color',
        default=False,
        help='Do not emit color codes in output')

    network = optparse.OptionGroup(parser, 'Network Options')
    network.add_option(
        '--proxy', dest='proxy',
        default=None, metavar='URL',
        help='Use the specified HTTP/HTTPS/SOCKS proxy. To enable experimental '
             'SOCKS proxy, specify a proper scheme. For example '
             'socks5://127.0.0.1:1080/. Pass in an empty string (--proxy "") '
             'for direct connection')
    network.add_option(
        '--socket-timeout',
        dest='socket_timeout', type=float, default=None, metavar='SECONDS',
        help='Time to wait before giving up, in seconds')
    network.add_option(
        '--source-address',
        metavar='IP', dest='source_address', default=None,
        help='Client-side IP address to bind to (experimental)',
    )
    network.add_option(
        '-4', '--force-ipv4',
        action='store_const', const='0.0.0.0', dest='source_address',
        help='Make all connections via IPv4 (experimental)',
    )
    network.add_option(
        '-6', '--force-ipv6',
        action='store_const', const='::', dest='source_address',
        help='Make all connections via IPv6 (experimental)',
    )
    network.add_option(
        '--geo-verification-proxy',
        dest='geo_verification_proxy', default=None, metavar='URL',
        help='Use this proxy to verify the IP address for some geo-restricted sites. '
        'The default proxy specified by --proxy (or none, if the options is not present) is used for the actual downloading. (experimental)'
    )
    network.add_option(
        '--cn-verification-proxy',
        dest='cn_verification_proxy', default=None, metavar='URL',
        help=optparse.SUPPRESS_HELP,
    )

    selection = optparse.OptionGroup(parser, 'Video Selection')
    selection.add_option(
        '--playlist-start',
        dest='playliststart', metavar='NUMBER', default=1, type=int,
        help='Playlist video to start at (default is %default)')
    selection.add_option(
        '--playlist-end',
        dest='playlistend', metavar='NUMBER', default=None, type=int,
        help='Playlist video to end at (default is last)')
    selection.add_option(
        '--playlist-items',
        dest='playlist_items', metavar='ITEM_SPEC', default=None,
        help='Playlist video items to download. Specify indices of the videos in the playlist separated by commas like: "--playlist-items 1,2,5,8" if you want to download videos indexed 1, 2, 5, 8 in the playlist. You can specify range: "--playlist-items 1-3,7,10-13", it will download the videos at index 1, 2, 3, 7, 10, 11, 12 and 13.')
    selection.add_option(
        '--match-title',
        dest='matchtitle', metavar='REGEX',
        help='Download only matching titles (regex or caseless sub-string)')
    selection.add_option(
        '--reject-title',
        dest='rejecttitle', metavar='REGEX',
        help='Skip download for matching titles (regex or caseless sub-string)')
    selection.add_option(
        '--max-downloads',
        dest='max_downloads', metavar='NUMBER', type=int, default=None,
        help='Abort after downloading NUMBER files')
    selection.add_option(
        '--min-filesize',
        metavar='SIZE', dest='min_filesize', default=None,
        help='Do not download any videos smaller than SIZE (e.g. 50k or 44.6m)')
    selection.add_option(
        '--max-filesize',
        metavar='SIZE', dest='max_filesize', default=None,
        help='Do not download any videos larger than SIZE (e.g. 50k or 44.6m)')
    selection.add_option(
        '--date',
        metavar='DATE', dest='date', default=None,
        help='Download only videos uploaded in this date')
    selection.add_option(
        '--datebefore',
        metavar='DATE', dest='datebefore', default=None,
        help='Download only videos uploaded on or before this date (i.e. inclusive)')
    selection.add_option(
        '--dateafter',
        metavar='DATE', dest='dateafter', default=None,
        help='Download only videos uploaded on or after this date (i.e. inclusive)')
    selection.add_option(
        '--min-views',
        metavar='COUNT', dest='min_views', default=None, type=int,
        help='Do not download any videos with less than COUNT views')
    selection.add_option(
        '--max-views',
        metavar='COUNT', dest='max_views', default=None, type=int,
        help='Do not download any videos with more than COUNT views')
    selection.add_option(
        '--match-filter',
        metavar='FILTER', dest='match_filter', default=None,
        help=(
            'Generic video filter (experimental). '
            'Specify any key (see help for -o for a list of available keys) to'
            ' match if the key is present, '
            '!key to check if the key is not present,'
            'key > NUMBER (like "comment_count > 12", also works with '
            '>=, <, <=, !=, =) to compare against a number, and '
            '& to require multiple matches. '
            'Values which are not known are excluded unless you'
            ' put a question mark (?) after the operator.'
            'For example, to only match videos that have been liked more than '
            '100 times and disliked less than 50 times (or the dislike '
            'functionality is not available at the given service), but who '
            'also have a description, use --match-filter '
            '"like_count > 100 & dislike_count <? 50 & description" .'
        ))
    selection.add_option(
        '--no-playlist',
        action='store_true', dest='noplaylist', default=False,
        help='Download only the video, if the URL refers to a video and a playlist.')
    selection.add_option(
        '--yes-playlist',
        action='store_false', dest='noplaylist', default=False,
        help='Download the playlist, if the URL refers to a video and a playlist.')
    selection.add_option(
        '--age-limit',
        metavar='YEARS', dest='age_limit', default=None, type=int,
        help='Download only videos suitable for the given age')
    selection.add_option(
        '--download-archive', metavar='FILE',
        dest='download_archive',
        help='Download only videos not listed in the archive file. Record the IDs of all downloaded videos in it.')
    selection.add_option(
        '--include-ads',
        dest='include_ads', action='store_true',
        help='Download advertisements as well (experimental)')

    authentication = optparse.OptionGroup(parser, 'Authentication Options')
    authentication.add_option(
        '-u', '--username',
        dest='username', metavar='USERNAME',
        help='Login with this account ID')
    authentication.add_option(
        '-p', '--password',
        dest='password', metavar='PASSWORD',
        help='Account password. If this option is left out, youtube-dl will ask interactively.')
    authentication.add_option(
        '-2', '--twofactor',
        dest='twofactor', metavar='TWOFACTOR',
        help='Two-factor auth code')
    authentication.add_option(
        '-n', '--netrc',
        action='store_true', dest='usenetrc', default=False,
        help='Use .netrc authentication data')
    authentication.add_option(
        '--video-password',
        dest='videopassword', metavar='PASSWORD',
        help='Video password (vimeo, smotri, youku)')

    adobe_pass = optparse.OptionGroup(parser, 'Adobe Pass Options')
    adobe_pass.add_option(
        '--ap-mso',
        dest='ap_mso', metavar='MSO',
        help='Adobe Pass multiple-system operator (TV provider) identifier, use --ap-list-mso for a list of available MSOs')
    adobe_pass.add_option(
        '--ap-username',
        dest='ap_username', metavar='USERNAME',
        help='Multiple-system operator account login')
    adobe_pass.add_option(
        '--ap-password',
        dest='ap_password', metavar='PASSWORD',
        help='Multiple-system operator account password. If this option is left out, youtube-dl will ask interactively.')
    adobe_pass.add_option(
        '--ap-list-mso',
        action='store_true', dest='ap_list_mso', default=False,
        help='List all supported multiple-system operators')

    video_format = optparse.OptionGroup(parser, 'Video Format Options')
    video_format.add_option(
        '-f', '--format',
        action='store', dest='format', metavar='FORMAT', default=None,
        help='Video format code, see the "FORMAT SELECTION" for all the info')
    video_format.add_option(
        '--all-formats',
        action='store_const', dest='format', const='all',
        help='Download all available video formats')
    video_format.add_option(
        '--prefer-free-formats',
        action='store_true', dest='prefer_free_formats', default=False,
        help='Prefer free video formats unless a specific one is requested')
    video_format.add_option(
        '-F', '--list-formats',
        action='store_true', dest='listformats',
        help='List all available formats of requested videos')
    video_format.add_option(
        '--youtube-include-dash-manifest',
        action='store_true', dest='youtube_include_dash_manifest', default=True,
        help=optparse.SUPPRESS_HELP)
    video_format.add_option(
        '--youtube-skip-dash-manifest',
        action='store_false', dest='youtube_include_dash_manifest',
        help='Do not download the DASH manifests and related data on YouTube videos')
    video_format.add_option(
        '--merge-output-format',
        action='store', dest='merge_output_format', metavar='FORMAT', default=None,
        help=(
            'If a merge is required (e.g. bestvideo+bestaudio), '
            'output to given container format. One of mkv, mp4, ogg, webm, flv. '
            'Ignored if no merge is required'))

    subtitles = optparse.OptionGroup(parser, 'Subtitle Options')
    subtitles.add_option(
        '--write-sub', '--write-srt',
        action='store_true', dest='writesubtitles', default=False,
        help='Write subtitle file')
    subtitles.add_option(
        '--write-auto-sub', '--write-automatic-sub',
        action='store_true', dest='writeautomaticsub', default=False,
        help='Write automatically generated subtitle file (YouTube only)')
    subtitles.add_option(
        '--all-subs',
        action='store_true', dest='allsubtitles', default=False,
        help='Download all the available subtitles of the video')
    subtitles.add_option(
        '--list-subs',
        action='store_true', dest='listsubtitles', default=False,
        help='List all available subtitles for the video')
    subtitles.add_option(
        '--sub-format',
        action='store', dest='subtitlesformat', metavar='FORMAT', default='best',
        help='Subtitle format, accepts formats preference, for example: "srt" or "ass/srt/best"')
    subtitles.add_option(
        '--sub-lang', '--sub-langs', '--srt-lang',
        action='callback', dest='subtitleslangs', metavar='LANGS', type='str',
        default=[], callback=_comma_separated_values_options_callback,
        help='Languages of the subtitles to download (optional) separated by commas, use --list-subs for available language tags')

    downloader = optparse.OptionGroup(parser, 'Download Options')
    downloader.add_option(
        '-r', '--limit-rate', '--rate-limit',
        dest='ratelimit', metavar='RATE',
        help='Maximum download rate in bytes per second (e.g. 50K or 4.2M)')
    downloader.add_option(
        '-R', '--retries',
        dest='retries', metavar='RETRIES', default=10,
        help='Number of retries (default is %default), or "infinite".')
    downloader.add_option(
        '--fragment-retries',
        dest='fragment_retries', metavar='RETRIES', default=10,
        help='Number of retries for a fragment (default is %default), or "infinite" (DASH and hlsnative only)')
    downloader.add_option(
        '--skip-unavailable-fragments',
        action='store_true', dest='skip_unavailable_fragments', default=True,
        help='Skip unavailable fragments (DASH and hlsnative only)')
    downloader.add_option(
        # Was registered on the 'general' group by mistake; this option pairs
        # with --skip-unavailable-fragments and belongs under Download Options.
        '--abort-on-unavailable-fragment',
        action='store_false', dest='skip_unavailable_fragments',
        help='Abort downloading when some fragment is not available')
    downloader.add_option(
        '--buffer-size',
        dest='buffersize', metavar='SIZE', default='1024',
        help='Size of download buffer (e.g. 1024 or 16K) (default is %default)')
    downloader.add_option(
        '--no-resize-buffer',
        action='store_true', dest='noresizebuffer', default=False,
        help='Do not automatically adjust the buffer size. By default, the buffer size is automatically resized from an initial value of SIZE.')
    downloader.add_option(
        '--test',
        action='store_true', dest='test', default=False,
        help=optparse.SUPPRESS_HELP)
    downloader.add_option(
        '--playlist-reverse',
        action='store_true',
        help='Download playlist videos in reverse order')
    downloader.add_option(
        '--xattr-set-filesize',
        dest='xattr_set_filesize', action='store_true',
        help='Set file xattribute ytdl.filesize with expected filesize (experimental)')
    downloader.add_option(
        '--hls-prefer-native',
        dest='hls_prefer_native', action='store_true', default=None,
        help='Use the native HLS downloader instead of ffmpeg')
    downloader.add_option(
        '--hls-prefer-ffmpeg',
        dest='hls_prefer_native', action='store_false', default=None,
        help='Use ffmpeg instead of the native HLS downloader')
    downloader.add_option(
        '--hls-use-mpegts',
        dest='hls_use_mpegts', action='store_true',
        help='Use the mpegts container for HLS videos, allowing to play the '
             'video while downloading (some players may not be able to play it)')
    downloader.add_option(
        '--external-downloader',
        dest='external_downloader', metavar='COMMAND',
        help='Use the specified external downloader. '
             'Currently supports %s' % ','.join(list_external_downloaders()))
    downloader.add_option(
        '--external-downloader-args',
        dest='external_downloader_args', metavar='ARGS',
        help='Give these arguments to the external downloader')

    workarounds = optparse.OptionGroup(parser, 'Workarounds')
    workarounds.add_option(
        '--encoding',
        dest='encoding', metavar='ENCODING',
        help='Force the specified encoding (experimental)')
    workarounds.add_option(
        '--no-check-certificate',
        action='store_true', dest='no_check_certificate', default=False,
        help='Suppress HTTPS certificate validation')
    workarounds.add_option(
        '--prefer-insecure',
        '--prefer-unsecure', action='store_true', dest='prefer_insecure',
        help='Use an unencrypted connection to retrieve information about the video. (Currently supported only for YouTube)')
    workarounds.add_option(
        '--user-agent',
        metavar='UA', dest='user_agent',
        help='Specify a custom user agent')
    workarounds.add_option(
        '--referer',
        metavar='URL', dest='referer', default=None,
        help='Specify a custom referer, use if the video access is restricted to one domain',
    )
    workarounds.add_option(
        '--add-header',
        metavar='FIELD:VALUE', dest='headers', action='append',
        help='Specify a custom HTTP header and its value, separated by a colon \':\'. You can use this option multiple times',
    )
    workarounds.add_option(
        '--bidi-workaround',
        dest='bidi_workaround', action='store_true',
        help='Work around terminals that lack bidirectional text support. Requires bidiv or fribidi executable in PATH')
    workarounds.add_option(
        '--sleep-interval', '--min-sleep-interval', metavar='SECONDS',
        dest='sleep_interval', type=float,
        help=(
            'Number of seconds to sleep before each download when used alone '
            'or a lower bound of a range for randomized sleep before each download '
            '(minimum possible number of seconds to sleep) when used along with '
            '--max-sleep-interval.'))
    workarounds.add_option(
        '--max-sleep-interval', metavar='SECONDS',
        dest='max_sleep_interval', type=float,
        help=(
            'Upper bound of a range for randomized sleep before each download '
            '(maximum possible number of seconds to sleep). Must only be used '
            'along with --min-sleep-interval.'))

    verbosity = optparse.OptionGroup(parser, 'Verbosity / Simulation Options')
    verbosity.add_option(
        '-q', '--quiet',
        action='store_true', dest='quiet', default=False,
        help='Activate quiet mode')
    verbosity.add_option(
        '--no-warnings',
        dest='no_warnings', action='store_true', default=False,
        help='Ignore warnings')
    verbosity.add_option(
        '-s', '--simulate',
        action='store_true', dest='simulate', default=False,
        help='Do not download the video and do not write anything to disk')
    verbosity.add_option(
        '--skip-download',
        action='store_true', dest='skip_download', default=False,
        help='Do not download the video')
    verbosity.add_option(
        '-g', '--get-url',
        action='store_true', dest='geturl', default=False,
        help='Simulate, quiet but print URL')
    verbosity.add_option(
        '-e', '--get-title',
        action='store_true', dest='gettitle', default=False,
        help='Simulate, quiet but print title')
    verbosity.add_option(
        '--get-id',
        action='store_true', dest='getid', default=False,
        help='Simulate, quiet but print id')
    verbosity.add_option(
        '--get-thumbnail',
        action='store_true', dest='getthumbnail', default=False,
        help='Simulate, quiet but print thumbnail URL')
    verbosity.add_option(
        '--get-description',
        action='store_true', dest='getdescription', default=False,
        help='Simulate, quiet but print video description')
    verbosity.add_option(
        '--get-duration',
        action='store_true', dest='getduration', default=False,
        help='Simulate, quiet but print video length')
    verbosity.add_option(
        '--get-filename',
        action='store_true', dest='getfilename', default=False,
        help='Simulate, quiet but print output filename')
    verbosity.add_option(
        '--get-format',
        action='store_true', dest='getformat', default=False,
        help='Simulate, quiet but print output format')
    verbosity.add_option(
        '-j', '--dump-json',
        action='store_true', dest='dumpjson', default=False,
        help='Simulate, quiet but print JSON information. See --output for a description of available keys.')
    verbosity.add_option(
        '-J', '--dump-single-json',
        action='store_true', dest='dump_single_json', default=False,
        help='Simulate, quiet but print JSON information for each command-line argument. If the URL refers to a playlist, dump the whole playlist information in a single line.')
    verbosity.add_option(
        '--print-json',
        action='store_true', dest='print_json', default=False,
        help='Be quiet and print the video information as JSON (video is still being downloaded).',
    )
    verbosity.add_option(
        '--newline',
        action='store_true', dest='progress_with_newline', default=False,
        help='Output progress bar as new lines')
    verbosity.add_option(
        '--no-progress',
        action='store_true', dest='noprogress', default=False,
        help='Do not print progress bar')
    verbosity.add_option(
        '--console-title',
        action='store_true', dest='consoletitle', default=False,
        help='Display progress in console titlebar')
    verbosity.add_option(
        '-v', '--verbose',
        action='store_true', dest='verbose', default=False,
        help='Print various debugging information')
    verbosity.add_option(
        '--dump-pages', '--dump-intermediate-pages',
        action='store_true', dest='dump_intermediate_pages', default=False,
        help='Print downloaded pages encoded using base64 to debug problems (very verbose)')
    verbosity.add_option(
        '--write-pages',
        action='store_true', dest='write_pages', default=False,
        help='Write downloaded intermediary pages to files in the current directory to debug problems')
    verbosity.add_option(
        '--youtube-print-sig-code',
        action='store_true', dest='youtube_print_sig_code', default=False,
        help=optparse.SUPPRESS_HELP)
    verbosity.add_option(
        '--print-traffic', '--dump-headers',
        dest='debug_printtraffic', action='store_true', default=False,
        help='Display sent and read HTTP traffic')
    verbosity.add_option(
        '-C', '--call-home',
        dest='call_home', action='store_true', default=False,
        help='Contact the youtube-dl server for debugging')
    verbosity.add_option(
        '--no-call-home',
        dest='call_home', action='store_false', default=False,
        help='Do NOT contact the youtube-dl server for debugging')

    filesystem = optparse.OptionGroup(parser, 'Filesystem Options')
    filesystem.add_option(
        '-a', '--batch-file',
        dest='batchfile', metavar='FILE',
        help='File containing URLs to download (\'-\' for stdin)')
    filesystem.add_option(
        '--id', default=False,
        action='store_true', dest='useid', help='Use only video ID in file name')
    filesystem.add_option(
        '-o', '--output',
        dest='outtmpl', metavar='TEMPLATE',
        help=('Output filename template, see the "OUTPUT TEMPLATE" for all the info'))
    filesystem.add_option(
        '--autonumber-size',
        dest='autonumber_size', metavar='NUMBER',
        help='Specify the number of digits in %(autonumber)s when it is present in output filename template or --auto-number option is given')
    filesystem.add_option(
        '--restrict-filenames',
        action='store_true', dest='restrictfilenames', default=False,
        help='Restrict filenames to only ASCII characters, and avoid "&" and spaces in filenames')
    filesystem.add_option(
        '-A', '--auto-number',
        action='store_true', dest='autonumber', default=False,
        help='[deprecated; use -o "%(autonumber)s-%(title)s.%(ext)s" ] Number downloaded files starting from 00000')
    filesystem.add_option(
        '-t', '--title',
        action='store_true', dest='usetitle', default=False,
        help='[deprecated] Use title in file name (default)')
    filesystem.add_option(
        '-l', '--literal', default=False,
        action='store_true', dest='usetitle',
        help='[deprecated] Alias of --title')
    filesystem.add_option(
        '-w', '--no-overwrites',
        action='store_true', dest='nooverwrites', default=False,
        help='Do not overwrite files')
    filesystem.add_option(
        '-c', '--continue',
        action='store_true', dest='continue_dl', default=True,
        help='Force resume of partially downloaded files. By default, youtube-dl will resume downloads if possible.')
    filesystem.add_option(
        '--no-continue',
        action='store_false', dest='continue_dl',
        help='Do not resume partially downloaded files (restart from beginning)')
    filesystem.add_option(
        '--no-part',
        action='store_true', dest='nopart', default=False,
        help='Do not use .part files - write directly into output file')
    filesystem.add_option(
        '--no-mtime',
        action='store_false', dest='updatetime', default=True,
        help='Do not use the Last-modified header to set the file modification time')
    filesystem.add_option(
        '--write-description',
        action='store_true', dest='writedescription', default=False,
        help='Write video description to a .description file')
    filesystem.add_option(
        '--write-info-json',
        action='store_true', dest='writeinfojson', default=False,
        help='Write video metadata to a .info.json file')
    filesystem.add_option(
        '--write-annotations',
        action='store_true', dest='writeannotations', default=False,
        help='Write video annotations to a .annotations.xml file')
    filesystem.add_option(
        '--load-info-json', '--load-info',
        dest='load_info_filename', metavar='FILE',
        help='JSON file containing the video information (created with the "--write-info-json" option)')
    filesystem.add_option(
        '--cookies',
        dest='cookiefile', metavar='FILE',
        help='File to read cookies from and dump cookie jar in')
    filesystem.add_option(
        '--cache-dir', dest='cachedir', default=None, metavar='DIR',
        help='Location in the filesystem where youtube-dl can store some downloaded information permanently. By default $XDG_CACHE_HOME/youtube-dl or ~/.cache/youtube-dl . At the moment, only YouTube player files (for videos with obfuscated signatures) are cached, but that may change.')
    filesystem.add_option(
        '--no-cache-dir', action='store_const', const=False, dest='cachedir',
        help='Disable filesystem caching')
    filesystem.add_option(
        '--rm-cache-dir',
        action='store_true', dest='rm_cachedir',
        help='Delete all filesystem cache files')

    thumbnail = optparse.OptionGroup(parser, 'Thumbnail images')
    thumbnail.add_option(
        '--write-thumbnail',
        action='store_true', dest='writethumbnail', default=False,
        help='Write thumbnail image to disk')
    thumbnail.add_option(
        '--write-all-thumbnails',
        action='store_true', dest='write_all_thumbnails', default=False,
        help='Write all thumbnail image formats to disk')
    thumbnail.add_option(
        '--list-thumbnails',
        action='store_true', dest='list_thumbnails', default=False,
        help='Simulate and list all available thumbnail formats')

    postproc = optparse.OptionGroup(parser, 'Post-processing Options')
    postproc.add_option(
        '-x', '--extract-audio',
        action='store_true', dest='extractaudio', default=False,
        help='Convert video files to audio-only files (requires ffmpeg or avconv and ffprobe or avprobe)')
    postproc.add_option(
        '--audio-format', metavar='FORMAT', dest='audioformat', default='best',
        help='Specify audio format: "best", "aac", "vorbis", "mp3", "m4a", "opus", or "wav"; "%default" by default')
    postproc.add_option(
        '--audio-quality', metavar='QUALITY',
        dest='audioquality', default='5',
        help='Specify ffmpeg/avconv audio quality, insert a value between 0 (better) and 9 (worse) for VBR or a specific bitrate like 128K (default %default)')
    postproc.add_option(
        '--recode-video',
        metavar='FORMAT', dest='recodevideo', default=None,
        help='Encode the video to another format if necessary (currently supported: mp4|flv|ogg|webm|mkv|avi)')
    postproc.add_option(
        '--postprocessor-args',
        dest='postprocessor_args', metavar='ARGS',
        help='Give these arguments to the postprocessor')
    postproc.add_option(
        '-k', '--keep-video',
        action='store_true', dest='keepvideo', default=False,
        help='Keep the video file on disk after the post-processing; the video is erased by default')
    postproc.add_option(
        '--no-post-overwrites',
        action='store_true', dest='nopostoverwrites', default=False,
        help='Do not overwrite post-processed files; the post-processed files are overwritten by default')
    postproc.add_option(
        '--embed-subs',
        action='store_true', dest='embedsubtitles', default=False,
        help='Embed subtitles in the video (only for mp4, webm and mkv videos)')
    postproc.add_option(
        '--embed-thumbnail',
        action='store_true', dest='embedthumbnail', default=False,
        help='Embed thumbnail in the audio as cover art')
    postproc.add_option(
        '--add-metadata',
        action='store_true', dest='addmetadata', default=False,
        help='Write metadata to the video file')
    postproc.add_option(
        '--metadata-from-title',
        metavar='FORMAT', dest='metafromtitle',
        help='Parse additional metadata like song title / artist from the video title. '
             'The format syntax is the same as --output, '
             'the parsed parameters replace existing values. '
             'Additional templates: %(album)s, %(artist)s. '
             'Example: --metadata-from-title "%(artist)s - %(title)s" matches a title like '
             '"Coldplay - Paradise"')
    postproc.add_option(
        '--xattrs',
        action='store_true', dest='xattrs', default=False,
        help='Write metadata to the video file\'s xattrs (using dublin core and xdg standards)')
    postproc.add_option(
        '--fixup',
        metavar='POLICY', dest='fixup', default='detect_or_warn',
        help='Automatically correct known faults of the file. '
             'One of never (do nothing), warn (only emit a warning), '
             'detect_or_warn (the default; fix file if we can, warn otherwise)')
    postproc.add_option(
        '--prefer-avconv',
        action='store_false', dest='prefer_ffmpeg',
        help='Prefer avconv over ffmpeg for running the postprocessors (default)')
    postproc.add_option(
        '--prefer-ffmpeg',
        action='store_true', dest='prefer_ffmpeg',
        help='Prefer ffmpeg over avconv for running the postprocessors')
    postproc.add_option(
        '--ffmpeg-location', '--avconv-location', metavar='PATH',
        dest='ffmpeg_location',
        help='Location of the ffmpeg/avconv binary; either the path to the binary or its containing directory.')
    postproc.add_option(
        '--exec',
        metavar='CMD', dest='exec_cmd',
        help='Execute a command on the file after downloading, similar to find\'s -exec syntax. Example: --exec \'adb push {} /sdcard/Music/ && rm {}\'')
    postproc.add_option(
        '--convert-subs', '--convert-subtitles',
        metavar='FORMAT', dest='convertsubtitles', default=None,
        help='Convert the subtitles to other format (currently supported: srt|ass|vtt)')

    parser.add_option_group(general)
    parser.add_option_group(network)
    parser.add_option_group(selection)
    parser.add_option_group(downloader)
    parser.add_option_group(filesystem)
    parser.add_option_group(thumbnail)
    parser.add_option_group(verbosity)
    parser.add_option_group(workarounds)
    parser.add_option_group(video_format)
    parser.add_option_group(subtitles)
    parser.add_option_group(authentication)
    parser.add_option_group(adobe_pass)
    parser.add_option_group(postproc)

    if overrideArguments is not None:
        opts, args = parser.parse_args(overrideArguments)
        if opts.verbose:
            write_string('[debug] Override config: ' + repr(overrideArguments) + '\n')
    else:
        def compat_conf(conf):
            if sys.version_info < (3,):
                return [a.decode(preferredencoding(), 'replace') for a in conf]
            return conf

        command_line_conf = compat_conf(sys.argv[1:])

        if '--ignore-config' in command_line_conf:
            system_conf = []
            user_conf = []
        else:
            system_conf = _readOptions('/etc/youtube-dl.conf')
            if '--ignore-config' in system_conf:
                user_conf = []
            else:
                user_conf = _readUserConf()
        argv = system_conf + user_conf + command_line_conf

        opts, args = parser.parse_args(argv)
        if opts.verbose:
            write_string('[debug] System config: ' + repr(_hide_login_info(system_conf)) + '\n')
            write_string('[debug] User config: ' + repr(_hide_login_info(user_conf)) + '\n')
            write_string('[debug] Command-line args: ' + repr(_hide_login_info(command_line_conf)) + '\n')

    return parser, opts, args

Example 16

Project: broc
Source File: Syntax.py
View license
def DIRECTORY(v):
    """
    Add a sub directory to the current module.
    Args:
        v : the name of the subdirectory; a path relative to the module
    Raises:
        BrocArgumentIllegalError : if v escapes the module's own tree or
                                   the subdirectory has no BROC file
    """
    # gather all dependent module
    env = Environment.GetCurrent()
    child_broc_dir = os.path.abspath(os.path.join(env.ModulePath(), v))
    # the subdirectory must stay inside the module's own tree
    # BUG FIX: the raise below was missing its closing paren and passed two
    # arguments to a three-placeholder format string
    if env.ModulePath() not in child_broc_dir:
        raise BrocArgumentIllegalError("DIRECTORY(%s) is wrong: %s not in %s" % \
                                       (v, child_broc_dir, env.ModulePath()))

    if sys.argv[0] == 'PLANISH':
        # during planish the parent BrocNode is passed through sys.argv
        # BUG FIX: 'parent' was previously read before this assignment
        parent = sys.argv[1]
        child_broc_file = os.path.join(parent.module.root_path, v, 'BROC')
        if not os.path.exists(child_broc_file):
            raise BrocArgumentIllegalError('Not found %s in Tag Directory(%s)' % (child_broc_file, v))
        try:
            execfile(child_broc_file)
        except BaseException as err:
            traceback.print_exc()
            raise BrocArgumentIllegalError(err)
    else: # find all targets to build
        # NOTE(review): outside planish there is no parent node, so the BROC
        # file is resolved inside the module itself — confirm against callers
        child_broc_file = os.path.join(child_broc_dir, 'BROC')
        if not os.path.exists(child_broc_file):
            raise BrocArgumentIllegalError('Not found %s in Tag Directory(%s)' % (child_broc_file, v))
        env.AddSubDir(v)

def PUBLISH(srcs, out_dir):
    """
    Copy srcs to out_dir.
    Args:
        srcs: whitespace-separated files to publish; each must belong to
              the current module
        out_dir: the destination directory; must start with $OUT
    Raises:
        BrocArgumentIllegalError : if out_dir does not start with $OUT
        NotInSelfModuleError : if a source file lies outside the module
    """
    # planish mode only builds the dependency tree; nothing to publish
    if sys.argv[0] == 'PLANISH':
        return
    env = Environment.GetCurrent()
    if not out_dir.strip().startswith('$OUT'):
        raise BrocArgumentIllegalError("PUBLISH argument dst(%s) must start with $OUT \
                                         in %s " % (out_dir, env.BrocPath()))
    # every source must resolve to a path inside the module's own tree
    for item in srcs.split():
        full_path = os.path.normpath(os.path.join(env.BrocDir(), item))
        if env.ModulePath() not in full_path:
            raise NotInSelfModuleError(full_path, env.ModulePath())
    env.AddPublish(srcs, out_dir)


def SVN_PATH():
    """
    Return the local path of the module (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().SvnPath()


def SVN_URL():
    """
    Return the url of the module (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().SvnUrl()


def SVN_REVISION():
    """
    Return the revision of the module (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().SvnRevision()


def SVN_LAST_CHANGED_REV():
    """
    Return the last changed revision (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().SvnLastChangedRev()


def GIT_PATH():
    """
    Return the local path of the module (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().GitPath()
def GIT_URL():
    """
    Return the url of the module (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().GitUrl()


def GIT_BRANCH():
    """
    Return the branch name of the module (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().GitBranch()


def GIT_COMMIT_ID():
    """
    Return the commit id of the module (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().GitCommitID()


def GIT_TAG():
    """
    Return the tag of the module (None in planish mode).
    """
    if sys.argv[0] == 'PLANISH':
        return
    return Environment.GetCurrent().GitTag()

class BrocLoader(object):
    """
    the class loading BROC file; a singleton whose state lives in a single
    shared __impl instance
    """
    class __impl(object):
        """
        the implementation of singleton interface
        """
        def __init__(self):
            """
            initialize empty loader state
            """
            self._root = None                      # root BrocNode representing the main module
            self._nodes = dict()                   # module cvspath -> list of BrocNode objects
            self._checked_configs = set()          # storing content of tag CONFIGS already handled
            self._broc_dir = tempfile.mkdtemp()    # the temporary directory storing all BROC files
            self._queue = Queue.Queue()            # nodes whose BROC file still needs loading
            self._lack_broc = set()                # the set of module who lack BROC file

        def Id(self):
            """
            test method, return singleton id
            """
            return id(self)

        def SetRoot(self, root):
            """
            set the root node once; subsequent calls are ignored
            Args:
                root : the BrocNode object
            """
            if not self._root:
                self._root = root
                BrocTree.BrocTree().SetRoot(root)
                self._queue.put(root)

        def AddNode(self, node):
            """
            add new node, grouped by its module cvspath
            Args:
                node : the object of BrocNode
            """
            if node.module.module_cvspath not in self._nodes:
                self._nodes[node.module.module_cvspath] = []
            self._nodes[node.module.module_cvspath].append(node)

        def AllNodes(self):
            """
            return the dict mapping module cvspath to its BrocNode objects
            """
            return self._nodes

        def LackBrocModules(self):
            """
            return the set object containing the modules that lack BROC file
            """
            return self._lack_broc

        def LoadBROC(self):
            """
            drain the queue and execute each node's BROC file; modules whose
            BROC file cannot be fetched are recorded in self._lack_broc
            """
            # main thread to load BROC
            # first node is first in queue, representing the main module
            while not self._queue.empty():
                parent = self._queue.get()
                # BROC files consult sys.argv to detect planish mode
                # (see DIRECTORY/PUBLISH above)
                sys.argv = ['PLANISH', parent]
                broc_file = self._download_broc(parent)
                if not broc_file:
                    self._lack_broc.add(parent.module.origin_config)
                    continue
                try:
                    execfile(broc_file)
                except BaseException as err:
                    traceback.print_exc()
            # print dependent tree
            BrocTree.BrocTree().Dump()

        def handle_configs(self, s, parent):
            """
            create and enqueue a child node for one CONFIGS entry
            Args:
                s : one dependency expression set at tag CONFIGS
                    (NOTE(review): original comment was garbled by scraping;
                    presumably of the form module@branch/tag — confirm)
                parent : the BrocNode object that depends on s
            """
            # each CONFIGS expression is only expanded once
            if s in self._checked_configs:
                return
            repo_domain = BrocConfig.BrocConfig().RepoDomain(parent.module.repo_kind)
            postfix_branch = BrocConfig.BrocConfig().SVNPostfixBranch()
            postfix_tag = BrocConfig.BrocConfig().SVNPostfixTag()
            child_module = PlanishUtil.ParseConfig(s,
                                           parent.module.workspace,
                                           parent.module.dep_level + 1,
                                           parent.module.repo_kind,
                                           repo_domain,
                                           postfix_branch,
                                           postfix_tag)
            child_node = BrocTree.BrocNode(child_module, parent, False)
            parent.AddChild(child_node)
            self.AddNode(child_node)
            self._queue.put(child_node)
            self._checked_configs.add(s)

        def _download_broc(self, node):
            """
            download BROC file from repository
            Args:
                node : the BrocNode object
            Returns:
                return abs path of BROC file if download success
                return None if download failed
            """
            broc_path = None
            cmd = None
            # for svn: export only the BROC file into the temporary directory
            if node.module.repo_kind == BrocModule_pb2.Module.SVN:
                hash_value = Function.CalcHash(node.module.url)
                broc_url = os.path.join(node.module.url, 'BROC')
                broc_path = os.path.join(self._broc_dir, "%s_BROC" % hash_value)
                if node.module.revision:
                    broc_url = "%s -r %s" % (broc_url, node.module.revision)
                cmd = "svn export %s %s" % (broc_url, broc_path)
            else:
                # for GIT: clone the whole module, BROC lives in the work tree
                broc_path = os.path.join(node.module.workspace, node.module.module_cvspath, 'BROC')
                broc_dir = os.path.dirname(broc_path)
                if not os.path.exists(broc_path):
                    # BUG FIX: cmd was None here, so the original 'cmd +=' raised
                    # TypeError; building the pieces in a list and joining them
                    # also avoids a dangling '&&' when no branch/tag checkout is
                    # needed
                    parts = ["git clone %s %s" % ("%s.git" % node.module.url, broc_dir)]
                    checkout_ref = None
                    if node.module.br_name and node.module.br_name != 'master':
                        checkout_ref = node.module.br_name
                    elif node.module.tag_name:
                        checkout_ref = node.module.tag_name
                    if checkout_ref:
                        parts.append("cd %s && (git checkout %s || (git fetch origin %s:%s && git checkout %s))" \
                                     % (broc_dir, checkout_ref, checkout_ref, checkout_ref, checkout_ref))
                    cmd = " && ".join(parts)

            if cmd:
                Log.Log().LevPrint("MSG", "Getting BROC(%s) ..." % cmd)
                ret, msg = Function.RunCommand(cmd)
                if ret != 0:
                    Log.Log().LevPrint("ERROR", msg)
                    return None

            return broc_path

    # class BrocLoader
    __instance = None

    def __init__(self):
        """ Create singleton instance """
        # Check whether we already have an instance
        if BrocLoader.__instance is None:
            # Create and remember instance
            BrocLoader.__instance = BrocLoader.__impl()

        # Store instance reference as the only member in the handle;
        # written through __dict__ to avoid recursing into our own __setattr__
        self.__dict__['_BrocLoader__instance'] = BrocLoader.__instance

    def __getattr__(self, attr):
        """ Delegate access to implementation """
        return getattr(self.__instance, attr)

    def __setattr__(self, attr, value):
        """ Delegate access to implementation """
        return setattr(self.__instance, attr, value)

Example 17

Project: smart_server
Source File: smart_manager.py
View license
def main():
    parser = argparse.ArgumentParser(description='SMART Server Management Tool')

    parser.add_argument("-a", "--all-steps", dest="all_steps",
                    action="store_true",
                    default=False,
                    help="All steps: clone, generate settings, "+
                      "kill running servers, generate sample data, "+
                      "run app server, reset api server, "+
                      "load sample data, run api servers")

    parser.add_argument("-b", "--branch", dest="using_branch",
                    default=False,
                    help="Use a specific branch for checkouts and updates")

    parser.add_argument("-d", "--development-branch", dest="branch_dev",
                    action="store_true",
                    default=False,
                    help="Use development branch for checkous and updates")
                      
    parser.add_argument("-g", "--clone-git-repositories", dest="clone_git",
                    action="store_true",
                    default=False,
                    help="Clone git repositories")
                    
    parser.add_argument("-u", "--update-git-repositories", dest="update_git",
                    action="store_true",
                    default=False,
                    help="Update git repositories")

    parser.add_argument("-s", "--generate-settings-files", dest="generate_settings",
                    action="store_true",
                    default=False,
                    help="Generate settings files")
                    
    parser.add_argument("-k", "--kill-servers", dest="kill_servers",
                    action="store_true",
                    default=False,
                    help="Kill all currently-running django development servers")
                      
    parser.add_argument("-p", "--generate-sample-data", dest="generate_sample_data",
                    action="store_true",
                    default=False,
                    help="Generate sample patient data")

    parser.add_argument("-v", "--run-app-server", dest="run_app_server",
                    action="store_true",
                    default=False,
                    help="Run app server (first kills all running SMART servers). Can be used in conjunction with -w")
    
    parser.add_argument("-r", "--reset-api-server", dest="reset_servers",
                    action="store_true",
                    default=False,
                    help="Reset API server")
                    
    parser.add_argument("-l", "--load-sample-data", dest="load_sample_data",
                    action="store_true",
                    default=False,
                    help="Load sample data into DB")
                    
    parser.add_argument("-c", "--create-user", dest="create_user",
                    action="store_true",
                    default=False,
                    help="Create a user account for web login")
    
    parser.add_argument("-w", "--run-api-servers", dest="run_api_servers",
                    action="store_true",
                    default=False,
                    help="Run api server (first kills all running SMART servers). Can be used conjunction with -v")

    args = parser.parse_args()
    repos = ["smart_server", "smart_ui_server", "smart_sample_patients", "smart_sample_apps"]

    reloadflag = ""
    if django.VERSION[:3] <= (1,3,0):
        reloadflag = " --noreload "

    if args.branch_dev:
        args.using_branch = "dev"
    if args.using_branch:
       print "USING BRANCH", args.using_branch
    if not args.all_steps and not ( 
        args.clone_git or
        args.update_git or
        args.generate_settings or
        args.generate_sample_data or
        args.load_sample_data or
        args.create_user or 
        args.kill_servers or
        args.reset_servers or
        args.run_app_server or 
        args.run_api_servers):
            parser.print_help()
            sys.exit(1)

    if args.all_steps:
        args.clone_git = True
        args.generate_settings  = True
        args.generate_sample_data  = True
        args.load_sample_data  = True
        args.create_user = True
        args.kill_servers = True
        args.run_app_server = True
        args.reset_servers  = True
        args.run_api_servers = True

    if args.clone_git:
        if not args.using_branch:
            args.using_branch = "master"

        print "Cloning (4) SMART git repositories..."
        for r in repos:
            call_command("git clone --recursive --recurse-submodules https://github.com/smart-platforms/"+r+".git", 
                        print_output=True)

    if args.update_git or args.clone_git:
        for r in repos:

            if args.using_branch:
                call_command("cd "+r+" && git checkout "+args.using_branch+" && cd ..")

            call_command("cd "+r+" && " +
                     "git pull && " +
                     "git submodule update --init --recursive && " +
                     "cd .. ",
                      print_output=True)
        
    if args.generate_settings:
        print "Configuring SMART server settings..."

        api_server_base_url = get_input("SMART API Server", "http://localhost:7000")
        
        chrome_consumer = get_input("Chrome App Consumer ID", "chrome")
        
        chrome_secret = get_input("Chrome App Consumer secret",  
                                  ''.join([choice(PASSWORD_LETTERBANK) for i in range(8)]))

        # TO DO: The password should be random here, but somehow we need to be able to change the DB password to match
        db_password = get_input("Database User Password",  "smart")          

        ui_server_base_url = get_input("SMART UI server", "http://localhost:7001")
        app_server_base_url = get_input("SMART App server", "http://localhost:8001")
        
        standalone_mode = get_input(
            """Run server in standalone mode (patient data stored in local db)?  
            If you choose 'no', the server will be configured in proxy mode, 
            with patient data hosted at a REST URL you provide.""", "yes")

        if standalone_mode=="no":
            proxy_base = get_input("Proxy server to use for medical record data",
                                   "none")

            proxy_user = get_input("User to associate with proxied requests", 
                                   "[email protected]")
        
        call_command("cp smart_server/settings.py.default smart_server/settings.py")

        call_command("cp smart_server/bootstrap_helpers/application_list.json.default " + 
                            "smart_server/bootstrap_helpers/application_list.json ")
                            
        call_command("cp smart_server/bootstrap_helpers/bootstrap_applications.py.default " + 
                            "smart_server/bootstrap_helpers/bootstrap_applications.py ")
                   
        fill_field('smart_server/bootstrap_helpers/application_list.json', 'app_server_base_url', app_server_base_url)
        fill_field('smart_server/bootstrap_helpers/bootstrap_applications.py', 'app_server_base_url', app_server_base_url)
        fill_field('smart_server/bootstrap_helpers/bootstrap_applications.py', 'ui_server_base_url', ui_server_base_url)

        call_command("cp smart_ui_server/settings.py.default smart_ui_server/settings.py")
        call_command("cp smart_sample_apps/settings.py.default smart_sample_apps/settings.py")

        fill_field('smart_server/settings.py', 'path_to_smart_server', 
                   os.path.join(cwd, "smart_server"))
        
        fill_field('smart_ui_server/settings.py', 'path_to_smart_ui_server', 
                   os.path.join(cwd, "smart_ui_server"))
        
        fill_field('smart_sample_apps/settings.py', 'path_to_smart_sample_apps', 
                   os.path.join(cwd, "smart_sample_apps"))

        fill_field('smart_server/settings.py', 'api_server_base_url', api_server_base_url)
        fill_field('smart_ui_server/settings.py', 'api_server_base_url', api_server_base_url)
        fill_field('smart_sample_apps/settings.py', 'app_server_base_url', app_server_base_url)
        fill_field('smart_sample_apps/settings.py', 'api_server_base_url', api_server_base_url)

        fill_field('smart_server/settings.py', 'chrome_consumer', chrome_consumer)
        fill_field('smart_server/settings.py', 'chrome_secret', chrome_secret)
        fill_field('smart_server/settings.py', 'db_password', db_password)

        fill_field('smart_ui_server/settings.py', 'chrome_consumer', chrome_consumer)
        fill_field('smart_ui_server/settings.py', 'chrome_secret', chrome_secret)
        fill_field('smart_ui_server/settings.py', 'db_password', db_password)
        
        fill_field('smart_sample_apps/settings.py', 'db_password', db_password)
        
        fill_field('smart_server/settings.py', 'django_secret_key', ''.join([choice(PASSWORD_LETTERBANK) for i in range(8)]))
        fill_field('smart_ui_server/settings.py', 'django_secret_key', ''.join([choice(PASSWORD_LETTERBANK) for i in range(8)]))
        fill_field('smart_sample_apps/settings.py', 'django_secret_key', ''.join([choice(PASSWORD_LETTERBANK) for i in range(8)]))

        fill_field('smart_server/settings.py', 'ui_server_base_url', ui_server_base_url)
        
        fill_field('smart_ui_server/settings.py', 'pretty_name_value', 'Reference EMR')

        if standalone_mode=="no":
            print "nostandalone"
            fill_field('smart_server/settings.py', 'use_proxy', 'True')
            fill_field('smart_server/settings.py', 'proxy_user_email', proxy_user)
            fill_field('smart_server/settings.py', 'proxy_base', proxy_base)

        else: 
            print "yes standalone"
            fill_field('smart_server/settings.py', 'use_proxy', 'False')

        fill_field('smart_server/settings.py', 'triplestore_engine', 'sesame')

        fill_field('smart_server/settings.py', 'triplestore_endpoint',
                'http://localhost:8080/openrdf-sesame/repositories/record_rdf')

    if args.run_app_server or args.run_api_servers:
        args.kill_servers = True
        server_settings = imp.load_source("settings", "smart_server/settings.py")
        app_settings = imp.load_source("settings", "smart_sample_apps/settings.py")
        app_server = app_settings.SMART_APP_SERVER_BASE
        api_server = server_settings.SITE_URL_PREFIX
        ui_server = server_settings.SMART_UI_SERVER_LOCATION

    if args.kill_servers:
        call_command("ps ax | "+
                     "grep -i 'python' | "+
                     "grep -i 'manage.py' | "+
                     "egrep  -o '^[ 0-9]+' | "+
                     "xargs -t  kill", failure_okay=True)
                     
    if args.generate_sample_data:
        call_command("cd smart_sample_patients/bin && " + 
                     "rm -rf ../generated-data/*.xml && " + 
                     "python generate.py --write ../generated-data &&" + 
                     "python generate-vitals-patient.py ../generated-data/99912345.xml &&" +
                     "cd ../..", print_output=True)

    if args.run_app_server:
        port = get_port(app_server)
        print "port:", port
        call_command("cd smart_sample_apps && python manage.py runconcurrentserver %s 0.0.0.0:%s &"%(reloadflag, port), 
                     print_output=True)
        call_command("sleep 2")

        print "App Server running."

    if args.reset_servers:
        print "Resetting the SMART server..."
        print "Note: Enter the SMART databse password when prompted (2 times)."
        print "      It is 'smart' by default."
        call_command("cd smart_server && "+
                     "sh ./reset.sh && "+
                     "cd ../..", print_output=True)
        
        print "Resetting the SMART UI server..."
        print "Note: Enter the SMART databse password when prompted (2 times)."
        print "      It is 'smart' by default."
        call_command("cd smart_ui_server && "+
                     "sh ./reset.sh &&"+
                     "cd ../..", print_output=True)


    if args.load_sample_data:
        call_command("cd smart_server && " + 
                     "PYTHONPATH=.:.. DJANGO_SETTINGS_MODULE=settings "+
                     "python load_tools/load_one_patient.py  " + 
                     "../smart_sample_patients/generated-data/* ../smart_sample_patients/deidentified-patients/*  && "
                     "cd ..", print_output=True)

    if args.create_user:
        print "Configuring a user ..."

        given_name = get_input("Given Name", "Demo")
        family_name = get_input("Family Name", "User")
        email = get_input("Email", "[email protected]")
        password = get_input("Password", "password")

        call_command("cd smart_server && " + 
                     "PYTHONPATH=.:.. DJANGO_SETTINGS_MODULE=settings "+
                     "python load_tools/create_user.py  " + 
                     given_name + " " + 
                     family_name + " " + 
                     email + " " + 
                     password + " && " 
                     "cd ..", print_output=True)

    if args.run_api_servers:

        port = get_port(api_server)
        print "port:", port
        call_command("cd smart_server && python manage.py runconcurrentserver %s 0.0.0.0:%s &"%(reloadflag, port), 
                     print_output=True)
        print "API Servers running."

        port = get_port(ui_server)
        print "port:", port
        call_command("cd smart_ui_server && python manage.py runconcurrentserver %s 0.0.0.0:%s &"%(reloadflag, port), 
                     print_output=True)
        call_command("sleep 2")

Example 18

View license
    def test_process_scpv2(self):
        """Verify ``SearchCommand.process`` under the SCP v2 (chunked) protocol.

        Exercises four phases against a simulated splunkd conversation:

        1. Recognition of all standard options supplied in getinfo metadata.
        2. Exposure of ``fieldnames``, ``input_header``, ``metadata``,
           ``search_results_info``, and ``service`` after processing.
        3. Error messages plus ``SystemExit`` when standard option values are
           invalid.
        4. An error message (including a traceback) plus ``SystemExit`` when an
           exception is raised during command execution.

        Note: this is Python 2 code (``u''`` literals, ``StringIO``,
        ``assertRegexpMatches``); ``TestCommand`` and ``encode_string`` are
        defined elsewhere in the test module.
        """

        # SearchCommand.process should

        # 1. Recognize all standard options:

        # Template for the getinfo chunk's JSON metadata. Doubled braces
        # ({{ / }}) are literal braces under str.format; single-brace fields
        # are filled per scenario below.
        metadata = (
            '{{'
                '"action": "getinfo", "preview": false, "searchinfo": {{'
                    '"latest_time": "0",'
                    '"splunk_version": "20150522",'
                    '"username": "admin",'
                    '"app": "searchcommands_app",'
                    '"args": ['
                        '"logging_configuration={logging_configuration}",'
                        '"logging_level={logging_level}",'
                        '"record={record}",'
                        '"show_configuration={show_configuration}",'
                        '"required_option_1=value_1",'
                        '"required_option_2=value_2"'
                    '],'
                    '"search": "%7C%20inputlookup%20tweets%20%7C%20countmatches%20fieldname%3Dword_count%20pattern%3D%22%5Cw%2B%22%20text%20record%3Dt%20%7C%20export%20add_timestamp%3Df%20add_offset%3Dt%20format%3Dcsv%20segmentation%3Draw",'
                    '"earliest_time": "0",'
                    '"session_key": "0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^",'
                    '"owner": "admin",'
                    '"sid": "1433261372.158",'
                    '"splunkd_uri": "https://127.0.0.1:8089",'
                    '"dispatch_dir": {dispatch_dir},'
                    '"raw_args": ['
                        '"logging_configuration={logging_configuration}",'
                        '"logging_level={logging_level}",'
                        '"record={record}",'
                        '"show_configuration={show_configuration}",'
                        '"required_option_1=value_1",'
                        '"required_option_2=value_2"'
                    ']'
                '}}'
            '}}')

        basedir = self._package_directory

        # Fixture paths for logging configuration and a recorded dispatch
        # directory used to populate the metadata template.
        default_logging_configuration = os.path.join(basedir, 'apps', 'app_with_logging_configuration', 'default', 'logging.conf')
        dispatch_dir = os.path.join(basedir, 'recordings', 'scpv2', 'Splunk-6.3', 'countmatches.dispatch_dir')
        logging_configuration = os.path.join(basedir, 'apps', 'app_with_logging_configuration', 'logging.conf')
        logging_level = 'ERROR'
        record = False
        show_configuration = True

        # NOTE(review): encode_string appears to return a quoted string; the
        # [1:-1] slice strips those quotes so the value nests inside the
        # already-quoted args entries, while dispatch_dir keeps its quotes to
        # stand alone as a JSON string value — TODO confirm against
        # encode_string's definition.
        getinfo_metadata = metadata.format(
            dispatch_dir=encode_string(dispatch_dir),
            logging_configuration=encode_string(logging_configuration)[1:-1],
            logging_level=logging_level,
            record=('true' if record is True else 'false'),
            show_configuration=('true' if show_configuration is True else 'false'))

        execute_metadata = '{"action":"execute","finished":true}'
        execute_body = 'test\r\ndata\r\n'

        # SCP v2 chunk framing: "chunked 1.0,<metadata_length>,<body_length>\n"
        # followed by the metadata and body bytes. The getinfo chunk has no
        # body (length 0); the execute chunk carries the CSV payload.
        ifile = StringIO(
            'chunked 1.0,{},0\n{}'.format(len(getinfo_metadata), getinfo_metadata) +
            'chunked 1.0,{},{}\n{}{}'.format(len(execute_metadata), len(execute_body), execute_metadata, execute_body))

        command = TestCommand()
        result = StringIO()
        argv = ['some-external-search-command.py']

        # Defaults before process() has parsed the metadata.
        self.assertEqual(command.logging_level, 'WARNING')
        self.assertIs(command.record, None)
        self.assertIs(command.show_configuration, None)

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except SystemExit as error:
            self.fail('Unexpected exception: {}: {}'.format(type(error).__name__, error))

        # Standard options parsed from the getinfo metadata.
        self.assertEqual(command.logging_configuration, logging_configuration)
        self.assertEqual(command.logging_level, 'ERROR')
        self.assertEqual(command.record, record)
        self.assertEqual(command.show_configuration, show_configuration)
        self.assertEqual(command.required_option_1, 'value_1')
        self.assertEqual(command.required_option_2, 'value_2')

        # Exact chunked response expected on the output side, including the
        # __mv_test multivalue companion column emitted for the test field.
        self.assertEqual(
            'chunked 1.0,68,0\n'
            '{"inspector":{"messages":[["INFO","test command configuration: "]]}}\n'
            'chunked 1.0,17,23\n'
            '{"finished":true}test,__mv_test\r\n'
            'data,\r\n',
            result.getvalue())

        self.assertEqual(command.protocol_version, 2)

        # 2. Provide access to these properties:
        #   fieldnames
        #   input_header
        #   metadata
        #   search_results_info
        #   service

        self.assertEqual([], command.fieldnames)

        command_metadata = command.metadata
        input_header = command.input_header

        # input_header is the SCP v1-style view synthesized from the v2
        # metadata; fields with no v2 equivalent are None.
        self.assertIsNone(input_header['allowStream'])
        self.assertEqual(input_header['infoPath'], os.path.join(command_metadata.searchinfo.dispatch_dir, 'info.csv'))
        self.assertIsNone(input_header['keywords'])
        self.assertEqual(input_header['preview'], command_metadata.preview)
        self.assertIs(input_header['realtime'], False)
        self.assertEqual(input_header['search'], command_metadata.searchinfo.search)
        self.assertEqual(input_header['sid'], command_metadata.searchinfo.sid)
        self.assertEqual(input_header['splunkVersion'], command_metadata.searchinfo.splunk_version)
        self.assertIsNone(input_header['truncated'])

        self.assertEqual(command_metadata.preview, input_header['preview'])
        self.assertEqual(command_metadata.searchinfo.app, 'searchcommands_app')
        self.assertEqual(command_metadata.searchinfo.args, ['logging_configuration=' + logging_configuration, 'logging_level=ERROR', 'record=false', 'show_configuration=true', 'required_option_1=value_1', 'required_option_2=value_2'])
        self.assertEqual(command_metadata.searchinfo.dispatch_dir, os.path.dirname(input_header['infoPath']))
        self.assertEqual(command_metadata.searchinfo.earliest_time, 0.0)
        self.assertEqual(command_metadata.searchinfo.latest_time, 0.0)
        self.assertEqual(command_metadata.searchinfo.owner, 'admin')
        self.assertEqual(command_metadata.searchinfo.raw_args, command_metadata.searchinfo.args)
        # The search string is the percent-decoded form of the encoded
        # "search" value in the metadata template above.
        self.assertEqual(command_metadata.searchinfo.search, '| inputlookup tweets | countmatches fieldname=word_count pattern="\\w+" text record=t | export add_timestamp=f add_offset=t format=csv segmentation=raw')
        self.assertEqual(command_metadata.searchinfo.session_key, '0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^')
        self.assertEqual(command_metadata.searchinfo.sid, '1433261372.158')
        self.assertEqual(command_metadata.searchinfo.splunk_version, '20150522')
        self.assertEqual(command_metadata.searchinfo.splunkd_uri, 'https://127.0.0.1:8089')
        self.assertEqual(command_metadata.searchinfo.username, 'admin')

        # Flatten nested namespace-like attributes to plain dicts so
        # assertDictEqual can compare them structurally — presumably these are
        # record/ObjectView instances; verify against the SDK types.
        command.search_results_info.search_metrics = command.search_results_info.search_metrics.__dict__
        command.search_results_info.optional_fields_json = command.search_results_info.optional_fields_json.__dict__

        self.maxDiff = None

        # Full expected contents of info.csv as loaded from the recorded
        # dispatch directory fixture.
        self.assertDictEqual(command.search_results_info.__dict__, {
            u'is_summary_index': 0,
            u'bs_thread_count': 1,
            u'rt_backfill': 0,
            u'rtspan': '',
            u'search_StartTime': 1433261392.934936,
            u'read_raw': 1,
            u'root_sid': '',
            u'field_rendering': '',
            u'query_finished': 1,
            u'optional_fields_json': {},
            u'group_list': '',
            u'remoteServers': '',
            u'rt_latest': '',
            u'remote_log_download_mode': 'disabled',
            u'reduce_search': '',
            u'request_finalization': 0,
            u'auth_token': 'UQZSgWwE2f9oIKrj1QG^kVhW^T_cR4H5Z65bPtMhwlHytS5jFrFYyH^dGzjTusDjVTgoBNeR7bvIzctHF7DrLJ1ANevgDOWEWRvABNj6d_k0koqxw9Io',
            u'indexed_realtime': 0,
            u'ppc_bs': '$SPLUNK_HOME/etc',
            u'drop_count': 0,
            u'datamodel_map': '',
            u'search_can_be_event_type': 0,
            u'search_StartUp_Spent': 0,
            u'realtime': 0,
            u'splunkd_uri': 'https://127.0.0.1:8089',
            u'columnOrder': '',
            u'kv_store_settings': 'hosts;127.0.0.1:8191\\;;local;127.0.0.1:8191;read_preference;958513E3-8716-4ABF-9559-DA0C9678437F;replica_set_name;958513E3-8716-4ABF-9559-DA0C9678437F;status;ready;',
            u'label': '',
            u'summary_maxtimespan': '',
            u'indexed_realtime_offset': 0,
            u'sid': 1433261392.159,
            u'msg': [],
            u'internal_only': 0,
            u'summary_id': '',
            u'orig_search_head': '',
            u'ppc_app': 'chunked_searchcommands',
            u'countMap': {
                u'invocations.dispatch.writeStatus': u'1',
                u'duration.dispatch.writeStatus': u'2',
                u'duration.startup.handoff': u'79',
                u'duration.startup.configuration': u'34',
                u'invocations.startup.handoff': u'1',
                u'invocations.startup.configuration': u'1'},
            u'is_shc_mode': 0,
            u'shp_id': '958513E3-8716-4ABF-9559-DA0C9678437F',
            u'timestamp': 1433261392.936374, u'is_remote_sorted': 0,
            u'remote_search': '',
            u'splunkd_protocol': 'https',
            u'site': '',
            u'maxevents': 0,
            u'keySet': '',
            u'summary_stopped': 0,
            u'search_metrics': {
                u'ConsideredEvents': 0,
                u'ConsideredBuckets': 0,
                u'TotalSlicesInBuckets': 0,
                u'EliminatedBuckets': 0,
                u'DecompressedSlices': 0},
            u'summary_mode': 'all', u'now': 1433261392.0,
            u'splunkd_port': 8089, u'is_saved_search': 0,
            u'rtoptions': '',
            u'search': '| inputlookup random_data max=50000 | sum total=total value1 record=t | export add_timestamp=f add_offset=t format=csv segmentation=raw',
            u'bundle_version': 0,
            u'generation_id': 0,
            u'bs_thread_id': 0,
            u'is_batch_mode': 0,
            u'scan_count': 0,
            u'rt_earliest': '',
            u'default_group': '*',
            u'tstats_reduce': '',
            u'kv_store_additional_settings': 'hosts_guids;958513E3-8716-4ABF-9559-DA0C9678437F\\;;',
            u'enable_event_stream': 0,
            u'is_remote': 0,
            u'is_scheduled': 0,
            u'sample_ratio': 1,
            u'ppc_user': 'admin',
            u'sample_seed': 0})

        # The service is constructed from the metadata's splunkd connection
        # details; owner/sharing are intentionally left unset.
        self.assertIsInstance(command.service, Service)

        self.assertEqual(command.service.authority, command_metadata.searchinfo.splunkd_uri)
        self.assertEqual(command.service.scheme, command.search_results_info.splunkd_protocol)
        self.assertEqual(command.service.port, command.search_results_info.splunkd_port)
        self.assertEqual(command.service.token, command_metadata.searchinfo.session_key)
        self.assertEqual(command.service.namespace.app, command.metadata.searchinfo.app)
        self.assertIsNone(command.service.namespace.owner)
        self.assertIsNone(command.service.namespace.sharing)

        self.assertEqual(command.protocol_version, 2)

        # 3. Produce an error message, log a debug message, and exit when invalid standard option values are encountered

        # Note on loggers
        # Loggers are global and can't be removed once they're created. We create loggers that are keyed by class name
        # Each instance of a class thus created gets access to the same logger. We created one in the prior test and
        # set it's level to ERROR. That level is retained in this test.

        # Deliberately invalid values for every standard option.
        logging_configuration = 'non-existent-logging.conf'
        logging_level = 'NON-EXISTENT-LOGGING-LEVEL'
        record = 'Non-boolean value'
        show_configuration = 'Non-boolean value'

        getinfo_metadata = metadata.format(
            dispatch_dir=encode_string(dispatch_dir),
            logging_configuration=encode_string(logging_configuration)[1:-1],
            logging_level=logging_level,
            record=record,
            show_configuration=show_configuration)

        execute_metadata = '{"action":"execute","finished":true}'
        execute_body = 'test\r\ndata\r\n'

        ifile = StringIO(
            'chunked 1.0,{},0\n{}'.format(len(getinfo_metadata), getinfo_metadata) +
            'chunked 1.0,{},{}\n{}{}'.format(len(execute_metadata), len(execute_body), execute_metadata, execute_body))

        command = TestCommand()
        result = StringIO()
        argv = ['test.py']

        # noinspection PyTypeChecker
        self.assertRaises(SystemExit, command.process, argv, ifile, ofile=result)
        # Invalid values leave the retained logging level (see note above) and
        # coerce the boolean options to False.
        self.assertEqual(command.logging_level, 'ERROR')
        self.assertEqual(command.record, False)
        self.assertEqual(command.show_configuration, False)
        self.assertEqual(command.required_option_1, 'value_1')
        self.assertEqual(command.required_option_2, 'value_2')

        # One "Illegal value" ERROR message per invalid option, then a final
        # empty finished chunk.
        self.assertEqual(
            'chunked 1.0,287,0\n'
            '{"inspector":{"messages":[["ERROR","Illegal value: logging_configuration=non-existent-logging.conf"],'
            '["ERROR","Illegal value: logging_level=NON-EXISTENT-LOGGING-LEVEL"],'
            '["ERROR","Illegal value: record=Non-boolean value"],'
            '["ERROR","Illegal value: show_configuration=Non-boolean value"]]}}\n'
            'chunked 1.0,17,0\n'
            '{"finished":true}',
            result.getvalue())

        self.assertEqual(command.protocol_version, 2)

        # 4. Produce an error message, log an error message that includes a traceback, and exit when an exception is
        #    raised during command execution.

        logging_configuration = os.path.join(basedir, 'apps', 'app_with_logging_configuration', 'logging.conf')
        logging_level = 'WARNING'
        record = False
        show_configuration = False

        getinfo_metadata = metadata.format(
            dispatch_dir=encode_string(dispatch_dir),
            logging_configuration=encode_string(logging_configuration)[1:-1],
            logging_level=logging_level,
            record=('true' if record is True else 'false'),
            show_configuration=('true' if show_configuration is True else 'false'))

        execute_metadata = '{"action":"execute","finished":true}'
        # The 'raise_exception' action value makes TestCommand raise during
        # execution — presumably a hook defined on TestCommand; verify there.
        execute_body = 'action\r\nraise_exception\r\n'

        ifile = StringIO(
            'chunked 1.0,{},0\n{}'.format(len(getinfo_metadata), getinfo_metadata) +
            'chunked 1.0,{},{}\n{}{}'.format(len(execute_metadata), len(execute_body), execute_metadata, execute_body))

        command = TestCommand()
        result = StringIO()
        argv = ['test.py']

        # process() must exit non-zero; a clean return or any other exception
        # type is a test failure.
        try:
            command.process(argv, ifile, ofile=result)
        except SystemExit as error:
            self.assertNotEqual(0, error.code)
        except BaseException as error:
            self.fail('{0}: {1}: {2}\n'.format(type(error).__name__, error, result.getvalue()))
        else:
            self.fail('Expected SystemExit, not a return from TestCommand.process: {}\n'.format(result.getvalue()))

        self.assertEqual(command.logging_configuration, logging_configuration)
        self.assertEqual(command.logging_level, logging_level)
        self.assertEqual(command.record, record)
        self.assertEqual(command.show_configuration, show_configuration)
        self.assertEqual(command.required_option_1, 'value_1')
        self.assertEqual(command.required_option_2, 'value_2')

        finished = r'"finished":true'

        inspector = \
            r'"inspector":\{"messages":\[\["ERROR","StandardError at \\".+\\", line \d+ : test ' \
            r'logging_configuration=\\".+\\" logging_level=\\"WARNING\\" record=\\"f\\" ' \
            r'required_option_1=\\"value_1\\" required_option_2=\\"value_2\\" show_configuration=\\"f\\""\]\]\}'

        # The JSON key order of "inspector" vs "finished" is unspecified, so
        # the regex accepts either ordering via alternation.
        self.assertRegexpMatches(
            result.getvalue(),
            r'^chunked 1.0,2,0\n'
            r'\{\}\n'
            r'chunked 1.0,\d+,0\n'
            r'\{(' + inspector + r',' + finished + r'|' + finished + r',' + inspector + r')\}')

        self.assertEqual(command.protocol_version, 2)
        return

Example 19

Project: tree-hmm
Source File: plot.py
View license
def plot_params(args):
    """Plot alpha, theta, and the emission probabilities"""
    old_err = sp.seterr(under='ignore')
    oldsize = matplotlib.rcParams['font.size']
    K, L = args.emit_probs.shape if not args.continuous_observations else args.means.shape

    # alpha
    #matplotlib.rcParams['font.size'] = 12
    pyplot.figure()
    _, xedges, yedges = sp.histogram2d([0,K], [0,K], bins=[K,K])
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    pyplot.imshow(args.alpha.astype(sp.float64), extent=extent, interpolation='nearest',
                  vmin=0, vmax=1,  cmap='OrRd', origin='lower')
    pyplot.xticks(sp.arange(K) + .5, sp.arange(K)+1)
    pyplot.gca().set_xticks(sp.arange(K)+1, minor=True)
    pyplot.yticks(sp.arange(K) + .5, sp.arange(K)+1)
    pyplot.gca().set_yticks(sp.arange(K)+1, minor=True)
    pyplot.grid(which='minor', alpha=.2)
    for line in pyplot.gca().yaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines(minor=True) + pyplot.gca().yaxis.get_ticklines(minor=True):
    # label is a Text instance
        line.set_markersize(0)
    pyplot.ylabel('Horizontal parent state')
    pyplot.xlabel('Node state')
    pyplot.title(r"Top root transition ($\alpha$) for {approx} iteration {iteration}".
                        format(approx=args.approx, iteration=args.iteration))
    b = pyplot.colorbar(shrink=.9)
    b.set_label("Probability")
    outfile = (args.out_params + '_it{iteration}.png').format(param='alpha', **args.__dict__)
    pyplot.savefig(os.path.join(args.out_dir, outfile), dpi=240)


    # beta
    pyplot.figure()
    _, xedges, yedges = sp.histogram2d([0,K], [0,K], bins=[K,K])
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    pyplot.clf()
    pyplot.imshow(args.beta.astype(sp.float64), extent=extent, interpolation='nearest',
                  vmin=0, vmax=1, cmap='OrRd', origin='lower')
    pyplot.xticks(sp.arange(K) + .5, sp.arange(K)+1)
    pyplot.gca().set_xticks(sp.arange(K)+1, minor=True)
    pyplot.yticks(sp.arange(K) + .5, sp.arange(K)+1)
    pyplot.gca().set_yticks(sp.arange(K)+1, minor=True)
    pyplot.grid(which='minor', alpha=.2)
    for line in pyplot.gca().yaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines(minor=True) + pyplot.gca().yaxis.get_ticklines(minor=True):
    # label is a Text instance
        line.set_markersize(0)
    pyplot.ylabel('Vertical parent state')
    pyplot.xlabel('Node state')
    pyplot.title(r"Left root transition ($\beta$) for {approx} iteration {iteration}".
                        format(approx=args.approx, iteration=args.iteration))
    b = pyplot.colorbar(shrink=.9)
    b.set_label("Probability")
    outfile = (args.out_params + '_it{iteration}.png').format(param='beta', **args.__dict__)
    pyplot.savefig(os.path.join(args.out_dir, outfile), dpi=240)


    # theta
    if args.separate_theta:
        theta_tmp = args.theta
        for i in range((args.theta.shape)[0]):
            setattr(args, 'theta_%s'%(i+1), args.theta[i,:,:,:])

    for theta_name in ['theta'] + ['theta_%s' % i for i in range(20)]:
        #print 'trying', theta_name
        if not hasattr(args, theta_name):
            #print 'missing', theta_name
            continue
        _, xedges, yedges = sp.histogram2d([0,K], [0,K], bins=[K,K])
        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
        if K == 18:
            numx_plots = 6
            numy_plots = 3
        elif K == 15:
            numx_plots = 5
            numy_plots = 3
        else:
            numx_plots = int(ceil(sp.sqrt(K)))
            numy_plots = int(ceil(sp.sqrt(K)))
        matplotlib.rcParams['font.size'] = 8
        fig, axs = pyplot.subplots(numy_plots, numx_plots, sharex=True, sharey=True, figsize=(numx_plots*2.5,numy_plots*2.5))
        for k in xrange(K):
            pltx, plty = k // numx_plots, k % numx_plots
            #axs[pltx,plty].imshow(args.theta[k,:,:], extent=extent, interpolation='nearest',
            axs[pltx,plty].imshow(getattr(args, theta_name)[:,k,:].astype(sp.float64), extent=extent, interpolation='nearest',
                          vmin=0, vmax=1, cmap='OrRd', aspect='auto', origin='lower')
            #if k < numx_plots:
            #axs[pltx,plty].text(0 + .5, K - .5, 'vp=%s' % (k+1), horizontalalignment='left', verticalalignment='top', fontsize=10)
            axs[pltx,plty].text(0 + .5, K - .5, 'hp=%s' % (k+1), horizontalalignment='left', verticalalignment='top', fontsize=10)
            #axs[pltx,plty].xticks(sp.arange(K) + .5, sp.arange(K))
            #axs[pltx,plty].yticks(sp.arange(K) + .5, sp.arange(K))
            axs[pltx,plty].set_xticks(sp.arange(K) + .5)
            axs[pltx,plty].set_xticks(sp.arange(K)+1, minor=True)
            axs[pltx,plty].set_xticklabels(sp.arange(K) + 1)
            axs[pltx,plty].set_yticks(sp.arange(K) + .5)
            axs[pltx,plty].set_yticks(sp.arange(K)+1, minor=True)
            axs[pltx,plty].set_yticklabels(sp.arange(K) + 1)
            for line in axs[pltx,plty].yaxis.get_ticklines() + axs[pltx,plty].xaxis.get_ticklines() + axs[pltx,plty].yaxis.get_ticklines(minor=True) + axs[pltx,plty].xaxis.get_ticklines(minor=True):
                line.set_markersize(0)
            axs[pltx,plty].grid(True, which='minor', alpha=.2)

        #fig.suptitle(r"$\Theta$ with fixed parents for {approx} iteration {iteration}".
        #                    format(approx=args.approx, iteration=args.iteration),
        #                    fontsize=14, verticalalignment='top')
        fig.suptitle('Node state', y=.03, fontsize=14, verticalalignment='center')
        #fig.suptitle('Horizontal parent state', y=.5, x=.02, rotation=90,
        fig.suptitle('Vertical parent state', y=.5, x=.02, rotation=90,
                     verticalalignment='center', fontsize=14)
        matplotlib.rcParams['font.size'] = 6.5
        fig.subplots_adjust(wspace=.05, hspace=.05, left=.05, right=.95)
        #b = fig.colorbar(shrink=.9)
        #b.set_label("Probability")
        outfile = (args.out_params + '_vertparent_it{iteration}.png').format(param=theta_name, **args.__dict__)
        pyplot.savefig(os.path.join(args.out_dir, outfile), dpi=240)


        fig, axs = pyplot.subplots(numy_plots, numx_plots, sharex=True, sharey=True, figsize=(numx_plots*2.5,numy_plots*2.5))
        for k in xrange(K):
            pltx, plty = k // numx_plots, k % numx_plots
            axs[pltx,plty].imshow(getattr(args, theta_name)[k,:,:].astype(sp.float64), extent=extent, interpolation='nearest',
            #axs[pltx,plty].imshow(args.theta[:,k,:], extent=extent, interpolation='nearest',
                          vmin=0, vmax=1, cmap='OrRd', aspect='auto', origin='lower')
            #if k < numx_plots:
            axs[pltx,plty].text(0 + .5, K - .5, 'vp=%s' % (k+1), horizontalalignment='left', verticalalignment='top', fontsize=10)
            #axs[pltx,plty].xticks(sp.arange(K) + .5, sp.arange(K))
            #axs[pltx,plty].yticks(sp.arange(K) + .5, sp.arange(K))
            axs[pltx,plty].set_xticks(sp.arange(K) + .5)
            axs[pltx,plty].set_xticks(sp.arange(K)+1, minor=True)
            axs[pltx,plty].set_xticklabels(sp.arange(K) + 1)
            axs[pltx,plty].set_yticks(sp.arange(K) + .5)
            axs[pltx,plty].set_yticks(sp.arange(K)+1, minor=True)
            axs[pltx,plty].set_yticklabels(sp.arange(K) + 1)
            for line in axs[pltx,plty].yaxis.get_ticklines() + axs[pltx,plty].xaxis.get_ticklines() + axs[pltx,plty].yaxis.get_ticklines(minor=True) + axs[pltx,plty].xaxis.get_ticklines(minor=True):
                line.set_markersize(0)
            axs[pltx,plty].grid(True, which='minor', alpha=.2)

        #fig.suptitle(r"$\Theta$ with fixed parents for {approx} iteration {iteration}".
        #                    format(approx=args.approx, iteration=args.iteration),
        #                    fontsize=14, verticalalignment='top')
        fig.suptitle('Node state', y=.03, fontsize=14, verticalalignment='center')
        fig.suptitle('Horizontal parent state', y=.5, x=.02, rotation=90,
        #fig.suptitle('Vertical parent state', y=.5, x=.02, rotation=90,
                     verticalalignment='center', fontsize=14)
        matplotlib.rcParams['font.size'] = 6.5
        fig.subplots_adjust(wspace=.05, hspace=.05, left=.05, right=.95)
        #b = fig.colorbar(shrink=.9)
        #b.set_label("Probability")
        outfile = (args.out_params + '_it{iteration}.png').format(param=theta_name, **args.__dict__)
        pyplot.savefig(os.path.join(args.out_dir, outfile), dpi=240)


    # emission probabilities
    if args.continuous_observations:
        # plot mean values
        matplotlib.rcParams['font.size'] = 8
        pyplot.figure(figsize=(max(1,round(L/3.)),max(1, round(K/3.))))
        print (max(1,round(L/3.)),max(1, round(K/3.)))
        pyplot.imshow(args.means.astype(sp.float64), interpolation='nearest', aspect='auto',
                      vmin=0, vmax=args.means.max(), cmap='OrRd', origin='lower')
        for k in range(K):
            for l in range(L):
                pyplot.text(l, k, '%.1f' % (args.means[k,l]), horizontalalignment='center', verticalalignment='center', fontsize=5)
        pyplot.yticks(sp.arange(K), sp.arange(K)+1)
        pyplot.gca().set_yticks(sp.arange(K)+.5, minor=True)
        pyplot.xticks(sp.arange(L), valid_marks, rotation=30, horizontalalignment='right')
        pyplot.gca().set_xticks(sp.arange(L)+.5, minor=True)
        pyplot.grid(which='minor', alpha=.2)
        for line in pyplot.gca().yaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines(minor=True) + pyplot.gca().yaxis.get_ticklines(minor=True):
        # label is a Text instance
            line.set_markersize(0)
        pyplot.ylabel('Hidden State')
        pyplot.title("Emission Mean")
        #b = pyplot.colorbar(shrink=.7)
        #b.set_label("Probability")
        outfile = (args.out_params + '_it{iteration}.png').format(param='emission_means', **args.__dict__)
        pyplot.savefig(os.path.join(args.out_dir, outfile), dpi=240)

        # plot variances
        pyplot.figure(figsize=(max(1,round(L/3.)),max(1, round(K/3.))))
        print (L/3,K/3.)
        pyplot.imshow(args.variances.astype(sp.float64), interpolation='nearest', aspect='auto',
                      vmin=0, vmax=args.variances.max(), cmap='OrRd', origin='lower')
        for k in range(K):
            for l in range(L):
                pyplot.text(l, k, '%.1f' % (args.variances[k,l]), horizontalalignment='center', verticalalignment='center', fontsize=5)
        pyplot.yticks(sp.arange(K), sp.arange(K)+1)
        pyplot.gca().set_yticks(sp.arange(K)+.5, minor=True)
        pyplot.xticks(sp.arange(L), valid_marks, rotation=30, horizontalalignment='right')
        pyplot.gca().set_xticks(sp.arange(L)+.5, minor=True)
        pyplot.grid(which='minor', alpha=.2)
        for line in pyplot.gca().yaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines(minor=True) + pyplot.gca().yaxis.get_ticklines(minor=True):
        # label is a Text instance
            line.set_markersize(0)
        pyplot.ylabel('Hidden State')
        pyplot.title("Emission Variance")
        #b = pyplot.colorbar(shrink=.7)
        #b.set_label("Probability")
        outfile = (args.out_params + '_it{iteration}.png').format(param='emission_variances', **args.__dict__)
        pyplot.savefig(os.path.join(args.out_dir, outfile), dpi=240)
    else:
        matplotlib.rcParams['font.size'] = 8
        pyplot.figure(figsize=(max(1,round(L/3.)),max(1, round(K/3.))))
        print (L/3,K/3.)
        pyplot.imshow(args.emit_probs.astype(sp.float64), interpolation='nearest', aspect='auto',
                      vmin=0, vmax=1, cmap='OrRd', origin='lower')
        for k in range(K):
            for l in range(L):
                pyplot.text(l, k, '%2.0f' % (args.emit_probs[k,l] * 100), horizontalalignment='center', verticalalignment='center')
        pyplot.yticks(sp.arange(K), sp.arange(K)+1)
        pyplot.gca().set_yticks(sp.arange(K)+.5, minor=True)
        pyplot.xticks(sp.arange(L), valid_marks, rotation=30, horizontalalignment='right')
        pyplot.gca().set_xticks(sp.arange(L)+.5, minor=True)
        pyplot.grid(which='minor', alpha=.2)
        for line in pyplot.gca().yaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines(minor=True) + pyplot.gca().yaxis.get_ticklines(minor=True):
        # label is a Text instance
            line.set_markersize(0)
        pyplot.ylabel('Hidden State')
        pyplot.title("Emission probabilities")
        #b = pyplot.colorbar(shrink=.7)
        #b.set_label("Probability")
        outfile = (args.out_params + '_it{iteration}.png').format(param='emission', **args.__dict__)
        pyplot.savefig(os.path.join(args.out_dir, outfile), dpi=240)


    #broad_paper_enrichment = sp.array([[16,2,2,6,17,93,99,96,98,2],
    #                               [12,2,6,9,53,94,95,14,44,1],
    #                               [13,72,0,9,48,78,49,1,10,1],
    #                               [11,1,15,11,96,99,75,97,86,4],
    #                               [5,0,10,3,88,57,5,84,25,1],
    #                               [7,1,1,3,58,75,8,6,5,1],
    #                               [2,1,2,1,56,3,0,6,2,1],
    #                               [92,2,1,3,6,3,0,0,1,1],
    #                               [5,0,43,43,37,11,2,9,4,1],
    #                               [1,0,47,3,0,0,0,0,0,1],
    #                               [0,0,3,2,0,0,0,0,0,0],
    #                               [1,27,0,2,0,0,0,0,0,0],
    #                               [0,0,0,0,0,0,0,0,0,0],
    #                               [22,28,19,41,6,5,26,5,13,37],
    #                               [85,85,91,88,76,77,91,73,85,78],
    #                               [float('nan'), float('nan'),float('nan'),float('nan'),float('nan'),float('nan'),float('nan'),float('nan'),float('nan'),float('nan')]
    #                            ]) / 100.
    #mapping_from_broad = dict(zip(range(K), (5,2,0,14,4,6,9,1,12,-1,3,12,8,7,10,12,11,13)))
    #broad_paper_enrichment = broad_paper_enrichment[tuple(mapping_from_broad[i] for i in range(K)), :]
    #broad_names = ['Active promoter', 'Weak promoter', 'Inactive/poised promoter', 'Strong enhancer',
    #               'Strong enhancer', 'weak/poised enhancer', 'Weak/poised enhancer', 'Insulator',
    #               'Transcriptional transition', 'Transcriptional elongation', 'Weak transcribed',
    #               'Polycomb repressed', 'Heterochrom; low signal', 'Repetitive/CNV', 'Repetitive/CNV',
    #               'NA', 'NA', 'NA']
    #pyplot.figure(figsize=(L/3,K/3.))
    #print (L/3,K/3.)
    #pyplot.imshow(broad_paper_enrichment, interpolation='nearest', aspect='auto',
    #              vmin=0, vmax=1, cmap='OrRd', origin='lower')
    #for k in range(K):
    #    for l in range(L):
    #        pyplot.text(l, k, '%2.0f' % (broad_paper_enrichment[k,l] * 100), horizontalalignment='center', verticalalignment='center')
    #    pyplot.text(L, k, broad_names[mapping_from_broad[k]], horizontalalignment='left', verticalalignment='center', fontsize=6)
    #pyplot.yticks(sp.arange(K), sp.arange(K)+1)
    #pyplot.gca().set_yticks(sp.arange(K)+.5, minor=True)
    #pyplot.xticks(sp.arange(L), valid_marks, rotation=30, horizontalalignment='right')
    #pyplot.gca().set_xticks(sp.arange(L)+.5, minor=True)
    #pyplot.grid(which='minor', alpha=.2)
    #for line in pyplot.gca().yaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines() + pyplot.gca().xaxis.get_ticklines(minor=True) + pyplot.gca().yaxis.get_ticklines(minor=True):
    ## label is a Text instance
    #    line.set_markersize(0)
    #pyplot.ylabel('Hidden State')
    #pyplot.title("Broad paper Emission probabilities")
    ##b = pyplot.colorbar(shrink=.7)
    ##b.set_label("Probability")
    #pyplot.subplots_adjust(right=.7)
    #outfile = (args.out_params + '_broadpaper.png').format(param='emission', **args.__dict__)
    #pyplot.savefig(os.path.join(args.out_dir, outfile), dpi=240)

    pyplot.close('all')
    sp.seterr(**old_err)
    matplotlib.rcParams['font.size'] = oldsize

Example 20

Project: ursgal
Source File: barth_et_al_large_scale.py
View license
def main(folder):
    '''

    Example script for reproducing the data for figure 3

    usage:

        ./barth_et_al_large_scale.py <folder>

    The folder determines the target folder where the files will be downloaded

    Chlamydomonas reinhardtii samples

    Three biological replicates of 4 conditions (2_3, 2_4, 3_1, 4_1)

    For more details on the samples please refer to
    Barth, J.; Bergner, S. V.; Jaeger, D.; Niehues, A.; Schulze, S.; Scholz,
    M.; Fufezan, C. The interplay of light and oxygen in the reactive oxygen
    stress response of Chlamydomonas reinhardtii dissected by quantitative mass
    spectrometry. MCP 2014, 13 (4), 969–989.

    Merge all search results (per biological replicate and condition, on folder
    level) on engine level and validate via percolator.

    'LTQ XL high res':

        * repetition 1
        * repetition 2

    'LTQ XL low res':

        * repetition 3

    Database:

        * Creinhardtii_281_v5_5_CP_MT_with_contaminants_target_decoy.fasta

    Note:

        The database and the files will be automatically downloaded from our
        webpage and peptideatlas
    '''

    # Shared search parameters: target/decoy database, variable modifications
    # and the download locations (FTP for the mzML files, HTTP for the fasta).
    input_params = {
        'database' : os.path.join(
            os.pardir,
            'example_data',
            'Creinhardtii_281_v5_5_CP_MT_with_contaminants_target_decoy.fasta'
        ),
        'modifications' : [
            'M,opt,any,Oxidation',
            '*,opt,Prot-N-term,Acetyl',  # N-Acetylation
        ],
        'ftp_url'       : 'ftp.peptideatlas.org',

        'ftp_login'         : 'PASS00269',
        'ftp_password'      : 'FI4645a',

        'ftp_output_folder_root' : folder,
        'http_url': 'http://www.uni-muenster.de/Biologie.IBBP.AGFufezan/misc/Creinhardtii_281_v5_5_CP_MT_with_contaminants_target_decoy.fasta' ,
        'http_output_folder' : os.path.join(
            os.pardir,
            'example_data'
        )
    }

    uc = ursgal.UController(
        params = input_params
    )

    # Fetch the fasta database once if it is not present locally.
    # (NOTE: the original script repeated this check a second time after the
    # FTP downloads; that duplicate was redundant and has been removed.)
    if os.path.exists(input_params['database']) is False:
        uc.fetch_file(
            engine     = 'get_http_files_1_0_0'
        )

    # Maps (target subfolder, instrument profile) -> mzML files to download.
    # The profile decides the instrument-specific search parameters.
    output_folder_to_file_list ={
        ('rep1_sample_2_3','LTQ XL high res') : [
            'CF_07062012_pH8_2_3A.mzML',
            'CF_13062012_pH3_2_3A.mzML',
            'CF_13062012_pH4_2_3A.mzML',
            'CF_13062012_pH5_2_3A.mzML',
            'CF_13062012_pH6_2_3A.mzML',
            'CF_13062012_pH11FT_2_3A.mzML',
        ],
        ('rep1_sample_2_4','LTQ XL high res') : [
            'CF_07062012_pH8_2_4A.mzML',
            'CF_13062012_pH3_2_4A_120615113039.mzML',
            'CF_13062012_pH4_2_4A.mzML',
            'CF_13062012_pH5_2_4A.mzML',
            'CF_13062012_pH6_2_4A.mzML',
            'CF_13062012_pH11FT_2_4A.mzML',

        ],
        ('rep1_sample_3_1','LTQ XL high res') : [
            'CF_12062012_pH8_1_3A.mzML',
            'CF_13062012_pH3_1_3A.mzML',
            'CF_13062012_pH4_1_3A.mzML',
            'CF_13062012_pH5_1_3A.mzML',
            'CF_13062012_pH6_1_3A.mzML',
            'CF_13062012_pH11FT_1_3A.mzML',
        ],
        ('rep1_sample_4_1','LTQ XL high res') : [
            'CF_07062012_pH8_1_4A.mzML',
            'CF_13062012_pH3_1_4A.mzML',
            'CF_13062012_pH4_1_4A.mzML',
            'CF_13062012_pH5_1_4A.mzML',
            'CF_13062012_pH6_1_4A.mzML',
            'CF_13062012_pH11FT_1_4A.mzML',
        ],

        ('rep2_sample_2_3','LTQ XL high res') : [
            'JB_18072012_2-3_A_FT.mzML',
            'JB_18072012_2-3_A_pH3.mzML',
            'JB_18072012_2-3_A_pH4.mzML',
            'JB_18072012_2-3_A_pH5.mzML',
            'JB_18072012_2-3_A_pH6.mzML',
            'JB_18072012_2-3_A_pH8.mzML',
        ],
        ('rep2_sample_2_4','LTQ XL high res') : [

            'JB_18072012_2-4_A_FT.mzML',
            'JB_18072012_2-4_A_pH3.mzML',
            'JB_18072012_2-4_A_pH4.mzML',
            'JB_18072012_2-4_A_pH5.mzML',
            'JB_18072012_2-4_A_pH6.mzML',
            'JB_18072012_2-4_A_pH8.mzML',

        ],
        ('rep2_sample_3_1','LTQ XL high res') : [
            'JB_18072012_3-1_A_FT.mzML',
            'JB_18072012_3-1_A_pH3.mzML',
            'JB_18072012_3-1_A_pH4.mzML',
            'JB_18072012_3-1_A_pH5.mzML',
            'JB_18072012_3-1_A_pH6.mzML',
            'JB_18072012_3-1_A_pH8.mzML',
        ],
        ('rep2_sample_4_1','LTQ XL high res') : [
            'JB_18072012_4-1_A_FT.mzML',
            'JB_18072012_4-1_A_pH3.mzML',
            'JB_18072012_4-1_A_pH4.mzML',
            'JB_18072012_4-1_A_pH5.mzML',
            'JB_18072012_4-1_A_pH6.mzML',
            'JB_18072012_4-1_A_pH8.mzML',
        ],

        ('rep3_sample_2_3','LTQ XL low res'): [
            'JB_FASP_pH3_2-3_28122012.mzML',
            'JB_FASP_pH4_2-3_28122012.mzML',
            'JB_FASP_pH5_2-3_28122012.mzML',
            'JB_FASP_pH6_2-3_28122012.mzML',
            'JB_FASP_pH8_2-3_28122012.mzML',
            'JB_FASP_pH11-FT_2-3_28122012.mzML',
        ],
        ('rep3_sample_2_4','LTQ XL low res'): [
            'JB_FASP_pH3_2-4_28122012.mzML',
            'JB_FASP_pH4_2-4_28122012.mzML',
            'JB_FASP_pH5_2-4_28122012.mzML',
            'JB_FASP_pH6_2-4_28122012.mzML',
            'JB_FASP_pH8_2-4_28122012.mzML',
            'JB_FASP_pH11-FT_2-4_28122012.mzML',
        ],
        ('rep3_sample_3_1','LTQ XL low res'): [
            'JB_FASP_pH3_3-1_28122012.mzML',
            'JB_FASP_pH4_3-1_28122012.mzML',
            'JB_FASP_pH5_3-1_28122012.mzML',
            'JB_FASP_pH6_3-1_28122012.mzML',
            'JB_FASP_pH8_3-1_28122012.mzML',
            'JB_FASP_pH11-FT_3-1_28122012.mzML',
        ],
        ('rep3_sample_4_1','LTQ XL low res'): [
            'JB_FASP_pH3_4-1_28122012.mzML',
            'JB_FASP_pH4_4-1_28122012.mzML',
            'JB_FASP_pH5_4-1_28122012.mzML',
            'JB_FASP_pH6_4-1_28122012.mzML',
            'JB_FASP_pH8_4-1_28122012.mzML',
            'JB_FASP_pH11-FT_4-1_28122012_130121201449.mzML',
        ],
    }

    # Download each replicate/condition folder from peptideatlas via FTP.
    # Sorted iteration makes the download order deterministic.
    for (outfolder,profile), mzML_file_list in sorted(output_folder_to_file_list.items()):
        uc.params['ftp_output_folder'] = os.path.join(
            input_params['ftp_output_folder_root'],
            outfolder
        )
        uc.params['ftp_include_ext'] = mzML_file_list

        if os.path.exists(uc.params['ftp_output_folder']) is False:
            os.makedirs( uc.params['ftp_output_folder'] )

        uc.fetch_file(
            engine     = 'get_ftp_files_1_0_0'
        )

    search_engines  = [
        'omssa_2_1_9',
        'xtandem_piledriver',
        'myrimatch_2_1_138',
        'msgfplus_v9979',
        'msamanda_1_0_0_5243',
    ]

    # This dict will be populated with the percolator-validated results
    # of each engine ( 3 replicates x4 conditions = 12 files each )
    percolator_results = {
        'omssa_2_1_9'        : [],
        'xtandem_piledriver' : [],
        'msgfplus_v9979'     : [],
        'myrimatch_2_1_138'  : [],
        'msamanda_1_0_0_5243': [],
    }

    for search_engine in search_engines:

        for mzML_dir_ext, mass_spectrometer in output_folder_to_file_list.keys():
            # The profile sets instrument-appropriate defaults (mass
            # tolerances etc.) before each search.
            uc.set_profile( mass_spectrometer )

            mzML_dir = os.path.join(
                input_params['ftp_output_folder_root'],
                mzML_dir_ext
            )
            # all files ending with .mzml in that directory will be used!

            unified_results_list = []
            for filename in glob.glob( os.path.join( mzML_dir,'*.mzML') ):
                if filename.lower().endswith(".mzml"):
                    unified_search_results = uc.search(
                        input_file = filename,
                        engine     = search_engine,
                    )
                    unified_results_list.append(
                        unified_search_results
                    )

            # Merging results from the 6 pH-fractions:
            merged_unified = uc.merge_csvs( unified_results_list )

            # Validation with Percolator:
            percolator_validated = uc.validate(
                input_file = merged_unified,
                engine     = 'percolator_2_08',  # one could replace this with 'qvality'
            )
            percolator_results[ search_engine ].append(
                percolator_validated
            )

            # At this point, the analysis is finished. We got
            # Percolator-validated results for each of the 3
            # replicates and 12 conditions.

            # But let's see how well the five search engines
            # performed! To compare, we collect all PSMs with
            # an estimated FDR <= 0.01 for each engine, and
            # plot this information with the VennDiagram UNode.
            # We will also use the Combine FDR Score method
            # to combine the results from all five engines,
            # and increase the number of identified peptides.

    five_large_merged = []
    filtered_final_results = []

    # We will estimate the FDR for all 60 files
    # (5 engines x12 files) when using percolator PEPs as
    # quality score
    uc.params['validation_score_field'] = 'PEP'
    uc.params['bigger_scores_better']   = False

    # To make obtain smaller CSV files (and make plotting
    # less RAM-intensive, we remove all decoys and PSMs above
    # 0.06 FDR
    uc.params['csv_filter_rules'] = [
        ['estimated_FDR', 'lte', 0.06],
        ['Is decoy', 'equals', 'false']
    ]
    for engine, percolator_validated_list in percolator_results.items():

        # unfiltered files for cFDR script
        twelve_merged = uc.merge_csvs( percolator_validated_list )

        twelve_filtered = []
        for one_of_12 in percolator_validated_list:
            one_of_12_FDR = uc.add_estimated_fdr(
                input_file = one_of_12,
            )
            one_of_12_FDR_filtered = uc.filter_csv(
                input_file = one_of_12_FDR,
            )
            twelve_filtered.append( one_of_12_FDR_filtered )

        # For the combined FDR scoring, we merge all 12 files:
        filtered_merged = uc.merge_csvs( twelve_filtered )

        five_large_merged.append( twelve_merged )
        filtered_final_results.append( filtered_merged )

    # The five big merged files of each engine are combined:
    cFDR = uc.combine_search_results(
        input_files = five_large_merged,
        engine      = 'combine_FDR_0_1',
    )

    # We estimate the FDR of this combined approach:
    uc.params['validation_score_field'] = 'Combined FDR Score'
    uc.params['bigger_scores_better']   = False

    cFDR_FDR = uc.add_estimated_fdr(
        input_file = cFDR,
    )

    # Removing decoys and low quality hits, to obtain a
    # smaller file:
    uc.params['csv_filter_rules'] = [
        ['estimated_FDR', 'lte', 0.06],
        ['Is decoy', 'equals', 'false']
    ]
    cFDR_filtered_results = uc.filter_csv(
        input_file = cFDR_FDR,
    )
    filtered_final_results.append( cFDR_filtered_results )

    # Since we produced quite a lot of files, let's print the full
    # paths to our most important result files so we find them quickly:
    print('''
    These files can now be easily parsed and plotted with your
    plotting tool of choice! We used the Python plotting library
    matplotlib. Each unique combination of Sequence, modification
    and charge was counted as a unique peptide.
    ''')
    print("\n########### Result files: ##############")
    for x in filtered_final_results:
        print( x )

Example 21

Project: ursgal
Source File: percolator_2_08.py
View license
    def preflight( self ):
        '''
        Formating the command line to via self.params
        '''

        PERCOLATOR_FIELDS = OrderedDict([
            (
                'SpecId', {
                    'csv_field': 'Spectrum Title',
                    'DefaultDirection': 'DefaultDirection'
                }
            ),
            (
                'Label', {
                    'csv_field': '',
                    'DefaultDirection': '-'
                }
            ),
            (
                'ScanNr', {
                    'csv_field': 'Spectrum ID',
                    'DefaultDirection': '-'
                }
            ),
            (
                'lnrSp', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': \
                'The natural logarithm of the rank of the match based on the Sp score'
                }
            ),
            (
                'deltLCn', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': \
                "The difference between this PSM's XCorr and the XCorr of the last-ranked \
                PSM for this spectrum, divided by this PSM's XCorr or 1, whichever is larger."
                }
            ),
            (
                'deltCn', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': \
                    "The difference between this PSM's XCorr and the XCorr of the next-ranked \
                PSM for this spectrum, divided by this PSM's XCorr or 1, whichever is larger. \
                Note that this definition differs from that of the standard delta Cn reported \
                by SEQUEST®"
                        }
            ),
            (
                'Xcorr', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': \
                        "The SEQUEST cross-correlation score"
                        }
            ),
            (
                'Sp', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': \
                "The preliminary SEQUEST score."
                        }
            ),
            (
                'IonFrac', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': \
                "The fraction of b and y ions theoretical ions matched to the spectrum"
                        }
            ),
            (
                'Mass',    {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': \
                "The observed mass [M+H]+"
                        }
            ),
            (
                'PepLen',  {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': \
                "The length of the matched peptide, in residues"
                }
            ),
            (
                'Charge1', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge2', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge3', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge4', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge5', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge6', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge7', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge8', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge9', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'Charge10', {
                    'csv_field': '',
                    'DefaultDirection': 0
                }
            ),
            (
                'enzN', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': "Is the peptide preceded by an enzymatic (tryptic) site?"
                }
            ),
            (
                'enzC', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': "Does the peptide have an enzymatic (tryptic) C-terminus?"
                }
            ),
            (
                'enzInt', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': "Number of missed internal enzymatic (tryptic) sites"
                }
            ),
            (
                'lnNumSP', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': "The natural logarithm of the number of database peptides within the \
                    specified precursor range"
                }
            ),
            (
                'dM', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': "The difference between the calculated and observed mass"
                }
            ),
            (
                'absdM', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': "The absolute value of the difference between the calculated and \
                    observed mass"
                }
            ),
            (
                'Peptide', {
                    'csv_field': '',
                    'DefaultDirection': 0,
                    'description': ""
                }
            ),
            (
                'Proteins', {
                    'csv_field': 'Protein ID',
                    'DefaultDirection': 0,
                    'description': ""
                }
            ),
        ])


        self.params['translations']['csv_input_file'] = os.path.join(
            self.params['input_dir_path'],
            self.params['input_file']
        )

        self.params['translations']['output_file_incl_path'] = os.path.join(
            self.params['output_dir_path'],
            self.params['output_file']
        )

        # this file will contain the decoys only (we need it for fdr calculations and such)
        self.params['translations']['decoy_output_file_incl_path'] = os.path.join(
            self.params['output_dir_path'],
            "decoysOnly_" + self.params['output_file']
        )

        self.params['translations']['percolator_in'] = \
            '{output_file_incl_path}.tsv'.format(
                **self.params['translations']
        )

        self.params['translations']['percolator_out'] = \
            '{output_file_incl_path}.psms'.format(**self.params['translations'])

        # writing the Percolator input file (tab separated format)
        o = open( self.params['translations']['percolator_in'], 'w')
        writer = csv.DictWriter(
            o,
            list(PERCOLATOR_FIELDS.keys()),
            delimiter='\t',
            extrasaction='ignore'
        )
        writer.writeheader()

        self.params['translations']['percolator_decoy_out'] = \
            '{decoy_output_file_incl_path}.psms'.format(**self.params['translations'])
        self.params['command_list'] = [
            # percolator -X pout.xml pin.tab >| yeast-01.psms
            self.exe,
            '--only-psms',
            '{percolator_in}'.format(**self.params['translations']),
            '--results-psms',
            '{percolator_out}'.format(**self.params['translations']),
            '--decoy-results-psms',
            '{percolator_decoy_out}'.format(**self.params['translations'])
        ]

        last_search_engine = self.get_last_search_engine(
            history = self.stats['history']
        )
        self.params['_score_list'] = self.generating_score_list()
        minimum_score = None
        bigger_scores_better = self.UNODE_UPARAMS['bigger_scores_better']['uvalue_style_translation'][last_search_engine]
        if bigger_scores_better is False:
            for p, _score in enumerate(self.params['_score_list']):
                if _score <= 0:
                    # jumping over truly zero or negative values ...
                    continue
                else:
                    s = -1 * math.log( _score, 10 )
                    if minimum_score is None:
                        minimum_score = _score
                # OMSSA HACK OMG PLEASE HELP US ALL
                if "OMSSA" in last_search_engine.upper():
                    fo = transform_score(_score, minimum_score)
                    if fo < 100:
                        minimum_score = _score
                        break
                else:
                    break

        for n, spectrum_title in enumerate( self.params['grouped_psms'].keys()):
            best_score = self.params['grouped_psms'][ spectrum_title ][0][0]
            worst_score = self.params['grouped_psms'][ spectrum_title ][ -1 ][0]

            if bigger_scores_better is False:
                best_score = transform_score(best_score, minimum_score)
                worst_score = transform_score(worst_score, minimum_score)

            for m, (score, line_dict) in enumerate(
                    self.params['grouped_psms'][ spectrum_title ]):
                t = {}
                if bigger_scores_better is True:
                    rank_of_score = bisect.bisect_right(
                        self.params['_score_list'],
                        score
                    )
                else:
                    rank_of_score = bisect.bisect_left(
                        self.params['_score_list'],
                        score
                    )
                    rank_of_score = len( self.params['_score_list'] ) - rank_of_score
                #
                # t['lnrSp'] = math.log( 1 + rank_of_score )
                # t['Sp'] = rank_of_score
                #

                charge      = float(line_dict['Charge'])
                exp_mz      = float(line_dict['Exp m/z'])
                t['Mass']   = ( exp_mz * charge ) - ( charge - 1 ) * PROTON
                #
                t['Xcorr'] = score
                if bigger_scores_better is False:
                    t['Xcorr'] = transform_score(t['Xcorr'], minimum_score)

                t['PepLen'] = len( line_dict['Sequence'] )
                t['Charge{Charge}'.format(**line_dict)] = 1

                normalization = t['Xcorr']
                if t['Xcorr'] < 1:
                    normalization = 1
                if m == len( self.params['grouped_psms'][ spectrum_title ] ) - 1:
                    # last entry
                    deltLCn = 0
                    deltCn = 0
                else:
                    deltLCn = (t['Xcorr'] - worst_score) / normalization
                    next_score = self.params['grouped_psms'][ spectrum_title ][m + 1][0]
                    if bigger_scores_better is False:
                        next_score = transform_score(next_score, minimum_score)

                    deltCn = (t['Xcorr'] - next_score ) / normalization
                t['deltCn'] = deltCn
                t['deltLCn'] = deltLCn
                if line_dict['Is decoy'].upper() == 'TRUE':
                    t['Label'] = -1
                else:
                    t['Label'] = 1

                # if self.params['translations']['decoy_tag'] in line_dict['proteinacc_start_stop_pre_post_;']:
                #     # bug mzIdentML msgf+ convert
                #     t['Label'] = -1

                # this - is - sparta (or, if you like mzIdentML) ...
                # splitted = line_dict['proteinacc_start_stop_pre_post_;'].split('_')
                # aka http://imgur.com/WjiX9
                pre_aa = []
                for prot_pre_aa in line_dict['Sequence Pre AA'].split(self.params['translations']['protein_delimiter']):
                    for p_pre_aa in prot_pre_aa.split(';'):
                        pre_aa.append(p_pre_aa)
                post_aa = []
                for prot_post_aa in line_dict['Sequence Post AA'].split(self.params['translations']['protein_delimiter']):
                    for p_post_aa in prot_post_aa.split(';'):
                        post_aa.append(p_post_aa)
                allowed_aa = set(self.params['translations']['enzyme'].split(';')[0] + '-')
                cleavage_site = self.params['translations']['enzyme'].split(';')[1]
                inhibitor_aa = self.params['translations']['enzyme'].split(';')[2]
                final_pre_aa = pre_aa[0]
                final_post_aa = post_aa[0]
                t['enzN'] = 0
                t['enzC'] = 0
                if cleavage_site == 'C':
                    for i, aa in enumerate(pre_aa):
                        if aa in allowed_aa  \
                            or line_dict['Sequence Start'] in ['1','2']:
                            t['enzN'] = 1
                            final_pre_aa = aa
                            final_post_aa = post_aa[i]
                        if line_dict['Sequence'][-1] in allowed_aa\
                            or post_aa == '-':
                            t['enzC'] = 1
                elif cleavage_site == 'N':
                    for i, aa in enumerate(post_aa):
                        if aa in allowed_aa:
                            t['enzC'] = 1
                            final_post_aa = aa
                            final_pre_aa = pre_aa[i]
                        if line_dict['Sequence'][0] in allowed_aa\
                            or line_dict['Sequence Start'] in ['1','2']:
                            t['enzN'] = 1

                t['enzInt'] = 0
                for aa in line_dict['Sequence'][:-1]:
                    if aa in allowed_aa:
                        t['enzInt'] += 1

                t['dM'] = float(line_dict['Calc m/z']) - float(line_dict['Exp m/z'])
                t['absdM'] = abs(t['dM'])

                mods = line_dict['Modifications']
                if mods.strip() == '':
                    t['Peptide'] = '{0}.{Sequence}.{1}'.format(
                        sorted(pre_aa)[0],
                        sorted(post_aa)[0],
                        **line_dict
                    )
                else:
                    t['Peptide'] = '{0}.{Sequence}#{1}.{2}'.format(
                        sorted(pre_aa)[0],
                        mods,
                        sorted(post_aa)[0],
                        **line_dict
                    )
                # although peptides without mod ill have # at their end,
                # mapping further down will still work ...

                for per_key in PERCOLATOR_FIELDS.keys():
                    mapped_key = PERCOLATOR_FIELDS[ per_key ]['csv_field']
                    if 'Charge' in per_key:
                        if per_key not in t.keys():
                            t[ per_key ] = 0
                    if mapped_key != '':
                        t[ per_key ] = line_dict[ mapped_key ].strip()
                writer.writerow( t )
        o.close()


        # marking temporary files for deletion:
        self.created_tmp_files += [
            self.params['translations']['decoy_output_file_incl_path'],
            self.params['translations']['percolator_in'],
            '{output_file_incl_path}.psms'.format( **self.params['translations'] ),
            '{output_file_incl_path}.peptides'.format( **self.params['translations'] ),
        ]

Example 22

Project: youtube-dl-GUI
Source File: options.py
View license
def parseOpts(overrideArguments=None):
    def _readOptions(filename_bytes, default=[]):
        try:
            optionf = open(filename_bytes)
        except IOError:
            return default  # silently skip if file is not present
        try:
            res = []
            for l in optionf:
                res += shlex.split(l, comments=True)
        finally:
            optionf.close()
        return res

    def _readUserConf():
        """Locate and read the user configuration file.

        Search order: $XDG_CONFIG_HOME (or ~/.config) youtube-dl dir/file,
        then %appdata%/youtube-dl (Windows), then ~/youtube-dl.conf[.txt].
        Returns a list of option tokens; [] when no config file is found.
        """
        xdg_config_home = compat_getenv('XDG_CONFIG_HOME')
        if xdg_config_home:
            primary = os.path.join(xdg_config_home, 'youtube-dl', 'config')
            fallback = os.path.join(xdg_config_home, 'youtube-dl.conf')
        else:
            home = compat_expanduser('~')
            primary = os.path.join(home, '.config', 'youtube-dl', 'config')
            fallback = os.path.join(home, '.config', 'youtube-dl.conf')
        conf_path = primary if os.path.isfile(primary) else fallback
        # default=None marks "file absent" so the fallback cascade can proceed
        userConf = _readOptions(conf_path, None)

        if userConf is None:
            appdata_dir = compat_getenv('appdata')
            if appdata_dir:
                for leaf in ('config', 'config.txt'):
                    userConf = _readOptions(
                        os.path.join(appdata_dir, 'youtube-dl', leaf),
                        default=None)
                    if userConf is not None:
                        break

        if userConf is None:
            for name in ('youtube-dl.conf', 'youtube-dl.conf.txt'):
                userConf = _readOptions(
                    os.path.join(compat_expanduser('~'), name),
                    default=None)
                if userConf is not None:
                    break

        # no config anywhere: behave as if an empty one was found
        return [] if userConf is None else userConf

    def _format_option_string(option):
        ''' ('-o', '--option') -> -o, --format METAVAR'''

        opts = []

        if option._short_opts:
            opts.append(option._short_opts[0])
        if option._long_opts:
            opts.append(option._long_opts[0])
        if len(opts) > 1:
            opts.insert(1, ', ')

        if option.takes_value():
            opts.append(' %s' % option.metavar)

        return "".join(opts)

    def _comma_separated_values_options_callback(option, opt_str, value, parser):
        setattr(parser.values, option.dest, value.split(','))

    def _hide_login_info(opts):
        opts = list(opts)
        for private_opt in ['-p', '--password', '-u', '--username', '--video-password']:
            try:
                i = opts.index(private_opt)
                opts[i + 1] = 'PRIVATE'
            except ValueError:
                pass
        return opts

    # No need to wrap help messages if we're on a wide console
    columns = get_term_width()
    max_width = columns if columns else 80
    max_help_position = 80

    fmt = optparse.IndentedHelpFormatter(width=max_width, max_help_position=max_help_position)
    fmt.format_option_strings = _format_option_string

    kw = {
        'version': __version__,
        'formatter': fmt,
        'usage': '%prog [OPTIONS] URL [URL...]',
        'conflict_handler': 'resolve',
    }

    parser = optparse.OptionParser(**compat_kwargs(kw))

    general = optparse.OptionGroup(parser, 'General Options')
    general.add_option(
        '-h', '--help',
        action='help',
        help='print this help text and exit')
    general.add_option(
        '-v', '--version',
        action='version',
        help='print program version and exit')
    general.add_option(
        '-U', '--update',
        action='store_true', dest='update_self',
        help='update this program to latest version. Make sure that you have sufficient permissions (run with sudo if needed)')
    general.add_option(
        '-i', '--ignore-errors',
        action='store_true', dest='ignoreerrors', default=False,
        help='continue on download errors, for example to skip unavailable videos in a playlist')
    general.add_option(
        '--abort-on-error',
        action='store_false', dest='ignoreerrors',
        help='Abort downloading of further videos (in the playlist or the command line) if an error occurs')
    general.add_option(
        '--dump-user-agent',
        action='store_true', dest='dump_user_agent', default=False,
        help='display the current browser identification')
    general.add_option(
        '--list-extractors',
        action='store_true', dest='list_extractors', default=False,
        help='List all supported extractors and the URLs they would handle')
    general.add_option(
        '--extractor-descriptions',
        action='store_true', dest='list_extractor_descriptions', default=False,
        help='Output descriptions of all supported extractors')
    general.add_option(
        '--default-search',
        dest='default_search', metavar='PREFIX',
        help='Use this prefix for unqualified URLs. For example "gvsearch2:" downloads two videos from google videos for  youtube-dl "large apple". Use the value "auto" to let youtube-dl guess ("auto_warning" to emit a warning when guessing). "error" just throws an error. The default value "fixup_error" repairs broken URLs, but emits an error if this is not possible instead of searching.')
    general.add_option(
        '--ignore-config',
        action='store_true',
        help='Do not read configuration files. '
        'When given in the global configuration file /etc/youtube-dl.conf: '
        'Do not read the user configuration in ~/.config/youtube-dl/config '
        '(%APPDATA%/youtube-dl/config.txt on Windows)')
    general.add_option(
        '--flat-playlist',
        action='store_const', dest='extract_flat', const='in_playlist',
        default=False,
        help='Do not extract the videos of a playlist, only list them.')

    network = optparse.OptionGroup(parser, 'Network Options')
    network.add_option(
        '--proxy', dest='proxy',
        default=None, metavar='URL',
        help='Use the specified HTTP/HTTPS proxy. Pass in an empty string (--proxy "") for direct connection')
    network.add_option(
        '--socket-timeout',
        dest='socket_timeout', type=float, default=None, metavar='SECONDS',
        help='Time to wait before giving up, in seconds')
    network.add_option(
        '--source-address',
        metavar='IP', dest='source_address', default=None,
        help='Client-side IP address to bind to (experimental)',
    )
    network.add_option(
        '-4', '--force-ipv4',
        action='store_const', const='0.0.0.0', dest='source_address',
        help='Make all connections via IPv4 (experimental)',
    )
    network.add_option(
        '-6', '--force-ipv6',
        action='store_const', const='::', dest='source_address',
        help='Make all connections via IPv6 (experimental)',
    )

    selection = optparse.OptionGroup(parser, 'Video Selection')
    selection.add_option(
        '--playlist-start',
        dest='playliststart', metavar='NUMBER', default=1, type=int,
        help='playlist video to start at (default is %default)')
    selection.add_option(
        '--playlist-end',
        dest='playlistend', metavar='NUMBER', default=None, type=int,
        help='playlist video to end at (default is last)')
    selection.add_option(
        '--match-title',
        dest='matchtitle', metavar='REGEX',
        help='download only matching titles (regex or caseless sub-string)')
    selection.add_option(
        '--reject-title',
        dest='rejecttitle', metavar='REGEX',
        help='skip download for matching titles (regex or caseless sub-string)')
    selection.add_option(
        '--max-downloads',
        dest='max_downloads', metavar='NUMBER', type=int, default=None,
        help='Abort after downloading NUMBER files')
    selection.add_option(
        '--min-filesize',
        metavar='SIZE', dest='min_filesize', default=None,
        help='Do not download any videos smaller than SIZE (e.g. 50k or 44.6m)')
    selection.add_option(
        '--max-filesize',
        metavar='SIZE', dest='max_filesize', default=None,
        help='Do not download any videos larger than SIZE (e.g. 50k or 44.6m)')
    selection.add_option(
        '--date',
        metavar='DATE', dest='date', default=None,
        help='download only videos uploaded in this date')
    selection.add_option(
        '--datebefore',
        metavar='DATE', dest='datebefore', default=None,
        help='download only videos uploaded on or before this date (i.e. inclusive)')
    selection.add_option(
        '--dateafter',
        metavar='DATE', dest='dateafter', default=None,
        help='download only videos uploaded on or after this date (i.e. inclusive)')
    selection.add_option(
        '--min-views',
        metavar='COUNT', dest='min_views', default=None, type=int,
        help='Do not download any videos with less than COUNT views',)
    selection.add_option(
        '--max-views',
        metavar='COUNT', dest='max_views', default=None, type=int,
        help='Do not download any videos with more than COUNT views')
    selection.add_option(
        '--no-playlist',
        action='store_true', dest='noplaylist', default=False,
        help='If the URL refers to a video and a playlist, download only the video.')
    selection.add_option(
        '--age-limit',
        metavar='YEARS', dest='age_limit', default=None, type=int,
        help='download only videos suitable for the given age')
    selection.add_option(
        '--download-archive', metavar='FILE',
        dest='download_archive',
        help='Download only videos not listed in the archive file. Record the IDs of all downloaded videos in it.')
    selection.add_option(
        '--include-ads',
        dest='include_ads', action='store_true',
        help='Download advertisements as well (experimental)')

    authentication = optparse.OptionGroup(parser, 'Authentication Options')
    authentication.add_option(
        '-u', '--username',
        dest='username', metavar='USERNAME',
        help='login with this account ID')
    authentication.add_option(
        '-p', '--password',
        dest='password', metavar='PASSWORD',
        help='account password')
    authentication.add_option(
        '-2', '--twofactor',
        dest='twofactor', metavar='TWOFACTOR',
        help='two-factor auth code')
    authentication.add_option(
        '-n', '--netrc',
        action='store_true', dest='usenetrc', default=False,
        help='use .netrc authentication data')
    authentication.add_option(
        '--video-password',
        dest='videopassword', metavar='PASSWORD',
        help='video password (vimeo, smotri)')

    video_format = optparse.OptionGroup(parser, 'Video Format Options')
    video_format.add_option(
        '-f', '--format',
        action='store', dest='format', metavar='FORMAT', default=None,
        help=(
            'video format code, specify the order of preference using'
            ' slashes, as in -f 22/17/18 . '
            ' Instead of format codes, you can select by extension for the '
            'extensions aac, m4a, mp3, mp4, ogg, wav, webm. '
            'You can also use the special names "best",'
            ' "bestvideo", "bestaudio", "worst". '
            ' By default, youtube-dl will pick the best quality.'
            ' Use commas to download multiple audio formats, such as'
            ' -f  136/137/mp4/bestvideo,140/m4a/bestaudio.'
            ' You can merge the video and audio of two formats into a single'
            ' file using -f <video-format>+<audio-format> (requires ffmpeg or'
            ' avconv), for example -f bestvideo+bestaudio.'))
    video_format.add_option(
        '--all-formats',
        action='store_const', dest='format', const='all',
        help='download all available video formats')
    video_format.add_option(
        '--prefer-free-formats',
        action='store_true', dest='prefer_free_formats', default=False,
        help='prefer free video formats unless a specific one is requested')
    video_format.add_option(
        '--max-quality',
        action='store', dest='format_limit', metavar='FORMAT',
        help='highest quality format to download')
    video_format.add_option(
        '-F', '--list-formats',
        action='store_true', dest='listformats',
        help='list all available formats')
    video_format.add_option(
        '--youtube-include-dash-manifest',
        action='store_true', dest='youtube_include_dash_manifest', default=True,
        help=optparse.SUPPRESS_HELP)
    video_format.add_option(
        '--youtube-skip-dash-manifest',
        action='store_false', dest='youtube_include_dash_manifest',
        help='Do not download the DASH manifest on YouTube videos')
    video_format.add_option(
        '--merge-output-format',
        action='store', dest='merge_output_format', metavar='FORMAT', default=None,
        help=(
            'If a merge is required (e.g. bestvideo+bestaudio), output to given container format. One of mkv, mp4, ogg, webm, flv.'
            'Ignored if no merge is required'))

    subtitles = optparse.OptionGroup(parser, 'Subtitle Options')
    subtitles.add_option(
        '--write-sub', '--write-srt',
        action='store_true', dest='writesubtitles', default=False,
        help='write subtitle file')
    subtitles.add_option(
        '--write-auto-sub', '--write-automatic-sub',
        action='store_true', dest='writeautomaticsub', default=False,
        help='write automatic subtitle file (youtube only)')
    subtitles.add_option(
        '--all-subs',
        action='store_true', dest='allsubtitles', default=False,
        help='downloads all the available subtitles of the video')
    subtitles.add_option(
        '--list-subs',
        action='store_true', dest='listsubtitles', default=False,
        help='lists all available subtitles for the video')
    subtitles.add_option(
        '--sub-format',
        action='store', dest='subtitlesformat', metavar='FORMAT', default='srt',
        help='subtitle format (default=srt) ([sbv/vtt] youtube only)')
    subtitles.add_option(
        '--sub-lang', '--sub-langs', '--srt-lang',
        action='callback', dest='subtitleslangs', metavar='LANGS', type='str',
        default=[], callback=_comma_separated_values_options_callback,
        help='languages of the subtitles to download (optional) separated by commas, use IETF language tags like \'en,pt\'')

    downloader = optparse.OptionGroup(parser, 'Download Options')
    downloader.add_option(
        '-r', '--rate-limit',
        dest='ratelimit', metavar='LIMIT',
        help='maximum download rate in bytes per second (e.g. 50K or 4.2M)')
    downloader.add_option(
        '-R', '--retries',
        dest='retries', metavar='RETRIES', default=10,
        help='number of retries (default is %default)')
    downloader.add_option(
        '--buffer-size',
        dest='buffersize', metavar='SIZE', default='1024',
        help='size of download buffer (e.g. 1024 or 16K) (default is %default)')
    downloader.add_option(
        '--no-resize-buffer',
        action='store_true', dest='noresizebuffer', default=False,
        help='do not automatically adjust the buffer size. By default, the buffer size is automatically resized from an initial value of SIZE.')
    downloader.add_option(
        '--test',
        action='store_true', dest='test', default=False,
        help=optparse.SUPPRESS_HELP)
    downloader.add_option(
        '--playlist-reverse',
        action='store_true',
        help='Download playlist videos in reverse order')

    workarounds = optparse.OptionGroup(parser, 'Workarounds')
    workarounds.add_option(
        '--encoding',
        dest='encoding', metavar='ENCODING',
        help='Force the specified encoding (experimental)')
    workarounds.add_option(
        '--no-check-certificate',
        action='store_true', dest='no_check_certificate', default=False,
        help='Suppress HTTPS certificate validation.')
    workarounds.add_option(
        '--prefer-insecure',
        '--prefer-unsecure', action='store_true', dest='prefer_insecure',
        help='Use an unencrypted connection to retrieve information about the video. (Currently supported only for YouTube)')
    workarounds.add_option(
        '--user-agent',
        metavar='UA', dest='user_agent',
        help='specify a custom user agent')
    workarounds.add_option(
        '--referer',
        metavar='URL', dest='referer', default=None,
        help='specify a custom referer, use if the video access is restricted to one domain',
    )
    workarounds.add_option(
        '--add-header',
        metavar='FIELD:VALUE', dest='headers', action='append',
        help='specify a custom HTTP header and its value, separated by a colon \':\'. You can use this option multiple times',
    )
    workarounds.add_option(
        '--bidi-workaround',
        dest='bidi_workaround', action='store_true',
        help='Work around terminals that lack bidirectional text support. Requires bidiv or fribidi executable in PATH')

    verbosity = optparse.OptionGroup(parser, 'Verbosity / Simulation Options')
    verbosity.add_option(
        '-q', '--quiet',
        action='store_true', dest='quiet', default=False,
        help='activates quiet mode')
    verbosity.add_option(
        '--no-warnings',
        dest='no_warnings', action='store_true', default=False,
        help='Ignore warnings')
    verbosity.add_option(
        '-s', '--simulate',
        action='store_true', dest='simulate', default=False,
        help='do not download the video and do not write anything to disk',)
    verbosity.add_option(
        '--skip-download',
        action='store_true', dest='skip_download', default=False,
        help='do not download the video',)
    verbosity.add_option(
        '-g', '--get-url',
        action='store_true', dest='geturl', default=False,
        help='simulate, quiet but print URL')
    verbosity.add_option(
        '-e', '--get-title',
        action='store_true', dest='gettitle', default=False,
        help='simulate, quiet but print title')
    verbosity.add_option(
        '--get-id',
        action='store_true', dest='getid', default=False,
        help='simulate, quiet but print id')
    verbosity.add_option(
        '--get-thumbnail',
        action='store_true', dest='getthumbnail', default=False,
        help='simulate, quiet but print thumbnail URL')
    verbosity.add_option(
        '--get-description',
        action='store_true', dest='getdescription', default=False,
        help='simulate, quiet but print video description')
    verbosity.add_option(
        '--get-duration',
        action='store_true', dest='getduration', default=False,
        help='simulate, quiet but print video length')
    verbosity.add_option(
        '--get-filename',
        action='store_true', dest='getfilename', default=False,
        help='simulate, quiet but print output filename')
    verbosity.add_option(
        '--get-format',
        action='store_true', dest='getformat', default=False,
        help='simulate, quiet but print output format')
    verbosity.add_option(
        '-j', '--dump-json',
        action='store_true', dest='dumpjson', default=False,
        help='simulate, quiet but print JSON information. See --output for a description of available keys.')
    verbosity.add_option(
        '-J', '--dump-single-json',
        action='store_true', dest='dump_single_json', default=False,
        help='simulate, quiet but print JSON information for each command-line argument. If the URL refers to a playlist, dump the whole playlist information in a single line.')
    verbosity.add_option(
        '--print-json',
        action='store_true', dest='print_json', default=False,
        help='Be quiet and print the video information as JSON (video is still being downloaded).',
    )
    verbosity.add_option(
        '--newline',
        action='store_true', dest='progress_with_newline', default=False,
        help='output progress bar as new lines')
    verbosity.add_option(
        '--no-progress',
        action='store_true', dest='noprogress', default=False,
        help='do not print progress bar')
    verbosity.add_option(
        '--console-title',
        action='store_true', dest='consoletitle', default=False,
        help='display progress in console titlebar')
    verbosity.add_option(
        '-v', '--verbose',
        action='store_true', dest='verbose', default=False,
        help='print various debugging information')
    verbosity.add_option(
        '--dump-intermediate-pages',
        action='store_true', dest='dump_intermediate_pages', default=False,
        help='print downloaded pages to debug problems (very verbose)')
    verbosity.add_option(
        '--write-pages',
        action='store_true', dest='write_pages', default=False,
        help='Write downloaded intermediary pages to files in the current directory to debug problems')
    verbosity.add_option(
        '--youtube-print-sig-code',
        action='store_true', dest='youtube_print_sig_code', default=False,
        help=optparse.SUPPRESS_HELP)
    verbosity.add_option(
        '--print-traffic',
        dest='debug_printtraffic', action='store_true', default=False,
        help='Display sent and read HTTP traffic')
    verbosity.add_option(
        '-C', '--call-home',
        dest='call_home', action='store_true', default=False,
        help='Contact the youtube-dl server for debugging.')
    verbosity.add_option(
        '--no-call-home',
        dest='call_home', action='store_false', default=False,
        help='Do NOT contact the youtube-dl server for debugging.')

    filesystem = optparse.OptionGroup(parser, 'Filesystem Options')
    filesystem.add_option(
        '-a', '--batch-file',
        dest='batchfile', metavar='FILE',
        help='file containing URLs to download (\'-\' for stdin)')
    filesystem.add_option(
        '--id', default=False,
        action='store_true', dest='useid', help='use only video ID in file name')
    filesystem.add_option(
        '-o', '--output',
        dest='outtmpl', metavar='TEMPLATE',
        help=('output filename template. Use %(title)s to get the title, '
              '%(uploader)s for the uploader name, %(uploader_id)s for the uploader nickname if different, '
              '%(autonumber)s to get an automatically incremented number, '
              '%(ext)s for the filename extension, '
              '%(format)s for the format description (like "22 - 1280x720" or "HD"), '
              '%(format_id)s for the unique id of the format (like Youtube\'s itags: "137"), '
              '%(upload_date)s for the upload date (YYYYMMDD), '
              '%(extractor)s for the provider (youtube, metacafe, etc), '
              '%(id)s for the video id, '
              '%(playlist_title)s, %(playlist_id)s, or %(playlist)s (=title if present, ID otherwise) for the playlist the video is in, '
              '%(playlist_index)s for the position in the playlist. '
              '%(height)s and %(width)s for the width and height of the video format. '
              '%(resolution)s for a textual description of the resolution of the video format. '
              '%% for a literal percent. '
              'Use - to output to stdout. Can also be used to download to a different directory, '
              'for example with -o \'/my/downloads/%(uploader)s/%(title)s-%(id)s.%(ext)s\' .'))
    filesystem.add_option(
        '--autonumber-size',
        dest='autonumber_size', metavar='NUMBER',
        help='Specifies the number of digits in %(autonumber)s when it is present in output filename template or --auto-number option is given')
    filesystem.add_option(
        '--restrict-filenames',
        action='store_true', dest='restrictfilenames', default=False,
        help='Restrict filenames to only ASCII characters, and avoid "&" and spaces in filenames')
    filesystem.add_option(
        '-A', '--auto-number',
        action='store_true', dest='autonumber', default=False,
        help='[deprecated; use  -o "%(autonumber)s-%(title)s.%(ext)s" ] number downloaded files starting from 00000')
    filesystem.add_option(
        '-t', '--title',
        action='store_true', dest='usetitle', default=False,
        help='[deprecated] use title in file name (default)')
    filesystem.add_option(
        '-l', '--literal', default=False,
        action='store_true', dest='usetitle',
        help='[deprecated] alias of --title')
    filesystem.add_option(
        '-w', '--no-overwrites',
        action='store_true', dest='nooverwrites', default=False,
        help='do not overwrite files')
    filesystem.add_option(
        '-c', '--continue',
        action='store_true', dest='continue_dl', default=True,
        help='force resume of partially downloaded files. By default, youtube-dl will resume downloads if possible.')
    filesystem.add_option(
        '--no-continue',
        action='store_false', dest='continue_dl',
        help='do not resume partially downloaded files (restart from beginning)')
    filesystem.add_option(
        '--no-part',
        action='store_true', dest='nopart', default=False,
        help='do not use .part files - write directly into output file')
    filesystem.add_option(
        '--no-mtime',
        action='store_false', dest='updatetime', default=True,
        help='do not use the Last-modified header to set the file modification time')
    filesystem.add_option(
        '--write-description',
        action='store_true', dest='writedescription', default=False,
        help='write video description to a .description file')
    filesystem.add_option(
        '--write-info-json',
        action='store_true', dest='writeinfojson', default=False,
        help='write video metadata to a .info.json file')
    filesystem.add_option(
        '--write-annotations',
        action='store_true', dest='writeannotations', default=False,
        help='write video annotations to a .annotation file')
    filesystem.add_option(
        '--write-thumbnail',
        action='store_true', dest='writethumbnail', default=False,
        help='write thumbnail image to disk')
    filesystem.add_option(
        '--load-info',
        dest='load_info_filename', metavar='FILE',
        help='json file containing the video information (created with the "--write-json" option)')
    filesystem.add_option(
        '--cookies',
        dest='cookiefile', metavar='FILE',
        help='file to read cookies from and dump cookie jar in')
    filesystem.add_option(
        '--cache-dir', dest='cachedir', default=None, metavar='DIR',
        help='Location in the filesystem where youtube-dl can store some downloaded information permanently. By default $XDG_CACHE_HOME/youtube-dl or ~/.cache/youtube-dl . At the moment, only YouTube player files (for videos with obfuscated signatures) are cached, but that may change.')
    filesystem.add_option(
        '--no-cache-dir', action='store_const', const=False, dest='cachedir',
        help='Disable filesystem caching')
    filesystem.add_option(
        '--rm-cache-dir',
        action='store_true', dest='rm_cachedir',
        help='Delete all filesystem cache files')

    postproc = optparse.OptionGroup(parser, 'Post-processing Options')
    postproc.add_option(
        '-x', '--extract-audio',
        action='store_true', dest='extractaudio', default=False,
        help='convert video files to audio-only files (requires ffmpeg or avconv and ffprobe or avprobe)')
    postproc.add_option(
        '--audio-format', metavar='FORMAT', dest='audioformat', default='best',
        help='"best", "aac", "vorbis", "mp3", "m4a", "opus", or "wav"; "%default" by default')
    postproc.add_option(
        '--audio-quality', metavar='QUALITY',
        dest='audioquality', default='5',
        help='ffmpeg/avconv audio quality specification, insert a value between 0 (better) and 9 (worse) for VBR or a specific bitrate like 128K (default %default)')
    postproc.add_option(
        '--recode-video',
        metavar='FORMAT', dest='recodevideo', default=None,
        help='Encode the video to another format if necessary (currently supported: mp4|flv|ogg|webm|mkv)')
    postproc.add_option(
        '-k', '--keep-video',
        action='store_true', dest='keepvideo', default=False,
        help='keeps the video file on disk after the post-processing; the video is erased by default')
    postproc.add_option(
        '--no-post-overwrites',
        action='store_true', dest='nopostoverwrites', default=False,
        help='do not overwrite post-processed files; the post-processed files are overwritten by default')
    postproc.add_option(
        '--embed-subs',
        action='store_true', dest='embedsubtitles', default=False,
        help='embed subtitles in the video (only for mp4 videos)')
    postproc.add_option(
        '--embed-thumbnail',
        action='store_true', dest='embedthumbnail', default=False,
        help='embed thumbnail in the audio as cover art')
    postproc.add_option(
        '--add-metadata',
        action='store_true', dest='addmetadata', default=False,
        help='write metadata to the video file')
    postproc.add_option(
        '--xattrs',
        action='store_true', dest='xattrs', default=False,
        help='write metadata to the video file\'s xattrs (using dublin core and xdg standards)')
    postproc.add_option(
        '--fixup',
        metavar='POLICY', dest='fixup', default='detect_or_warn',
        help='(experimental) Automatically correct known faults of the file. '
             'One of never (do nothing), warn (only emit a warning), '
             'detect_or_warn(check whether we can do anything about it, warn '
             'otherwise')
    postproc.add_option(
        '--prefer-avconv',
        action='store_false', dest='prefer_ffmpeg',
        help='Prefer avconv over ffmpeg for running the postprocessors (default)')
    postproc.add_option(
        '--prefer-ffmpeg',
        action='store_true', dest='prefer_ffmpeg',
        help='Prefer ffmpeg over avconv for running the postprocessors')
    postproc.add_option(
        '--exec',
        metavar='CMD', dest='exec_cmd',
        help='Execute a command on the file after downloading, similar to find\'s -exec syntax. Example: --exec \'adb push {} /sdcard/Music/ && rm {}\'')

    parser.add_option_group(general)
    parser.add_option_group(network)
    parser.add_option_group(selection)
    parser.add_option_group(downloader)
    parser.add_option_group(filesystem)
    parser.add_option_group(verbosity)
    parser.add_option_group(workarounds)
    parser.add_option_group(video_format)
    parser.add_option_group(subtitles)
    parser.add_option_group(authentication)
    parser.add_option_group(postproc)

    if overrideArguments is not None:
        opts, args = parser.parse_args(overrideArguments)
        if opts.verbose:
            write_string('[debug] Override config: ' + repr(overrideArguments) + '\n')
    else:
        commandLineConf = sys.argv[1:]
        if '--ignore-config' in commandLineConf:
            systemConf = []
            userConf = []
        else:
            systemConf = _readOptions('/etc/youtube-dl.conf')
            if '--ignore-config' in systemConf:
                userConf = []
            else:
                userConf = _readUserConf()
        argv = systemConf + userConf + commandLineConf

        opts, args = parser.parse_args(argv)
        if opts.verbose:
            write_string('[debug] System config: ' + repr(_hide_login_info(systemConf)) + '\n')
            write_string('[debug] User config: ' + repr(_hide_login_info(userConf)) + '\n')
            write_string('[debug] Command-line args: ' + repr(_hide_login_info(commandLineConf)) + '\n')

    return parser, opts, args

Example 23

View license
def train(dim_word_desc=400,# word vector dimensionality
          dim_word_q=400,
          dim_word_ans=600,
          dim_proj=300,
          dim=400,# the number of LSTM units
          encoder_desc='lstm',
          encoder_desc_word='lstm',
          encoder_desc_sent='lstm',
          use_dq_sims=False,
          eyem=None,
          learn_h0=False,
          use_desc_skip_c_g=False,
          debug=False,
          encoder_q='lstm',
          patience=10,
          max_epochs=5000,
          dispFreq=100,
          decay_c=0.,
          alpha_c=0.,
          clip_c=-1.,
          lrate=0.01,
          n_words_q=49145,
          n_words_desc=115425,
          n_words_ans=409,
          pkl_train_files=None,
          pkl_valid_files=None,
          maxlen=2000, # maximum length of the description
          optimizer='rmsprop',
          batch_size=2,
          vocab=None,
          valid_batch_size=16,
          use_elu_g=False,
          saveto='model.npz',
          model_dir=None,
          ms_nlayers=3,
          validFreq=1000,
          saveFreq=1000, # save the parameters after every saveFreq updates
          datasets=[None],
          truncate=400,
          momentum=0.9,
          use_bidir=False,
          cost_mask=None,
          valid_datasets=['/u/yyu/stor/caglar/rc-data/cnn/cnn_test_data.h5',
                          '/u/yyu/stor/caglar/rc-data/cnn/cnn_valid_data.h5'],
          dropout_rate=0.5,
          use_dropout=True,
          reload_=True,
          **opt_ds):

    ensure_dir_exists(model_dir)
    mpath = os.path.join(model_dir, saveto)
    mpath_best = os.path.join(model_dir, prfx("best", saveto))
    mpath_last = os.path.join(model_dir, prfx("last", saveto))
    mpath_stats = os.path.join(model_dir, prfx("stats", saveto))

    # Model options
    model_options = locals().copy()
    model_options['use_sent_reps'] = opt_ds['use_sent_reps']
    stats = defaultdict(list)

    del model_options['eyem']
    del model_options['cost_mask']

    if cost_mask is not None:
        cost_mask = sharedX(cost_mask)

    # reload options and parameters
    if reload_:
        print "Reloading the model."
        if os.path.exists(mpath_best):
            print "Reloading the best model from %s." % mpath_best
            with open(os.path.join(mpath_best, '%s.pkl' % mpath_best), 'rb') as f:
                models_options = pkl.load(f)
            params = init_params(model_options)
            params = load_params(mpath_best, params)
        elif os.path.exists(mpath):
            print "Reloading the model from %s." % mpath
            with open(os.path.join(mpath, '%s.pkl' % mpath), 'rb') as f:
                models_options = pkl.load(f)
            params = init_params(model_options)
            params = load_params(mpath, params)
        else:
            raise IOError("Couldn't open the file.")
    else:
        print "Couldn't reload the models initializing from scratch."
        params = init_params(model_options)

    if datasets[0]:
        print "Short dataset", datasets[0]

    print 'Loading data'
    print 'Building model'
    if pkl_train_files is None or pkl_valid_files is None:
        train, valid, test = load_data(path=datasets[0],
                                       valid_path=valid_datasets[0],
                                       test_path=valid_datasets[1],
                                       batch_size=batch_size,
                                       **opt_ds)
    else:
        train, valid, test = load_pkl_data(train_file_paths=pkl_train_files,
                                           valid_file_paths=pkl_valid_files,
                                           batch_size=batch_size,
                                           vocab=vocab,
                                           eyem=eyem,
                                           **opt_ds)

    tparams = init_tparams(params)
    trng, use_noise, inps_d, \
                     opt_ret, \
                     cost, errors, ent_errors, ent_derrors, probs = \
                        build_model(tparams,
                                    model_options,
                                    prepare_data if not opt_ds['use_sent_reps'] \
                                            else prepare_data_sents,
                                    valid,
                                    cost_mask=cost_mask)

    alphas = opt_ret['dec_alphas']

    if opt_ds['use_sent_reps']:
        inps = [inps_d["desc"], \
                inps_d["word_mask"], \
                inps_d["q"], \
                inps_d['q_mask'], \
                inps_d['ans'], \
                inps_d['wlen'],
                inps_d['slen'], inps_d['qlen'],\
                inps_d['ent_mask']
                ]
    else:
        inps = [inps_d["desc"], \
                inps_d["word_mask"], \
                inps_d["q"], \
                inps_d['q_mask'], \
                inps_d['ans'], \
                inps_d['wlen'], \
                inps_d['qlen'], \
                inps_d['ent_mask']]

    outs = [cost, errors, probs, alphas]
    if ent_errors:
        outs += [ent_errors]

    if ent_derrors:
        outs += [ent_derrors]

    # before any regularizer
    print 'Building f_log_probs...',
    f_log_probs = theano.function(inps, outs, profile=profile)
    print 'Done'

    # Apply weight decay on the feed-forward connections
    if decay_c > 0.:
        decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
        weight_decay = 0.

        for kk, vv in tparams.iteritems():
            if "logit" in kk or "ff" in kk:
                weight_decay += (vv ** 2).sum()

        weight_decay *= decay_c
        cost += weight_decay

    # after any regularizer
    print 'Computing gradient...',
    grads = safe_grad(cost, itemlist(tparams))
    print 'Done'

    # Gradient clipping:
    if clip_c > 0.:
        g2 = get_norms(grads)
        for p, g in grads.iteritems():
            grads[p] = tensor.switch(g2 > (clip_c**2),
                                     (g / tensor.sqrt(g2 + 1e-8)) * clip_c,
                                     g)
    inps.pop()
    if optimizer.lower() == "adasecant":
        learning_rule = Adasecant(delta_clip=25.0,
                                  use_adagrad=True,
                                  grad_clip=0.25,
                                  gamma_clip=0.)
    elif optimizer.lower() == "rmsprop":
        learning_rule = RMSPropMomentum(init_momentum=momentum)
    elif optimizer.lower() == "adam":
        learning_rule = Adam()
    elif optimizer.lower() == "adadelta":
        learning_rule = AdaDelta()

    lr = tensor.scalar(name='lr')
    print 'Building optimizers...',
    learning_rule = None

    if learning_rule:
        f_grad_shared, f_update = learning_rule.get_funcs(learning_rate=lr,
                                                          grads=grads,
                                                          inp=inps,
                                                          cost=cost,
                                                          errors=errors)
    else:
        f_grad_shared, f_update = eval(optimizer)(lr,
                                                  tparams,
                                                  grads,
                                                  inps,
                                                  cost,
                                                  errors)

    print 'Done'
    print 'Optimization'
    history_errs = []
    # reload history
    if reload_ and os.path.exists(mpath):
        history_errs = list(numpy.load(mpath)['history_errs'])

    best_p = None
    bad_count = 0

    if validFreq == -1:
        validFreq = len(train[0]) / batch_size

    if saveFreq == -1:
        saveFreq = len(train[0]) / batch_size

    best_found = False
    uidx = 0
    estop = False

    train_cost_ave, train_err_ave, \
            train_gnorm_ave = reset_train_vals()

    for eidx in xrange(max_epochs):
        n_samples = 0

        if train.done:
            train.reset()

        for d_, q_, a, em in train:
            n_samples += len(a)
            uidx += 1
            use_noise.set_value(1.)

            if opt_ds['use_sent_reps']:
                # To mask the description and the question.
                d, d_mask, q, q_mask, dlen, slen, qlen = prepare_data_sents(d_,
                                                                            q_)

                if d is None:
                    print 'Minibatch with zero sample under length ', maxlen
                    uidx -= 1
                    continue

                ud_start = time.time()
                cost, errors, gnorm, pnorm = f_grad_shared(d,
                                                           d_mask,
                                                           q,
                                                           q_mask,
                                                           a,
                                                           dlen,
                                                           slen,
                                                           qlen)
            else:
                d, d_mask, q, q_mask, dlen, qlen = prepare_data(d_, q_)

                if d is None:
                    print 'Minibatch with zero sample under length ', maxlen
                    uidx -= 1
                    continue

                ud_start = time.time()
                cost, errors, gnorm, pnorm = f_grad_shared(d, d_mask,
                                                           q, q_mask,
                                                           a,
                                                           dlen,
                                                           qlen)

            upnorm = f_update(lrate)
            ud = time.time() - ud_start

            # Collect the running ave train stats.
            train_cost_ave = running_ave(train_cost_ave,
                                         cost)
            train_err_ave = running_ave(train_err_ave,
                                        errors)
            train_gnorm_ave = running_ave(train_gnorm_ave,
                                          gnorm)

            if numpy.isnan(cost) or numpy.isinf(cost):
                print 'NaN detected'
                import ipdb; ipdb.set_trace()

            if numpy.mod(uidx, dispFreq) == 0:
                print 'Epoch ', eidx, ' Update ', uidx, \
                        ' Cost ', cost, ' UD ', ud, \
                        ' UpNorm ', upnorm[0].tolist(), \
                        ' GNorm ', gnorm, \
                        ' Pnorm ', pnorm, 'Terrors ', errors

            if numpy.mod(uidx, saveFreq) == 0:
                print 'Saving...',
                if best_p is not None and best_found:
                    numpy.savez(mpath_best, history_errs=history_errs, **best_p)
                    pkl.dump(model_options, open('%s.pkl' % mpath_best, 'wb'))
                else:
                    params = unzip(tparams)

                numpy.savez(mpath, history_errs=history_errs, **params)
                pkl.dump(model_options, open('%s.pkl' % mpath, 'wb'))
                pkl.dump(stats, open("%s.pkl" % mpath_stats, 'wb'))

                print 'Done'
                print_param_norms(tparams)

            if numpy.mod(uidx, validFreq) == 0:
                use_noise.set_value(0.)
                if valid.done:
                    valid.reset()

                valid_costs, valid_errs, valid_probs, \
                        valid_alphas, error_ent, error_dent = eval_model(f_log_probs,
                                                  prepare_data if not opt_ds['use_sent_reps'] \
                                                    else prepare_data_sents,
                                                  model_options,
                                                  valid,
                                                  use_sent_rep=opt_ds['use_sent_reps'])

                valid_alphas_ = numpy.concatenate([va.argmax(0) for va  in valid_alphas.tolist()], axis=0)
                valid_err = valid_errs.mean()
                valid_cost = valid_costs.mean()
                valid_alpha_ent = -negentropy(valid_alphas)

                mean_valid_alphas = valid_alphas_.mean()
                std_valid_alphas = valid_alphas_.std()

                mean_valid_probs = valid_probs.argmax(1).mean()
                std_valid_probs = valid_probs.argmax(1).std()

                history_errs.append([valid_cost, valid_err])

                stats['train_err_ave'].append(train_err_ave)
                stats['train_cost_ave'].append(train_cost_ave)
                stats['train_gnorm_ave'].append(train_gnorm_ave)

                stats['valid_errs'].append(valid_err)
                stats['valid_costs'].append(valid_cost)
                stats['valid_err_ent'].append(error_ent)
                stats['valid_err_desc_ent'].append(error_dent)

                stats['valid_alphas_mean'].append(mean_valid_alphas)
                stats['valid_alphas_std'].append(std_valid_alphas)
                stats['valid_alphas_ent'].append(valid_alpha_ent)

                stats['valid_probs_mean'].append(mean_valid_probs)
                stats['valid_probs_std'].append(std_valid_probs)

                if uidx == 0 or valid_err <= numpy.array(history_errs)[:, 1].min():
                    best_p = unzip(tparams)
                    bad_counter = 0
                    best_found = True
                else:
                    bst_found = False

                if numpy.isnan(valid_err):
                    import ipdb; ipdb.set_trace()


                print "============================"
                print '\t>>>Valid error: ', valid_err, \
                        ' Valid cost: ', valid_cost
                print '\t>>>Valid pred mean: ', mean_valid_probs, \
                        ' Valid pred std: ', std_valid_probs
                print '\t>>>Valid alphas mean: ', mean_valid_alphas, \
                        ' Valid alphas std: ', std_valid_alphas, \
                        ' Valid alpha negent: ', valid_alpha_ent, \
                        ' Valid error ent: ', error_ent, \
                        ' Valid error desc ent: ', error_dent

                print "============================"
                print "Running average train stats "
                print '\t>>>Train error: ', train_err_ave, \
                        ' Train cost: ', train_cost_ave, \
                        ' Train grad norm: ', train_gnorm_ave
                print "============================"


                train_cost_ave, train_err_ave, \
                    train_gnorm_ave = reset_train_vals()


        print 'Seen %d samples' % n_samples

        if estop:
            break

    if best_p is not None:
        zipp(best_p, tparams)

    use_noise.set_value(0.)
    valid.reset()
    valid_cost, valid_error, valid_probs, \
            valid_alphas, error_ent = eval_model(f_log_probs,
                                      prepare_data if not opt_ds['use_sent_reps'] \
                                           else prepare_data_sents,
                                      model_options, valid,
                                      use_sent_rep=opt_ds['use_sent_rep'])

    print " Final eval resuts: "
    print 'Valid error: ', valid_error.mean()
    print 'Valid cost: ', valid_cost.mean()
    print '\t>>>Valid pred mean: ', valid_probs.mean(), \
            ' Valid pred std: ', valid_probs.std(), \
            ' Valid error ent: ', error_ent

    params = copy.copy(best_p)

    numpy.savez(mpath_last,
                zipped_params=best_p,
                history_errs=history_errs,
                **params)

    return valid_err, valid_cost

Example 24

Project: catkin_tools
Source File: build.py
View license
def build_isolated_workspace(
    context,
    packages=None,
    start_with=None,
    no_deps=False,
    unbuilt=False,
    n_jobs=None,
    force_cmake=False,
    pre_clean=False,
    force_color=False,
    quiet=False,
    interleave_output=False,
    no_status=False,
    limit_status_rate=10.0,
    lock_install=False,
    no_notify=False,
    continue_on_failure=False,
    summarize_build=None,
):
    """Builds a catkin workspace in isolation

    This function will find all of the packages in the source space, start some
    executors, feed them packages to build based on dependencies and topological
    ordering, and then monitor the output of the executors, handling loggings of
    the builds, starting builds, failing builds, and finishing builds of
    packages, and handling the shutdown of the executors when appropriate.

    :param context: context in which to build the catkin workspace
    :type context: :py:class:`catkin_tools.verbs.catkin_build.context.Context`
    :param packages: list of packages to build, by default their dependencies will also be built
    :type packages: list
    :param start_with: package to start with, skipping all packages which proceed it in the topological order
    :type start_with: str
    :param no_deps: If True, the dependencies of packages will not be built first
    :type no_deps: bool
    :param n_jobs: number of parallel package build n_jobs
    :type n_jobs: int
    :param force_cmake: forces invocation of CMake if True, default is False
    :type force_cmake: bool
    :param force_color: forces colored output even if terminal does not support it
    :type force_color: bool
    :param quiet: suppresses the output of commands unless there is an error
    :type quiet: bool
    :param interleave_output: prints the output of commands as they are received
    :type interleave_output: bool
    :param no_status: disables status bar
    :type no_status: bool
    :param limit_status_rate: rate to which status updates are limited; the default 0, places no limit.
    :type limit_status_rate: float
    :param lock_install: causes executors to synchronize on access of install commands
    :type lock_install: bool
    :param no_notify: suppresses system notifications
    :type no_notify: bool
    :param continue_on_failure: do not stop building other jobs on error
    :type continue_on_failure: bool
    :param summarize_build: if True summarizes the build at the end, if None and continue_on_failure is True and the
        the build fails, then the build will be summarized, but if False it never will be summarized.
    :type summarize_build: bool

    :raises: SystemExit if buildspace is a file or no packages were found in the source space
        or if the provided options are invalid
    """
    pre_start_time = time.time()

    # Assert that the limit_status_rate is valid
    if limit_status_rate < 0:
        sys.exit("[build] @[email protected]{rf}Error:@| The value of --status-rate must be greater than or equal to zero.")

    # Declare a buildspace marker describing the build config for error checking
    buildspace_marker_data = {
        'workspace': context.workspace,
        'profile': context.profile,
        'install': context.install,
        'install_space': context.install_space_abs,
        'devel_space': context.devel_space_abs,
        'source_space': context.source_space_abs}

    # Check build config
    if os.path.exists(os.path.join(context.build_space_abs, BUILDSPACE_MARKER_FILE)):
        with open(os.path.join(context.build_space_abs, BUILDSPACE_MARKER_FILE)) as buildspace_marker_file:
            existing_buildspace_marker_data = yaml.load(buildspace_marker_file)
            misconfig_lines = ''
            for (k, v) in existing_buildspace_marker_data.items():
                new_v = buildspace_marker_data.get(k, None)
                if new_v != v:
                    misconfig_lines += (
                        '\n - %s: %s (stored) is not %s (commanded)' %
                        (k, v, new_v))
            if len(misconfig_lines) > 0:
                sys.exit(clr(
                    "\[email protected]{rf}Error:@| Attempting to build a catkin workspace using build space: "
                    "\"%s\" but that build space's most recent configuration "
                    "differs from the commanded one in ways which will cause "
                    "problems. Fix the following options or use @{yf}`catkin "
                    "clean -b`@| to remove the build space: %s" %
                    (context.build_space_abs, misconfig_lines)))

    # Summarize the context
    summary_notes = []
    if force_cmake:
        summary_notes += [clr("@[email protected]{cf}NOTE:@| Forcing CMake to run for each package.")]
    log(context.summary(summary_notes))

    # Make sure there is a build folder and it is not a file
    if os.path.exists(context.build_space_abs):
        if os.path.isfile(context.build_space_abs):
            sys.exit(clr(
                "[build] @{rf}Error:@| Build space '{0}' exists but is a file and not a folder."
                .format(context.build_space_abs)))
    # If it dosen't exist, create it
    else:
        log("[build] Creating build space: '{0}'".format(context.build_space_abs))
        os.makedirs(context.build_space_abs)

    # Write the current build config for config error checking
    with open(os.path.join(context.build_space_abs, BUILDSPACE_MARKER_FILE), 'w') as buildspace_marker_file:
        buildspace_marker_file.write(yaml.dump(buildspace_marker_data, default_flow_style=False))

    # Get all the packages in the context source space
    # Suppress warnings since this is a utility function
    workspace_packages = find_packages(context.source_space_abs, exclude_subspaces=True, warnings=[])

    # Get packages which have not been built yet
    built_packages, unbuilt_pkgs = get_built_unbuilt_packages(context, workspace_packages)

    # Handle unbuilt packages
    if unbuilt:
        # Check if there are any unbuilt
        if len(unbuilt_pkgs) > 0:
            # Add the unbuilt packages
            packages.extend(list(unbuilt_pkgs))
        else:
            log("[build] No unbuilt packages to be built.")
            return

    # If no_deps is given, ensure packages to build are provided
    if no_deps and packages is None:
        log(clr("[build] @[email protected]{rf}Error:@| With no_deps, you must specify packages to build."))
        return

    # Find list of packages in the workspace
    packages_to_be_built, packages_to_be_built_deps, all_packages = determine_packages_to_be_built(
        packages, context, workspace_packages)

    if not no_deps:
        # Extend packages to be built to include their deps
        packages_to_be_built.extend(packages_to_be_built_deps)

    # Also re-sort
    try:
        packages_to_be_built = topological_order_packages(dict(packages_to_be_built))
    except AttributeError:
        log(clr("[build] @[email protected]{rf}Error:@| The workspace packages have a circular "
                "dependency, and cannot be built. Please run `catkin list "
                "--deps` to determine the problematic package(s)."))
        return

    # Check the number of packages to be built
    if len(packages_to_be_built) == 0:
        log(clr('[build] No packages to be built.'))

    # Assert start_with package is in the workspace
    verify_start_with_option(
        start_with,
        packages,
        all_packages,
        packages_to_be_built + packages_to_be_built_deps)

    # Populate .catkin file if we're not installing
    # NOTE: This is done to avoid the Catkin CMake code from doing it,
    # which isn't parallel-safe. Catkin CMake only modifies this file if
    # it's package source path isn't found.
    if not context.install:
        dot_catkin_file_path = os.path.join(context.devel_space_abs, '.catkin')
        # If the file exists, get the current paths
        if os.path.exists(dot_catkin_file_path):
            dot_catkin_paths = open(dot_catkin_file_path, 'r').read().split(';')
        else:
            dot_catkin_paths = []

        # Update the list with the new packages (in topological order)
        packages_to_be_built_paths = [
            os.path.join(context.source_space_abs, path)
            for path, pkg in packages_to_be_built
        ]

        new_dot_catkin_paths = [
            os.path.join(context.source_space_abs, path)
            for path in [os.path.join(context.source_space_abs, path) for path, pkg in all_packages]
            if path in dot_catkin_paths or path in packages_to_be_built_paths
        ]

        # Write the new file if it's different, otherwise, leave it alone
        if dot_catkin_paths == new_dot_catkin_paths:
            wide_log("[build] Package table is up to date.")
        else:
            wide_log("[build] Updating package table.")
            open(dot_catkin_file_path, 'w').write(';'.join(new_dot_catkin_paths))

    # Remove packages before start_with
    if start_with is not None:
        for path, pkg in list(packages_to_be_built):
            if pkg.name != start_with:
                wide_log(clr("@[email protected]{pf}[email protected]|  @{gf}[email protected]| @{cf}{}@|").format(pkg.name))
                packages_to_be_built.pop(0)
            else:
                break

    # Get the names of all packages to be built
    packages_to_be_built_names = [p.name for _, p in packages_to_be_built]
    packages_to_be_built_deps_names = [p.name for _, p in packages_to_be_built_deps]

    # Generate prebuild and prebuild clean jobs, if necessary
    prebuild_jobs = {}
    setup_util_present = os.path.exists(os.path.join(context.devel_space_abs, '_setup_util.py'))
    catkin_present = 'catkin' in (packages_to_be_built_names + packages_to_be_built_deps_names)
    catkin_built = 'catkin' in built_packages
    prebuild_built = 'catkin_tools_prebuild' in built_packages

    # Handle the prebuild jobs if the develspace is linked
    prebuild_pkg_deps = []
    if context.link_devel:
        prebuild_pkg = None

        # Construct a dictionary to lookup catkin package by name
        pkg_dict = dict([(pkg.name, (pth, pkg)) for pth, pkg in all_packages])

        if setup_util_present:
            # Setup util is already there, determine if it needs to be
            # regenerated
            if catkin_built:
                if catkin_present:
                    prebuild_pkg_path, prebuild_pkg = pkg_dict['catkin']
            elif prebuild_built:
                if catkin_present:
                    # TODO: Clean prebuild package
                    ct_prebuild_pkg_path = get_prebuild_package(
                        context.build_space_abs, context.devel_space_abs, force_cmake)
                    ct_prebuild_pkg = parse_package(ct_prebuild_pkg_path)

                    prebuild_jobs['caktin_tools_prebuild'] = create_catkin_clean_job(
                        context,
                        ct_prebuild_pkg,
                        ct_prebuild_pkg_path,
                        dependencies=[],
                        dry_run=False,
                        clean_build=True,
                        clean_devel=True,
                        clean_install=True)

                    # TODO: Build catkin package
                    prebuild_pkg_path, prebuild_pkg = pkg_dict['catkin']
                    prebuild_pkg_deps.append('catkin_tools_prebuild')
            else:
                # How did these get here??
                log("Warning: devel space setup files have an unknown origin.")
        else:
            # Setup util needs to be generated
            if catkin_built or prebuild_built:
                log("Warning: generated devel space setup files have been deleted.")

            if catkin_present:
                # Build catkin package
                prebuild_pkg_path, prebuild_pkg = pkg_dict['catkin']
            else:
                # Generate and buildexplicit prebuild package
                prebuild_pkg_path = get_prebuild_package(context.build_space_abs, context.devel_space_abs, force_cmake)
                prebuild_pkg = parse_package(prebuild_pkg_path)

        if prebuild_pkg is not None:
            # Create the prebuild job
            prebuild_job = create_catkin_build_job(
                context,
                prebuild_pkg,
                prebuild_pkg_path,
                dependencies=prebuild_pkg_deps,
                force_cmake=force_cmake,
                pre_clean=pre_clean,
                prebuild=True)

            # Add the prebuld job
            prebuild_jobs[prebuild_job.jid] = prebuild_job

    # Remove prebuild jobs from normal job list
    for prebuild_jid, prebuild_job in prebuild_jobs.items():
        if prebuild_jid in packages_to_be_built_names:
            packages_to_be_built_names.remove(prebuild_jid)

    # Initial jobs list is just the prebuild jobs
    jobs = [] + list(prebuild_jobs.values())

    # Get all build type plugins
    build_job_creators = {
        ep.name: ep.load()['create_build_job']
        for ep in pkg_resources.iter_entry_points(group='catkin_tools.jobs')
    }

    # It's a problem if there aren't any build types available
    if len(build_job_creators) == 0:
        sys.exit('Error: No build types available. Please check your catkin_tools installation.')

    # Construct jobs
    for pkg_path, pkg in all_packages:
        if pkg.name not in packages_to_be_built_names:
            continue

        # Ignore metapackages
        if 'metapackage' in [e.tagname for e in pkg.exports]:
            continue

        # Get actual execution deps
        deps = [
            p.name for _, p
            in get_cached_recursive_build_depends_in_workspace(pkg, packages_to_be_built)
            if p.name not in prebuild_jobs
        ]
        # All jobs depend on the prebuild jobs if they're defined
        if not no_deps:
            for j in prebuild_jobs.values():
                deps.append(j.jid)

        # Determine the job parameters
        build_job_kwargs = dict(
            context=context,
            package=pkg,
            package_path=pkg_path,
            dependencies=deps,
            force_cmake=force_cmake,
            pre_clean=pre_clean)

        # Create the job based on the build type
        build_type = get_build_type(pkg)

        if build_type in build_job_creators:
            jobs.append(build_job_creators[build_type](**build_job_kwargs))
        else:
            wide_log(clr(
                "[build] @[email protected]{yf}Warning:@| Skipping package `{}` because it "
                "has an unsupported package build type: `{}`"
            ).format(pkg.name, build_type))

            wide_log(clr("[build] Note: Available build types:"))
            for bt_name in build_job_creators.keys():
                wide_log(clr("[build]  - `{}`".format(bt_name)))

    # Queue for communicating status
    event_queue = Queue()

    try:
        # Spin up status output thread
        status_thread = ConsoleStatusController(
            'build',
            ['package', 'packages'],
            jobs,
            n_jobs,
            [pkg.name for _, pkg in context.packages],
            [p for p in context.whitelist],
            [p for p in context.blacklist],
            event_queue,
            show_notifications=not no_notify,
            show_active_status=not no_status,
            show_buffered_stdout=not quiet and not interleave_output,
            show_buffered_stderr=not interleave_output,
            show_live_stdout=interleave_output,
            show_live_stderr=interleave_output,
            show_stage_events=not quiet,
            show_full_summary=(summarize_build is True),
            pre_start_time=pre_start_time,
            active_status_rate=limit_status_rate)
        status_thread.start()

        # Initialize locks
        locks = {
            'installspace': asyncio.Lock() if lock_install else FakeLock()
        }

        # Block while running N jobs asynchronously
        try:
            all_succeeded = run_until_complete(execute_jobs(
                'build',
                jobs,
                locks,
                event_queue,
                context.log_space_abs,
                max_toplevel_jobs=n_jobs,
                continue_on_failure=continue_on_failure,
                continue_without_deps=False))
        except Exception:
            status_thread.keep_running = False
            all_succeeded = False
            status_thread.join(1.0)
            wide_log(str(traceback.format_exc()))

        status_thread.join(1.0)

        # Warn user about new packages
        now_built_packages, now_unbuilt_pkgs = get_built_unbuilt_packages(context, workspace_packages)
        new_pkgs = [p for p in unbuilt_pkgs if p not in now_unbuilt_pkgs]
        if len(new_pkgs) > 0:
            log(clr("[build] @/@!Note:@| @/Workspace packages have changed, "
                    "please re-source setup files to use [email protected]|"))

        if all_succeeded:
            # Create isolated devel setup if necessary
            if context.isolate_devel:
                if not context.install:
                    _create_unmerged_devel_setup(context, now_unbuilt_pkgs)
                else:
                    _create_unmerged_devel_setup_for_install(context)
            return 0
        else:
            return 1

    except KeyboardInterrupt:
        wide_log("[build] Interrupted by user!")
        event_queue.put(None)

Example 25

Project: cgat
Source File: bam2wiggle.py
View license
def main(argv=None):
    """Convert a BAM file into wiggle/bedgraph/bigwig coverage output.

    The BAM file is given as the first positional argument; an output
    filename pattern may be given as the second.  Two workflows are used:

    * if reads are shifted/extended (``--shift-size``/``--extend``) or
      pairs are merged (``--merge-pairs``), reads are first written as
      bed intervals and coverage is computed with ``bedtools genomecov``,
      then converted to bigwig with UCSC tools;
    * otherwise a pysam pileup iterator writes a wiggle/bedgraph track
      directly, optionally converted to bigwig with ``wigToBigWig``.

    Returns a non-zero error code if the external bigwig conversion
    fails; otherwise falls through to ``E.Stop()``.
    """

    if not argv:
        argv = sys.argv

    # setup command line parser
    parser = E.OptionParser(
        version="%prog version: $Id$",
        usage=globals()["__doc__"])

    parser.add_option("-o", "--output-format", dest="output_format",
                      type="choice",
                      choices=(
                          "bedgraph", "wiggle", "bigbed",
                          "bigwig", "bed"),
                      help="output format [default=%default]")

    parser.add_option("-s", "--shift-size", dest="shift", type="int",
                      help="shift reads by a certain amount (ChIP-Seq) "
                      "[%default]")

    parser.add_option("-e", "--extend", dest="extend", type="int",
                      help="extend reads by a certain amount "
                      "(ChIP-Seq) [%default]")

    parser.add_option("-p", "--wiggle-span", dest="span", type="int",
                      help="span of a window in wiggle tracks "
                      "[%default]")

    parser.add_option("-m", "--merge-pairs", dest="merge_pairs",
                      action="store_true",
                      help="merge paired-ended reads into a single "
                      "bed interval [default=%default].")

    parser.add_option("--scale-base", dest="scale_base", type="float",
                      help="number of reads/pairs to scale bigwig file to. "
                      "The default is to scale to 1M reads "
                      "[default=%default]")

    parser.add_option("--scale-method", dest="scale_method", type="choice",
                      choices=("none", "reads",),
                      help="scale bigwig output. 'reads' will normalize by "
                      "the total number reads in the bam file that are used "
                      "to construct the bigwig file. If --merge-pairs is used "
                      "the number of pairs output will be used for "
                      "normalization. 'none' will not scale the bigwig file"
                      "[default=%default]")

    parser.add_option("--max-insert-size", dest="max_insert_size",
                      type="int",
                      help="only merge if insert size less that "
                      "# bases. 0 turns of this filter "
                      "[default=%default].")

    parser.add_option("--min-insert-size", dest="min_insert_size",
                      type="int",
                      help="only merge paired-end reads if they are "
                      "at least # bases apart. "
                      "0 turns of this filter. [default=%default]")

    parser.set_defaults(
        samfile=None,
        output_format="wiggle",
        shift=0,
        extend=0,
        span=1,
        merge_pairs=None,
        min_insert_size=0,
        max_insert_size=0,
        scale_method='none',
        scale_base=1000000,
    )

    # add common options (-h/--help, ...) and parse command line
    (options, args) = E.Start(parser, argv=argv, add_output_options=True)

    if len(args) >= 1:
        options.samfile = args[0]
    if len(args) == 2:
        options.output_filename_pattern = args[1]
    if not options.samfile:
        raise ValueError("please provide a bam file")

    # Read BAM file using Pysam
    samfile = pysam.Samfile(options.samfile, "rb")

    # Create temporary files / folders
    tmpdir = tempfile.mkdtemp()
    E.debug("temporary files are in %s" % tmpdir)
    tmpfile_wig = os.path.join(tmpdir, "wig")
    tmpfile_sizes = os.path.join(tmpdir, "sizes")

    # Create dictionary of contig sizes
    contig_sizes = dict(list(zip(samfile.references, samfile.lengths)))
    # write contig sizes (required by the UCSC conversion tools)
    outfile_size = IOTools.openFile(tmpfile_sizes, "w")
    for contig, size in sorted(contig_sizes.items()):
        outfile_size.write("%s\t%s\n" % (contig, size))
    outfile_size.close()

    # Shift and extend only available for bigwig format
    if options.shift or options.extend:
        if options.output_format != "bigwig":
            raise ValueError(
                "shift and extend only available for bigwig output")

    # Output filename required for bigwig / bigbed computation
    if options.output_format == "bigwig":
        if not options.output_filename_pattern:
            raise ValueError(
                "please specify an output file for bigwig computation.")

        # Define executable to use for binary conversion
        if options.output_format == "bigwig":
            executable_name = "wigToBigWig"
        else:
            raise ValueError("unknown output format `%s`" %
                             options.output_format)

        # check required executable file is in the path
        executable = IOTools.which(executable_name)
        if not executable:
            raise OSError("could not find %s in path." % executable_name)

        # Open output file
        outfile = IOTools.openFile(tmpfile_wig, "w")
        E.info("starting output to %s" % tmpfile_wig)
    else:
        outfile = IOTools.openFile(tmpfile_wig, "w")
        E.info("starting output to stdout")

    # Set up output write functions
    # NOTE(review): `outf` is only defined for wiggle/bigwig/bedgraph;
    # the "bed" and "bigbed" choices would hit a NameError in workflow 2.
    if options.output_format in ("wiggle", "bigwig"):
        # wiggle is one-based, so add 1, also step-size is 1, so need
        # to output all bases
        if options.span == 1:
            def outf(outfile, contig, start, end, val):
                outfile.write(
                    "".join(["%i\t%i\n" % (x, val)
                             for x in range(start + 1, end + 1)]))
        else:
            outf = SpanWriter(options.span)
    elif options.output_format == "bedgraph":
        # bed is 0-based, open-closed
        def outf(outfile, contig, start, end, val):
            outfile.write("%s\t%i\t%i\t%i\n" % (contig, start, end, val))

    # initialise counters
    ninput, nskipped, ncontigs = 0, 0, 0

    # set output file name
    output_filename_pattern = options.output_filename_pattern

    # shift and extend or merge pairs. Output temporary bed file
    if options.shift > 0 or options.extend > 0 or options.merge_pairs:
        # Workflow 1: convert to bed intervals and use bedtools
        # genomecov to build a coverage file.
        # Convert to bigwig with UCSC tools bedGraph2BigWig

        if options.merge_pairs:
            # merge pairs using bam2bed
            E.info("merging pairs to temporary file")
            counter = _bam2bed.merge_pairs(
                samfile,
                outfile,
                min_insert_size=options.min_insert_size,
                max_insert_size=options.max_insert_size,
                bed_format=3)
            E.info("merging results: {}".format(counter))
            if counter.output == 0:
                raise ValueError("no pairs output after merging")
        else:
            # create bed file with shifted/extended tags
            shift, extend = options.shift, options.extend
            shift_extend = shift + extend
            counter = E.Counter()

            for contig in samfile.references:
                E.debug("output for %s" % contig)
                lcontig = contig_sizes[contig]

                for read in samfile.fetch(contig):
                    pos = read.pos
                    if read.is_reverse:
                        start = max(0, read.pos + read.alen - shift_extend)
                    else:
                        start = max(0, read.pos + shift)

                    # intervals extending beyond contig are removed
                    if start >= lcontig:
                        continue

                    end = min(lcontig, start + extend)
                    outfile.write("%s\t%i\t%i\n" % (contig, start, end))
                    counter.output += 1

        outfile.close()

        if options.scale_method == "reads":
            scale_factor = float(options.scale_base) / counter.output

            E.info("scaling: method=%s scale_quantity=%i scale_factor=%f" %
                   (options.scale_method,
                    counter.output,
                    scale_factor))
            scale = "-scale %f" % scale_factor
        else:
            scale = ""

        # Convert bed file to coverage file (bedgraph)
        tmpfile_bed = os.path.join(tmpdir, "bed")
        E.info("computing coverage")
        # calculate coverage - format is bedgraph
        statement = """bedtools genomecov -bg -i %(tmpfile_wig)s %(scale)s
        -g %(tmpfile_sizes)s > %(tmpfile_bed)s""" % locals()
        E.run(statement)

        # Convert bedgraph to bigwig
        E.info("converting to bigwig")
        tmpfile_sorted = os.path.join(tmpdir, "sorted")
        statement = ("sort -k 1,1 -k2,2n %(tmpfile_bed)s > %(tmpfile_sorted)s;"
                     "bedGraphToBigWig %(tmpfile_sorted)s %(tmpfile_sizes)s "
                     "%(output_filename_pattern)s" % locals())
        E.run(statement)

    else:

        # Workflow 2: use pysam column iterator to build a
        # wig file. Then convert to bigwig of bedgraph file
        # with UCSC tools.
        def column_iter(iterator):
            # collapse consecutive pileup columns with the same depth
            # into (start, end, depth) runs
            start = None
            end = 0
            n = None
            for t in iterator:
                if t.pos - end > 1 or n != t.n:
                    if start is not None:
                        yield start, end, n
                    start = t.pos
                    end = t.pos
                    n = t.n
                end = t.pos
            yield start, end, n

        if options.scale_method != "none":
            raise NotImplementedError(
                "scaling not implemented for pileup method")

        # Bedgraph track definition
        if options.output_format == "bedgraph":
            outfile.write("track type=bedGraph\n")

        for contig in samfile.references:
            # if contig != "chrX": continue
            E.debug("output for %s" % contig)
            lcontig = contig_sizes[contig]

            # Write wiggle header
            if options.output_format in ("wiggle", "bigwig"):
                outfile.write("variableStep chrom=%s span=%i\n" %
                              (contig, options.span))

            # Generate pileup per contig using pysam and iterate over columns
            for start, end, val in column_iter(samfile.pileup(contig)):
                # patch: there was a problem with bam files and reads
                # overextending at the end. These are usually Ns, but
                # need to check as otherwise wigToBigWig fails.
                if lcontig <= end:
                    E.warn("read extending beyond contig: %s: %i > %i" %
                           (contig, end, lcontig))
                    end = lcontig
                    if start >= end:
                        continue

                if val > 0:
                    outf(outfile, contig, start, end, val)
            ncontigs += 1

        # Close output file
        # BUG FIX: original compared `type(outf) == type(SpanWriter)`,
        # i.e. the instance's class against SpanWriter's *metaclass*
        # (`type`), which is always False -- so buffered SpanWriter
        # output was never flushed.  Use isinstance instead.
        if isinstance(outf, SpanWriter):
            outf.flush(outfile)
        else:
            outfile.flush()

        E.info("finished output")

        # Report counters
        E.info("ninput=%i, ncontigs=%i, nskipped=%i" %
               (ninput, ncontigs, nskipped))

        # Convert to binary formats
        if options.output_format == "bigwig":
            outfile.close()

            E.info("starting %s conversion" % executable)
            try:
                retcode = subprocess.call(
                    " ".join((executable,
                              tmpfile_wig,
                              tmpfile_sizes,
                              output_filename_pattern)),
                    shell=True)
                if retcode != 0:
                    E.warn("%s terminated with signal: %i" %
                           (executable, -retcode))
                    return -retcode
            except OSError as msg:
                E.warn("Error while executing bigwig: %s" % msg)
                return 1
            E.info("finished bigwig conversion")
        else:
            with open(tmpfile_wig) as inf:
                sys.stdout.write(inf.read())

    # Cleanup temp files
    shutil.rmtree(tmpdir)

    E.Stop()

Example 26

Project: fbht
Source File: main.py
View license
def main():
    global globalLogin
    global globalEmail
    global globalPassword

    print '                       \n \n \n \n \n \n          '
    print '                    ______ ____  _    _ _______   '
    print '                   |  ____|  _ \| |  | |__   __|  '
    print '                   | |__  | |_) | |__| |  | |     '
    print '                   |  __| |  _ <|  __  |  | |     '
    print '                   | |    | |_) | |  | |  | |     '
    print '                   |_|    |____/|_|  |_|  |_|     '
    print '                              ____  __            '
    print '                             |___ \/_ |           '
    print '                        __   ____) || |           '
    print '                        \ \ / /__ < | |           '
    print '                         \ V /___) || |           '
    print '                          \_/|____(_)_|           '
    print '                                                                   '
    print '               _     _                                             '
    print '    ____      | |   (_)                                            '
    print '   / __ \  ___| |__  _ _ __   ___   ___   __ _  __ ___      ____ _ '
    print '  / / _` |/ __| \'_ \| | \'_ \ / _ \ / _ \ / _` |/ _` \ \ /\ / / _` |'
    print ' | | (_| | (__| | | | | | | | (_) | (_) | (_| | (_| |\ V  V | (_| |'
    print '  \ \__,_|\___|_| |_|_|_| |_|\___/ \___/ \__, |\__,_| \_/\_/ \__,_|'
    print '   \____/                                 __/ |                    '
    print '                                         |___/                     '
    print '\n\n\n\n\n\n'	

    raw_input('Enjoy it :D . Press enter to get started')
    
    def testAccounts():
        option = 0
        def createAcc():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword
            appID = raw_input('Enter Application ID: ')
            
            if (login(email,password,'real'))!= -1:
                number = raw_input('Insert the amount of accounts for creation (4 min): ')
                for i in range(int(number)/4):
                    sTime = time()
                    devTest(appID)
                getTest(appID)
                           
        def deleteAcc():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword
                
            if (login(email,password,'real'))!= -1:
                deleteUser()
                deleteAccounts()
                sTime = time()
                raw_input('Execution time : %d' %(time() - sTime) + '\nPress Enter to continue: ')                
              
        def connectAcc():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword
                
            if (login(email,password,'real'))!= -1:
                sTime = time()
                massLoginTest()
                raw_input('Execution time : %d' %(time() - sTime) + '\nPress Enter to continue:')                
         
        def friendRequest():
            sTime = time()
            massLogin()
            friendshipRequest()
            raw_input('Execution time : %d' %(time() - sTime) + '\nPress Enter to continue:')
        
        def friendAccept():
            sTime = time()
            massLogin()
            acceptRequest()
            raw_input('Execution time : %d' %(time() - sTime) + '\nPress Enter to continue:')                
        
        def back():
            option = 0

        testAccountsOptions = {
                               1 : createAcc,
                               2 : deleteAcc,
                               3 : connectAcc,
                               4 : friendRequest,
                               5 : friendAccept,
                               6 : back,
                            }
        while option not in testAccountsOptions.keys():
            print '======= Test account options ======='
            print '1)  Create accounts\n'
            print '2)  Delete all accounts for a given user\n'
            print '3)  Connect all the accounts of the database\n'
            print '4)  Send friendship requests (Test Accounts)\n'
            print '5)  Accept friendship requests (Test Accounts)\n'
            print '6)  Take me back\n'
            try:
                option = int(raw_input('Insert your choice: '))
            except:
                print 'That\'s not an integer, try again'
        
        #Executes and restores variable after
        testAccountsOptions[option]()
        option = 0
        
    def phishingVectors():
        option = 0

        def previewSimple():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword
            
            if (login(email,password,'real'))!= -1:
                option = raw_input("Insert option for privacy 0:Public 1:Friends 2:Only Me : ")
                if option in privacySet:  
                    summary = raw_input('Insert a summary for the link: ')
                    link = raw_input('Insert de evil link: ')
                    realLink = raw_input('Insert de real link: ')
                    title = raw_input('Insert a title for the link: ')
                    image = raw_input('Insert the image url for the post: ')
                    comment = raw_input('Insert a comment for the post associated: ')
                    linkPreview(link,realLink,title,summary,comment,image, privacy[option])
                else:
                    print "Wrong privacy value, try again "                       

        def previewYoutube():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword
            
            if (login(email,password,'real'))!= -1:
                option = raw_input("Insert option for privacy 0:Public 1:Friends 2:Only Me : ")
                if option in privacySet:
                    summary = raw_input('Insert a summary for the video: ')
                    link = raw_input('Insert de evil link: ')
                    videoLink = raw_input('Insert de youtube link: ')
                    title = raw_input('Insert a title for the video: ')
                    videoID = raw_input('Insert the video ID (w?=): ')
                    comment = raw_input('Insert a comment for the post associated to the video: ')
                    linkPreviewYoutube(link,videoLink,title,summary,comment,videoID,privacy[option])
                else:
                    print "Wrong privacy value, try again "                 
        
        def youtubeHijack():   
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                option = raw_input("Insert option for privacy 0:Public 1:Friends 2:Only Me : ")
                if option in privacySet:
                    summary = raw_input('Insert a summary for the video: ')
                    videoLink = raw_input('Insert de youtube link: ')
                    title = raw_input('Insert a title for the video: ')
                    videoID = raw_input('Insert the video ID (watch?v=): ')
                    comment = raw_input('Insert a comment for the post associated to the video: ')
                    hijackedVideo = raw_input('Insert the ID for the hijacked video (watch?v=): ')
                    hijackVideo(videoLink,title,summary,comment,videoID,hijackedVideo,privacy[option])
                else:
                    print "Wrong privacy value, try again "
            
        def messageSimple():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                victim = raw_input('Insert the victims user ID: ')
                realLink = raw_input('Insert the real link: ')
                title = raw_input('Insert a title for the link: ')
                subject = raw_input('Insert the subject: ')
                summary = raw_input('Insert a summary for the link: ')
                message = raw_input('Insert the body of the message: ')            
                evilLink = raw_input('Insert the evil link: ')
                imageLink = raw_input('Insert the image associated to the post: ')
                privateMessageLink(message,victim,subject,realLink,title,summary,imageLink,evilLink)    

        def messageYoutube():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                victimId = raw_input('Insert the victims user ID: ')
                subject = raw_input('Insert the subject: ')
                message = raw_input('Insert the message: ')
                title = raw_input('Insert a title for the video: ')
                summary = raw_input('Insert a summary for the video: ')
                videoLink = raw_input('Insert de youtube link: ')
                evilLink = raw_input('Insert the evil link (For hijacking insert same link as above): ')
                videoID = raw_input('Insert the video ID (watch?v=): ')
                hijackedVideo = raw_input('Insert the ID for the hijacked video (watch?v=) - For Non-Hijackig press enter: ')          
                privateMessagePhishing(victimId,message,subject,evilLink,videoLink,title,summary,videoID,hijackedVideo)

        def appSpoof():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword            
            
            if (login(email,password,'real'))!= -1:
                appId = raw_input('Insert a valid AppId: ')
                link = raw_input('Insert de evil link: ')
                picture = raw_input('Insert a link to a picture for the post: ')
                title = raw_input('Insert a title for the post: ')
                domain = raw_input('Insert a domain for the post: ')
                description = raw_input('Insert a description for the post: ')
                comment = raw_input('Insert a comment for the post: ')
                appMessageSpoof(appId,link,picture,title,domain,description,comment)
        
        def requestList():              
            warning = True
            while ( (warning is not '0') and (warning is not '1')):
                warning = raw_input('Your account could be blocked.. Continue? 0|1: ')

            
            if (warning == '1'):
                
                victim = raw_input('Insert the victim username (Bypass friends list first): ')
                
                if (globalLogin == False):
                    email,password = setMail()
                else:
                    email = globalEmail
                    password = globalPassword
    
                if (login(email,password,'real'))!= -1:
                    sendRequestToList(victim)               
        
        def back():
            option = 0

        phishingVectorsOptions = {
                               1 : previewSimple,
                               2 : previewYoutube,
                               3 : youtubeHijack,
                               4 : messageSimple,
                               5 : messageYoutube,
                               6 : appSpoof,
                               7 : requestList,
                               8 : back,
                            }
        while option not in phishingVectorsOptions.keys():
            print '======= Phishing vector options ======='
            print '1)  Link Preview hack (Simple web version)\n'
            print '2)  Link Preview hack (Youtube version)\n'
            print '3)  Youtube hijack\n'
            print '4)  Private message, Link Preview hack (Simple web version)\n'
            print '5)  Private message, Link Preview hack (Youtube version)\n'
            print '6)  Publish a post as an App (App Message Spoof)\n'
            print '7)  Send friend request to disclosed friend list from your account\n'
            print '8)  Take me back\n'
            try:
                option = int(raw_input('Insert your choice: '))
            except:
                print 'That\'s not an integer, try again'
        
        #Executes and restores variable after
        phishingVectorsOptions[option]()
        option = 0
        
    def OSINT():
        """OSINT submenu: friendship-privacy bypasses, friend-graph plotting
        and username/ID lookups.  Prompts until a valid menu entry is chosen,
        then dispatches to the matching inner helper."""
        option = 0

        def bypass():
            # Bypass a victim's hidden friend list via a 'transitive' account
            # (a user whose friendship with the victim is visible) and dump
            # the result to a text file.
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                victim = raw_input('Insert the victim username or userId: ')
                transitive = raw_input('Insert the transitive username or userId: ')

                # NOTE(review): the announced path is dumps/<victim>.txt here but
                # dumps/<victim>/<victim>.txt in bypassGraph/bypassDot -- confirm
                # which layout bypassFriendshipPrivacy actually writes.
                print "The information will be stored in %s. \n" % os.path.join("dumps",victim+".txt")
                bypassFriendshipPrivacy(victim, transitive)            

        def bypassGraph():
            # Same bypass as above, but also renders the friendship graph; when
            # the victim's friend list is already public, plot it directly.
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                victim = raw_input('Insert the victim username or userId: ')
                check = checkPrivacy(victim)
                if (check == -1):
                    # Friend list hidden: fall back to the transitive-account trick.
                    transitive = raw_input('Insert the transitive username or userId: ')
                    print 'The information will be stored in %s \n' % os.path.join("dumps",victim,victim+".txt")
                    bypassFriendshipPrivacyPlot(victim, transitive)
                else:
                    print 'Friends available public ;D'
                    victim = checkMe(victim)
                    friendList, friendsName = friendshipPlot(check,victim)
                    simpleGraph(friendList, victim)         

        def analize():
            # Run graph analysis for a victim: 1 analyses fresh data,
            # 0 re-analyses a previously generated graph.
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                analize = int(raw_input('Analyze an existing one, or a new one? (0|1): '))
                victim = raw_input('Insert the victim username or userId: ')
                if (analize == 1):
                    analyzeGraph(victim)
                else:
                    reAnalyzeGraph(victim)                


        def linkDisclosed():
            # Print links for the disclosed friendships stored in a dump file.
            fileName = raw_input('Insert the victim username: ')
            linkFriends(fileName)            

        def bypassDot():
            # Friendship bypass emitting only a Graphviz .dot file (no render).
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                victim = raw_input('Insert the victim username or userId: ')
                check = checkPrivacy(victim)
                if (check == -1):
                    transitive = raw_input('Insert the transitive username or userId: ')
                    print 'The information will be stored in %s \n' % os.path.join("dumps",victim,victim+".txt")
                    dotFile(victim, transitive)
                else:
                    print 'Friends publicly available ;D'
                    friendList, friendsName = friendshipPlot(check,victim)
                    simpleDotGraph(friendsName, victim)            

        def bypassDB():
            # Database-backed variant: results go into a per-victim table
            # before the .dot graph is plotted.
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword             
            if (login(email,password,'real'))!= -1:
                victim = raw_input('Insert the victim username or userId: ')
                if ( checkTableExistence(victim) != True):
                    if (createVictimTable(victim) != -1):
                        check = checkPrivacy(victim)
                        if (check == -1):
                            transitive = raw_input('Insert the transitive username or userId: ')
                            dotFileDatabase(victim, transitive)
                            plotDOT(victim)    
                        else:
                            print 'Friends publicly available ;D'
                            friendList, friendsName = friendshipPlot(check,victim)
                            simpleDotGraphDatabase(friendsName, victim)
                            plotDOT(victim)

        def publicFriends():
            # Dump a user's publicly visible friend list.
            email,password = setMail()
            if (login(email,password,'real'))!= -1:
                username = raw_input("Insert the username: ")
                getFriends(username)             

        def idFromUsername():
            # Resolve numeric user IDs from usernames.
            email,password = setMail()
            if (login(email,password,'real'))!= -1:
                username = raw_input("Insert the username: ")
                getUserIDS(username)

        def back():
            # Sentinel entry: selecting it does nothing and control returns to
            # the main menu (the assignment only creates a local variable).
            option = 0


        OSINTOptions = {
                               1 : bypass,
                               2 : bypassGraph,
                               3 : analize,
                               4 : linkDisclosed,
                               5 : bypassDot,
                               6 : bypassDB,
                               7 : publicFriends,
                               8 : idFromUsername,
                               9 : back,
                            }
        while option not in OSINTOptions.keys():
            print '======= OSINT options ======='
            print '1)  Bypass friendship privacy\n'
            print '2)  Bypass friendship privacy with graph support\n'
            print '3)  Analyze an existing graph\n'
            print '4)  Link to disclosed friendships\n'
            print '5)  Bypass friendship (only .dot without graph integration)\n'
            print '6)  Bypass - database support (Beta) \n '
            print '7)  Get public friends\n'
            print '8)  Get userIDS from usernames\n'
            print '9)  Take me back\n'

            try:
                option = int(raw_input('Insert your choice: '))
                if option > len(OSINTOptions):
                    # Out-of-range choice: funnel into the except branch to reprompt.
                    raise
            except:
                print 'That\'s not an integer, try again'

        #Executes and restores variable after
        OSINTOptions[option]()
        option = 0
        
    def bruteforcing():
        option = 0

        def userEnumeration():
            mailFile = raw_input('Insert the filename that contains the list of emails (place it in PRIVATE folder first): ')
            raw_input('Verified emails will be stored in PRIVATE --> existence --> verified.txt ')
            accountexists(mailFile)                

        def bruteforce():
            mailFile = raw_input('Insert the filename that contains the list of emails and passwords (place it in PRIVATE folder first) with email:password pattern: ')
            raw_input('Verified loggins will be stored in PRIVATE --> loggedin --> loggedin.txt ')
            checkLogin(mailFile)            

        def celBruteforce():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword
            if (login(email,password,'real'))!= -1:
                raw_input('Dumps will be stored in cellphones --> cellphones.txt')
                first = raw_input('Insert the zone code: ')
                start = raw_input('Insert the start number: ')
                end = raw_input('Insert the end number: ')
                bruteforceCel(first,start,end)
        
        def back():
            option = 0
            
        bruteforcingOptions = {
                               1 : userEnumeration,
                               2 : bruteforce,
                               3 : celBruteforce,
                               4 : back,
                            }
        while option not in bruteforcingOptions.keys():
            print '======= Bruteforce options ======='
            print '1)  Check existence of mails\n'
            print '2)  Check working account and passwords\n'
            print '3)  Bruteforce cellphones\n'
            print '4)  Take me back\n'
            try:
                option = int(raw_input('Insert your choice: '))
                if option > len(bruteforcingOptions):
                    raise
            except:
                print 'That\'s not an integer, try again'
        
        #Executes and restores variable after
        bruteforcingOptions[option]()
        option = 0
        
    def gathering():
        option = 0
        
        def photosSingleCredential():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                threads = raw_input('Usage: insert the threads filename and place it in massive folder first: ')
                raw_input('Dumps will be stored in massive\\photos')
                takePhotos(threads)
        
        def photosMassive():
            raw_input('You must first run option 30, if you didn\'t I will fail :D ')
            raw_input('Dumps will be stored in massive\\photos')
            steal()
        
        def back():
            option = 0
            
        gatheringOptions = {
                               1 : photosSingleCredential,
                               2 : photosMassive,
                               3 : back,
                            }
        while option not in gatheringOptions.keys():
            print '======= Gathering options ======='
            print '1)  Take the photos!\n'
            print '2)  Steal private photos from password verified dump\n'
            print '3)  Take me back\n'
            
            try:
                option = int(raw_input('Insert your choice: '))
                if option > len(gatheringOptions):
                    raise
            except:
                print 'That\'s not an integer, try again'
        
        #Executes and restores variable after
        gatheringOptions[option]()
        option = 0            
    
    def miscellaneous():
        option = 0

        def broadcast():
            while True:
                online = raw_input("Send only to online friends? 0|1: ")
                if ((int(online) == 1) or (int(online) == 0)):
                    break
                
            email,password = setMail()
            if (login(email,password,'real'))!= -1:
                sendBroadcast(int(1))               

        def ddos():
            print 'Facebook note DDoS attack, discovered by chr13: http://chr13.com/about-me/'
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                raw_input('Usage: First you must create an empty note. Once your note is created, write down the note ID number from the URL. ENTER TO CONTINUE...')
                imageURL = raw_input('Insert the image URL from the site attack: ')
                noteID = raw_input('Insert the note ID: ')
                option = raw_input("Insert option for privacy 0:Public 1:Friends 2:Only Me : ")
                if option in privacySet:
                    noteDDoS(imageURL,noteID, privacy[option])                

        def spam():
            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword             
            if (login(email,password,'real'))!= -1:
                victim = raw_input('Insert the victim username or userId: ')
                if ( checkTableExistence(victim) != True):
                    if (createVictimTable(victim) != -1):
                        check = checkPrivacy(victim)
                        if (check == -1):
                            transitive = raw_input('Insert the transitive username or userId: ')
                            dotFileDatabase(victim, transitive)
                            plotDOT(victim)    
                        else:
                            print 'Friends publicly available ;D'
                            friendList, friendsName = friendshipPlot(check,victim)
                            simpleDotGraphDatabase(friendsName, victim)
                            plotDOT(victim)

        def friendly():

            if (globalLogin == False):
                email,password = setMail()
            else:
                email = globalEmail
                password = globalPassword

            if (login(email,password,'real'))!= -1:
                raw_input('Usage: First you must create an empty note. Once your note is created, write down the note ID number from the URL. ENTER TO CONTINUE...')
                noteID = raw_input('Insert the note ID: ')
                option = raw_input("Insert option for privacy 0:Public 1:Friends 2:Only Me : ")
                if option in privacySet:
                    friendlyLogout(noteID, privacy[option])
                       
        def newLike():
            try:
                counter = 0
                postId = []
                
                print "Insert the Post ID's (Must be from a page). If no more posts for adding,insert '0' :"
                while True:
                    response = raw_input('post[%d]:'%counter)
                    if ( response is not '0' ):
                        counter+=1
                        postId.append(response)
                    else:
                        break
                        
                likeDev(postId)
            except EOFError:
                print 'EOFError'
                stdin.flush()
                pass
            except signalCaught as e:
                print ' %s' %e.args[0]
                        
        def oldLike():

            try:
                counter = 0
                postId = []
                
                print 'Insert the Post ID\'s (Must be from a page). If no more posts for adding,insert \'0\' :'
                while True:
                    response = raw_input('post[%d]:'%counter)
                    if ( response is not '0' ):
                        counter+=1
                        postId.append(response)
                    else:
                        break
                
                quantity = raw_input('Insert the amount of likes: ')
                like(postId, quantity)
            except EOFError:
                print 'EOFError'
                stdin.flush()
                pass
            except signalCaught as e:
                print ' %s' %e.args[0]  
                raw_input('Press enter to continue..')

        def dead():
            print 'Mail bomber through test accounts'
            print 'Test accounts massive creation'
            print 'Blocked Test account login bypass'
            print 'We hope this tool to be useless in the future'
            raw_input('Press enter to continue: ')

        def back():
            option = 0

        miscellaneousOptions = {
                               1 : broadcast,
                               2 : ddos,
                               3 : spam,
                               4 : friendly,
                               5 : newLike,
                               6 : oldLike,
                               7 : dead,
                               8 : back,
                            }
        while option not in miscellaneousOptions.keys():
            print '======= Miscellaneous options ======='
            print '1)  Send broadcast to friends (Individual messages)\n'
            print '2)  Note DDoS attack\n'
            print '3)  SPAM any fanpage inbox\n'
            print '4)  Logout all your friends - FB blackout \n'
            print '5)  NEW Like flood\n'
            print '6)  Old Like Flood (Not working)\n'
            print '7)  Print dead attacks :\'( \n'
            print '8)  Take me back\n'
            
            try:
                option = int(raw_input('Insert your choice: '))
                if option > len(miscellaneousOptions):
                    raise
            except:
                print 'That\'s not an integer, try again'
        
        #Executes and restores variable after
        miscellaneousOptions[option]()
        option = 0  
        
    def configuration():
        option = 0

        def statusDB():

            status()
            raw_input('Press enter to continue: ')


        def loggingLevel():

            print 'This will increase the execution time significantly'
            setGlobalLogginng()
            
        def back():
            option = 0

        configurationOptions = {
                               1 : statusDB,
                               2 : loggingLevel,
                               3 : back,
                            }
        while option not in configurationOptions.keys():
            print '======= Configuration options ======='
            print '1)  Print database status\n'
            print '2)  Increase logging level globally\n'
            print '3)  Take me back\n'
            
            try:
                option = int(raw_input('Insert your choice: '))
                if option > len(configurationOptions):
                    raise
            except:
                print 'That\'s not an integer, try again'
        
        #Executes and restores variable after
        configurationOptions[option]()
        option = 0  
        
    def exitFBHT():

        connect.close()
        
        print '\n \n \n \n \n \n\n \n \n \n \n \n\n \n \n \n \n \n\n \n \n \n \n \n\n \n \n \n '                        
        print ' _    _            _      _______ _            _____  _                  _   _  '
        print '| |  | |          | |    |__   __| |          |  __ \| |                | | | | '
        print '| |__| | __ _  ___| | __    | |  | |__   ___  | |__) | | __ _ _ __   ___| |_| | '
        print '|  __  |/ _` |/ __| |/ /    | |  |  _ \ / _ \ |  ___/| |/ _` |  _ \ / _ \ __| | '
        print '| |  | | (_| | (__|   <     | |  | | | |  __/ | |    | | (_| | | | |  __/ |_|_| '
        print '|_|  |_|\__,_|\___|_|\_\    |_|  |_| |_|\___| |_|    |_|\__,_|_| |_|\___|\__(_) '
        print '\n \n \n \n \n \n\n \n \n \n \n \n\n \n \n \n \n \n\n \n \n \n \n \n\n \n \n \n '

        exit(0)
        
    # Top-level menu: maps the numeric choice to its submenu handler above.
    options = {
               1 : testAccounts,
               2 : phishingVectors,
               3 : OSINT,
               4 : bruteforcing,
               5 : gathering,
               6 : miscellaneous,
               7 : configuration,
               8 : exitFBHT,
            }

    # Main loop: (re)install the SIGINT handler each pass, prompt until a
    # valid option is entered, then run the chosen submenu.  exitFBHT() is
    # the only exit from this loop.
    while 1:
        signal.signal(signal.SIGINT, signal_handler)
        option = -1
        while option not in options.keys():

            print '1) Test accounts'
            print '2) Phishing vectors'
            print '3) OSINT'
            print '4) Bruteforcing'
            print '5) Gathering information with credentials'
            print '6) Miscellaneous'
            print '7) Configuration'
            print '8) Take me out of here'

            try:
                option = int(raw_input('Insert your choice: '))
                if option > len(options):
                    # Out-of-range choice: funnel into the except branch to reprompt.
                    raise
            except:
                print 'That\'s not an integer, try again'

        #Executes and restores variable after
        options[option]()
        option = 0

Example 27

Project: tensorflow-char-rnn
Source File: train.py
View license
def main():
    """Train a character-level RNN.

    Parses CLI arguments, prepares the data and vocabulary, builds the
    training/validation/evaluation CharRNN graphs, then runs the training
    loop, checkpointing the latest model each epoch and the best model by
    validation perplexity, and persisting progress to result.json so runs
    can be resumed with --init_dir.
    """
    parser = argparse.ArgumentParser()

    # Data and vocabulary file
    parser.add_argument('--data_file', type=str,
                        default='data/tiny_shakespeare.txt',
                        help='data file')

    parser.add_argument('--encoding', type=str,
                        default='utf-8',
                        help='the encoding of the data file.')

    # Parameters for saving models.
    parser.add_argument('--output_dir', type=str, default='output',
                        help=('directory to store final and'
                              ' intermediate results and models.'))

    # Parameters to configure the neural network.
    parser.add_argument('--hidden_size', type=int, default=128,
                        help='size of RNN hidden state vector')
    parser.add_argument('--embedding_size', type=int, default=0,
                        help='size of character embeddings')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers in the RNN')
    parser.add_argument('--num_unrollings', type=int, default=10,
                        help='number of unrolling steps.')
    parser.add_argument('--model', type=str, default='lstm',
                        help='which model to use (rnn, lstm or gru).')

    # Parameters to control the training.
    parser.add_argument('--num_epochs', type=int, default=50,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=20,
                        help='minibatch size')
    parser.add_argument('--train_frac', type=float, default=0.9,
                        help='fraction of data used for training.')
    parser.add_argument('--valid_frac', type=float, default=0.05,
                        help='fraction of data used for validation.')
    # test_frac is computed as (1 - train_frac - valid_frac).
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='dropout rate, default to 0 (no dropout).')

    parser.add_argument('--input_dropout', type=float, default=0.0,
                        help=('dropout rate on input layer, default to 0 (no dropout),'
                              'and no dropout if using one-hot representation.'))

    # Parameters for gradient descent.
    parser.add_argument('--max_grad_norm', type=float, default=5.,
                        help='clip global grad norm')
    parser.add_argument('--learning_rate', type=float, default=2e-3,
                        help='initial learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.95,
                        help='decay rate')

    # Parameters for logging.
    parser.add_argument('--log_to_file', dest='log_to_file', action='store_true',
                        help=('whether the experiment log is stored in a file under'
                              '  output_dir or printed at stdout.'))
    parser.set_defaults(log_to_file=False)

    parser.add_argument('--progress_freq', type=int,
                        default=100,
                        help=('frequency for progress report in training'
                              ' and evalution.'))

    parser.add_argument('--verbose', type=int,
                        default=0,
                        help=('whether to show progress report in training'
                              ' and evalution.'))

    # Parameters to feed in the initial model and current best model.
    parser.add_argument('--init_model', type=str,
                        default='',
                        help=('initial model'))
    parser.add_argument('--best_model', type=str,
                        default='',
                        help=('current best model'))
    parser.add_argument('--best_valid_ppl', type=float,
                        default=np.Inf,
                        help=('current valid perplexity'))

    # Parameters for using saved best models.
    parser.add_argument('--init_dir', type=str, default='',
                        help='continue from the outputs in the given directory')

    # Parameters for debugging.
    parser.add_argument('--debug', dest='debug', action='store_true',
                        help='show debug information')
    parser.set_defaults(debug=False)

    # Parameters for unittesting the implementation.
    parser.add_argument('--test', dest='test', action='store_true',
                        help=('use the first 1000 character to as data'
                              ' to test the implementation'))
    parser.set_defaults(test=False)

    args = parser.parse_args()

    # Specifying location to store model, best model and tensorboard log.
    # NOTE(review): these paths are derived from --output_dir *before* it may
    # be overwritten by --init_dir below, so on resume the checkpoints still
    # go under the original --output_dir -- confirm this is intended.
    args.save_model = os.path.join(args.output_dir, 'save_model/model')
    args.save_best_model = os.path.join(args.output_dir, 'best_model/model')
    args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/')
    args.vocab_file = ''

    # Create necessary directories.
    if args.init_dir:
        args.output_dir = args.init_dir
    else:
        # Fresh run: wipe any previous output before recreating the tree.
        if os.path.exists(args.output_dir):
            shutil.rmtree(args.output_dir)
        for paths in [args.save_model, args.save_best_model,
                      args.tb_log_dir]:
            os.makedirs(os.path.dirname(paths))

    # Specify logging config.
    if args.log_to_file:
        args.log_file = os.path.join(args.output_dir, 'experiment_log.txt')
    else:
        args.log_file = 'stdout'

    # Set logging file.
    if args.log_file == 'stdout':
        logging.basicConfig(stream=sys.stdout,
                            format='%(asctime)s %(levelname)s:%(message)s',
                            level=logging.INFO,
                            datefmt='%I:%M:%S')
    else:
        logging.basicConfig(filename=args.log_file,
                            format='%(asctime)s %(levelname)s:%(message)s',
                            level=logging.INFO,
                            datefmt='%I:%M:%S')

    print('=' * 60)
    print('All final and intermediate outputs will be stored in %s/' % args.output_dir)
    print('All information will be logged to %s' % args.log_file)
    print('=' * 60 + '\n')

    if args.debug:
        logging.info('args are:\n%s', args)

    # Prepare parameters.
    if args.init_dir:
        # Resuming: hyperparameters, latest/best model and vocab all come
        # from the result.json written by the previous run.
        with open(os.path.join(args.init_dir, 'result.json'), 'r') as f:
            result = json.load(f)
        params = result['params']
        args.init_model = result['latest_model']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        if 'encoding' in result:
            args.encoding = result['encoding']
        else:
            args.encoding = 'utf-8'
        args.vocab_file = os.path.join(args.init_dir, 'vocab.json')
    else:
        params = {'batch_size': args.batch_size,
                  'num_unrollings': args.num_unrollings,
                  'hidden_size': args.hidden_size,
                  'max_grad_norm': args.max_grad_norm,
                  'embedding_size': args.embedding_size,
                  'num_layers': args.num_layers,
                  'learning_rate': args.learning_rate,
                  'model': args.model,
                  'dropout': args.dropout,
                  'input_dropout': args.input_dropout}
        # best_valid_ppl stays unset here; safe because the first epoch's
        # "(not best_model) or ..." test short-circuits before reading it.
        best_model = ''
    logging.info('Parameters are:\n%s\n', json.dumps(params, sort_keys=True, indent=4))

    # Read and split data.
    logging.info('Reading data from: %s', args.data_file)
    with codecs.open(args.data_file, 'r', encoding=args.encoding) as f:
        text = f.read()

    if args.test:
        text = text[:1000]
    logging.info('Number of characters: %s', len(text))

    if args.debug:
        n = 10
        logging.info('First %d characters: %s', n, text[:n])

    logging.info('Creating train, valid, test split')
    train_size = int(args.train_frac * len(text))
    valid_size = int(args.valid_frac * len(text))
    test_size = len(text) - train_size - valid_size
    train_text = text[:train_size]
    valid_text = text[train_size:train_size + valid_size]
    test_text = text[train_size + valid_size:]

    if args.vocab_file:
        vocab_index_dict, index_vocab_dict, vocab_size = load_vocab(args.vocab_file, args.encoding)
    else:
        logging.info('Creating vocabulary')
        vocab_index_dict, index_vocab_dict, vocab_size = create_vocab(text)
        vocab_file = os.path.join(args.output_dir, 'vocab.json')
        save_vocab(vocab_index_dict, vocab_file, args.encoding)
        logging.info('Vocabulary is saved in %s', vocab_file)
        args.vocab_file = vocab_file

    params['vocab_size'] = vocab_size
    logging.info('Vocab size: %d', vocab_size)

    # Create batch generators.
    batch_size = params['batch_size']
    num_unrollings = params['num_unrollings']
    train_batches = BatchGenerator(train_text, batch_size, num_unrollings, vocab_size,
                                   vocab_index_dict, index_vocab_dict)
    # valid_batches = BatchGenerator(valid_text, 1, 1, vocab_size,
    #                                vocab_index_dict, index_vocab_dict)
    valid_batches = BatchGenerator(valid_text, batch_size, num_unrollings, vocab_size,
                                   vocab_index_dict, index_vocab_dict)

    # Test evaluation runs one character at a time (batch 1, unrolling 1).
    test_batches = BatchGenerator(test_text, 1, 1, vocab_size,
                                  vocab_index_dict, index_vocab_dict)

    if args.debug:
        logging.info('Test batch generators')
        logging.info(batches2string(train_batches.next(), index_vocab_dict))
        logging.info(batches2string(valid_batches.next(), index_vocab_dict))
        logging.info('Show vocabulary')
        logging.info(vocab_index_dict)
        logging.info(index_vocab_dict)

    # Create graphs
    # Three CharRNN instances share one set of variables (reuse_variables)
    # under distinct name scopes for training, validation and evaluation.
    logging.info('Creating graph')
    graph = tf.Graph()
    with graph.as_default():
        with tf.name_scope('training'):
            train_model = CharRNN(is_training=True, use_batch=True, **params)
        tf.get_variable_scope().reuse_variables()
        with tf.name_scope('validation'):
            valid_model = CharRNN(is_training=False, use_batch=True, **params)
        with tf.name_scope('evaluation'):
            test_model = CharRNN(is_training=False, use_batch=False, **params)
            saver = tf.train.Saver(name='checkpoint_saver')
            best_model_saver = tf.train.Saver(name='best_model_saver')

    logging.info('Model size (number of parameters): %s\n', train_model.model_size)
    logging.info('Start training\n')

    result = {}
    result['params'] = params
    result['vocab_file'] = args.vocab_file
    result['encoding'] = args.encoding

    try:
        # Use try and finally to make sure that intermediate
        # results are saved correctly so that training can
        # be continued later after interruption.
        with tf.Session(graph=graph) as session:
            # Version 8 changed the api of summary writer to use
            # graph instead of graph_def.
            if TF_VERSION >= 8:
                graph_info = session.graph
            else:
                graph_info = session.graph_def

            train_writer = tf.train.SummaryWriter(args.tb_log_dir + 'train/', graph_info)
            valid_writer = tf.train.SummaryWriter(args.tb_log_dir + 'valid/', graph_info)

            # load a saved model or start from random initialization.
            if args.init_model:
                saver.restore(session, args.init_model)
            else:
                # Deprecated in later TF versions; valid for the TF release
                # this script targets (pre-1.0).
                tf.initialize_all_variables().run()
            for i in range(args.num_epochs):
                logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', i)
                logging.info('Training on training set')
                # training step
                ppl, train_summary_str, global_step = train_model.run_epoch(
                    session,
                    train_size,
                    train_batches,
                    is_training=True,
                    verbose=args.verbose,
                    freq=args.progress_freq)
                # record the summary
                train_writer.add_summary(train_summary_str, global_step)
                train_writer.flush()
                # save model
                saved_path = saver.save(session, args.save_model,
                                                    global_step=train_model.global_step)
                logging.info('Latest model saved in %s\n', saved_path)
                logging.info('Evaluate on validation set')

                # valid_ppl, valid_summary_str, _ = valid_model.run_epoch(
                valid_ppl, valid_summary_str, _ = valid_model.run_epoch(
                    session,
                    valid_size,
                    valid_batches,
                    is_training=False,
                    verbose=args.verbose,
                    freq=args.progress_freq)

                # save and update best model
                if (not best_model) or (valid_ppl < best_valid_ppl):
                    best_model = best_model_saver.save(
                        session,
                        args.save_best_model,
                        global_step=train_model.global_step)
                    best_valid_ppl = valid_ppl
                valid_writer.add_summary(valid_summary_str, global_step)
                valid_writer.flush()
                logging.info('Best model is saved in %s', best_model)
                logging.info('Best validation ppl is %f\n', best_valid_ppl)
                result['latest_model'] = saved_path
                result['best_model'] = best_model
                # Convert to float because numpy.float is not json serializable.
                result['best_valid_ppl'] = float(best_valid_ppl)
                # Rewrite result.json every epoch so a crash loses at most
                # one epoch of bookkeeping.
                result_path = os.path.join(args.output_dir, 'result.json')
                if os.path.exists(result_path):
                    os.remove(result_path)
                with open(result_path, 'w') as f:
                    json.dump(result, f, indent=2, sort_keys=True)

            logging.info('Latest model is saved in %s', saved_path)
            logging.info('Best model is saved in %s', best_model)
            logging.info('Best validation ppl is %f\n', best_valid_ppl)
            logging.info('Evaluate the best model on test set')
            saver.restore(session, best_model)
            test_ppl, _, _ = test_model.run_epoch(session, test_size, test_batches,
                                                   is_training=False,
                                                   verbose=args.verbose,
                                                   freq=args.progress_freq)
            result['test_ppl'] = float(test_ppl)
    finally:
        # Persist the latest bookkeeping even on interruption or failure so
        # training can be resumed with --init_dir.
        result_path = os.path.join(args.output_dir, 'result.json')
        if os.path.exists(result_path):
            os.remove(result_path)
        with open(result_path, 'w') as f:
            json.dump(result, f, indent=2, sort_keys=True)

Example 28

Project: crossbar
Source File: router.py
View license
    def _create_resource(self, path_config, nested=True):
        """
        Creates child resource to be added to the parent.

        Dispatches on ``path_config['type']`` (one of: ``websocket``,
        ``static``, ``wsgi``, ``redirect``, ``reverseproxy``, ``json``,
        ``cgi``, ``longpoll``, ``publisher``, ``webhook``, ``caller``,
        ``upload``, ``resource``, ``schemadoc``, ``path``) and builds the
        corresponding Twisted Web resource.

        :param path_config: Configuration for the new child resource.
        :type path_config: dict
        :param nested: Whether the resource is mounted below the web root.
            Only the WSGI branch inspects this: a root-mounted WSGI
            resource gets wrapped in a ``WSGIRootResource``.
        :type nested: bool

        :returns: Resource -- the new child resource

        :raises ApplicationError: on invalid configuration or when a
            referenced module/class/package resource cannot be imported.
        """
        # WAMP-WebSocket resource
        #
        if path_config['type'] == 'websocket':

            ws_factory = WampWebSocketServerFactory(self._router_session_factory, self.config.extra.cbdir, path_config, self._templates)

            # FIXME: Site.start/stopFactory should start/stop factories wrapped as Resources
            ws_factory.startFactory()

            return WebSocketResource(ws_factory)

        # Static file hierarchy resource
        #
        elif path_config['type'] == 'static':

            static_options = path_config.get('options', {})

            if 'directory' in path_config:

                # 'directory' is resolved relative to the node directory (cbdir)
                static_dir = os.path.abspath(os.path.join(self.config.extra.cbdir, path_config['directory']))

            elif 'package' in path_config:

                if 'resource' not in path_config:
                    raise ApplicationError(u"crossbar.error.invalid_configuration", "missing resource")

                try:
                    # NOTE(review): the import only verifies the package exists and
                    # surfaces a clear error early; 'mod' itself is unused — the
                    # directory is resolved via pkg_resources below.
                    mod = importlib.import_module(path_config['package'])
                except ImportError as e:
                    emsg = "Could not import resource {} from package {}: {}".format(path_config['resource'], path_config['package'], e)
                    self.log.error(emsg)
                    raise ApplicationError(u"crossbar.error.invalid_configuration", emsg)
                else:
                    try:
                        static_dir = os.path.abspath(pkg_resources.resource_filename(path_config['package'], path_config['resource']))
                    except Exception as e:
                        emsg = "Could not import resource {} from package {}: {}".format(path_config['resource'], path_config['package'], e)
                        self.log.error(emsg)
                        raise ApplicationError(u"crossbar.error.invalid_configuration", emsg)

            else:

                raise ApplicationError(u"crossbar.error.invalid_configuration", "missing web spec")

            static_dir = static_dir.encode('ascii', 'ignore')  # http://stackoverflow.com/a/20433918/884770

            # create resource for file system hierarchy
            #
            if static_options.get('enable_directory_listing', False):
                static_resource_class = StaticResource
            else:
                static_resource_class = StaticResourceNoListing

            cache_timeout = static_options.get('cache_timeout', DEFAULT_CACHE_TIMEOUT)

            static_resource = static_resource_class(static_dir, cache_timeout=cache_timeout)

            # set extra MIME types
            #
            static_resource.contentTypes.update(EXTRA_MIME_TYPES)
            if 'mime_types' in static_options:
                static_resource.contentTypes.update(static_options['mime_types'])
            patchFileContentTypes(static_resource)

            # render 404 page on any concrete path not found
            #
            static_resource.childNotFound = Resource404(self._templates, static_dir)

            return static_resource

        # WSGI resource
        #
        elif path_config['type'] == 'wsgi':

            if not _HAS_WSGI:
                raise ApplicationError(u"crossbar.error.invalid_configuration", "WSGI unsupported")

            if 'module' not in path_config:
                raise ApplicationError(u"crossbar.error.invalid_configuration", "missing WSGI app module")

            if 'object' not in path_config:
                raise ApplicationError(u"crossbar.error.invalid_configuration", "missing WSGI app object")

            # import WSGI app module and object
            mod_name = path_config['module']
            try:
                mod = importlib.import_module(mod_name)
            except ImportError as e:
                raise ApplicationError(u"crossbar.error.invalid_configuration", "WSGI app module '{}' import failed: {} - Python search path was {}".format(mod_name, e, sys.path))
            else:
                obj_name = path_config['object']
                if obj_name not in mod.__dict__:
                    raise ApplicationError(u"crossbar.error.invalid_configuration", "WSGI app object '{}' not in module '{}'".format(obj_name, mod_name))
                else:
                    app = getattr(mod, obj_name)

            # Create a threadpool for running the WSGI requests in
            pool = ThreadPool(maxthreads=path_config.get("maxthreads", 20),
                              minthreads=path_config.get("minthreads", 0),
                              name="crossbar_wsgi_threadpool")
            # make sure the pool is stopped on reactor shutdown (it is started below)
            self._reactor.addSystemEventTrigger('before', 'shutdown', pool.stop)
            pool.start()

            # Create a Twisted Web WSGI resource from the user's WSGI application object
            try:
                wsgi_resource = WSGIResource(self._reactor, pool, app)

                # a WSGI resource mounted at the site root needs a root wrapper
                if not nested:
                    wsgi_resource = WSGIRootResource(wsgi_resource, {})
            except Exception as e:
                raise ApplicationError(u"crossbar.error.invalid_configuration", "could not instantiate WSGI resource: {}".format(e))
            else:
                return wsgi_resource

        # Redirecting resource
        #
        elif path_config['type'] == 'redirect':
            redirect_url = path_config['url'].encode('ascii', 'ignore')
            return RedirectResource(redirect_url)

        # Reverse proxy resource
        #
        elif path_config['type'] == 'reverseproxy':
            host = path_config['host']
            port = int(path_config.get('port', 80))
            path = path_config.get('path', '').encode('ascii', 'ignore')
            return ReverseProxyResource(host, port, path)

        # JSON value resource
        #
        elif path_config['type'] == 'json':
            value = path_config['value']

            return JsonResource(value)

        # CGI script resource
        #
        elif path_config['type'] == 'cgi':

            cgi_processor = path_config['processor']
            cgi_directory = os.path.abspath(os.path.join(self.config.extra.cbdir, path_config['directory']))
            cgi_directory = cgi_directory.encode('ascii', 'ignore')  # http://stackoverflow.com/a/20433918/884770

            return CgiDirectory(cgi_directory, cgi_processor, Resource404(self._templates, cgi_directory))

        # WAMP-Longpoll transport resource
        #
        elif path_config['type'] == 'longpoll':

            path_options = path_config.get('options', {})

            lp_resource = WampLongPollResource(self._router_session_factory,
                                               timeout=path_options.get('request_timeout', 10),
                                               killAfter=path_options.get('session_timeout', 30),
                                               queueLimitBytes=path_options.get('queue_limit_bytes', 128 * 1024),
                                               queueLimitMessages=path_options.get('queue_limit_messages', 100),
                                               debug_transport_id=path_options.get('debug_transport_id', None)
                                               )
            lp_resource._templates = self._templates

            return lp_resource

        # Publisher resource (part of REST-bridge)
        #
        elif path_config['type'] == 'publisher':

            # create a vanilla session: the publisher will use this to inject events
            #
            publisher_session_config = ComponentConfig(realm=path_config['realm'], extra=None)
            publisher_session = ApplicationSession(publisher_session_config)

            # add the publisher session to the router
            #
            self._router_session_factory.add(publisher_session, authrole=path_config.get('role', 'anonymous'))

            # now create the publisher Twisted Web resource
            #
            return PublisherResource(path_config.get('options', {}), publisher_session)

        # Webhook resource (part of REST-bridge)
        #
        elif path_config['type'] == 'webhook':

            # create a vanilla session: the webhook will use this to inject events
            #
            webhook_session_config = ComponentConfig(realm=path_config['realm'], extra=None)
            webhook_session = ApplicationSession(webhook_session_config)

            # add the webhook session to the router
            #
            self._router_session_factory.add(webhook_session, authrole=path_config.get('role', 'anonymous'))

            # now create the webhook Twisted Web resource
            #
            return WebhookResource(path_config.get('options', {}), webhook_session)

        # Caller resource (part of REST-bridge)
        #
        elif path_config['type'] == 'caller':

            # create a vanilla session: the caller will use this to inject calls
            #
            caller_session_config = ComponentConfig(realm=path_config['realm'], extra=None)
            caller_session = ApplicationSession(caller_session_config)

            # add the calling session to the router
            #
            self._router_session_factory.add(caller_session, authrole=path_config.get('role', 'anonymous'))

            # now create the caller Twisted Web resource
            #
            return CallerResource(path_config.get('options', {}), caller_session)

        # File Upload resource
        #
        elif path_config['type'] == 'upload':

            upload_directory = os.path.abspath(os.path.join(self.config.extra.cbdir, path_config['directory']))
            upload_directory = upload_directory.encode('ascii', 'ignore')  # http://stackoverflow.com/a/20433918/884770
            if not os.path.isdir(upload_directory):
                emsg = "configured upload directory '{}' in file upload resource isn't a directory".format(upload_directory)
                self.log.error(emsg)
                raise ApplicationError(u"crossbar.error.invalid_configuration", emsg)

            # temp dir: explicit config wins, else a 'crossbar-uploads' folder
            # under the system temp dir is created on demand
            if 'temp_directory' in path_config:
                temp_directory = os.path.abspath(os.path.join(self.config.extra.cbdir, path_config['temp_directory']))
                temp_directory = temp_directory.encode('ascii', 'ignore')  # http://stackoverflow.com/a/20433918/884770
            else:
                temp_directory = os.path.abspath(tempfile.gettempdir())
                temp_directory = os.path.join(temp_directory, 'crossbar-uploads')
                if not os.path.exists(temp_directory):
                    os.makedirs(temp_directory)

            if not os.path.isdir(temp_directory):
                emsg = "configured temp directory '{}' in file upload resource isn't a directory".format(temp_directory)
                self.log.error(emsg)
                raise ApplicationError(u"crossbar.error.invalid_configuration", emsg)

            # file upload progress and finish events are published via this session
            #
            upload_session_config = ComponentConfig(realm=path_config['realm'], extra=None)
            upload_session = ApplicationSession(upload_session_config)

            self._router_session_factory.add(upload_session, authrole=path_config.get('role', 'anonymous'))

            self.log.info("File upload resource started. Uploads to {upl} using temp folder {tmp}.", upl=upload_directory, tmp=temp_directory)

            return FileUploadResource(upload_directory, temp_directory, path_config['form_fields'], upload_session, path_config.get('options', {}))

        # Generic Twisted Web resource
        #
        elif path_config['type'] == 'resource':

            try:
                klassname = path_config['classname']

                self.log.debug("Starting class '{name}'", name=klassname)

                # split dotted path into module and class name, then import
                c = klassname.split('.')
                module_name, klass_name = '.'.join(c[:-1]), c[-1]
                module = importlib.import_module(module_name)
                make = getattr(module, klass_name)

                return make(path_config.get('extra', {}))

            except Exception as e:
                emsg = "Failed to import class '{}' - {}".format(klassname, e)
                self.log.error(emsg)
                self.log.error("PYTHONPATH: {pythonpath}", pythonpath=sys.path)
                raise ApplicationError(u"crossbar.error.class_import_failed", emsg)

        # Schema Docs resource
        #
        elif path_config['type'] == 'schemadoc':

            realm = path_config['realm']

            if realm not in self.realm_to_id:
                raise ApplicationError(u"crossbar.error.no_such_object", "No realm with URI '{}' configured".format(realm))

            realm_id = self.realm_to_id[realm]

            realm_schemas = self.realms[realm_id].session._schemas

            return SchemaDocResource(self._templates, realm, realm_schemas)

        # Nested subpath resource
        #
        elif path_config['type'] == 'path':

            nested_paths = path_config.get('paths', {})

            # the '/' entry (if any) becomes the resource itself; otherwise 404
            if '/' in nested_paths:
                nested_resource = self._create_resource(nested_paths['/'])
            else:
                nested_resource = Resource404(self._templates, b'')

            # nest subpaths under the current entry
            #
            self._add_paths(nested_resource, nested_paths)

            return nested_resource

        else:
            raise ApplicationError(u"crossbar.error.invalid_configuration",
                                   "invalid Web path type '{}' in {} config".format(path_config['type'],
                                                                                    'nested' if nested else 'root'))

Example 29

Project: pygame_cffi
Source File: run_tests.py
View license
def run(*args, **kwds):
    """Run the Pygame unit test suite and return (total tests run, fails dict)

    Positional arguments (optional):
    The names of tests to include. If omitted then all tests are run. Test
    names need not include the trailing '_test'.

    Keyword arguments:
    incomplete - fail incomplete tests (default False)
    nosubprocess - run all test suites in the current process
                   (default False, use separate subprocesses)
    dump - dump failures/errors as dict ready to eval (default False)
    file - if provided, the name of a file into which to dump failures/errors
    timings - if provided, the number of times to run each individual test to
              get an average run time (default is run each test once)
    exclude - A list of TAG names to exclude from the run. The items may be
              comma or space separated.
    show_output - show silenced stderr/stdout on errors (default False)
    all - dump all results, not just errors (default False)
    randomize - randomize order of tests (default False)
    seed - if provided, a seed randomizer integer
    multi_thread - if provided, the number of THREADS in which to run
                   subprocessed tests
    time_out - if subprocess is True then the time limit in seconds before
               killing a test (default 30)
    fake - if provided, the name of the fake tests package in the
           run_tests__tests subpackage to run instead of the normal
           Pygame tests
    python - the path to a python executable to run subprocessed tests
             (default sys.executable)
    interactive - allow tests tagged 'interactive'.

    Return value:
    A tuple of total number of tests run, dictionary of error information. The
    dictionary is empty if no errors were recorded.

    By default individual test modules are run in separate subprocesses. This
    recreates normal Pygame usage where pygame.init() and pygame.quit() are
    called only once per program execution, and avoids unfortunate
    interactions between test modules. Also, a time limit is placed on test
    execution, so frozen tests are killed when their time allotment has
    expired. Use the single process option if threading is not working
    properly or if tests are taking too long. It is not guaranteed that all
    tests will pass in single process mode.

    Tests are run in a randomized order if the randomize argument is True or a
    seed argument is provided. If no seed integer is provided then the system
    time is used.

    Individual test modules may have a corresponding *_tags.py module,
    defining a __tags__ attribute, a list of tag strings used to selectively
    omit modules from a run. By default only the 'interactive', 'ignore', and
    'subprocess_ignore' tags are ignored. 'interactive' is for modules that
    take user input, like cdrom_test.py. 'ignore' and 'subprocess_ignore'
    are for disabling modules for foreground and subprocess modes
    respectively. These are for disabling tests on optional modules or for
    experimental modules with known problems. These modules can be run from
    the console as a Python program.

    This function can only be called once per Python session. It is not
    reentrant.

    """

    global was_run

    if was_run:
        raise RuntimeError("run() was already called this session")
    was_run = True

    options = kwds.copy()
    # 'nosubprocess', 'randomize' and 'seed' are read with get() (not pop())
    # because they must also be passed through to run_test / the subprocess
    # command line below.
    option_nosubprocess = options.get('nosubprocess', False)
    option_dump = options.pop('dump', False)
    option_file = options.pop('file', None)
    option_all = options.pop('all', False)
    option_randomize = options.get('randomize', False)
    option_seed = options.get('seed', None)
    option_multi_thread = options.pop('multi_thread', 1)
    option_time_out = options.pop('time_out', 120)
    option_fake = options.pop('fake', None)
    option_python = options.pop('python', sys.executable)
    option_exclude = options.pop('exclude', ())
    option_interactive = options.pop('interactive', False)

    # Build up the effective exclusion tag tuple from mode and Python version.
    if not option_interactive and 'interactive' not in option_exclude:
        option_exclude += ('interactive',)
    if not option_nosubprocess and 'subprocess_ignore' not in option_exclude:
        option_exclude += ('subprocess_ignore',)
    elif 'ignore' not in option_exclude:
        option_exclude += ('ignore',)
    if sys.version_info < (3, 0, 0):
        option_exclude += ('python2_ignore',)
    else:
        option_exclude += ('python3_ignore',)

    main_dir, test_subdir, fake_test_subdir = prepare_test_env()
    test_runner_py = os.path.join(test_subdir, "test_utils", "test_runner.py")
    cur_working_dir = os.path.abspath(os.getcwd())

    ###########################################################################
    # Compile a list of test modules. If fake, then compile list of fake
    # xxxx_test.py from run_tests__tests

    TEST_MODULE_RE = re.compile(r'^(.+_test)\.py$')

    test_mods_pkg_name = test_pkg_name

    if option_fake is not None:
        test_mods_pkg_name = '.'.join([test_mods_pkg_name,
                                       'run_tests__tests',
                                       option_fake])
        test_subdir = os.path.join(fake_test_subdir, option_fake)
        working_dir = test_subdir
    else:
        working_dir = main_dir


    # Added in because some machines will need os.environ else there will be
    # false failures in subprocess mode. Same issue as python2.6. Needs some
    # env vars.

    test_env = os.environ

    fmt1 = '%s.%%s' % test_mods_pkg_name
    fmt2 = '%s.%%s_test' % test_mods_pkg_name
    if args:
        test_modules = [
            m.endswith('_test') and (fmt1 % m) or (fmt2 % m) for m in args
        ]
    else:
        test_modules = []
        for f in sorted(os.listdir(test_subdir)):
            for match in TEST_MODULE_RE.findall(f):
                test_modules.append(fmt1 % (match,))

    ###########################################################################
    # Remove modules to be excluded.

    tmp = test_modules
    test_modules = []
    for name in tmp:
        tag_module_name = "%s_tags" % (name[0:-5],)
        try:
            tag_module = import_submodule(tag_module_name)
        except ImportError:
            # no tags module at all: the test module is always included
            test_modules.append(name)
        else:
            try:
                tags = tag_module.__tags__
            except AttributeError:
                print ("%s has no tags: ignoring" % (tag_module_name,))
                # BUGFIX: was 'test_module.append(name)' (NameError) — a test
                # whose tags module lacks __tags__ should still be included.
                test_modules.append(name)
            else:
                for tag in tags:
                    if tag in option_exclude:
                        print ("skipping %s (tag '%s')" % (name, tag))
                        break
                else:
                    test_modules.append(name)
    del tmp, tag_module_name, name

    ###########################################################################
    # Meta results

    results = {}
    meta_results = {'__meta__' : {}}
    meta = meta_results['__meta__']

    ###########################################################################
    # Randomization

    if option_randomize or option_seed is not None:
        if option_seed is None:
            option_seed = time.time()
        meta['random_seed'] = option_seed
        print ("\nRANDOM SEED USED: %s\n" % option_seed)
        random.seed(option_seed)
        random.shuffle(test_modules)

    ###########################################################################
    # Single process mode

    if option_nosubprocess:
        unittest_patch.patch(**options)

        options['exclude'] = option_exclude
        t = time.time()
        for module in test_modules:
            results.update(run_test(module, **options))
        t = time.time() - t

    ###########################################################################
    # Subprocess mode
    #

    if not option_nosubprocess:
        if is_pygame_pkg:
            from pygame.tests.test_utils.async_sub import proc_in_time_or_kill
        else:
            from test.test_utils.async_sub import proc_in_time_or_kill

        pass_on_args = ['--exclude', ','.join(option_exclude)]
        for option in ['timings', 'seed']:
            value = options.pop(option, None)
            if value is not None:
                pass_on_args.append('--%s' % option)
                pass_on_args.append(str(value))
        for option, value in options.items():
            option = option.replace('_', '-')
            if value:
                pass_on_args.append('--%s' % option)

        def sub_test(module):
            # run one test module in a child interpreter, killing it if it
            # exceeds option_time_out seconds
            print ('loading %s' % module)

            cmd = [option_python, test_runner_py, module ] + pass_on_args

            return (module,
                    (cmd, test_env, working_dir),
                    proc_in_time_or_kill(cmd, option_time_out, env=test_env,
                                         wd=working_dir))

        if option_multi_thread > 1:
            def tmap(f, args):
                return pygame.threads.tmap (
                    f, args, stop_on_error = False,
                    num_workers = option_multi_thread
                )
        else:
            tmap = map

        t = time.time()

        for module, cmd, (return_code, raw_return) in tmap(sub_test,
                                                           test_modules):
            test_file = '%s.py' % os.path.join(test_subdir, module)
            cmd, test_env, working_dir = cmd

            test_results = get_test_results(raw_return)
            if test_results:
                results.update(test_results)
            else:
                results[module] = {}

            add_to_results = [
                'return_code', 'raw_return',  'cmd', 'test_file',
                'test_env', 'working_dir', 'module',
            ]

            results[module].update(from_namespace(locals(), add_to_results))

        t = time.time() - t

    ###########################################################################
    # Output Results
    #

    untrusty_total, combined = combine_results(results, t)
    total, fails = test_failures(results)

    meta['total_tests'] = total
    meta['combined'] = combined
    results.update(meta_results)

    if option_nosubprocess:
        assert total == untrusty_total

    if not option_dump:
        print (combined)
    else:
        results = option_all and results or fails
        print (TEST_RESULTS_START)
        print (pformat(results))

    if option_file is not None:
        results_file = open(option_file, 'w')
        try:
            results_file.write(pformat(results))
        finally:
            results_file.close()

    return total, fails

Example 30

Project: ck-crowdtuning
Source File: module.py
View license
def html_viewer(i):
    """      
    Input:  {
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0

              html
              (style)      - styles - useful for plotting JavaScript-based graphs
            }

    """

    import os
    global cfg, work

    orig_module_uid=work['self_module_uid']

    mcfg=i.get('module_cfg',{})
    if len(mcfg)>0: cfg=mcfg

    mwork=i.get('module_work',{})
    if len(mwork)>0: work=mwork

    st=''

    url0=i['url_base']

    ap=i.get('all_params',{})

    ruoa=i.get('repo_uoa','')
    muoa=work['self_module_uoa']
    muid=work['self_module_uid']
    duoa=i.get('data_uoa','')

    ik=cfg['improvements_keys']

    # Load program optimization entry
    rx=ck.access({'action':'load',
                  'module_uoa':cfg['module_deps']['module'],
                  'data_uoa':cfg['module_deps']['program.optimization']})
    if rx['return']>0: return rx
    urld=rx['dict'].get('url_discuss','')

    # Load Entry
    r=ck.access({'action':'load',
                 'repo_uoa':ruoa,
                 'module_uoa':muoa,
                 'data_uoa':duoa})
    if r['return']>0: 
       return {'return':0, 'html':'<b>CK error:</b> '+r['error']+'!'}

    p=r['path']
    d=r['dict']
    duid=r['data_uid']

    # Load program module to get desc keys
    r=ck.access({'action':'load',
                 'module_uoa':cfg['module_deps']['module'],
                 'data_uoa':cfg['replay_desc']['module_uoa']})
    if r['return']>0: return r
    pdesc=r.get('desc',{})
    xxkey=cfg['replay_desc'].get('desc_key','')
    if xxkey!='':
       pdesc=pdesc.get(xxkey,{})

    h='<center>\n'
    h+='\n\n<script language="JavaScript">function copyToClipboard (text) {window.prompt ("Copy to clipboard: Ctrl+C, Enter", text);}</script>\n\n' 

    h+='<H2>Distinct solutions after online classification ('+cfg['desc']+')</H2>\n'
    h+='</center>\n'

    h+='<p>\n'

    cid=muid+':'+duid

    h+='<table border="0" cellpadding="4" cellspacing="0">\n'
    x=muid
    if muoa!=muid: x+=' ('+muoa+')'
    h+='<tr><td><b>Scenario UID</b></td><td>'+x+'</td></tr>\n'
    h+='<tr><td><b>Data UID</b></td><td>'+duid+'</td></tr>\n'
    h+='<tr><td><td></td></tr>\n'

    url5=ck.cfg.get('wiki_data_web','')

    if url5!='' or urld!='':
       if url5!='':
          x='<a href="'+url5+muid+'_'+duid+'">GitHub wiki</a>'
       if urld!='':
          if x!='': x+=', '
          x+='<a href="'+urld+'">Google group</a>' 

       h+='<tr><td><b>Discuss:</b></td><td>'+x+'</td></tr>\n'

       h+='<tr><td><td></td></tr>\n'

    urlx=url0+'action=get&cid='+cfg['module_deps']['program.optimization']+':'+duid+'&scenario_module_uoa='+muid+'&out=json'
    urls=url0+'action=pull&common_action=yes&cid='+muid+':'+duid+'&filename=summary.json'
    urlc=url0+'action=pull&common_action=yes&cid='+muid+':'+duid+'&filename=classification.json'

    x=''
    if urls!='':
       x+='[ <a href="'+urls+'">All solutions in JSON</a> ]'
    if urlc!='':
       if x!='': x+=', '
       x+='[ <a href="'+urlc+'">Solutions\' classification in JSON</a> ]'

    if x!='':
       h+='<tr><td><b>Download:</b></td><td>'+x+'</td></tr>\n'

    h+='<tr><td><b>Reproduce all (with reactions):</b></td><td><i>ck replay '+cid+'</i></td></tr>\n'

    h+='<tr><td><td></td></tr>\n'

    pr=cfg.get('prune_results',[])
    mm=d.get('meta',{})
    em=d.get('extra_meta',{})
    obj=mm.get('objective','')

    for k in pr:
        qd=k.get('desc','')
        qi=k.get('id','')
        qr=k.get('ref_uid','')
        qm=k.get('ref_module_uoa','')

        x=mm.get(qi,'')
        if x!='' and qm!='' and qr!='':
           xuid=mm.get(qr,'')
           if xuid!='':
              x='<a href="'+url0+'wcid='+qm+':'+xuid+'">'+x+'</a>'

        h+='<tr><td><b>'+qd+'</b></td><td>'+x+'</td></tr>\n'

    h+='<tr><td><b>Objective</b></td><td>'+obj+'</td></tr>\n'

    h+='<tr><td></td><td></td></tr>\n'

    kk=0
    for kx in range(0, len(ik)):
        k=ik[kx]
        k1=k.replace('$#obj#$',obj)
        ik[kx]=k1

        if pdesc.get(k1,{}).get('desc','')!='':
           k1=pdesc[k1]['desc']

        kk+=1

        h+='<tr><td><b>Improvement key IK'+str(kk)+'</b></td><td>'+k1+'</td></tr>\n'

    ik0=ik[0] # first key to sort

    h+='</table>\n'

    h+='<p>\n'
    h+='<center>\n'

    bgraph={"0":[], "1":[]} # graph with highest improvements

    # Load summary
    sols=[]

    psum=os.path.join(p, fsummary)
    if os.path.isfile(psum):
       rx=ck.load_json_file({'json_file':psum})
       if rx['return']>0: return rx
       sols=rx['dict']

    # Load classification file
    classification={}
    pcl=os.path.join(p, fclassification)
    if os.path.isfile(pcl):
       rx=ck.load_json_file({'json_file':pcl})
       if rx['return']>0: return rx
       classification=rx['dict']

    h+='<p>\n'
    h+='$#graph#$\n'
    h+='<p>\n'

    # List solutions
    if len(sols)==0:
       h+='<h2>No distinct solutions found!</h2>\n'
    else:
       # Check host URL prefix and default module/action
       h+='<table class="ck_table" border="0">\n'

       h+=' <tr style="background-color:#cfcfff;">\n'
       h+='  <td colspan="1"></td>\n'
       h+='  <td colspan="1" style="background-color:#bfbfff;"></td>\n'
       h+='  <td colspan="'+str(len(ik))+'" align="center"><b>Improvements (<4% variation)</b></td>\n'
       h+='  <td colspan="2" align="center" style="background-color:#bfbfff;"></td>\n'
       h+='  <td colspan="4"></td>\n'
       h+='  <td colspan="4" align="center" style="background-color:#bfbfff;"><b>Distinct workload for highest improvement</b></td>\n'
       h+='  <td colspan="4"></td>\n'
       h+='  <td colspan="1" align="center" style="background-color:#bfbfff;"></td>\n'
       h+=' </tr>\n'

       h+=' <tr style="background-color:#cfcfff;">\n'
       h+='  <td><b>\n'
       h+='   #\n'
       h+='  </b></td>\n'
       h+='  <td style="background-color:#bfbfff;"><b>\n'
       h+='   Solution UID\n'
       h+='  </b></td>\n'

       for k in range(0, len(ik)):
           h+='  <td align="right"><b>\n'
           h+='   IK'+str(k+1)+'\n'
           h+='  </b></td>\n'

       h+='  <td align="center" style="background-color:#bfbfff;"><b>\n'
       h+='   New distinct optimization choices\n'
       h+='  </b></td>\n'
       h+='  <td align="center" style="background-color:#bfbfff;" align="right"><b>\n'
       h+='   Ref\n'
       h+='  </b></td>\n'

       h+='  <td align="center"><b>\n'
       h+='   Best species\n'
       h+='  </b></td>\n'
       h+='  <td align="center"><b>\n'
       h+='   Worst species\n'
       h+='  </b></td>\n'
       h+='  <td align="center"><b>\n'
       h+='   Touched\n'
       h+='  </b></td>\n'
       h+='  <td align="center"><b>\n'
       h+='   Iters\n'
       h+='  </b></td>\n'
       h+='  <td style="background-color:#bfbfff;"><b>\n'
       h+='   Program\n'
       h+='  </b></td>\n'
       h+='  <td style="background-color:#bfbfff;"><b>\n'
       h+='   CMD\n'
       h+='  </b></td>\n'
       h+='  <td style="background-color:#bfbfff;"><b>\n'
       h+='   Dataset\n'
       h+='  </b></td>\n'
       h+='  <td style="background-color:#bfbfff;"><b>\n'
       h+='   Dataset file\n'
       h+='  </b></td>\n'
       h+='  <td align="right"><b>\n'
       h+='   CPU freq (MHz)\n'
       h+='  </b></td>\n'
       h+='  <td align="right"><b>\n'
       h+='   Cores\n'
       h+='  </b></td>\n'
       h+='  <td><b>\n'
       h+='   Platform\n'
       h+='  </b></td>\n'
       h+='  <td><b>\n'
       h+='   OS\n'
       h+='  </b></td>\n'

       h+='  <td align="center" style="background-color:#bfbfff;">\n'
       h+='   <b>Replay</b>\n'
       h+='  </td>\n'

       h+=' </tr>\n'

       # List
       num=0
       iq=-1
       iq1=0

       res={}
       sres=[]
       ires=0

       em={}

       cls={}

       while iq1<len(sols): # already sorted by most "interesting" solutions (such as highest speedups)
           if iq!=iq1:
              num+=1

              iq+=1
              q=sols[iq]

              em=q.get('extra_meta',{})

              suid=q['solution_uid']

              cls=classification.get(suid,{})

              xcls=cls.get('highest_improvements_workload',{})
              program_uoa=xcls.get('program_uoa','')
              cmd=xcls.get('cmd_key','')
              dataset_uoa=xcls.get('dataset_uoa','')
              dataset_file=xcls.get('dataset_file','')

              wl_best=len(cls.get('best',[]))
              wl_worst=len(cls.get('worst',[]))

              url_wl=url0+'action=get_workloads&cid='+cfg['module_deps']['program.optimization']+':'+duid+'&scenario_module_uoa='+muid+'&solution_uid='+suid+'&out=json'
              url_wl_best=url_wl+'&key=best'
              url_wl_worst=url_wl+'&key=worst'

              res={}
              ref_res={}
              sres=[]
              ires=0

              # Try to load all solutions
              p1=os.path.join(p, suid)

              try:
                 dirList=os.listdir(p1)
              except Exception as e:
                  None
              else:
                  for fn in dirList:
                      if fn.startswith('ckp-') and fn.endswith('.flat.json'):
                         uid=fn[4:-10]

                         px=os.path.join(p1, fn)
                         rx=ck.load_json_file({'json_file':px})
                         if rx['return']>0: return rx
                         d1=rx['dict']

                         px=os.path.join(p1,'ckp-'+uid+'.features_flat.json')
                         if rx['return']>0: return rx
                         d2=rx['dict']

                         x={'flat':d1, 'features_flat':d2}

                         px=os.path.join(p1, 'ckp-'+uid+'.features.json')
                         rx=ck.load_json_file({'json_file':px})
                         if rx['return']>0: return rx
                         dx=rx['dict']

                         if dx.get('permanent','')=='yes':
                            ref_res==x
                         else:
                            res[uid]=x
                         
                  rr=list(res.keys())
                  sres=sorted(rr, key=lambda v: (float(res[v].get('flat',{}).get(ik0,0.0))), reverse=True)

           rr={}
           if ires<len(sres):
              rr=res.get(sres[ires],{})
              ires+=1

              iterations=q.get('iterations',1)
              touched=q.get('touched',1)

              choices=q['choices']

              ref_sol=q.get('ref_choices',{})
              ref_sol_order=q.get('ref_choices_order',[])

              target_os=choices.get('target_os','')

              speedup=''

              cmd1=''
              cmd2=''

              ss='S'+str(num)
              h+=' <tr>\n'
              h+='  <td valign="top" style="background-color:#efefff;">\n'
              if ires<2:

                 h+='   '+ss+'\n'
              h+='  </td>\n'

              h+='  <td valign="top">\n'
              if ires<2 and urlx!='':
                 h+='   <a href="'+urlx+'&solution_uid='+suid+'">'+suid+'</a>\n'
              h+='  </td>\n'

              for k in range(0, len(ik)):
                  h+='  <td valign="top" align="right" style="background-color:#efefff;">\n'

#                  dv=rr.get('flat',{}).get(ik[k],'')

                  dv=''
                  dvw=''
                  points=q.get('points',[])
                  iresx=ires-1
#                  if iresx<len(points):
#                     dv=points[iresx].get('improvements_best',{}).get(ik[k],'')
#                     dvw=points[iresx].get('improvements_worst',{}).get(ik[k],'')

                  # Add to graph (first dimension and first solution)
#                  if k==0 and ires==1:
                  dv=cls.get('highest_improvements',{}).get(ik[k],'')
                  dvw=cls.get('highest_degradations',{}).get(ik[k],'')

                  if k==0:
                     bgraph['0'].append([ss,dv])
                     bgraph['1'].append([ss,dvw])

                  y=''
                  if type(dv)==int or type(dv)==ck.type_long:
                     y=str(dv)
                  else:
                     try:
                        y=('%.2f' % dv)
                     except Exception as e: 
                        y=dv
                        pass

                  if dv!='':
                     if dv>1.0:
                        y='<span style="color:#bf0000">'+y+'</span>'
                     elif dv!=0:
                        y='<span style="color:#0000bf">'+y+'</span>'
                  

                  h+=str(y)+'\n'
                  h+='  </td>\n'


              h+='  <td valign="top">\n'
              dv=rr.get('flat',{}).get('##characteristics#compile#joined_compiler_flags#min','')
              h+='   '+dv+'\n'
              h+='  </td>\n'

              h+='  <td valign="top" align="center">\n'
              if ires<2:
                 # Ideally should add pipeline description somewhere
                 # to properly recreate flags. However since it is most of the time -Ox
                 # we don't need to make it complex at the moment 

                 ry=rebuild_cmd({'choices':ref_sol,
                                 'choices_order':ref_sol_order,
                                 'choices_desc':{}})
                 if ry['return']>0: return ry
                 ref=ry['cmd']

                 h+='   '+ref+'\n'
              h+='   \n'
              h+='  </td>\n'

              h+='  <td valign="top" align="center" style="background-color:#efefff;">\n'
              if ires<2:
                 h+='   <a href="'+url_wl_best+'">'+str(wl_best)+'</a>\n'
              h+='  </td>\n'

              h+='  <td valign="top" align="center" style="background-color:#efefff;">\n'
              if ires<2:
                 h+='   <a href="'+url_wl_worst+'">'+str(wl_worst)+'</a>\n'
              h+='  </td>\n'

              h+='  <td valign="top" align="center" style="background-color:#efefff;">\n'
              if ires<2:
                 h+='   '+str(touched)+'\n'
              h+='  </td>\n'

              h+='  <td valign="top" align="center" style="background-color:#efefff;">\n'
              if ires<2:
                 h+='   '+str(iterations)+'\n'
              h+='  </td>\n'

              h+='  <td valign="top">\n'
              if ires<2:
                 h+='   <a href="'+url0+'wcid=program:'+program_uoa+'">'+program_uoa+'</a>\n'
              h+='  </td>\n'

              h+='  <td valign="top">\n'
              if ires<2:
                 h+='   '+cmd+'\n'
              h+='  </td>\n'

              h+='  <td valign="top">\n'
              if ires<2:
                 h+='   <a href="'+url0+'wcid=dataset:'+dataset_uoa+'">'+dataset_uoa+'</a>\n'
              h+='  </td>\n'

              h+='  <td valign="top">\n'
              if ires<2:
                 h+='   <a href="'+url0+'action=pull&common_func=yes&cid=dataset:'+dataset_uoa+'&filename='+dataset_file+'">'+dataset_file+'</a>\n'
              h+='  </td>\n'

#              h+='  <td valign="top" align="right">\n'
#              if ires<2:
#                 h+='   '+str(em.get('kernel_repetitions',-1))+'\n'
#              h+='  </td>\n'

              h+='  <td valign="top" align="right" style="background-color:#efefff;">\n'
              if ires<2:
                 x=''
                 qq=em.get('cpu_cur_freq',[])
                 for q in qq:
                     xq=qq[q]
                     if x!='': x+=', '
                     x+=str(xq)
                 h+='   '+x+'\n'
              h+='  </td>\n'

              h+='  <td valign="top" align="right" style="background-color:#efefff;">\n'
              if ires<2:
                 qq=em.get('cpu_num_proc',1)
                 h+='   '+str(qq)+'\n'
              h+='  </td>\n'

              h+='  <td valign="top" style="background-color:#efefff;">\n'
              if ires<2:
                 h+='   '+str(em.get('platform_name',''))+'\n'
              h+='  </td>\n'

              h+='  <td valign="top" style="background-color:#efefff;">\n'
              if ires<2:
                 h+='   '+str(em.get('os_name',''))+'\n'
              h+='  </td>\n'

              x='ck replay '+cid+' --solution_uid='+suid
              y=ck.cfg.get('add_extra_to_replay','')
              if y!='':x+=' '+y
              h+='    <td valign="top" align="center"><input type="button" class="ck_small_button" style="height:60px;" onClick="copyToClipboard(\''+x+'\');" value="Copy \nto \nclipboard"></td>\n'

              h+=' </tr>\n'

           else:
              iq1+=1


       h+='</table>\n'
       h+='<br><a href="http://arxiv.org/abs/1506.06256"><img src="'+url0+'action=pull&common_action=yes&cid='+cfg['module_deps']['module']+':'+orig_module_uid+'&filename=images/image-workflow1.png"></a>\n'

    h+='</center>\n'

    h+='<br><br>\n'

    rx=ck.access({'action':'links',
                  'module_uoa':cfg['module_deps']['program.optimization']})
    if rx['return']>0: return rx
    h+=rx['html']

    # Plot graph
    hg=''
    ftmp=''

    if len(bgraph['0'])>0:
       ii={'action':'plot',
           'module_uoa':cfg['module_deps']['graph'],

           "table":bgraph,

           "h_lines":[1.0],

           "ymin":0,

           "ignore_point_if_none":"yes",

           "plot_type":"d3_2d_bars",

           "display_y_error_bar":"no",

           "title":"Powered by Collective Knowledge",

           "axis_x_desc":"Distinct optimization solutions (highest improvement vs highest degradation)",
           "axis_y_desc":"Max improvement ( IK1 = Ref / Solution )",

           "plot_grid":"yes",

           "d3_div":"ck_interactive",

           "image_width":"900",
           "image_height":"400",

           "wfe_url":url0}

       # Trick to save to file (for interactive/live articles)
       if ap.get('fgg_save_graph_to_file','')=='yes':
          import copy
          iii=copy.deepcopy(ii)
          iii["substitute_x_with_loop"]="yes"
          iii["plot_type"]="mpl_2d_bars" 
          if 'ymin' in iii: del(iii['ymin'])
          if 'ymax' in iii: del(iii['ymax'])

          # Prepare batch file
          rx=ck.gen_tmp_file({'prefix':'tmp-', 'suffix':'.json'})
          if rx['return']>0: return rx
          ftmp=rx['file_name']

          rx=ck.save_json_to_file({'json_file':ftmp, 'dict':iii, 'sort_keys':'yes'})
          if rx['return']>0: return rx

       r=ck.access(ii)
       if r['return']==0:
          x=r.get('html','')
          if x!='':
             st=r.get('style','')

             hg='<div id="ck_box_with_shadow" style="width:920px;">\n'
             if ftmp!='':
                hg+='<center><b>Note: graph info has been saved to file '+ftmp+' for interactive publications</b></center>'
             hg+=' <div id="ck_interactive" style="text-align:center">\n'
             hg+=x+'\n'
             hg+=' </div>\n'
             hg+='</div>\n'

    h=h.replace('$#graph#$', hg)

    return {'return':0, 'html':h, 'style':st}

Example 31

Project: ck-wa
Source File: module.py
View license
def show(i):
    """Render an HTML page listing all Workload Automation (WA) results.

    Builds a filter/selector form, prunes the list of experiment entries by
    the selected filter values, renders a results table (one row per
    experiment entry) and, when a workload and scenario are selected,
    an interactive d3 bar graph of total execution times.

    NOTE(review): relies on module-level state defined elsewhere in this
    file: ck, cfg, work, hextra, selector, onchange, form_name.

    Input:  {
               (crowd_module_uoa) - if rendered from experiment crowdsourcing
               (crowd_key)        - add extra name to Web keys to avoid overlapping with original crowdsourcing HTML
               (crowd_on_change)  - reuse onchange doc from original crowdsourcing HTML
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0
            }

    """

    import os

    st=''

    cmuoa=i.get('crowd_module_uoa','')
    ckey=i.get('crowd_key','')

    # Use caller-supplied onchange JavaScript if given, otherwise the
    # module-level default.
    conc=i.get('crowd_on_change','')
    if conc=='':
        conc=onchange

    h='<hr>\n'
    h+='<center>\n'
    h+='\n\n<script language="JavaScript">function copyToClipboard (text) {window.prompt ("Copy to clipboard: Ctrl+C, Enter", text);}</script>\n\n' 

    h+='<h2>All WA results</h2>\n'

    h+=hextra

    # Check host URL prefix and default module/action
    rx=ck.access({'action':'form_url_prefix',
                  'module_uoa':'wfe',
                  'host':i.get('host',''), 
                  'port':i.get('port',''), 
                  'template':i.get('template','')})
    if rx['return']>0: return rx
    url0=rx['url']
    template=rx['template']

    url=url0
    action=i.get('action','')
    muoa=i.get('module_uoa','')

    st=''

    url+='action=index&module_uoa=wfe&native_action='+action+'&'+'native_module_uoa='+muoa
    url1=url

    # List entries (optionally via the crowdsourcing module instead of
    # this module itself)
    ii={'action':'search',
        'module_uoa':work['self_module_uid'],
        'add_meta':'yes'}

    if cmuoa!='':
        ii['module_uoa']=cmuoa

    r=ck.access(ii)
    if r['return']>0: return r

    lst=r['lst']

    # Collect distinct values for each selector key across all entries,
    # to populate the drop-down filters.
    choices={}
    wchoices={}

    for q in lst:
        d=q['meta']
        meta=d.get('meta',{})

        for kk in selector:
            kx=kk['key']
            k=ckey+kx

            if k not in choices: 
                choices[k]=[]
                wchoices[k]=[{'name':'','value':''}]  # empty option = no filter

            v=meta.get(kx,'')
            if v!='':
                if v not in choices[k]: 
                    choices[k].append(v)
                    wchoices[k].append({'name':v, 'value':v})

    # Prepare query div ***************************************************************
    if cmuoa=='':
        # Start form + URL (even when viewing entry)
        r=ck.access({'action':'start_form',
                     'module_uoa':cfg['module_deps']['wfe'],
                     'url':url1,
                     'name':form_name})
        if r['return']>0: return r
        h+=r['html']

    for kk in selector:
        k=ckey+kk['key']
        n=kk['name']

        nl=kk.get('new_line','')
        if nl=='yes':
            h+='<br>\n<div id="ck_entries_space8"></div>\n'

        # Remember the currently selected value so it can be used for
        # pruning below.
        v=''
        if i.get(k,'')!='':
            v=i[k]
            kk['value']=v

        # Show hardware
        ii={'action':'create_selector',
            'module_uoa':cfg['module_deps']['wfe'],
            'data':wchoices.get(k,[]),
            'name':k,
            'onchange':conc, 
            'skip_sort':'no',
            'selected_value':v}
        r=ck.access(ii)
        if r['return']>0: return r

        h+='<b>'+n+':</b> '+r['html'].strip()+'\n'

    h+='<br><br>'

    # Prune list: keep only entries matching every non-empty selector value
    plst=[]
    for q in lst:
        d=q['meta']
        meta=d.get('meta',{})

        # Check selector
        skip=False
        for kk in selector:
            k=kk['key']
            n=kk['name']
            v=kk.get('value','')

            if v!='' and meta.get(k,'')!=v:
                skip=True

        if not skip:
            plst.append(q)

    # Check if too many
    lplst=len(plst)
    if lplst==0:
        h+='<b>No results found!</b>'
        return {'return':0, 'html':h, 'style':st}
    elif lplst>50:
        h+='<b>Too many entries to show ('+str(lplst)+') - please, prune list further!</b>'
        return {'return':0, 'html':h, 'style':st}

    # Prepare table
    h+='<table border="1" cellpadding="7" cellspacing="0">\n'

    ha='align="center" valign="top"'
    hb='align="left" valign="top"'

    h+='  <tr style="background-color:#dddddd">\n'
    h+='   <td '+ha+'><b>All raw files</b></td>\n'
    h+='   <td '+ha+'><b>Workload</b></td>\n'
    h+='   <td '+ha+'><b>Scenario</b></td>\n'
    h+='   <td '+ha+'><b>Platform</b></td>\n'
    h+='   <td '+ha+'><b>serial number / adb device ID</b></td>\n'
    h+='   <td '+ha+'><b>CPU</b></td>\n'
    h+='   <td '+ha+'><b>GPU</b></td>\n'
    h+='   <td '+ha+'><b>OS</b></td>\n'
    h+='   <td '+ha+'><b>APK</b></td>\n'
    h+='   <td '+ha+'><b>WA version</b></td>\n'
    h+='   <td '+ha+'><b>Fail?</b></td>\n'
    h+='   <td '+hb+'><b>Choices</b></td>\n'
    h+='   <td '+hb+'><b>Characteristics</b></td>\n'
    h+='   <td '+ha+'><b>JSON results</b></td>\n'
    h+='   <td '+ha+'><b>Replay</b></td>\n'
    h+='  </tr>\n'  # FIX: was '<tr>' - header row was never closed

    # Dictionary to hold target meta (cache of machine entries keyed by
    # local target UID, loaded lazily below)
    tm={}

    ix=0
    bgraph={"0":[]} # Just for graph demo

    for q in sorted(plst, key=lambda x: x.get('meta',{}).get('meta',{}).get('workload_name','')):
        ix+=1

        duid=q['data_uid']
        path=q['path']

        d=q['meta']

        meta=d.get('meta',{})

        params=d.get('choices',{}).get('params',{}).get('params',{})

        pname=meta.get('program_uoa','')
        wname=meta.get('workload_name','')
        wuid=meta.get('program_uid','')

        apk_name=meta.get('apk_name','')
        apk_ver=meta.get('apk_version','')

        wa_ver=meta.get('wa_version','')

        scenario=meta.get('scenario','')

        ltarget_uoa=meta.get('local_target_uoa','')
        ltarget_uid=meta.get('local_target_uid','')

        if ltarget_uid!='' and ltarget_uid not in tm:
            # Load machine meta (best-effort: failures are silently skipped
            # and the adb device ID simply stays empty)
            rx=ck.access({'action':'load',
                          'module_uoa':cfg['module_deps']['machine'],
                          'data_uoa':ltarget_uid})
            if rx['return']==0:
                tm[ltarget_uid]=rx['dict']

        plat_name=meta.get('plat_name','')
        cpu_name=meta.get('cpu_name','')
        os_name=meta.get('os_name','')
        gpu_name=meta.get('gpu_name','')

        adb_id=tm.get(ltarget_uid,{}).get('device_id','')
        sn=meta.get('serial_number','')

        te=d.get('characteristics',{}).get('run',{})
        tet=te.get('total_execution_time',0)

        # Green background for successful runs, red for failures
        bgc='afffaf'
        fail=d.get('state',{}).get('fail','')
        fail_reason=d.get('state',{}).get('fail_reason','')
        if fail=='yes':
            if fail_reason=='': fail_reason='yes'

            bgc='ffafaf'
        else:
            # Only chart successful runs, and only when both workload and
            # scenario filters are active (otherwise the graph mixes
            # incomparable results)
            if i.get(ckey+'workload_name','')!='' and i.get(ckey+'scenario','')!='':
                bgraph['0'].append([ix,tet])

        bg=' style="background-color:#'+bgc+';"'

        h+='  <tr'+bg+'>\n'

        x=work['self_module_uid']
        if cmuoa!='': x=cmuoa
        h+='   <td '+ha+'>'+str(ix)+')&nbsp;<a href="'+url0+'&wcid='+x+':'+duid+'">'+duid+'</a></td>\n'

        x=wname
        if wuid!='': x='<a href="'+url0+'&wcid='+cfg['module_deps']['program']+':'+wuid+'">'+x+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=''
        if scenario!='':
            x='<a href="'+url0+'&wcid='+cfg['module_deps']['wa-scenario']+':'+scenario+'">'+scenario+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=plat_name
        if ltarget_uid!='':
           x='<a href="'+url0+'&wcid='+cfg['module_deps']['machine']+':'+ltarget_uid+'">'+x+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=sn
        if adb_id!='' and adb_id!=sn: x+=' / '+adb_id
        h+='   <td '+ha+'>'+x+'</td>\n'

        h+='   <td '+ha+'>'+cpu_name+'</td>\n'
        h+='   <td '+ha+'>'+gpu_name+'</td>\n'
        h+='   <td '+ha+'>'+os_name+'</td>\n'

        # APK; escape quotes/newlines so the text can be embedded in the
        # onClick JavaScript alert below
        x=apk_name
        if apk_ver!='': x+=' (V'+apk_ver+')'
        x=x.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')

        x1=''
        if x!='':
            x1='<input type="button" class="ck_small_button" onClick="alert(\''+x+'\');" value="See">'

        h+='   <td '+ha+'>'+x1+'</td>\n'

        h+='   <td '+ha+'>'+wa_ver+'</td>\n'

        x=fail_reason
        if x=='': 
            x='No'
        else:
            fail_reason=fail_reason.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')
            x='Yes <input type="button" class="ck_small_button" onClick="alert(\''+fail_reason+'\');" value="Log">'

        h+='   <td '+ha+'>'+x+'</td>\n'

        # Params (shown via a JavaScript alert, hence the same escaping)
        x=''
        for k in sorted(params):
            v=params[k]
            x+=str(k)+'='+str(v)+'\n'
        x=x.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')

        x1=''
        if x!='':
            x1='<input type="button" class="ck_small_button" onClick="alert(\''+x+'\');" value="See">'

        h+='   <td '+hb+'>'+x1+'</td>\n'

        # Characteristics
        # Check if has statistics
        dstat={}
        fstat=os.path.join(path,'ck-stat-flat-characteristics.json')
        if os.path.isfile(fstat):
            # FIX: dropped stray 'dict':dstat argument - the result is taken
            # from r['dict'] below, so passing dstat in was dead code
            r=ck.load_json_file({'json_file':fstat})
            if r['return']>0: return r
            dstat=r['dict']

        x=''
        if tet>0: x=('%.1f'%tet)+' sec.'

        # Check if has stats (center +- halfrange from repeated runs)
        x1=dstat.get("##characteristics#run#total_execution_time#center",None)
        x2=dstat.get("##characteristics#run#total_execution_time#halfrange",None)
        if x1!=None and x2!=None:
            x=('%.1f'%x1)+' &PlusMinus; '+('%.1f'%x2)+' sec.'

        # Check all numeric run characteristics, preferring statistical
        # center/halfrange values when available
        x5=''
        for k in sorted(te):
            v=te[k]

            kx="##characteristics#run#"+k

            kx1=dstat.get(kx+'#center',None)
            kx2=dstat.get(kx+'#halfrange',None)

            x6=''
            if type(v)==int:
                if kx1!=None and kx2!=None:
                    x6=str(kx1)+' +- '+str(kx2)
                else:
                    x6=str(v)
            elif type(v)==float:
                if kx1!=None and kx2!=None:
                    x6=('%.1f'%kx1)+' +- '+('%.1f'%kx2)
                else:
                    x6=('%.1f'%v)

            if x6!='':
                x5+=str(k)+'='+x6+'\n'

        x5=x5.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')
        if x5!='':
            x+='<br><input type="button" class="ck_small_button" onClick="alert(\''+x5+'\');" value="All">'

        h+='   <td '+ha+'>'+x+'</td>\n'

        # Check directories with results: look for <subdir>/wa-output/results.json,
        # falling back to wa-output/results.json directly under the entry path
        x=''
        xf1='wa-output'
        xf2='results.json'
        xf=xf1+'/'+xf2
        for d0 in os.listdir(path):
            found=False
            brk=False

            d1=os.path.join(d0,xf)
            d2=os.path.join(path,d1)

            if os.path.isfile(d2):
                found=True
            else:
                d1=xf
                d2=os.path.join(path,d1)

                if os.path.isfile(d2):
                    d0=xf1
                    found=True
                    brk=True  # direct hit is unique - no need to keep scanning

            if found:
                if x!='': x+='<br>\n'
                x1=work['self_module_uid']
                if cmuoa!='':
                    x1=cmuoa
                x+='[&nbsp;<a href="'+url0+'action=pull&common_action=yes&cid='+x1+':'+duid+'&filename='+d1+'">'+d0+'</a>&nbsp;]\n'

                if brk:
                    break
        h+='   <td '+ha+'>'+x+'</td>\n'

        h+='   <td '+ha+'><input type="button" class="ck_small_button" onClick="copyToClipboard(\'ck replay wa:'+wname+'\');" value="Replay"></td>\n'

        h+='  </tr>\n'  # FIX: was '<tr>' - data row was never closed

    h+='</table>\n'
    h+='</center>\n'

    if cmuoa=='':
        h+='</form>\n'

    # Render interactive execution-time graph when at least one successful
    # run was collected above
    if len(bgraph['0'])>0:
       ii={'action':'plot',
           'module_uoa':cfg['module_deps']['graph'],

           "table":bgraph,

           "h_lines":[1.0],

           "ymin":0,

           "ignore_point_if_none":"yes",

           "plot_type":"d3_2d_bars",

           "display_y_error_bar":"no",

           "title":"Powered by Collective Knowledge",

           "axis_x_desc":"Platform",
           "axis_y_desc":"Execution time (sec.)",

           "plot_grid":"yes",

           "d3_div":"ck_interactive",

           "image_width":"900",
           "image_height":"400",

           "wfe_url":url0}

       r=ck.access(ii)
       if r['return']==0:
          x=r.get('html','')
          if x!='':
             st+=r.get('style','')

             h+='<br>\n'
             h+='<center>\n'
             h+='<div id="ck_box_with_shadow" style="width:920px;">\n'
             h+=' <div id="ck_interactive" style="text-align:center">\n'
             h+=x+'\n'
             h+=' </div>\n'
             h+='</div>\n'
             h+='</center>\n'

    return {'return':0, 'html':h, 'style':st}

Example 32

Project: ck-wa
Source File: module.py
View license
def show(i):
    """
    Input:  {
               (crowd_module_uoa) - if rendered from experiment crowdsourcing
               (crowd_key)        - add extra name to Web keys to avoid overlapping with original crowdsourcing HTML
               (crowd_on_change)  - reuse onchange doc from original crowdsourcing HTML
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0
            }

    """

    import os

    st=''

    cmuoa=i.get('crowd_module_uoa','')
    ckey=i.get('crowd_key','')

    conc=i.get('crowd_on_change','')
    if conc=='':
        conc=onchange

    h='<hr>\n'
    h+='<center>\n'
    h+='\n\n<script language="JavaScript">function copyToClipboard (text) {window.prompt ("Copy to clipboard: Ctrl+C, Enter", text);}</script>\n\n' 

    h+='<h2>All WA results</h2>\n'

    h+=hextra

    # Check host URL prefix and default module/action
    rx=ck.access({'action':'form_url_prefix',
                  'module_uoa':'wfe',
                  'host':i.get('host',''), 
                  'port':i.get('port',''), 
                  'template':i.get('template','')})
    if rx['return']>0: return rx
    url0=rx['url']
    template=rx['template']

    url=url0
    action=i.get('action','')
    muoa=i.get('module_uoa','')

    st=''

    url+='action=index&module_uoa=wfe&native_action='+action+'&'+'native_module_uoa='+muoa
    url1=url

    # List entries
    ii={'action':'search',
        'module_uoa':work['self_module_uid'],
        'add_meta':'yes'}

    if cmuoa!='':
        ii['module_uoa']=cmuoa

    r=ck.access(ii)
    if r['return']>0: return r

    lst=r['lst']

    # Check unique entries
    choices={}
    wchoices={}

    for q in lst:
        d=q['meta']
        meta=d.get('meta',{})

        for kk in selector:
            kx=kk['key']
            k=ckey+kx

            if k not in choices: 
                choices[k]=[]
                wchoices[k]=[{'name':'','value':''}]

            v=meta.get(kx,'')
            if v!='':
                if v not in choices[k]: 
                    choices[k].append(v)
                    wchoices[k].append({'name':v, 'value':v})

    # Prepare query div ***************************************************************
    if cmuoa=='':
        # Start form + URL (even when viewing entry)
        r=ck.access({'action':'start_form',
                     'module_uoa':cfg['module_deps']['wfe'],
                     'url':url1,
                     'name':form_name})
        if r['return']>0: return r
        h+=r['html']

    for kk in selector:
        k=ckey+kk['key']
        n=kk['name']

        nl=kk.get('new_line','')
        if nl=='yes':
            h+='<br>\n<div id="ck_entries_space8"></div>\n'

        v=''
        if i.get(k,'')!='':
            v=i[k]
            kk['value']=v

        # Show hardware
        ii={'action':'create_selector',
            'module_uoa':cfg['module_deps']['wfe'],
            'data':wchoices.get(k,[]),
            'name':k,
            'onchange':conc, 
            'skip_sort':'no',
            'selected_value':v}
        r=ck.access(ii)
        if r['return']>0: return r

        h+='<b>'+n+':</b> '+r['html'].strip()+'\n'

    h+='<br><br>'

    # Prune list
    plst=[]
    for q in lst:
        d=q['meta']
        meta=d.get('meta',{})

        # Check selector
        skip=False
        for kk in selector:
            k=kk['key']
            n=kk['name']
            v=kk.get('value','')

            if v!='' and meta.get(k,'')!=v:
                skip=True

        if not skip:
            plst.append(q)

    # Check if too many
    lplst=len(plst)
    if lplst==0:
        h+='<b>No results found!</b>'
        return {'return':0, 'html':h, 'style':st}
    elif lplst>50:
        h+='<b>Too many entries to show ('+str(lplst)+') - please, prune list further!</b>'
        return {'return':0, 'html':h, 'style':st}

    # Prepare table
    h+='<table border="1" cellpadding="7" cellspacing="0">\n'

    ha='align="center" valign="top"'
    hb='align="left" valign="top"'

    h+='  <tr style="background-color:#dddddd">\n'
    h+='   <td '+ha+'><b>All raw files</b></td>\n'
    h+='   <td '+ha+'><b>Workload</b></td>\n'
    h+='   <td '+ha+'><b>Scenario</b></td>\n'
    h+='   <td '+ha+'><b>Platform</b></td>\n'
    h+='   <td '+ha+'><b>serial number / adb device ID</b></td>\n'
    h+='   <td '+ha+'><b>CPU</b></td>\n'
    h+='   <td '+ha+'><b>GPU</b></td>\n'
    h+='   <td '+ha+'><b>OS</b></td>\n'
    h+='   <td '+ha+'><b>APK</b></td>\n'
    h+='   <td '+ha+'><b>WA version</b></td>\n'
    h+='   <td '+ha+'><b>Fail?</b></td>\n'
    h+='   <td '+hb+'><b>Choices</b></td>\n'
    h+='   <td '+hb+'><b>Characteristics</b></td>\n'
    h+='   <td '+ha+'><b>JSON results</b></td>\n'
    h+='   <td '+ha+'><b>Replay</b></td>\n'
    h+='  <tr>\n'

    # Dictionary to hold target meta
    tm={}

    ix=0
    bgraph={"0":[]} # Just for graph demo

    for q in sorted(plst, key=lambda x: x.get('meta',{}).get('meta',{}).get('workload_name','')):
        ix+=1

        duid=q['data_uid']
        path=q['path']

        d=q['meta']

        meta=d.get('meta',{})

        params=d.get('choices',{}).get('params',{}).get('params',{})

        pname=meta.get('program_uoa','')
        wname=meta.get('workload_name','')
        wuid=meta.get('program_uid','')

        apk_name=meta.get('apk_name','')
        apk_ver=meta.get('apk_version','')

        wa_ver=meta.get('wa_version','')

        scenario=meta.get('scenario','')

        ltarget_uoa=meta.get('local_target_uoa','')
        ltarget_uid=meta.get('local_target_uid','')

        if ltarget_uid!='' and ltarget_uid not in tm:
            # Load machine meta
            rx=ck.access({'action':'load',
                          'module_uoa':cfg['module_deps']['machine'],
                          'data_uoa':ltarget_uid})
            if rx['return']==0:
                tm[ltarget_uid]=rx['dict']

        plat_name=meta.get('plat_name','')
        cpu_name=meta.get('cpu_name','')
        os_name=meta.get('os_name','')
        gpu_name=meta.get('gpu_name','')

        adb_id=tm.get(ltarget_uid,{}).get('device_id','')
        sn=meta.get('serial_number','')

        te=d.get('characteristics',{}).get('run',{})
        tet=te.get('total_execution_time',0)

        bgc='afffaf'
        fail=d.get('state',{}).get('fail','')
        fail_reason=d.get('state',{}).get('fail_reason','')
        if fail=='yes':
            if fail_reason=='': fail_reason='yes'

            bgc='ffafaf'
        else:
            if i.get(ckey+'workload_name','')!='' and i.get(ckey+'scenario','')!='':
                bgraph['0'].append([ix,tet])

        bg=' style="background-color:#'+bgc+';"'

        h+='  <tr'+bg+'>\n'

        x=work['self_module_uid']
        if cmuoa!='': x=cmuoa
        h+='   <td '+ha+'>'+str(ix)+')&nbsp;<a href="'+url0+'&wcid='+x+':'+duid+'">'+duid+'</a></td>\n'

        x=wname
        if wuid!='': x='<a href="'+url0+'&wcid='+cfg['module_deps']['program']+':'+wuid+'">'+x+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=''
        if scenario!='':
            x='<a href="'+url0+'&wcid='+cfg['module_deps']['wa-scenario']+':'+scenario+'">'+scenario+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=plat_name
        if ltarget_uid!='':
           x='<a href="'+url0+'&wcid='+cfg['module_deps']['machine']+':'+ltarget_uid+'">'+x+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=sn
        if adb_id!='' and adb_id!=sn: x+=' / '+adb_id
        h+='   <td '+ha+'>'+x+'</td>\n'

        h+='   <td '+ha+'>'+cpu_name+'</td>\n'
        h+='   <td '+ha+'>'+gpu_name+'</td>\n'
        h+='   <td '+ha+'>'+os_name+'</td>\n'

        # APK
        x=apk_name
        if apk_ver!='': x+=' (V'+apk_ver+')'
#        x=x.replace("'","\'").replace('"',"\\'").replace('\n','\\n')
        x=x.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')

        x1=''
        if x!='':
            x1='<input type="button" class="ck_small_button" onClick="alert(\''+x+'\');" value="See">'

        h+='   <td '+ha+'>'+x1+'</td>\n'

        h+='   <td '+ha+'>'+wa_ver+'</td>\n'

        x=fail_reason
        if x=='': 
            x='No'
        else:
            fail_reason=fail_reason.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')
            x='Yes <input type="button" class="ck_small_button" onClick="alert(\''+fail_reason+'\');" value="Log">'

        h+='   <td '+ha+'>'+x+'</td>\n'

        # Params
#        x='<table border="0" cellpadding="0" cellspacing="2">\n'
        x=''
        for k in sorted(params):
            v=params[k]
            x+=str(k)+'='+str(v)+'\n'
#            x+='<tr><td>'+str(k)+'=</td><td>'+str(v)+'</td></tr>\n'
#        x+='</table>\n'
#        x=x.replace("'","\'").replace('"',"\\'").replace('\n','\\n')
        x=x.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')

        x1=''
        if x!='':
            x1='<input type="button" class="ck_small_button" onClick="alert(\''+x+'\');" value="See">'

        h+='   <td '+hb+'>'+x1+'</td>\n'

        # Characteristics
        # Check if has statistics
        dstat={}
        fstat=os.path.join(path,'ck-stat-flat-characteristics.json')
        if os.path.isfile(fstat):
            r=ck.load_json_file({'json_file':fstat, 'dict':dstat})
            if r['return']>0: return r
            dstat=r['dict']

        x=''
        if tet>0: x=('%.1f'%tet)+' sec.'

        # Check if has stats
        x1=dstat.get("##characteristics#run#total_execution_time#center",None)
        x2=dstat.get("##characteristics#run#total_execution_time#halfrange",None)
        if x1!=None and x2!=None:
            x=('%.1f'%x1)+' &PlusMinus; '+('%.1f'%x2)+' sec.'

        # Check all
        x5=''
        for k in sorted(te):
            v=te[k]

            kx="##characteristics#run#"+k

            kx1=dstat.get(kx+'#center',None)
            kx2=dstat.get(kx+'#halfrange',None)

            x6=''
            if type(v)==int:
                if kx1!=None and kx2!=None:
                    x6=str(kx1)+' +- '+str(kx2)
                else:
                    x6=str(v)
            elif type(v)==float:
                if kx1!=None and kx2!=None:
                    x6=('%.1f'%kx1)+' +- '+('%.1f'%kx2)
                else:
                    x6=('%.1f'%v)

            if x6!='':
                x5+=str(k)+'='+x6+'\n'

#        x5=x5.replace("'","\'").replace('"',"\\'").replace('\n','\\n')
        x5=x5.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')
        if x5!='':
            x+='<br><input type="button" class="ck_small_button" onClick="alert(\''+x5+'\');" value="All">'

        h+='   <td '+ha+'>'+x+'</td>\n'

        # Check directories with results
        x=''
        xf1='wa-output'
        xf2='results.json'
        xf=xf1+'/'+xf2
        for d0 in os.listdir(path):
            found=False
            brk=False

            d1=os.path.join(d0,xf)
            d2=os.path.join(path,d1)

            if os.path.isfile(d2):
                found=True
            else:
                d1=xf
                d2=os.path.join(path,d1)

                if os.path.isfile(d2):
                    d0=xf1
                    found=True
                    brk=True

            if found:
                if x!='': x+='<br>\n'
                x1=work['self_module_uid']
                if cmuoa!='':
                    x1=cmuoa
                x+='[&nbsp;<a href="'+url0+'action=pull&common_action=yes&cid='+x1+':'+duid+'&filename='+d1+'">'+d0+'</a>&nbsp;]\n'

                if brk:
                    break
        h+='   <td '+ha+'>'+x+'</td>\n'

        h+='   <td '+ha+'><input type="button" class="ck_small_button" onClick="copyToClipboard(\'ck replay wa:'+wname+'\');" value="Replay"></td>\n'

        h+='  <tr>\n'

    h+='</table>\n'
    h+='</center>\n'

    if cmuoa=='':
        h+='</form>\n'

    if len(bgraph['0'])>0:
       ii={'action':'plot',
           'module_uoa':cfg['module_deps']['graph'],

           "table":bgraph,

           "h_lines":[1.0],

           "ymin":0,

           "ignore_point_if_none":"yes",

           "plot_type":"d3_2d_bars",

           "display_y_error_bar":"no",

           "title":"Powered by Collective Knowledge",

           "axis_x_desc":"Platform",
           "axis_y_desc":"Execution time (sec.)",

           "plot_grid":"yes",

           "d3_div":"ck_interactive",

           "image_width":"900",
           "image_height":"400",

           "wfe_url":url0}

       r=ck.access(ii)
       if r['return']==0:
          x=r.get('html','')
          if x!='':
             st+=r.get('style','')

             h+='<br>\n'
             h+='<center>\n'
             h+='<div id="ck_box_with_shadow" style="width:920px;">\n'
             h+=' <div id="ck_interactive" style="text-align:center">\n'
             h+=x+'\n'
             h+=' </div>\n'
             h+='</div>\n'
             h+='</center>\n'

    return {'return':0, 'html':h, 'style':st}

Example 33

Project: ck-wa
Source File: module.py
View license
def show(i):
    """Render an HTML page listing all Workload Automation (WA) results.

    Builds a filter form with one drop-down per key in the module-level
    ``selector`` list, prunes the searched entries by the selected values,
    and renders one table row per remaining experiment entry (workload,
    platform, CPU/GPU/OS, APK, failure state, choices, characteristics,
    links to raw JSON results, and a "replay" button).  When both a
    workload name and a scenario are selected, an interactive d3 bar
    graph of execution times is appended below the table.

    Relies on module-level globals: ``ck``, ``cfg``, ``work``,
    ``selector``, ``hextra``, ``onchange`` and ``form_name``.

    Input:  {
               (crowd_module_uoa) - if rendered from experiment crowdsourcing
               (crowd_key)        - add extra name to Web keys to avoid overlapping with original crowdsourcing HTML
               (crowd_on_change)  - reuse onchange doc from original crowdsourcing HTML
            }

    Output: {
              return       - return code =  0, if successful
                                         >  0, if error
              (error)      - error text if return > 0

              html         - generated HTML body
              style        - extra CSS/JS style text to inject into the page
            }

    """

    # Local import keeps the module import section unchanged.
    import os

    st=''

    # Crowdsourcing context: when rendered from the crowdsourcing module,
    # form keys get a prefix (ckey) so they don't collide with the host page.
    cmuoa=i.get('crowd_module_uoa','')
    ckey=i.get('crowd_key','')

    conc=i.get('crowd_on_change','')
    if conc=='':
        conc=onchange

    h='<hr>\n'
    h+='<center>\n'
    h+='\n\n<script language="JavaScript">function copyToClipboard (text) {window.prompt ("Copy to clipboard: Ctrl+C, Enter", text);}</script>\n\n' 

    h+='<h2>All WA results</h2>\n'

    h+=hextra

    # Check host URL prefix and default module/action
    rx=ck.access({'action':'form_url_prefix',
                  'module_uoa':'wfe',
                  'host':i.get('host',''), 
                  'port':i.get('port',''), 
                  'template':i.get('template','')})
    if rx['return']>0: return rx
    url0=rx['url']
    template=rx['template']

    url=url0
    action=i.get('action','')
    muoa=i.get('module_uoa','')

    st=''

    # URL used as the form action: routes back to this native action via wfe.
    url+='action=index&module_uoa=wfe&native_action='+action+'&'+'native_module_uoa='+muoa
    url1=url

    # List entries (all results of this module, or of the crowd module if set)
    ii={'action':'search',
        'module_uoa':work['self_module_uid'],
        'add_meta':'yes'}

    if cmuoa!='':
        ii['module_uoa']=cmuoa

    r=ck.access(ii)
    if r['return']>0: return r

    lst=r['lst']

    # Check unique entries: collect distinct values per selector key to
    # populate the drop-downs (wchoices includes a leading empty option).
    choices={}
    wchoices={}

    for q in lst:
        d=q['meta']
        meta=d.get('meta',{})

        for kk in selector:
            kx=kk['key']
            k=ckey+kx

            if k not in choices: 
                choices[k]=[]
                wchoices[k]=[{'name':'','value':''}]

            v=meta.get(kx,'')
            if v!='':
                if v not in choices[k]: 
                    choices[k].append(v)
                    wchoices[k].append({'name':v, 'value':v})

    # Prepare query div ***************************************************************
    if cmuoa=='':
        # Start form + URL (even when viewing entry)
        r=ck.access({'action':'start_form',
                     'module_uoa':cfg['module_deps']['wfe'],
                     'url':url1,
                     'name':form_name})
        if r['return']>0: return r
        h+=r['html']

    for kk in selector:
        k=ckey+kk['key']
        n=kk['name']

        nl=kk.get('new_line','')
        if nl=='yes':
            h+='<br>\n<div id="ck_entries_space8"></div>\n'

        # Pick up the currently selected value from the request, and stash it
        # back on the selector entry so the pruning loop below can use it.
        v=''
        if i.get(k,'')!='':
            v=i[k]
            kk['value']=v

        # Show hardware
        ii={'action':'create_selector',
            'module_uoa':cfg['module_deps']['wfe'],
            'data':wchoices.get(k,[]),
            'name':k,
            'onchange':conc, 
            'skip_sort':'no',
            'selected_value':v}
        r=ck.access(ii)
        if r['return']>0: return r

        h+='<b>'+n+':</b> '+r['html'].strip()+'\n'

    h+='<br><br>'

    # Prune list: keep only entries matching every non-empty selector value.
    plst=[]
    for q in lst:
        d=q['meta']
        meta=d.get('meta',{})

        # Check selector
        skip=False
        for kk in selector:
            k=kk['key']
            n=kk['name']
            v=kk.get('value','')

            if v!='' and meta.get(k,'')!=v:
                skip=True

        if not skip:
            plst.append(q)

    # Check if too many (hard cap of 50 rows to keep the page usable)
    lplst=len(plst)
    if lplst==0:
        h+='<b>No results found!</b>'
        return {'return':0, 'html':h, 'style':st}
    elif lplst>50:
        h+='<b>Too many entries to show ('+str(lplst)+') - please, prune list further!</b>'
        return {'return':0, 'html':h, 'style':st}

    # Prepare table
    h+='<table border="1" cellpadding="7" cellspacing="0">\n'

    # Common cell-attribute snippets: centered vs left-aligned cells.
    ha='align="center" valign="top"'
    hb='align="left" valign="top"'

    h+='  <tr style="background-color:#dddddd">\n'
    h+='   <td '+ha+'><b>All raw files</b></td>\n'
    h+='   <td '+ha+'><b>Workload</b></td>\n'
    h+='   <td '+ha+'><b>Scenario</b></td>\n'
    h+='   <td '+ha+'><b>Platform</b></td>\n'
    h+='   <td '+ha+'><b>serial number / adb device ID</b></td>\n'
    h+='   <td '+ha+'><b>CPU</b></td>\n'
    h+='   <td '+ha+'><b>GPU</b></td>\n'
    h+='   <td '+ha+'><b>OS</b></td>\n'
    h+='   <td '+ha+'><b>APK</b></td>\n'
    h+='   <td '+ha+'><b>WA version</b></td>\n'
    h+='   <td '+ha+'><b>Fail?</b></td>\n'
    h+='   <td '+hb+'><b>Choices</b></td>\n'
    h+='   <td '+hb+'><b>Characteristics</b></td>\n'
    h+='   <td '+ha+'><b>JSON results</b></td>\n'
    h+='   <td '+ha+'><b>Replay</b></td>\n'
    # NOTE(review): looks like this should be '  </tr>\n' — browsers tolerate
    # the unclosed tag, so behavior is unchanged; confirm before fixing.
    h+='  <tr>\n'

    # Dictionary to hold target meta (cache of machine entries keyed by UID,
    # so each machine is loaded at most once across all rows)
    tm={}

    ix=0
    bgraph={"0":[]} # Just for graph demo

    # One table row per entry, sorted by workload name.
    for q in sorted(plst, key=lambda x: x.get('meta',{}).get('meta',{}).get('workload_name','')):
        ix+=1

        duid=q['data_uid']
        path=q['path']

        d=q['meta']

        meta=d.get('meta',{})

        params=d.get('choices',{}).get('params',{}).get('params',{})

        pname=meta.get('program_uoa','')
        wname=meta.get('workload_name','')
        wuid=meta.get('program_uid','')

        apk_name=meta.get('apk_name','')
        apk_ver=meta.get('apk_version','')

        wa_ver=meta.get('wa_version','')

        scenario=meta.get('scenario','')

        ltarget_uoa=meta.get('local_target_uoa','')
        ltarget_uid=meta.get('local_target_uid','')

        if ltarget_uid!='' and ltarget_uid not in tm:
            # Load machine meta (errors are silently ignored: the row just
            # renders without the adb device ID in that case)
            rx=ck.access({'action':'load',
                          'module_uoa':cfg['module_deps']['machine'],
                          'data_uoa':ltarget_uid})
            if rx['return']==0:
                tm[ltarget_uid]=rx['dict']

        plat_name=meta.get('plat_name','')
        cpu_name=meta.get('cpu_name','')
        os_name=meta.get('os_name','')
        gpu_name=meta.get('gpu_name','')

        adb_id=tm.get(ltarget_uid,{}).get('device_id','')
        sn=meta.get('serial_number','')

        # Run characteristics; tet = total execution time in seconds.
        te=d.get('characteristics',{}).get('run',{})
        tet=te.get('total_execution_time',0)

        # Row background: green for success, red for failure.
        bgc='afffaf'
        fail=d.get('state',{}).get('fail','')
        fail_reason=d.get('state',{}).get('fail_reason','')
        if fail=='yes':
            if fail_reason=='': fail_reason='yes'

            bgc='ffafaf'
        else:
            # Only successful runs contribute a bar to the demo graph, and
            # only when both workload and scenario filters are selected.
            if i.get(ckey+'workload_name','')!='' and i.get(ckey+'scenario','')!='':
                bgraph['0'].append([ix,tet])

        bg=' style="background-color:#'+bgc+';"'

        h+='  <tr'+bg+'>\n'

        # Link to the raw entry (this module's UID, or the crowd module's).
        x=work['self_module_uid']
        if cmuoa!='': x=cmuoa
        h+='   <td '+ha+'>'+str(ix)+')&nbsp;<a href="'+url0+'&wcid='+x+':'+duid+'">'+duid+'</a></td>\n'

        x=wname
        if wuid!='': x='<a href="'+url0+'&wcid='+cfg['module_deps']['program']+':'+wuid+'">'+x+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=''
        if scenario!='':
            x='<a href="'+url0+'&wcid='+cfg['module_deps']['wa-scenario']+':'+scenario+'">'+scenario+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=plat_name
        if ltarget_uid!='':
           x='<a href="'+url0+'&wcid='+cfg['module_deps']['machine']+':'+ltarget_uid+'">'+x+'</a>'
        h+='   <td '+ha+'>'+x+'</td>\n'

        x=sn
        if adb_id!='' and adb_id!=sn: x+=' / '+adb_id
        h+='   <td '+ha+'>'+x+'</td>\n'

        h+='   <td '+ha+'>'+cpu_name+'</td>\n'
        h+='   <td '+ha+'>'+gpu_name+'</td>\n'
        h+='   <td '+ha+'>'+os_name+'</td>\n'

        # APK
        x=apk_name
        if apk_ver!='': x+=' (V'+apk_ver+')'
#        x=x.replace("'","\'").replace('"',"\\'").replace('\n','\\n')
        # NOTE(review): escape text for embedding in the JS alert('...') below.
        # "\'" == "'" in Python source, so the first and third replaces are
        # no-ops; net effect is quote -> \' and newline -> \n. Confirm intended.
        x=x.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')

        x1=''
        if x!='':
            x1='<input type="button" class="ck_small_button" onClick="alert(\''+x+'\');" value="See">'

        h+='   <td '+ha+'>'+x1+'</td>\n'

        h+='   <td '+ha+'>'+wa_ver+'</td>\n'

        # Failure cell: "No", or "Yes" with a button showing the reason.
        x=fail_reason
        if x=='': 
            x='No'
        else:
            fail_reason=fail_reason.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')
            x='Yes <input type="button" class="ck_small_button" onClick="alert(\''+fail_reason+'\');" value="Log">'

        h+='   <td '+ha+'>'+x+'</td>\n'

        # Params: one "key=value" line per chosen parameter, shown via alert.
#        x='<table border="0" cellpadding="0" cellspacing="2">\n'
        x=''
        for k in sorted(params):
            v=params[k]
            x+=str(k)+'='+str(v)+'\n'
#            x+='<tr><td>'+str(k)+'=</td><td>'+str(v)+'</td></tr>\n'
#        x+='</table>\n'
#        x=x.replace("'","\'").replace('"',"\\'").replace('\n','\\n')
        x=x.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')

        x1=''
        if x!='':
            x1='<input type="button" class="ck_small_button" onClick="alert(\''+x+'\');" value="See">'

        h+='   <td '+hb+'>'+x1+'</td>\n'

        # Characteristics
        # Check if has statistics (flat-key stats file produced by repeated runs)
        dstat={}
        fstat=os.path.join(path,'ck-stat-flat-characteristics.json')
        if os.path.isfile(fstat):
            # NOTE(review): the 'dict' key passed here appears unused — the
            # loaded content is taken from r['dict'] below; confirm against
            # the ck.load_json_file API.
            r=ck.load_json_file({'json_file':fstat, 'dict':dstat})
            if r['return']>0: return r
            dstat=r['dict']

        x=''
        if tet>0: x=('%.1f'%tet)+' sec.'

        # Check if has stats: prefer "center +/- halfrange" over the raw time.
        x1=dstat.get("##characteristics#run#total_execution_time#center",None)
        x2=dstat.get("##characteristics#run#total_execution_time#halfrange",None)
        if x1!=None and x2!=None:
            x=('%.1f'%x1)+' &PlusMinus; '+('%.1f'%x2)+' sec.'

        # Check all: build a "key=value" list of every numeric run
        # characteristic, using the stats variant when available.
        x5=''
        for k in sorted(te):
            v=te[k]

            kx="##characteristics#run#"+k

            kx1=dstat.get(kx+'#center',None)
            kx2=dstat.get(kx+'#halfrange',None)

            x6=''
            if type(v)==int:
                if kx1!=None and kx2!=None:
                    x6=str(kx1)+' +- '+str(kx2)
                else:
                    x6=str(v)
            elif type(v)==float:
                if kx1!=None and kx2!=None:
                    x6=('%.1f'%kx1)+' +- '+('%.1f'%kx2)
                else:
                    x6=('%.1f'%v)

            if x6!='':
                x5+=str(k)+'='+x6+'\n'

#        x5=x5.replace("'","\'").replace('"',"\\'").replace('\n','\\n')
        x5=x5.replace("\'","'").replace("'","\\'").replace('\"','"').replace('"',"\\'").replace('\n','\\n')
        if x5!='':
            x+='<br><input type="button" class="ck_small_button" onClick="alert(\''+x5+'\');" value="All">'

        h+='   <td '+ha+'>'+x+'</td>\n'

        # Check directories with results: look for wa-output/results.json
        # either inside each sub-directory of the entry, or directly in the
        # entry path (in which case only one link is emitted).
        x=''
        xf1='wa-output'
        xf2='results.json'
        xf=xf1+'/'+xf2
        for d0 in os.listdir(path):
            found=False
            brk=False

            d1=os.path.join(d0,xf)
            d2=os.path.join(path,d1)

            if os.path.isfile(d2):
                found=True
            else:
                d1=xf
                d2=os.path.join(path,d1)

                if os.path.isfile(d2):
                    d0=xf1
                    found=True
                    brk=True

            if found:
                if x!='': x+='<br>\n'
                x1=work['self_module_uid']
                if cmuoa!='':
                    x1=cmuoa
                # Link pulls the JSON file through the CK web front-end.
                x+='[&nbsp;<a href="'+url0+'action=pull&common_action=yes&cid='+x1+':'+duid+'&filename='+d1+'">'+d0+'</a>&nbsp;]\n'

                if brk:
                    break
        h+='   <td '+ha+'>'+x+'</td>\n'

        h+='   <td '+ha+'><input type="button" class="ck_small_button" onClick="copyToClipboard(\'ck replay wa:'+wname+'\');" value="Replay"></td>\n'

        # NOTE(review): same as the header row — presumably meant '</tr>'.
        h+='  <tr>\n'

    h+='</table>\n'
    h+='</center>\n'

    if cmuoa=='':
        h+='</form>\n'

    # Interactive d3 bar graph of execution times (only when some
    # successful, fully-filtered runs were collected above).
    if len(bgraph['0'])>0:
       ii={'action':'plot',
           'module_uoa':cfg['module_deps']['graph'],

           "table":bgraph,

           "h_lines":[1.0],

           "ymin":0,

           "ignore_point_if_none":"yes",

           "plot_type":"d3_2d_bars",

           "display_y_error_bar":"no",

           "title":"Powered by Collective Knowledge",

           "axis_x_desc":"Platform",
           "axis_y_desc":"Execution time (sec.)",

           "plot_grid":"yes",

           "d3_div":"ck_interactive",

           "image_width":"900",
           "image_height":"400",

           "wfe_url":url0}

       r=ck.access(ii)
       # Plotting errors are ignored: the page still renders without the graph.
       if r['return']==0:
          x=r.get('html','')
          if x!='':
             st+=r.get('style','')

             h+='<br>\n'
             h+='<center>\n'
             h+='<div id="ck_box_with_shadow" style="width:920px;">\n'
             h+=' <div id="ck_interactive" style="text-align:center">\n'
             h+=x+'\n'
             h+=' </div>\n'
             h+='</div>\n'
             h+='</center>\n'

    return {'return':0, 'html':h, 'style':st}

Example 34

Project: cstar_perf
Source File: stress_compare.py
View license
def stress_compare(revisions,
                   title,
                   log,
                   operations = [],
                   subtitle = '',
                   capture_fincore=False,
                   initial_destroy=True,
                   leave_data=False,
                   keep_page_cache=False,
                   git_fetch_before_test=True,
                   bootstrap_before_test=True,
                   teardown_after_test=True
               ):
    """
    Run Stress on multiple C* branches and compare them.

    revisions - List of dictionaries that contain cluster configurations
                to trial. This is combined with the default config.
    title - The title of the comparison
    subtitle - A subtitle for more information (displayed smaller underneath)
    log - The json file path to record stats to
    operations - List of dictionaries indicating the operations. Example:
       [# cassandra-stress command, node defaults to cluster defined 'stress_node'
        {'type': 'stress',
         'command': 'write n=19M -rate threads=50',
         'node': 'node1',
         'wait_for_compaction': True},
        # nodetool command to run in parallel on nodes:
        {'type': 'nodetool',
         'command': 'decomission',
         'nodes': ['node1','node2']},
        # cqlsh script, node defaults to cluster defined 'stress_node'
        {'type': 'cqlsh',
         'script': "use my_ks; INSERT INTO blah (col1, col2) VALUES (val1, val2);",
         'node': 'node1'}
       ]
    capture_fincore - Enables capturing of linux-fincore logs of C* data files.
    initial_destroy - Destroy all data before the first revision is run.
    leave_data - Whether to leave the Cassandra data/commitlog/etc directories intact between revisions.
    keep_page_cache - Whether to leave the linux page cache intact between revisions.
    git_fetch_before_test (bool): If True, will update the cassandra.git with fab_common.git_repos
    bootstrap_before_test (bool): If True, will bootstrap DSE / C* before running the operations
    teardown_after_test (bool): If True, will shutdown DSE / C* after all of the operations
    """
    validate_revisions_list(revisions)
    validate_operations_list(operations)

    pristine_config = copy.copy(fab_config)

    # initial_destroy and git_fetch_before_test can be set in the job configuration,
    # or manually in the call to this function.
    # Either is fine, but they shouldn't conflict. If they do, a ValueError is raised.
    initial_destroy = get_bool_if_method_and_config_values_do_not_conflict('initial_destroy',
                                                                           initial_destroy,
                                                                           pristine_config,
                                                                           method_name='stress_compare')

    if initial_destroy:
        logger.info("Cleaning up from prior runs of stress_compare ...")
        teardown(destroy=True, leave_data=False)

    # https://datastax.jira.com/browse/CSTAR-633
    git_fetch_before_test = get_bool_if_method_and_config_values_do_not_conflict('git_fetch_before_test',
                                                                                 git_fetch_before_test,
                                                                                 pristine_config,
                                                                                 method_name='stress_compare')

    stress_shas = maybe_update_cassandra_git_and_setup_stress(operations, git_fetch=git_fetch_before_test)

    # Flamegraph Setup
    if flamegraph.is_enabled():
        execute(flamegraph.setup)

    with GracefulTerminationHandler() as handler:
        for rev_num, revision_config in enumerate(revisions):
            config = copy.copy(pristine_config)
            config.update(revision_config)
            revision = revision_config['revision']
            config['log'] = log
            config['title'] = title
            config['subtitle'] = subtitle
            product = dse if config.get('product') == 'dse' else cstar

            # leave_data, bootstrap_before_test, and teardown_after_test can be set in the job configuration,
            # or manually in the call to this function.
            # Either is fine, but they shouldn't conflict. If they do, a ValueError is raised.
            leave_data = get_bool_if_method_and_config_values_do_not_conflict('leave_data',
                                                                              leave_data,
                                                                              revision_config,
                                                                              method_name='stress_compare')

            # https://datastax.jira.com/browse/CSTAR-638
            bootstrap_before_test = get_bool_if_method_and_config_values_do_not_conflict('bootstrap_before_test',
                                                                                         bootstrap_before_test,
                                                                                         revision_config,
                                                                                         method_name='stress_compare')

            # https://datastax.jira.com/browse/CSTAR-639
            teardown_after_test = get_bool_if_method_and_config_values_do_not_conflict('teardown_after_test',
                                                                                       teardown_after_test,
                                                                                       revision_config,
                                                                                       method_name='stress_compare')

            logger.info("Bringing up {revision} cluster...".format(revision=revision))

            # Drop the page cache between each revision, especially
            # important when leave_data=True :
            if not keep_page_cache:
                drop_page_cache()

            # Only fetch from git on the first run and if git_fetch_before_test is True
            git_fetch_before_bootstrap = True if rev_num == 0 and git_fetch_before_test else False
            if bootstrap_before_test:
                revision_config['git_id'] = git_id = bootstrap(config,
                                                               destroy=initial_destroy,
                                                               leave_data=leave_data,
                                                               git_fetch=git_fetch_before_bootstrap)
            else:
                revision_config['git_id'] = git_id = config['revision']

            if flamegraph.is_enabled(revision_config):
                execute(flamegraph.ensure_stopped_perf_agent)
                execute(flamegraph.start_perf_agent, rev_num)

            if capture_fincore:
                start_fincore_capture(interval=10)

            last_stress_operation_id = 'None'
            for operation_i, operation in enumerate(operations, 1):
                try:
                    start = datetime.datetime.now()
                    stats = {
                        "id": str(uuid.uuid1()),
                        "type": operation['type'],
                        "revision": revision,
                        "git_id": git_id,
                        "start_date": start.isoformat(),
                        "label": revision_config.get('label', revision_config['revision']),
                        "test": '{operation_i}_{operation}'.format(
                            operation_i=operation_i,
                            operation=operation['type'])
                    }

                    if operation['type'] == 'stress':
                        last_stress_operation_id = stats['id']
                        # Default to all the nodes of the cluster if no
                        # nodes were specified in the command:
                        if operation.has_key('nodes'):
                            cmd = "{command} -node {hosts}".format(
                                command=operation['command'],
                                hosts=",".join(operation['nodes']))
                        elif '-node' in operation['command']:
                            cmd = operation['command']
                        else:
                            cmd = "{command} -node {hosts}".format(
                                command=operation['command'],
                                hosts=",".join([n for n in fab_config['hosts']]))
                        stats['command'] = cmd
                        stats['intervals'] = []
                        stats['test'] = '{operation_i}_{operation}'.format(
                            operation_i=operation_i, operation=cmd.strip().split(' ')[0]).replace(" ", "_")
                        logger.info('Running stress operation : {cmd}  ...'.format(cmd=cmd))
                        # Run stress:
                        # (stress takes the stats as a parameter, and adds
                        #  more as it runs):
                        stress_sha = stress_shas[operation.get('stress_revision', 'default')]
                        stats = stress(cmd, revision, stress_sha, stats=stats)
                        # Wait for all compactions to finish (unless disabled):
                        if operation.get('wait_for_compaction', True):
                            compaction_throughput = revision_config.get("compaction_throughput_mb_per_sec", 16)
                            wait_for_compaction(compaction_throughput=compaction_throughput)

                    elif operation['type'] == 'nodetool':
                        if 'nodes' not in operation:
                            operation['nodes'] = 'all'
                        if operation['nodes'] in ['all','ALL']:
                            nodes = [n for n in fab_config['hosts']]
                        else:
                            nodes = operation['nodes']

                        set_nodetool_path(os.path.join(product.get_bin_path(), 'nodetool'))
                        logger.info("Running nodetool on {nodes} with command: {command}".format(nodes=operation['nodes'], command=operation['command']))
                        stats['command'] = operation['command']
                        output = nodetool_multi(nodes, operation['command'])
                        stats['output'] = output
                        logger.info("Nodetool command finished on all nodes")

                    elif operation['type'] == 'cqlsh':
                        logger.info("Running cqlsh commands on {node}".format(node=operation['node']))
                        set_cqlsh_path(os.path.join(product.get_bin_path(), 'cqlsh'))
                        output = cqlsh(operation['script'], operation['node'])
                        stats['output'] = output.split("\n")
                        stats['command'] = operation['script']
                        logger.info("Cqlsh commands finished")

                    elif operation['type'] == 'bash':
                        nodes = operation.get('nodes', [n for n in fab_config['hosts']])
                        logger.info("Running bash commands on: {nodes}".format(nodes=nodes))
                        stats['output'] = bash(operation['script'], nodes)
                        stats['command'] = operation['script']
                        logger.info("Bash commands finished")

                    elif operation['type'] == 'spark_cassandra_stress':
                        nodes = operation.get('nodes', [n for n in fab_config['hosts']])
                        stress_node = config.get('stress_node', None)
                        # Note: once we have https://datastax.jira.com/browse/CSTAR-617, we should fix this to use
                        # client-tool when DSE_VERSION >= 4.8.0
                        # https://datastax.jira.com/browse/DSP-6025: dse client-tool
                        master_regex = re.compile(r"(.|\n)*(?P<master>spark:\/\/\d+.\d+.\d+.\d+:\d+)(.|\n)*")
                        master_out = dsetool_cmd(nodes[0], options='sparkmaster')[nodes[0]]
                        master_match = master_regex.match(master_out)
                        if not master_match:
                            raise ValueError('Could not find master address from "dsetool sparkmaster" cmd\n'
                                             'Found output: {f}'.format(f=master_out))
                        master_string = master_match.group('master')
                        build_spark_cassandra_stress = bool(distutils.util.strtobool(
                            str(operation.get('build_spark_cassandra_stress', 'True'))))
                        remove_existing_spark_data = bool(distutils.util.strtobool(
                            str(operation.get('remove_existing_spark_data', 'True'))))
                        logger.info("Running spark_cassandra_stress on {stress_node} "
                                    "using spark.cassandra.connection.host={node} and "
                                    "spark-master {master}".format(stress_node=stress_node,
                                                                   node=nodes[0],
                                                                   master=master_string))
                        output = spark_cassandra_stress(operation['script'], nodes, stress_node=stress_node,
                                                        master=master_string,
                                                        build_spark_cassandra_stress=build_spark_cassandra_stress,
                                                        remove_existing_spark_data=remove_existing_spark_data)
                        stats['output'] = output.get('output', 'No output captured')
                        stats['spark_cass_stress_time_in_seconds'] = output.get('stats', {}).get('TimeInSeconds', 'No time captured')
                        stats['spark_cass_stress_ops_per_second'] = output.get('stats', {}).get('OpsPerSecond', 'No ops/s captured')
                        logger.info("spark_cassandra_stress finished")

                    elif operation['type'] == 'ctool':
                        logger.info("Running ctool with parameters: {command}".format(command=operation['command']))
                        ctool = Ctool(operation['command'], common.config)
                        output = execute(ctool.run)
                        stats['output'] = output
                        logger.info("ctool finished")

                    elif operation['type'] == 'dsetool':
                        if 'nodes' not in operation:
                            operation['nodes'] = 'all'
                        if operation['nodes'] in ['all','ALL']:
                            nodes = [n for n in fab_config['hosts']]
                        else:
                            nodes = operation['nodes']

                        dsetool_options = operation['script']
                        logger.info("Running dsetool {command} on {nodes}".format(nodes=operation['nodes'], command=dsetool_options))
                        stats['command'] = dsetool_options
                        output = dsetool_cmd(nodes=nodes, options=dsetool_options)
                        stats['output'] = output
                        logger.info("dsetool command finished on all nodes")

                    elif operation['type'] == 'dse':
                        logger.info("Running dse command on {node}".format(node=operation['node']))
                        output = dse_cmd(node=operation['node'], options=operation['script'])
                        stats['output'] = output.split("\n")
                        stats['command'] = operation['script']
                        logger.info("dse commands finished")

                    end = datetime.datetime.now()
                    stats['end_date'] = end.isoformat()
                    stats['op_duration'] = str(end - start)
                    log_stats(stats, file=log)
                finally:
                    # Copy node logs:
                    retrieve_logs_and_create_tarball(job_id=stats['id'])
                    revision_config['last_log'] = stats['id']

                if capture_fincore:
                    stop_fincore_capture()
                    log_dir = os.path.join(CSTAR_PERF_LOGS_DIR, stats['id'])
                    retrieve_fincore_logs(log_dir)
                    # Restart fincore capture if this is not the last
                    # operation:
                    if operation_i < len(operations):
                        start_fincore_capture(interval=10)

            if flamegraph.is_enabled(revision_config):
                # Generate and Copy node flamegraphs
                execute(flamegraph.stop_perf_agent)
                execute(flamegraph.generate_flamegraph, rev_num)
                flamegraph_dir = os.path.join(os.path.expanduser('~'),'.cstar_perf', 'flamegraph')
                flamegraph_test_dir = os.path.join(flamegraph_dir, last_stress_operation_id)
                retrieve_flamegraph(flamegraph_test_dir, rev_num+1)
                sh.tar('cfvz', "{}.tar.gz".format(stats['id']), last_stress_operation_id, _cwd=flamegraph_dir)
                shutil.rmtree(flamegraph_test_dir)

            log_add_data(log, {'title':title,
                               'subtitle': subtitle,
                               'revisions': revisions})
            if teardown_after_test:
                if revisions[-1].get('leave_data', leave_data):
                    teardown(destroy=False, leave_data=True)
                else:
                    kill_delay = 300 if profiler.yourkit_is_enabled(revision_config) else 0
                    teardown(destroy=True, leave_data=False, kill_delay=kill_delay)

            if profiler.yourkit_is_enabled(revision_config):
                yourkit_config = profiler.yourkit_get_config()
                yourkit_dir = os.path.join(os.path.expanduser('~'),'.cstar_perf', 'yourkit')
                yourkit_test_dir = os.path.join(yourkit_dir, last_stress_operation_id)
                retrieve_yourkit(yourkit_test_dir, rev_num+1)
                sh.tar('cfvz', "{}.tar.gz".format(stats['id']),
                       last_stress_operation_id, _cwd=yourkit_dir)
                shutil.rmtree(yourkit_test_dir)

Example 35

Project: p2pool-n
Source File: main.py
View license
@defer.inlineCallbacks
def main(args, net, datadir_path, merged_urls, worker_endpoint):
    """Start and run a p2pool node until the reactor stops.

    args            -- parsed command-line arguments
    net             -- p2pool network definition (net.PARENT is the coin network)
    datadir_path    -- directory for the cached payout address, share store and peer addrs
    merged_urls     -- merged-mining URLs handed to the WorkerBridge
    worker_endpoint -- (interface, port) tuple the miner-facing server listens on

    Any uncaught exception is fatal: the reactor is stopped and the error logged.
    """
    try:
        print 'p2pool (version %s)' % (p2pool.__version__,)
        print
        
        @defer.inlineCallbacks
        def connect_p2p():
            # connect to bitcoind over bitcoin-p2p
            print '''Testing bitcoind P2P connection to '%s:%s'...''' % (args.bitcoind_address, args.bitcoind_p2p_port)
            factory = bitcoin_p2p.ClientFactory(net.PARENT)
            reactor.connectTCP(args.bitcoind_address, args.bitcoind_p2p_port, factory)
            def long():
                print '''    ...taking a while. Common reasons for this include all of bitcoind's connection slots being used...'''
            long_dc = reactor.callLater(5, long)
            yield factory.getProtocol() # waits until handshake is successful
            if not long_dc.called: long_dc.cancel()
            print '    ...success!'
            print
            defer.returnValue(factory)
        
        if args.testnet: # establish p2p connection first if testnet so bitcoind can work without connections
            factory = yield connect_p2p()
        
        # connect to bitcoind over JSON-RPC and do initial getmemorypool
        url = '%s://%s:%i/' % ('https' if args.bitcoind_rpc_ssl else 'http', args.bitcoind_address, args.bitcoind_rpc_port)
        print '''Testing bitcoind RPC connection to '%s' with username '%s'...''' % (url, args.bitcoind_rpc_username)
        bitcoind = jsonrpc.HTTPProxy(url, dict(Authorization='Basic ' + base64.b64encode(args.bitcoind_rpc_username + ':' + args.bitcoind_rpc_password)), timeout=30)
        yield helper.check(bitcoind, net)
        temp_work = yield helper.getwork(bitcoind)
        
        # re-poll bitcoind's getinfo every 20 minutes so warnings can be surfaced later
        bitcoind_getinfo_var = variable.Variable(None)
        @defer.inlineCallbacks
        def poll_warnings():
            bitcoind_getinfo_var.set((yield deferral.retry('Error while calling getinfo:')(bitcoind.rpc_getinfo)()))
        yield poll_warnings()
        deferral.RobustLoopingCall(poll_warnings).start(20*60)
        
        print '    ...success!'
        print '    Current block hash: %x' % (temp_work['previous_block'],)
        print '    Current block height: %i' % (temp_work['height'] - 1,)
        print
        
        if not args.testnet:
            factory = yield connect_p2p()
        
        print 'Determining payout address...'
        if args.pubkey_hash is None:
            # no address given on the command line: use the on-disk cache,
            # re-validating it against the local bitcoind before trusting it
            address_path = os.path.join(datadir_path, 'cached_payout_address')
            
            if os.path.exists(address_path):
                with open(address_path, 'rb') as f:
                    address = f.read().strip('\r\n')
                print '    Loaded cached address: %s...' % (address,)
            else:
                address = None
            
            if address is not None:
                res = yield deferral.retry('Error validating cached address:', 5)(lambda: bitcoind.rpc_validateaddress(address))()
                if not res['isvalid'] or not res['ismine']:
                    print '    Cached address is either invalid or not controlled by local bitcoind!'
                    address = None
            
            if address is None:
                print '    Getting payout address from bitcoind...'
                address = yield deferral.retry('Error getting payout address from bitcoind:', 5)(lambda: bitcoind.rpc_getaccountaddress('p2pool'))()
            
            with open(address_path, 'wb') as f:
                f.write(address)
            
            my_pubkey_hash = bitcoin_data.address_to_pubkey_hash(address, net.PARENT)
        else:
            my_pubkey_hash = args.pubkey_hash
        print '    ...success! Payout address:', bitcoin_data.pubkey_hash_to_address(my_pubkey_hash, net.PARENT)
        print
        
        print "Loading shares..."
        shares = {}
        known_verified = set()
        def share_cb(share):
            share.time_seen = 0 # XXX
            shares[share.hash] = share
            if len(shares) % 1000 == 0 and shares:
                print "    %i" % (len(shares),)
        ss = p2pool_data.ShareStore(os.path.join(datadir_path, 'shares.'), net, share_cb, known_verified.add)
        print "    ...done loading %i shares (%i verified)!" % (len(shares), len(known_verified))
        print
        
        
        print 'Initializing work...'
        
        node = p2pool_node.Node(factory, bitcoind, shares.values(), known_verified, net)
        yield node.start()
        
        # drop on-disk shares the tracker didn't accept, and keep the share
        # store in sync with the tracker from now on
        for share_hash in shares:
            if share_hash not in node.tracker.items:
                ss.forget_share(share_hash)
        for share_hash in known_verified:
            if share_hash not in node.tracker.verified.items:
                ss.forget_verified_share(share_hash)
        node.tracker.removed.watch(lambda share: ss.forget_share(share.hash))
        node.tracker.verified.removed.watch(lambda share: ss.forget_verified_share(share.hash))
        
        # persist the current best chain (up to 2*CHAIN_LENGTH shares) every minute
        def save_shares():
            for share in node.tracker.get_chain(node.best_share_var.value, min(node.tracker.get_height(node.best_share_var.value), 2*net.CHAIN_LENGTH)):
                ss.add_share(share)
                if share.hash in node.tracker.verified.items:
                    ss.add_verified_hash(share.hash)
        deferral.RobustLoopingCall(save_shares).start(60)
        
        print '    ...success!'
        print
        
        
        print 'Joining p2pool network using port %i...' % (args.p2pool_port,)
        
        @defer.inlineCallbacks
        def parse(host):
            # resolve a "host[:port]" string to an (ip, port) tuple
            port = net.P2P_PORT
            if ':' in host:
                host, port_str = host.split(':')
                port = int(port_str)
            defer.returnValue(((yield reactor.resolve(host)), port))
        
        # seed the peer address store from disk plus the network's bootstrap nodes
        addrs = {}
        if os.path.exists(os.path.join(datadir_path, 'addrs')):
            try:
                with open(os.path.join(datadir_path, 'addrs'), 'rb') as f:
                    addrs.update(dict((tuple(k), v) for k, v in json.loads(f.read())))
            except:
                print >>sys.stderr, 'error parsing addrs'
        for addr_df in map(parse, net.BOOTSTRAP_ADDRS):
            try:
                addr = yield addr_df
                if addr not in addrs:
                    addrs[addr] = (0, time.time(), time.time())
            except:
                log.err()
        
        connect_addrs = set()
        for addr_df in map(parse, args.p2pool_nodes):
            try:
                connect_addrs.add((yield addr_df))
            except:
                log.err()
        
        node.p2p_node = p2pool_node.P2PNode(node,
            port=args.p2pool_port,
            max_incoming_conns=args.p2pool_conns,
            addr_store=addrs,
            connect_addrs=connect_addrs,
            desired_outgoing_conns=args.p2pool_outgoing_conns,
            advertise_ip=args.advertise_ip,
        )
        node.p2p_node.start()
        
        # persist known peer addresses every minute
        def save_addrs():
            with open(os.path.join(datadir_path, 'addrs'), 'wb') as f:
                f.write(json.dumps(node.p2p_node.addr_store.items()))
        deferral.RobustLoopingCall(save_addrs).start(60)
        
        print '    ...success!'
        print
        
        if args.upnp:
            # periodically refresh a UPnP port mapping for incoming p2p connections
            @defer.inlineCallbacks
            def upnp_thread():
                while True:
                    try:
                        is_lan, lan_ip = yield ipdiscover.get_local_ip()
                        if is_lan:
                            pm = yield portmapper.get_port_mapper()
                            yield pm._upnp.add_port_mapping(lan_ip, args.p2pool_port, args.p2pool_port, 'p2pool', 'TCP')
                    except defer.TimeoutError:
                        pass
                    except:
                        if p2pool.DEBUG:
                            log.err(None, 'UPnP error:')
                    yield deferral.sleep(random.expovariate(1/120))
            upnp_thread()
        
        # start listening for workers with a JSON-RPC server
        
        print 'Listening for workers on %r port %i...' % (worker_endpoint[0], worker_endpoint[1])
        
        wb = work.WorkerBridge(node, my_pubkey_hash, args.donation_percentage, merged_urls, args.worker_fee)
        web_root = web.get_web_root(wb, datadir_path, bitcoind_getinfo_var)
        caching_wb = worker_interface.CachingWorkerBridge(wb)
        worker_interface.WorkerInterface(caching_wb).attach_to(web_root, get_handler=lambda request: request.redirect('static/'))
        web_serverfactory = server.Site(web_root)
        
        
        # connections whose first byte is '{' are stratum; everything else is HTTP
        serverfactory = switchprotocol.FirstByteSwitchFactory({'{': stratum.StratumServerFactory(caching_wb)}, web_serverfactory)
        deferral.retry('Error binding to worker port:', traceback=False)(reactor.listenTCP)(worker_endpoint[1], serverfactory, interface=worker_endpoint[0])
        
        # touch a flag file to signal readiness
        # NOTE(review): the outer single-argument os.path.join is a harmless no-op
        with open(os.path.join(os.path.join(datadir_path, 'ready_flag')), 'wb') as f:
            pass
        
        print '    ...success!'
        print
        
        
        # done!
        print 'Started successfully!'
        print 'Go to http://127.0.0.1:%i/ to view graphs and statistics!' % (worker_endpoint[1],)
        if args.donation_percentage > 1.1:
            print '''Donating %.1f%% of work towards P2Pool's development. Thanks for the tip!''' % (args.donation_percentage,)
        elif args.donation_percentage < .9:
            print '''Donating %.1f%% of work towards P2Pool's development. Please donate to encourage further development of P2Pool!''' % (args.donation_percentage,)
        else:
            print '''Donating %.1f%% of work towards P2Pool's development. Thank you!''' % (args.donation_percentage,)
            print 'You can increase this amount with --give-author argument! (or decrease it, if you must)'
        print
        
        
        # watchdog: a 30-second alarm is re-armed every second; if the reactor
        # stalls long enough for it to fire, a stack trace is written to stderr
        if hasattr(signal, 'SIGALRM'):
            signal.signal(signal.SIGALRM, lambda signum, frame: reactor.callFromThread(
                sys.stderr.write, 'Watchdog timer went off at:\n' + ''.join(traceback.format_stack())
            ))
            signal.siginterrupt(signal.SIGALRM, False)
            deferral.RobustLoopingCall(signal.alarm, 30).start(1)
        
        if args.irc_announce:
            # announce found blocks on freenode IRC
            from twisted.words.protocols import irc
            class IRCClient(irc.IRCClient):
                nickname = 'p2pool%02i' % (random.randrange(100),)
                channel = net.ANNOUNCE_CHANNEL
                def lineReceived(self, line):
                    if p2pool.DEBUG:
                        print repr(line)
                    irc.IRCClient.lineReceived(self, line)
                def signedOn(self):
                    self.in_channel = False
                    irc.IRCClient.signedOn(self)
                    self.factory.resetDelay()
                    self.join(self.channel)
                    @defer.inlineCallbacks
                    def new_share(share):
                        # announce a newly verified share only if it is a block
                        # and is recent; skip if the same block was already seen
                        if not self.in_channel:
                            return
                        if share.pow_hash <= share.header['bits'].target and abs(share.timestamp - time.time()) < 10*60:
                            yield deferral.sleep(random.expovariate(1/60))
                            message = '\x02%s BLOCK FOUND by %s! %s%064x' % (net.NAME.upper(), bitcoin_data.script2_to_address(share.new_script, net.PARENT), net.PARENT.BLOCK_EXPLORER_URL_PREFIX, share.header_hash)
                            if all('%x' % (share.header_hash,) not in old_message for old_message in self.recent_messages):
                                self.say(self.channel, message)
                                self._remember_message(message)
                    self.watch_id = node.tracker.verified.added.watch(new_share)
                    self.recent_messages = []
                def joined(self, channel):
                    self.in_channel = True
                def left(self, channel):
                    self.in_channel = False
                def _remember_message(self, message):
                    # keep only the last 100 messages for duplicate suppression
                    self.recent_messages.append(message)
                    while len(self.recent_messages) > 100:
                        self.recent_messages.pop(0)
                def privmsg(self, user, channel, message):
                    if channel == self.channel:
                        self._remember_message(message)
                def connectionLost(self, reason):
                    node.tracker.verified.added.unwatch(self.watch_id)
                    print 'IRC connection lost:', reason.getErrorMessage()
            class IRCClientFactory(protocol.ReconnectingClientFactory):
                protocol = IRCClient

            reactor.connectTCP("irc.freenode.net", 6667, IRCClientFactory(), bindAddress=(worker_endpoint[0], 0))
        
        # console status line: checked every 3s, printed on change or at least
        # every 15 seconds
        @defer.inlineCallbacks
        def status_thread():
            last_str = None
            last_time = 0
            while True:
                yield deferral.sleep(3)
                try:
                    height = node.tracker.get_height(node.best_share_var.value)
                    this_str = 'P2Pool: %i shares in chain (%i verified/%i total) Peers: %i (%i incoming)' % (
                        height,
                        len(node.tracker.verified.items),
                        len(node.tracker.items),
                        len(node.p2p_node.peers),
                        sum(1 for peer in node.p2p_node.peers.itervalues() if peer.incoming),
                    ) + (' FDs: %i R/%i W' % (len(reactor.getReaders()), len(reactor.getWriters())) if p2pool.DEBUG else '')
                    
                    datums, dt = wb.local_rate_monitor.get_datums_in_last()
                    my_att_s = sum(datum['work']/dt for datum in datums)
                    my_shares_per_s = sum(datum['work']/dt/bitcoin_data.target_to_average_attempts(datum['share_target']) for datum in datums)
                    this_str += '\n Local: %sH/s in last %s Local dead on arrival: %s Expected time to share: %s' % (
                        math.format(int(my_att_s)),
                        math.format_dt(dt),
                        math.format_binomial_conf(sum(1 for datum in datums if datum['dead']), len(datums), 0.95),
                        math.format_dt(1/my_shares_per_s) if my_shares_per_s else '???',
                    )
                    
                    if height > 2:
                        (stale_orphan_shares, stale_doa_shares), shares, _ = wb.get_stale_counts()
                        stale_prop = p2pool_data.get_average_stale_prop(node.tracker, node.best_share_var.value, min(60*60//net.SHARE_PERIOD, height))
                        real_att_s = p2pool_data.get_pool_attempts_per_second(node.tracker, node.best_share_var.value, min(height - 1, 60*60//net.SHARE_PERIOD)) / (1 - stale_prop)
                        
                        this_str += '\n Shares: %i (%i orphan, %i dead) Stale rate: %s Efficiency: %s Current payout: %.4f %s' % (
                            shares, stale_orphan_shares, stale_doa_shares,
                            math.format_binomial_conf(stale_orphan_shares + stale_doa_shares, shares, 0.95),
                            math.format_binomial_conf(stale_orphan_shares + stale_doa_shares, shares, 0.95, lambda x: (1 - x)/(1 - stale_prop)),
                            node.get_current_txouts().get(bitcoin_data.pubkey_hash_to_script2(my_pubkey_hash), 0)*1e-8, net.PARENT.SYMBOL,
                        )
                        this_str += '\n Pool: %sH/s Stale rate: %.1f%% Expected time to block: %s' % (
                            math.format(int(real_att_s)),
                            100*stale_prop,
                            math.format_dt(2**256 / node.bitcoind_work.value['bits'].target / real_att_s),
                        )
                        
                        for warning in p2pool_data.get_warnings(node.tracker, node.best_share_var.value, net, bitcoind_getinfo_var.value, node.bitcoind_work.value):
                            print >>sys.stderr, '#'*40
                            print >>sys.stderr, '>>> Warning: ' + warning
                            print >>sys.stderr, '#'*40
                        
                        if gc.garbage:
                            print '%i pieces of uncollectable cyclic garbage! Types: %r' % (len(gc.garbage), map(type, gc.garbage))
                    
                    if this_str != last_str or time.time() > last_time + 15:
                        print this_str
                        last_str = this_str
                        last_time = time.time()
                except:
                    log.err()
        status_thread()
    except:
        reactor.stop()
        log.err(None, 'Fatal error:')

Example 36

Project: p2pool-n
Source File: web.py
View license
def get_web_root(wb, datadir_path, bitcoind_getinfo_var, stop_event=variable.Event()):
    """Build and return the Twisted web resource tree for the p2pool node.

    Exposes legacy JSON endpoints at the root (rate, users, local_stats, ...)
    and newer endpoints under /web (share, log, graph_data, ...), and starts
    two background loops: one persisting a rolling 24h stats log and one
    sampling graph datapoints.

    Arguments:
      wb -- the work-bridge object; provides `.node`, miner rates, share
            hashes, fee/donation settings, and event hooks.
      datadir_path -- directory where 'stats' and 'graph_db' files live.
      bitcoind_getinfo_var -- variable holding bitcoind getinfo results,
            used for warning generation.
      stop_event -- event used to stop the background looping calls.
            NOTE(review): mutable default shared across calls that omit it;
            presumably fine since one web root is built per process.
    """
    node = wb.node
    start_time = time.time()
    
    web_root = resource.Resource()
    
    # Payout address -> fraction of cumulative weight over the last <=720 shares.
    def get_users():
        height, last = node.tracker.get_height_and_last(node.best_share_var.value)
        weights, total_weight, donation_weight = node.tracker.get_cumulative_weights(node.best_share_var.value, min(height, 720), 65535*2**256)
        res = {}
        for script in sorted(weights, key=lambda s: weights[s]):
            res[bitcoin_data.script2_to_address(script, node.net.PARENT)] = weights[script]/total_weight
        return res
    
    # Scale current payouts so they sum to `scale`. Entries below `trunc` are
    # consolidated: one of them is picked (weighted at random) to receive the
    # combined amount, and any rounding shortfall is added to a random entry.
    def get_current_scaled_txouts(scale, trunc=0):
        txouts = node.get_current_txouts()
        total = sum(txouts.itervalues())
        results = dict((script, value*scale//total) for script, value in txouts.iteritems())
        if trunc > 0:
            total_random = 0
            random_set = set()
            for s in sorted(results, key=results.__getitem__):
                if results[s] >= trunc:
                    break
                total_random += results[s]
                random_set.add(s)
            if total_random:
                winner = math.weighted_choice((script, results[script]) for script in random_set)
                for script in random_set:
                    del results[script]
                results[winner] = total_random
        if sum(results.itervalues()) < int(scale):
            results[math.weighted_choice(results.iteritems())] += int(scale) - sum(results.itervalues())
        return results
    
    # JSON address->amount map suitable for bitcoind's `sendmany`, scaled to
    # `total` coins with sub-`trunc` dust consolidated.
    def get_patron_sendmany(total=None, trunc='0.01'):
        if total is None:
            return 'need total argument. go to patron_sendmany/<TOTAL>'
        total = int(float(total)*1e8)
        trunc = int(float(trunc)*1e8)
        return json.dumps(dict(
            (bitcoin_data.script2_to_address(script, node.net.PARENT), value/1e8)
            for script, value in get_current_scaled_txouts(total, trunc).iteritems()
            if bitcoin_data.script2_to_address(script, node.net.PARENT) is not None
        ))
    
    # Pool-wide statistics averaged over (at most) the last hour of shares.
    def get_global_stats():
        # averaged over last hour
        if node.tracker.get_height(node.best_share_var.value) < 10:
            return None
        lookbehind = min(node.tracker.get_height(node.best_share_var.value), 3600//node.net.SHARE_PERIOD)
        
        nonstale_hash_rate = p2pool_data.get_pool_attempts_per_second(node.tracker, node.best_share_var.value, lookbehind)
        stale_prop = p2pool_data.get_average_stale_prop(node.tracker, node.best_share_var.value, lookbehind)
        diff = bitcoin_data.target_to_difficulty(wb.current_work.value['bits'].target)

        return dict(
            pool_nonstale_hash_rate=nonstale_hash_rate,
            pool_hash_rate=nonstale_hash_rate/(1 - stale_prop),
            pool_stale_prop=stale_prop,
            min_difficulty=bitcoin_data.target_to_difficulty(node.tracker.items[node.best_share_var.value].max_target),
            network_block_difficulty=diff,
            network_hashrate=(diff * 2**32 // node.net.PARENT.BLOCK_PERIOD),
        )
    
    # Statistics for this node's own miners over (at most) the last hour.
    def get_local_stats():
        if node.tracker.get_height(node.best_share_var.value) < 10:
            return None
        lookbehind = min(node.tracker.get_height(node.best_share_var.value), 3600//node.net.SHARE_PERIOD)
        
        global_stale_prop = p2pool_data.get_average_stale_prop(node.tracker, node.best_share_var.value, lookbehind)
        
        my_unstale_count = sum(1 for share in node.tracker.get_chain(node.best_share_var.value, lookbehind) if share.hash in wb.my_share_hashes)
        my_orphan_count = sum(1 for share in node.tracker.get_chain(node.best_share_var.value, lookbehind) if share.hash in wb.my_share_hashes and share.share_data['stale_info'] == 'orphan')
        my_doa_count = sum(1 for share in node.tracker.get_chain(node.best_share_var.value, lookbehind) if share.hash in wb.my_share_hashes and share.share_data['stale_info'] == 'doa')
        my_share_count = my_unstale_count + my_orphan_count + my_doa_count
        my_stale_count = my_orphan_count + my_doa_count
        
        my_stale_prop = my_stale_count/my_share_count if my_share_count != 0 else None
        
        # Hash rate derived from this node's shares over the actual elapsed
        # time between the oldest and newest share in the window.
        my_work = sum(bitcoin_data.target_to_average_attempts(share.target)
            for share in node.tracker.get_chain(node.best_share_var.value, lookbehind - 1)
            if share.hash in wb.my_share_hashes)
        actual_time = (node.tracker.items[node.best_share_var.value].timestamp -
            node.tracker.items[node.tracker.get_nth_parent_hash(node.best_share_var.value, lookbehind - 1)].timestamp)
        share_att_s = my_work / actual_time
        
        miner_hash_rates, miner_dead_hash_rates = wb.get_local_rates()
        (stale_orphan_shares, stale_doa_shares), shares, _ = wb.get_stale_counts()

        miner_last_difficulties = {}
        for addr in wb.last_work_shares.value:
            miner_last_difficulties[addr] = bitcoin_data.target_to_difficulty(wb.last_work_shares.value[addr].target)
        
        return dict(
            my_hash_rates_in_last_hour=dict(
                note="DEPRECATED",
                nonstale=share_att_s,
                rewarded=share_att_s/(1 - global_stale_prop),
                actual=share_att_s/(1 - my_stale_prop) if my_stale_prop is not None else 0, # 0 because we don't have any shares anyway
            ),
            my_share_counts_in_last_hour=dict(
                shares=my_share_count,
                unstale_shares=my_unstale_count,
                stale_shares=my_stale_count,
                orphan_stale_shares=my_orphan_count,
                doa_stale_shares=my_doa_count,
            ),
            my_stale_proportions_in_last_hour=dict(
                stale=my_stale_prop,
                orphan_stale=my_orphan_count/my_share_count if my_share_count != 0 else None,
                dead_stale=my_doa_count/my_share_count if my_share_count != 0 else None,
            ),
            miner_hash_rates=miner_hash_rates,
            miner_dead_hash_rates=miner_dead_hash_rates,
            miner_last_difficulties=miner_last_difficulties,
            efficiency_if_miner_perfect=(1 - stale_orphan_shares/shares)/(1 - global_stale_prop) if shares else None, # ignores dead shares because those are miner's fault and indicated by pseudoshare rejection
            efficiency=(1 - (stale_orphan_shares+stale_doa_shares)/shares)/(1 - global_stale_prop) if shares else None,
            peers=dict(
                incoming=sum(1 for peer in node.p2p_node.peers.itervalues() if peer.incoming),
                outgoing=sum(1 for peer in node.p2p_node.peers.itervalues() if not peer.incoming),
            ),
            shares=dict(
                total=shares,
                orphan=stale_orphan_shares,
                dead=stale_doa_shares,
            ),
            uptime=time.time() - start_time,
            attempts_to_share=bitcoin_data.target_to_average_attempts(node.tracker.items[node.best_share_var.value].max_target),
            attempts_to_block=bitcoin_data.target_to_average_attempts(node.bitcoind_work.value['bits'].target),
            block_value=node.bitcoind_work.value['subsidy']*1e-8,
            warnings=p2pool_data.get_warnings(node.tracker, node.best_share_var.value, node.net, bitcoind_getinfo_var.value, node.bitcoind_work.value),
            donation_proportion=wb.donation_percentage/100,
            version=p2pool.__version__,
            protocol_version=p2p.Protocol.VERSION,
            fee=wb.worker_fee,
        )
    
    # Adapts a plain callable into a web resource: extra URL path segments are
    # accumulated (via getChild) and passed as positional arguments; the
    # result is JSON-encoded unless a different mime type was given.
    class WebInterface(deferred_resource.DeferredResource):
        def __init__(self, func, mime_type='application/json', args=()):
            deferred_resource.DeferredResource.__init__(self)
            self.func, self.mime_type, self.args = func, mime_type, args
        
        def getChild(self, child, request):
            # Each path segment becomes another argument for `func`.
            return WebInterface(self.func, self.mime_type, self.args + (child,))
        
        @defer.inlineCallbacks
        def render_GET(self, request):
            request.setHeader('Content-Type', self.mime_type)
            request.setHeader('Access-Control-Allow-Origin', '*')
            res = yield self.func(*self.args)
            defer.returnValue(json.dumps(res) if self.mime_type == 'application/json' else res)
    
    # Chain height capped at 720 shares -- common lookbehind for rate endpoints.
    def decent_height():
        return min(node.tracker.get_height(node.best_share_var.value), 720)
    # Legacy top-level JSON endpoints.
    web_root.putChild('rate', WebInterface(lambda: p2pool_data.get_pool_attempts_per_second(node.tracker, node.best_share_var.value, decent_height())/(1-p2pool_data.get_average_stale_prop(node.tracker, node.best_share_var.value, decent_height()))))
    web_root.putChild('difficulty', WebInterface(lambda: bitcoin_data.target_to_difficulty(node.tracker.items[node.best_share_var.value].max_target)))
    web_root.putChild('users', WebInterface(get_users))
    web_root.putChild('user_stales', WebInterface(lambda: dict((bitcoin_data.pubkey_hash_to_address(ph, node.net.PARENT), prop) for ph, prop in
        p2pool_data.get_user_stale_props(node.tracker, node.best_share_var.value, node.tracker.get_height(node.best_share_var.value)).iteritems())))
    web_root.putChild('fee', WebInterface(lambda: wb.worker_fee))
    web_root.putChild('current_payouts', WebInterface(lambda: dict((bitcoin_data.script2_to_address(script, node.net.PARENT), value/1e8) for script, value in node.get_current_txouts().iteritems())))
    web_root.putChild('patron_sendmany', WebInterface(get_patron_sendmany, 'text/plain'))
    web_root.putChild('global_stats', WebInterface(get_global_stats))
    web_root.putChild('local_stats', WebInterface(get_local_stats))
    web_root.putChild('peer_addresses', WebInterface(lambda: ' '.join('%s%s' % (peer.transport.getPeer().host, ':'+str(peer.transport.getPeer().port) if peer.transport.getPeer().port != node.net.P2P_PORT else '') for peer in node.p2p_node.peers.itervalues())))
    web_root.putChild('peer_txpool_sizes', WebInterface(lambda: dict(('%s:%i' % (peer.transport.getPeer().host, peer.transport.getPeer().port), peer.remembered_txs_size) for peer in node.p2p_node.peers.itervalues())))
    # Ping every peer 3 times and report the best (minimum) latency in ms;
    # failed pings yield None via the errback.
    web_root.putChild('pings', WebInterface(defer.inlineCallbacks(lambda: defer.returnValue(
        dict([(a, (yield b)) for a, b in
            [(
                '%s:%i' % (peer.transport.getPeer().host, peer.transport.getPeer().port),
                defer.inlineCallbacks(lambda peer=peer: defer.returnValue(
                    min([(yield peer.do_ping().addCallback(lambda x: x/0.001).addErrback(lambda fail: None)) for i in xrange(3)])
                ))()
            ) for peer in list(node.p2p_node.peers.itervalues())]
        ])
    ))))
    web_root.putChild('peer_versions', WebInterface(lambda: dict(('%s:%i' % peer.addr, peer.other_sub_version) for peer in node.p2p_node.peers.itervalues())))
    web_root.putChild('payout_addr', WebInterface(lambda: bitcoin_data.pubkey_hash_to_address(wb.my_pubkey_hash, node.net.PARENT)))
    # Shares from the last 24h whose proof-of-work also satisfied the parent
    # network's block target, i.e. shares that were actual blocks.
    web_root.putChild('recent_blocks', WebInterface(lambda: [dict(
        ts=s.timestamp,
        hash='%064x' % s.header_hash,
        number=pack.IntType(24).unpack(s.share_data['coinbase'][1:4]) if len(s.share_data['coinbase']) >= 4 else None,
        share='%064x' % s.hash,
    ) for s in node.tracker.get_chain(node.best_share_var.value, min(node.tracker.get_height(node.best_share_var.value), 24*60*60//node.net.SHARE_PERIOD)) if s.pow_hash <= s.header['bits'].target]))
    web_root.putChild('uptime', WebInterface(lambda: time.time() - start_time))
    web_root.putChild('stale_rates', WebInterface(lambda: p2pool_data.get_stale_counts(node.tracker, node.best_share_var.value, decent_height(), rates=True)))
    
    new_root = resource.Resource()
    web_root.putChild('web', new_root)
    
    # Rolling 24h statistics log, persisted to <datadir>/stats as JSON.
    stat_log = []
    if os.path.exists(os.path.join(datadir_path, 'stats')):
        try:
            with open(os.path.join(datadir_path, 'stats'), 'rb') as f:
                stat_log = json.loads(f.read())
        except:
            # NOTE(review): bare except deliberately treats a corrupt/missing
            # stats file as best-effort -- the log just starts empty.
            log.err(None, 'Error loading stats:')
    # Every 5 minutes: drop entries older than 24h, append a fresh snapshot,
    # and rewrite the stats file.
    def update_stat_log():
        while stat_log and stat_log[0]['time'] < time.time() - 24*60*60:
            stat_log.pop(0)
        
        lookbehind = 3600//node.net.SHARE_PERIOD
        if node.tracker.get_height(node.best_share_var.value) < lookbehind:
            return None
        
        global_stale_prop = p2pool_data.get_average_stale_prop(node.tracker, node.best_share_var.value, lookbehind)
        (stale_orphan_shares, stale_doa_shares), shares, _ = wb.get_stale_counts()
        miner_hash_rates, miner_dead_hash_rates = wb.get_local_rates()
        
        stat_log.append(dict(
            time=time.time(),
            pool_hash_rate=p2pool_data.get_pool_attempts_per_second(node.tracker, node.best_share_var.value, lookbehind)/(1-global_stale_prop),
            pool_stale_prop=global_stale_prop,
            local_hash_rates=miner_hash_rates,
            local_dead_hash_rates=miner_dead_hash_rates,
            shares=shares,
            stale_shares=stale_orphan_shares + stale_doa_shares,
            stale_shares_breakdown=dict(orphan=stale_orphan_shares, doa=stale_doa_shares),
            current_payout=node.get_current_txouts().get(bitcoin_data.pubkey_hash_to_script2(wb.my_pubkey_hash), 0)*1e-8,
            peers=dict(
                incoming=sum(1 for peer in node.p2p_node.peers.itervalues() if peer.incoming),
                outgoing=sum(1 for peer in node.p2p_node.peers.itervalues() if not peer.incoming),
            ),
            attempts_to_share=bitcoin_data.target_to_average_attempts(node.tracker.items[node.best_share_var.value].max_target),
            attempts_to_block=bitcoin_data.target_to_average_attempts(node.bitcoind_work.value['bits'].target),
            block_value=node.bitcoind_work.value['subsidy']*1e-8,
        ))
        
        with open(os.path.join(datadir_path, 'stats'), 'wb') as f:
            f.write(json.dumps(stat_log))
    x = deferral.RobustLoopingCall(update_stat_log)
    x.start(5*60)
    stop_event.watch(x.stop)
    new_root.putChild('log', WebInterface(lambda: stat_log))
    
    # Full detail dump for a single share, addressed by its hex hash.
    def get_share(share_hash_str):
        if int(share_hash_str, 16) not in node.tracker.items:
            return None
        share = node.tracker.items[int(share_hash_str, 16)]
        
        return dict(
            parent='%064x' % share.previous_hash,
            children=['%064x' % x for x in sorted(node.tracker.reverse.get(share.hash, set()), key=lambda sh: -len(node.tracker.reverse.get(sh, set())))], # sorted from most children to least children
            type_name=type(share).__name__,
            local=dict(
                verified=share.hash in node.tracker.verified.items,
                time_first_seen=start_time if share.time_seen == 0 else share.time_seen,
                peer_first_received_from=share.peer_addr,
            ),
            share_data=dict(
                timestamp=share.timestamp,
                target=share.target,
                max_target=share.max_target,
                payout_address=bitcoin_data.script2_to_address(share.new_script, node.net.PARENT),
                donation=share.share_data['donation']/65535,
                stale_info=share.share_data['stale_info'],
                nonce=share.share_data['nonce'],
                desired_version=share.share_data['desired_version'],
                absheight=share.absheight,
                abswork=share.abswork,
            ),
            block=dict(
                hash='%064x' % share.header_hash,
                header=dict(
                    version=share.header['version'],
                    previous_block='%064x' % share.header['previous_block'],
                    merkle_root='%064x' % share.header['merkle_root'],
                    timestamp=share.header['timestamp'],
                    target=share.header['bits'].target,
                    nonce=share.header['nonce'],
                ),
                gentx=dict(
                    hash='%064x' % share.gentx_hash,
                    coinbase=share.share_data['coinbase'].ljust(2, '\x00').encode('hex'),
                    value=share.share_data['subsidy']*1e-8,
                    last_txout_nonce='%016x' % share.contents['last_txout_nonce'],
                ),
                other_transaction_hashes=['%064x' % x for x in share.get_other_tx_hashes(node.tracker)],
            ),
        )
    new_root.putChild('share', WebInterface(lambda share_hash_str: get_share(share_hash_str)))
    new_root.putChild('heads', WebInterface(lambda: ['%064x' % x for x in node.tracker.heads]))
    new_root.putChild('verified_heads', WebInterface(lambda: ['%064x' % x for x in node.tracker.verified.heads]))
    new_root.putChild('tails', WebInterface(lambda: ['%064x' % x for t in node.tracker.tails for x in node.tracker.reverse.get(t, set())]))
    new_root.putChild('verified_tails', WebInterface(lambda: ['%064x' % x for t in node.tracker.verified.tails for x in node.tracker.verified.reverse.get(t, set())]))
    new_root.putChild('best_share_hash', WebInterface(lambda: '%064x' % node.best_share_var.value))
    new_root.putChild('my_share_hashes', WebInterface(lambda: ['%064x' % my_share_hash for my_share_hash in wb.my_share_hashes]))
    # Raw packed bytes of a share, or empty string if unknown.
    def get_share_data(share_hash_str):
        if int(share_hash_str, 16) not in node.tracker.items:
            return ''
        share = node.tracker.items[int(share_hash_str, 16)]
        return p2pool_data.share_type.pack(share.as_share1a())
    new_root.putChild('share_data', WebInterface(lambda share_hash_str: get_share_data(share_hash_str), 'application/octet-stream'))
    new_root.putChild('currency_info', WebInterface(lambda: dict(
        symbol=node.net.PARENT.SYMBOL,
        block_explorer_url_prefix=node.net.PARENT.BLOCK_EXPLORER_URL_PREFIX,
        address_explorer_url_prefix=node.net.PARENT.ADDRESS_EXPLORER_URL_PREFIX,
        tx_explorer_url_prefix=node.net.PARENT.TX_EXPLORER_URL_PREFIX,
    )))
    new_root.putChild('version', WebInterface(lambda: p2pool.__version__))
    
    # Time-series graph database: load persisted state (best-effort), define
    # the dataview resolutions/windows, and periodically write state back.
    hd_path = os.path.join(datadir_path, 'graph_db')
    hd_data = _atomic_read(hd_path)
    hd_obj = {}
    if hd_data is not None:
        try:
            hd_obj = json.loads(hd_data)
        except Exception:
            log.err(None, 'Error reading graph database:')
    dataview_descriptions = {
        'last_hour': graph.DataViewDescription(150, 60*60),
        'last_day': graph.DataViewDescription(300, 60*60*24),
        'last_week': graph.DataViewDescription(300, 60*60*24*7),
        'last_month': graph.DataViewDescription(300, 60*60*24*30),
        'last_year': graph.DataViewDescription(300, 60*60*24*365.25),
    }
    hd = graph.HistoryDatabase.from_obj({
        'local_hash_rate': graph.DataStreamDescription(dataview_descriptions, is_gauge=False),
        'local_dead_hash_rate': graph.DataStreamDescription(dataview_descriptions, is_gauge=False),
        'local_share_hash_rates': graph.DataStreamDescription(dataview_descriptions, is_gauge=False,
            multivalues=True, multivalue_undefined_means_0=True,
            default_func=graph.make_multivalue_migrator(dict(good='local_share_hash_rate', dead='local_dead_share_hash_rate', orphan='local_orphan_share_hash_rate'),
                post_func=lambda bins: [dict((k, (v[0] - (sum(bin.get(rem_k, (0, 0))[0] for rem_k in ['dead', 'orphan']) if k == 'good' else 0), v[1])) for k, v in bin.iteritems()) for bin in bins])),
        'pool_rates': graph.DataStreamDescription(dataview_descriptions, multivalues=True,
            multivalue_undefined_means_0=True),
        'current_payout': graph.DataStreamDescription(dataview_descriptions),
        'current_payouts': graph.DataStreamDescription(dataview_descriptions, multivalues=True),
        'peers': graph.DataStreamDescription(dataview_descriptions, multivalues=True, default_func=graph.make_multivalue_migrator(dict(incoming='incoming_peers', outgoing='outgoing_peers'))),
        'miner_hash_rates': graph.DataStreamDescription(dataview_descriptions, is_gauge=False, multivalues=True),
        'miner_dead_hash_rates': graph.DataStreamDescription(dataview_descriptions, is_gauge=False, multivalues=True),
        'desired_version_rates': graph.DataStreamDescription(dataview_descriptions, multivalues=True,
            multivalue_undefined_means_0=True),
        'traffic_rate': graph.DataStreamDescription(dataview_descriptions, is_gauge=False, multivalues=True),
        'getwork_latency': graph.DataStreamDescription(dataview_descriptions),
        'memory_usage': graph.DataStreamDescription(dataview_descriptions),
    }, hd_obj)
    x = deferral.RobustLoopingCall(lambda: _atomic_write(hd_path, json.dumps(hd.to_obj())))
    x.start(100)
    stop_event.watch(x.stop)
    # Record every pseudoshare (and per-miner rates) into the graph streams.
    @wb.pseudoshare_received.watch
    def _(work, dead, user):
        t = time.time()
        hd.datastreams['local_hash_rate'].add_datum(t, work)
        if dead:
            hd.datastreams['local_dead_hash_rate'].add_datum(t, work)
        if user is not None:
            hd.datastreams['miner_hash_rates'].add_datum(t, {user: work})
            if dead:
                hd.datastreams['miner_dead_hash_rates'].add_datum(t, {user: work})
    # Record real shares as good/dead, then re-classify 200s later once it's
    # known whether the share ended up in the sharechain.
    @wb.share_received.watch
    def _(work, dead, share_hash):
        t = time.time()
        if not dead:
            hd.datastreams['local_share_hash_rates'].add_datum(t, dict(good=work))
        else:
            hd.datastreams['local_share_hash_rates'].add_datum(t, dict(dead=work))
        def later():
            res = node.tracker.is_child_of(share_hash, node.best_share_var.value)
            if res is None: res = False # share isn't connected to sharechain? assume orphaned
            if res and dead: # share was DOA, but is now in sharechain
                # move from dead to good
                hd.datastreams['local_share_hash_rates'].add_datum(t, dict(dead=-work, good=work))
            elif not res and not dead: # share wasn't DOA, and isn't in sharechain
                # move from good to orphan
                hd.datastreams['local_share_hash_rates'].add_datum(t, dict(good=-work, orphan=work))
        reactor.callLater(200, later)
    @node.p2p_node.traffic_happened.watch
    def _(name, bytes):
        hd.datastreams['traffic_rate'].add_datum(time.time(), {name: bytes})
    # Sampled every 5 seconds: pool rates, payouts, peers, version rates,
    # and (best-effort) resident memory usage.
    def add_point():
        if node.tracker.get_height(node.best_share_var.value) < 10:
            return None
        lookbehind = min(node.net.CHAIN_LENGTH, 60*60//node.net.SHARE_PERIOD, node.tracker.get_height(node.best_share_var.value))
        t = time.time()
        
        pool_rates = p2pool_data.get_stale_counts(node.tracker, node.best_share_var.value, lookbehind, rates=True)
        pool_total = sum(pool_rates.itervalues())
        hd.datastreams['pool_rates'].add_datum(t, pool_rates)
        
        current_txouts = node.get_current_txouts()
        hd.datastreams['current_payout'].add_datum(t, current_txouts.get(bitcoin_data.pubkey_hash_to_script2(wb.my_pubkey_hash), 0)*1e-8)
        miner_hash_rates, miner_dead_hash_rates = wb.get_local_rates()
        current_txouts_by_address = dict((bitcoin_data.script2_to_address(script, node.net.PARENT), amount) for script, amount in current_txouts.iteritems())
        hd.datastreams['current_payouts'].add_datum(t, dict((user, current_txouts_by_address[user]*1e-8) for user in miner_hash_rates if user in current_txouts_by_address))
        
        hd.datastreams['peers'].add_datum(t, dict(
            incoming=sum(1 for peer in node.p2p_node.peers.itervalues() if peer.incoming),
            outgoing=sum(1 for peer in node.p2p_node.peers.itervalues() if not peer.incoming),
        ))
        
        vs = p2pool_data.get_desired_version_counts(node.tracker, node.best_share_var.value, lookbehind)
        vs_total = sum(vs.itervalues())
        hd.datastreams['desired_version_rates'].add_datum(t, dict((str(k), v/vs_total*pool_total) for k, v in vs.iteritems()))
        try:
            hd.datastreams['memory_usage'].add_datum(t, memory.resident())
        except:
            # NOTE(review): memory.resident() is platform-dependent; failure
            # is tolerated and only reported in debug mode.
            if p2pool.DEBUG:
                traceback.print_exc()
    x = deferral.RobustLoopingCall(add_point)
    x.start(5)
    stop_event.watch(x.stop)
    @node.bitcoind_work.changed.watch
    def _(new_work):
        hd.datastreams['getwork_latency'].add_datum(time.time(), new_work['latency'])
    new_root.putChild('graph_data', WebInterface(lambda source, view: hd.datastreams[source].dataviews[view].get_data(time.time())))
    
    # Static frontend files live next to the main script in 'web-static'.
    web_root.putChild('static', static.File(os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), 'web-static')))
    
    return web_root

Example 37

Project: changes
Source File: config.py
View license
def create_app(_read_config=True, **config):
    app = flask.Flask(__name__,
                      static_folder=None,
                      template_folder=os.path.join(PROJECT_ROOT, 'templates'))

    # app.wsgi_app = TracerMiddleware(app.wsgi_app, app)

    # This key is insecure and you should override it on the server
    app.config['SECRET_KEY'] = 't\xad\xe7\xff%\xd2.\xfe\x03\x02=\xec\xaf\\2+\xb8=\xf7\x8a\x9aLD\xb1'

    app.config['SQLALCHEMY_COMMIT_ON_TEARDOWN'] = True
    app.config['SQLALCHEMY_DATABASE_URI'] = 'postgresql:///changes'
    app.config['SQLALCHEMY_POOL_SIZE'] = 60
    app.config['SQLALCHEMY_MAX_OVERFLOW'] = 20
    # required for flask-debugtoolbar and the db perf metrics we record
    app.config['SQLALCHEMY_RECORD_QUERIES'] = True

    app.config['REDIS_URL'] = 'redis://localhost/0'
    app.config['GROUPER_API_URL'] = 'https://localhost/'
    app.config['GROUPER_PERMISSIONS_ADMIN'] = 'changes.prod.admin'
    app.config['GROUPER_PERMISSIONS_PROJECT_ADMIN'] = 'changes.prod.project.admin'
    app.config['GROUPER_EXCLUDED_ROLES'] = ['np-owner']
    app.config['DEBUG'] = True
    app.config['HTTP_PORT'] = 5000
    app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0

    app.config['BAZEL_ARTIFACT_SUFFIX'] = '.bazel'

    app.config['BAZEL_TEST_OUTPUT_RELATIVE_PATH'] = 'bazel-testlogs/'

    app.config['API_TRACEBACKS'] = True

    # Expiration delay between when a snapshot image becomes superceded and when
    # it becomes truly expired (and thus no longer included in the sync information
    # for any cluster that runs that particular image's plan)
    app.config['CACHED_SNAPSHOT_EXPIRATION_DELTA'] = timedelta(hours=1)

    # default snapshot ID to use when no project-specific active image available
    app.config['DEFAULT_SNAPSHOT'] = None
    app.config['SNAPSHOT_S3_BUCKET'] = None
    app.config['LXC_PRE_LAUNCH'] = None
    app.config['LXC_POST_LAUNCH'] = None

    # APT mirror URLs to use for new LXC containers created by changes-client.
    # NB: these aren't currently supported in the public changes-client repo.
    app.config['LXC_APT_MIRROR'] = None
    app.config['LXC_APT_SECURITY_MIRROR'] = None

    # name of the template to use for LXC (usually the name of a particular
    # Linux distro). Defaults to ubuntu.
    app.config['LXC_TEMPLATE'] = 'ubuntu'

    # Location of artifacts server that is passed to changes-client
    # (include http:// or https://)
    #
    # The default artifact server url uses a random uri which is expected to fail
    # without being overridden. This value is referenced in test code.
    app.config['ARTIFACTS_SERVER'] = 'http://localhost:1234'

    # The default max artifact size handlers should be capable of processing.
    app.config['MAX_ARTIFACT_BYTES'] = 200 * 1024 * 1024
    # The max artifact size the analytics json handler should be capable of processing.
    app.config['MAX_ARTIFACT_BYTES_ANALYTICS_JSON'] = 70 * 1024 * 1024

    # the binary to use for running changes-client. Default is just
    # "changes-client", but can also be specified as e.g. a full path.
    app.config['CHANGES_CLIENT_BINARY'] = 'changes-client'

    app.config['CHANGES_CLIENT_DEFAULT_BUILD_TYPE'] = 'legacy'

    # Base URI to use for git repos that we want to clone (currently only used
    # for the "other_repos" buildstep config). The repo name is appended
    # directly to this, so it should already contain necessary colons and
    # slashes, etc. For example, if GIT_DEFAULT_BASE_URI is `[email protected]:`
    # and a repo is specified as `changes.git`, the clone url will be
    # `[email protected]:changes.git`
    app.config['GIT_DEFAULT_BASE_URI'] = None
    # Same as GIT_DEFAULT_BASE_URI but used for mercurial repos.
    app.config['MERCURIAL_DEFAULT_BASE_URI'] = None

    # This is a hash from each build type (string identifiers used in
    # build step configuration) to a "build spec", a definition of
    # how to use changes-client to build. To use changes-client, the key
    # 'uses_client' must be set to True.
    #
    # Required build spec keys for client:
    #   adapter -> basic or lxc
    #   jenkins-command -> command to run from jenkins directly ($JENKINS_COMMAND)
    #   commands -> array of hash from script -> string that represents a script
    #
    # Optional keys (lxc-only)
    #   pre-launch -> lxc pre-launch script
    #   post-launch -> lxc post-launch script
    #   release -> lxc release
    app.config['CHANGES_CLIENT_BUILD_TYPES'] = {
        'legacy': {'uses_client': False},
    }

    app.config['CELERY_ACCEPT_CONTENT'] = ['changes_json']
    app.config['CELERY_ACKS_LATE'] = True
    app.config['CELERY_BROKER_URL'] = 'redis://localhost/0'
    app.config['CELERY_DEFAULT_QUEUE'] = "default"
    app.config['CELERY_DEFAULT_EXCHANGE'] = "default"
    app.config['CELERY_DEFAULT_EXCHANGE_TYPE'] = "direct"
    app.config['CELERY_DEFAULT_ROUTING_KEY'] = "default"
    app.config['CELERY_DISABLE_RATE_LIMITS'] = True
    app.config['CELERY_IGNORE_RESULT'] = True
    app.config['CELERY_RESULT_BACKEND'] = None
    app.config['CELERY_RESULT_SERIALIZER'] = 'changes_json'
    app.config['CELERY_SEND_EVENTS'] = False
    app.config['CELERY_TASK_RESULT_EXPIRES'] = 1
    app.config['CELERY_TASK_SERIALIZER'] = 'changes_json'
    app.config['CELERYD_PREFETCH_MULTIPLIER'] = 1
    app.config['CELERYD_MAX_TASKS_PER_CHILD'] = 10000

    # By default, Celery logs writes to stdout/stderr as WARNING, which
    # is a bit harsh considering that some of the code is code we don't
    # own calling 'print'. This flips the default back to INFO, which seems
    # more appropriate. Can be overridden by the Changes config.
    app.config['CELERY_REDIRECT_STDOUTS_LEVEL'] = 'INFO'

    app.config['CELERY_QUEUES'] = (
        Queue('job.sync', routing_key='job.sync'),
        Queue('job.create', routing_key='job.create'),
        Queue('celery', routing_key='celery'),
        Queue('events', routing_key='events'),
        Queue('default', routing_key='default'),
        Queue('delete', routing_key='delete'),
        Queue('repo.sync', Exchange('fanout', 'fanout'), routing_key='repo.sync'),
        Queue('grouper.sync', routing_key='grouper.sync'),
        Broadcast('repo.update'),
    )
    app.config['CELERY_ROUTES'] = {
        'create_job': {
            'queue': 'job.create',
            'routing_key': 'job.create',
        },
        'sync_job': {
            'queue': 'job.sync',
            'routing_key': 'job.sync',
        },
        'sync_job_step': {
            'queue': 'job.sync',
            'routing_key': 'job.sync',
        },
        'sync_build': {
            'queue': 'job.sync',
            'routing_key': 'job.sync',
        },
        'check_repos': {
            'queue': 'repo.sync',
            'routing_key': 'repo.sync',
        },
        'sync_grouper': {
            'queue': 'grouper.sync',
            'routing_key': 'grouper.sync',
        },
        'sync_repo': {
            'queue': 'repo.sync',
            'routing_key': 'repo.sync',
        },
        'run_event_listener': {
            'queue': 'events',
            'routing_key': 'events',
        },
        'fire_signal': {
            'queue': 'events',
            'routing_key': 'events',
        },
        'update_local_repos': {
            'queue': 'repo.update',
        },
        'delete_old_data': {
            'queue': 'delete',
            'routing_key': 'delete',
        },
        'delete_old_data_10m': {
            'queue': 'delete',
            'routing_key': 'delete',
        },
        'delete_old_data_5h_delayed': {
            'queue': 'delete',
            'routing_key': 'delete',
        },
    }

    app.config['EVENT_LISTENERS'] = (
        ('changes.listeners.mail.build_finished_handler', 'build.finished'),
        ('changes.listeners.green_build.revision_result_updated_handler', 'revision_result.updated'),
        ('changes.listeners.build_revision.revision_created_handler', 'revision.created'),
        ('changes.listeners.build_finished_notifier.build_finished_handler', 'build.finished'),
        ('changes.listeners.phabricator_listener.build_finished_handler', 'build.finished'),
        ('changes.listeners.analytics_notifier.build_finished_handler', 'build.finished'),
        ('changes.listeners.analytics_notifier.job_finished_handler', 'job.finished'),
        ('changes.listeners.revision_result.revision_result_build_finished_handler', 'build.finished'),
        ('changes.listeners.stats_notifier.build_finished_handler', 'build.finished'),
        ('changes.listeners.snapshot_build.build_finished_handler', 'build.finished'),
    )

    # restrict outbound notifications to the given domains
    app.config['MAIL_DOMAIN_WHITELIST'] = ()

    app.config['DEBUG_TB_ENABLED'] = True

    app.config['DEBUG_TB_PANELS'] = ('flask_debugtoolbar.panels.versions.VersionDebugPanel',
                                     'flask_debugtoolbar.panels.timer.TimerDebugPanel',
                                     'flask_debugtoolbar.panels.headers.HeaderDebugPanel',
                                     'flask_debugtoolbar.panels.request_vars.RequestVarsDebugPanel',
                                     # Disable the config vars panel by default; it can contain sensitive information.
                                     # 'flask_debugtoolbar.panels.config_vars.ConfigVarsDebugPanel',
                                     'flask_debugtoolbar.panels.template.TemplateDebugPanel',
                                     'flask_debugtoolbar.panels.sqlalchemy.SQLAlchemyDebugPanel',
                                     'flask_debugtoolbar.panels.logger.LoggingPanel',
                                     'flask_debugtoolbar.panels.profiler.ProfilerDebugPanel')

    # celerybeat must be running for our cleanup tasks to execute
    # e.g. celery worker -B
    app.config['CELERYBEAT_SCHEDULE'] = {
        'cleanup-tasks': {
            'task': 'cleanup_tasks',
            'schedule': timedelta(minutes=1),
        },
        'check-repos': {
            'task': 'check_repos',
            'schedule': timedelta(minutes=2),
        },
        'sync-grouper': {
            'task': 'sync_grouper',
            'schedule': timedelta(minutes=1),
        },
        'aggregate-flaky-tests': {
            'task': 'aggregate_flaky_tests',
            # Hour 7 GMT is midnight PST, hopefully a time of low load
            'schedule': crontab(hour=7, minute=0),
        },
        'delete-old-data-10m': {
            'task': 'delete_old_data_10m',
            'schedule': timedelta(minutes=10),
        },
        'delete-old-data-5h-delayed': {
            'task': 'delete_old_data_5h_delayed',
            # This task runs every 4 hours but looks at 5 hours worth of tests
            # so consecutive runs will look at sets of tests that will overlap.
            # This is to make it unlikely to miss tests in between.
            #
            # While this is looking at 5 hours worth of tests, this should not be long running
            # as the shorter delete tasks will catch most cases and this checks
            # a time frame that should've been cleaned by them already.
            'schedule': crontab(hour='*/4'),
        },
        'update-local-repos': {
            'task': 'update_local_repos',
            'schedule': timedelta(minutes=1),
        }
    }
    app.config['CELERY_TIMEZONE'] = 'UTC'

    app.config['SENTRY_DSN'] = None
    app.config['SENTRY_INCLUDE_PATHS'] = [
        'changes',
    ]

    app.config['KOALITY_URL'] = None
    app.config['KOALITY_API_KEY'] = None

    app.config['GOOGLE_CLIENT_ID'] = None
    app.config['GOOGLE_CLIENT_SECRET'] = None
    app.config['GOOGLE_DOMAIN'] = None

    # must be a URL-safe base64-encoded 32-byte key
    app.config['COOKIE_ENCRYPTION_KEY'] = 'theDefaultKeyIs32BytesLongAndTotallyURLSafe='

    app.config['REPO_ROOT'] = None

    app.config['DEFAULT_FILE_STORAGE'] = 'changes.storage.s3.S3FileStorage'
    app.config['S3_ACCESS_KEY'] = None
    app.config['S3_SECRET_KEY'] = None
    app.config['S3_BUCKET'] = None

    app.config['PHABRICATOR_LINK_HOST'] = None
    app.config['PHABRICATOR_API_HOST'] = None
    app.config['PHABRICATOR_USERNAME'] = None
    app.config['PHABRICATOR_CERT'] = None

    # Configuration to access Zookeeper - currently used to discover mesos master leader instance
    # E.g., if mesos master is configured to talk to zk://zk1:2181,zk2:2181/mesos,
    # set ZOOKEEPER_HOSTS = 'zk1:2181,zk2:2181'
    #     ZOOKEEPER_MESOS_MASTER_PATH = '/mesos'
    #
    # This is only used to control mesos slave offline/online status from within Changes

    # Comma-separated list of host:port (or ip:port) to Zookeeper instances.
    app.config['ZOOKEEPER_HOSTS'] = 'zk:2181'
    # Namespace within zookeeper where mesos master election is performed.
    app.config['ZOOKEEPER_MESOS_MASTER_PATH'] = '/mesos'

    # List of valid tables to be written to when reporting project analytics.
    # Analytics artifacts targeting tables not listed here will be considered invalid.
    app.config['ANALYTICS_PROJECT_TABLES'] = []
    # URL any project analytics JSON entries will be posted to.
    # Entries will be posted as JSON, with the intended table specified as 'source' in the URL params.
    app.config['ANALYTICS_PROJECT_POST_URL'] = None

    app.config['SUPPORT_CONTACT'] = 'support'

    app.config['MAIL_DEFAULT_SENDER'] = '[email protected]'
    app.config['BASE_URI'] = 'http://localhost:5000'

    # if set to a string, most (all?) of the frontend js will make API calls
    # to the host this string is set to (e.g. http://changes.bigcompany.com)
    # THIS IS JUST FOR EASIER TESTING IN DEVELOPMENT. Although it won't even
    # work in prod: you'll have to start chrome with --disable-web-security to
    # make this work. Override this in your changes.conf.py file
    app.config['WEBAPP_USE_ANOTHER_HOST'] = None

    # Custom changes content unique to your deployment. This is intended to
    # customize the look and feel, provide contextual help and add custom links
    # to other internal tools. You should put your files in webapp/custom and
    # link them here.
    #
    # e.g. /acmecorp-changes/changes.js
    #
    # Some of the custom_content hooks can show images. Assume that the webserver
    # is willing to serve any file within the directory of the js file
    app.config['WEBAPP_CUSTOM_JS'] = None
    # This can be a .less file. We import it after the variables.less,
    # so you can override them in your file
    # Note: if you change this and nothing seems to happen, try deleting
    # webapp/.webassets-cache and bundled.css. This probably won't happen, though
    # If not specified, we will search for CUSTOM_CSS_FILE in the custom dir.
    app.config['WEBAPP_CUSTOM_CSS'] = None

    # In minutes, the timeout applied to jobs without a timeout specified at build time.
    # A timeout should nearly always be specified; this is just a safeguard so that
    # unspecified timeout doesn't mean "is allowed to run indefinitely".
    app.config['DEFAULT_JOB_TIMEOUT_MIN'] = 60

    # Number of milliseconds a transaction can run before triggering a warning.
    app.config['TRANSACTION_MS_WARNING_THRESHOLD'] = 2500

    # Hard maximum number of jobsteps to retry for a given job
    app.config['JOBSTEP_RETRY_MAX'] = 6
    # Maximum number of machines that we'll retry jobsteps for. This allows us
    # to retry more jobsteps if it's always the same machine failing.
    app.config['JOBSTEP_MACHINE_RETRY_MAX'] = 2

    # the PHID of the user creating quarantine tasks. We can use this to show
    # the list of open quarantine tasks inline
    app.config['QUARANTINE_PHID'] = None

    # The max length a test's output to be stored. If it is longer, the it will
    # be truncated.
    app.config['TEST_MESSAGE_MAX_LEN'] = 64 * 1024

    # List of packages needed to install bazel and any environment.
    app.config['BAZEL_APT_PKGS'] = ['bazel']

    # rsync source for encap
    # Example: rsync://example.com/encap/
    app.config['ENCAP_RSYNC_URL'] = None

    # In some configurations, build slaves might not have access to the Changes API via the
    # normal address; if PATCH_BASE_URI is specified, it'll be used as the base URI for
    # PATCH_URI variables provided to build slaves.
    app.config['PATCH_BASE_URI'] = None

    # name of default cluster to use for autogenerated jobs
    app.config['DEFAULT_CLUSTER'] = None

    # Maximum number of cpus allowed for a bazel executor. Since we expose `bazel.cpus` to
    # the user, this number needs to be bounded to avoid runaway resource allocation (by always
    # allocating large chunks of resources, like 12-16 cores), and to avoid invalid configuration
    # (like, requesting more cpus than available on a single slave, typically 32)
    app.config['MAX_CPUS_PER_EXECUTOR'] = 16

    # Minimum memory allowed per executor (in MB)
    app.config['MIN_MEM_MB_PER_EXECUTOR'] = 1024

    # Maximum memory allowed per executor (in MB)
    app.config['MAX_MEM_MB_PER_EXECUTOR'] = 16384

    # Maximum number of bazel executors allowed.
    app.config['MAX_EXECUTORS'] = 10

    # Absolute path to Bazel root (passed via --output_root to Bazel)
    # Storing bazel cache in tmpfs could be a bad idea because:
    #  - tmpfs means any files stored here will be stored purely in RAM and will eat into container limits
    #  - these containers are not persisted from the snapshot
    #
    # Bazel will create parent directories (if the user has appropriate permissions), if missing.
    app.config['BAZEL_ROOT_PATH'] = '/tmp/bazel_changes'

    # List of mandatory flags to be passed to `bazel test`
    app.config['BAZEL_MANDATORY_TEST_FLAGS'] = [
        '--spawn_strategy=sandboxed',
        '--genrule_strategy=sandboxed',
        '--keep_going',
    ]

    app.config['BAZEL_ADDITIONAL_TEST_FLAGS_WHITELIST_REGEX'] = [
        r'^--test_env=[A-Za-z0-9=]+',
        r'^--test_arg=[A-Za-z0-9=]+',
        r'^--define=[A-Za-z0-9=]+',
    ]

    app.config['SELECTIVE_TESTING_PROPAGATION_LIMIT'] = 30

    app.config['SELECTIVE_TESTING_ENABLED'] = False

    # Debug config entries passed to every autobazel jobstep
    app.config['BAZEL_DEBUG_CONFIG'] = {}

    # Extra test setup commands to be executed before collect-targets or `bazel test` invocations.
    app.config['BAZEL_EXTRA_SETUP_CMD'] = ['exit 0']

    # Jobsteps go from 'pending_allocation' to 'allocated' once an external scheduler claims them, and
    # once they begin running they're updated to 'in_progress'. If the scheduler somehow fails or drops
    # the task, this value is used to time out the 'allocated' status and revert back to 'pending_allocation'.
    # For current and expected schedulers, we don't allocate unless we think we can execute immediately, so
    # a 3 minute timeout is conservative and should be safe.
    app.config['JOBSTEP_ALLOCATION_TIMEOUT_SECONDS'] = 3 * 60

    app.config.update(config)

    if _read_config:
        if os.environ.get('CHANGES_CONF'):
            # CHANGES_CONF=/etc/changes.conf.py
            app.config.from_envvar('CHANGES_CONF')
        else:
            # Look for ~/.changes/changes.conf.py
            path = os.path.normpath(os.path.expanduser('~/.changes/changes.conf.py'))
            app.config.from_pyfile(path, silent=True)

    # default the DSN for changes-client to the server's DSN
    app.config.setdefault('CLIENT_SENTRY_DSN', app.config['SENTRY_DSN'])

    # Backwards compatibility with old configs containing BASE_URI
    if 'WEB_BASE_URI' not in app.config and 'BASE_URI' in app.config:
        app.config['WEB_BASE_URI'] = app.config['BASE_URI']
    if 'INTERNAL_BASE_URI' not in app.config and 'BASE_URI' in app.config:
        app.config['INTERNAL_BASE_URI'] = app.config['BASE_URI']

    parsed_url = urlparse(app.config['WEB_BASE_URI'])
    app.config.setdefault('PREFERRED_URL_SCHEME', 'https')

    if app.debug:
        app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
    else:
        app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30

    app.url_map.converters['uuid'] = UUIDConverter

    # now that config is set up, let's ensure the CUSTOM_JS / CUSTOM_CSS
    # variables are safe (within the changes directory) and convert them to
    # absolute paths
    if app.config['WEBAPP_CUSTOM_CSS']:
        app.config['WEBAPP_CUSTOM_CSS'] = os.path.join(
            PROJECT_ROOT, 'webapp/custom/', app.config['WEBAPP_CUSTOM_CSS'])

        enforce_is_subdir(
            app.config['WEBAPP_CUSTOM_CSS'],
            os.path.join(PROJECT_ROOT, 'webapp/custom'))
    else:
        app.config['WEBAPP_CUSTOM_CSS'] = _find_custom_css()

    if app.config['WEBAPP_CUSTOM_JS']:
        app.config['WEBAPP_CUSTOM_JS'] = os.path.join(
            PROJECT_ROOT, 'webapp/custom/', app.config['WEBAPP_CUSTOM_JS'])

        enforce_is_subdir(
            app.config['WEBAPP_CUSTOM_JS'],
            os.path.join(PROJECT_ROOT, 'webapp/custom'))

    # init sentry first
    sentry.init_app(app)

    @app.before_request
    def capture_user(*args, **kwargs):
        """Attach the authenticated user (if any) to the Sentry error context.

        Registered to run before every request so that any error reported
        while handling the request carries the user's id and email.
        """
        from changes.api.auth import get_current_user

        current = get_current_user()
        if current is None:
            return
        sentry.client.user_context({
            'id': current.id,
            'email': current.email,
        })

    api.init_app(app)
    db.init_app(app)
    mail.init_app(app)
    queue.init_app(app)
    redis.init_app(app)
    statsreporter.init_app(app)

    configure_debug_toolbar(app)

    from raven.contrib.celery import register_signal, register_logger_signal
    register_signal(sentry.client)
    register_logger_signal(sentry.client, loglevel=logging.WARNING)

    # configure debug routes first
    if app.debug:
        configure_debug_routes(app)

    configure_templates(app)

    # TODO: these can be moved to wsgi app entrypoints
    configure_api_routes(app)
    configure_web_routes(app)

    configure_jobs(app)
    configure_transaction_logging(app)

    rules_file = app.config.get('CATEGORIZE_RULES_FILE')
    if rules_file:
        # Fail at startup if we have a bad rules file.
        categorize.load_rules(rules_file)

    import jinja2
    webapp_template_folder = os.path.join(PROJECT_ROOT, 'webapp/html')
    template_folder = os.path.join(PROJECT_ROOT, 'templates')
    template_loader = jinja2.ChoiceLoader([
                app.jinja_loader,
                jinja2.FileSystemLoader([webapp_template_folder, template_folder])
                ])
    app.jinja_loader = template_loader

    return app

Example 38

Project: changes
Source File: config.py
View license
def create_app(_read_config=True, **config):
    app = flask.Flask(__name__,
                      static_folder=None,
                      template_folder=os.path.join(PROJECT_ROOT, 'templates'))

    # app.wsgi_app = TracerMiddleware(app.wsgi_app, app)

    # This key is insecure and you should override it on the server
    app.config['SECRET_KEY'] = 't\xad\xe7\xff%\xd2.\xfe\x03\x02=\xec\xaf\\2+\xb8=\xf7\x8a\x9aLD\xb1'

    app.config['SQLALCHEMY_COMMIT_ON_TEARDOWN'] = True
    app.config['SQLALCHEMY_DATABASE_URI'] = 'postgresql:///changes'
    app.config['SQLALCHEMY_POOL_SIZE'] = 60
    app.config['SQLALCHEMY_MAX_OVERFLOW'] = 20
    # required for flask-debugtoolbar and the db perf metrics we record
    app.config['SQLALCHEMY_RECORD_QUERIES'] = True

    app.config['REDIS_URL'] = 'redis://localhost/0'
    app.config['GROUPER_API_URL'] = 'https://localhost/'
    app.config['GROUPER_PERMISSIONS_ADMIN'] = 'changes.prod.admin'
    app.config['GROUPER_PERMISSIONS_PROJECT_ADMIN'] = 'changes.prod.project.admin'
    app.config['GROUPER_EXCLUDED_ROLES'] = ['np-owner']
    app.config['DEBUG'] = True
    app.config['HTTP_PORT'] = 5000
    app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0

    app.config['BAZEL_ARTIFACT_SUFFIX'] = '.bazel'

    app.config['BAZEL_TEST_OUTPUT_RELATIVE_PATH'] = 'bazel-testlogs/'

    app.config['API_TRACEBACKS'] = True

    # Expiration delay between when a snapshot image becomes superceded and when
    # it becomes truly expired (and thus no longer included in the sync information
    # for any cluster that runs that particular image's plan)
    app.config['CACHED_SNAPSHOT_EXPIRATION_DELTA'] = timedelta(hours=1)

    # default snapshot ID to use when no project-specific active image available
    app.config['DEFAULT_SNAPSHOT'] = None
    app.config['SNAPSHOT_S3_BUCKET'] = None
    app.config['LXC_PRE_LAUNCH'] = None
    app.config['LXC_POST_LAUNCH'] = None

    # APT mirror URLs to use for new LXC containers created by changes-client.
    # NB: these aren't currently supported in the public changes-client repo.
    app.config['LXC_APT_MIRROR'] = None
    app.config['LXC_APT_SECURITY_MIRROR'] = None

    # name of the template to use for LXC (usually the name of a particular
    # Linux distro). Defaults to ubuntu.
    app.config['LXC_TEMPLATE'] = 'ubuntu'

    # Location of artifacts server that is passed to changes-client
    # (include http:// or https://)
    #
    # The default artifact server url uses a random uri which is expected to fail
    # without being overridden. This value is referenced in test code.
    app.config['ARTIFACTS_SERVER'] = 'http://localhost:1234'

    # The default max artifact size handlers should be capable of processing.
    app.config['MAX_ARTIFACT_BYTES'] = 200 * 1024 * 1024
    # The max artifact size the analytics json handler should be capable of processing.
    app.config['MAX_ARTIFACT_BYTES_ANALYTICS_JSON'] = 70 * 1024 * 1024

    # the binary to use for running changes-client. Default is just
    # "changes-client", but can also be specified as e.g. a full path.
    app.config['CHANGES_CLIENT_BINARY'] = 'changes-client'

    app.config['CHANGES_CLIENT_DEFAULT_BUILD_TYPE'] = 'legacy'

    # Base URI to use for git repos that we want to clone (currently only used
    # for the "other_repos" buildstep config). The repo name is appended
    # directly to this, so it should already contain necessary colons and
    # slashes, etc. For example, if GIT_DEFAULT_BASE_URI is `[email protected]:`
    # and a repo is specified as `changes.git`, the clone url will be
    # `[email protected]:changes.git`
    app.config['GIT_DEFAULT_BASE_URI'] = None
    # Same as GIT_DEFAULT_BASE_URI but used for mercurial repos.
    app.config['MERCURIAL_DEFAULT_BASE_URI'] = None

    # This is a hash from each build type (string identifiers used in
    # build step configuration) to a "build spec", a definition of
    # how to use changes-client to build. To use changes-client, the key
    # 'uses_client' must be set to True.
    #
    # Required build spec keys for client:
    #   adapter -> basic or lxc
    #   jenkins-command -> command to run from jenkins directly ($JENKINS_COMMAND)
    #   commands -> array of hash from script -> string that represents a script
    #
    # Optional keys (lxc-only)
    #   pre-launch -> lxc pre-launch script
    #   post-launch -> lxc post-launch script
    #   release -> lxc release
    app.config['CHANGES_CLIENT_BUILD_TYPES'] = {
        'legacy': {'uses_client': False},
    }

    app.config['CELERY_ACCEPT_CONTENT'] = ['changes_json']
    app.config['CELERY_ACKS_LATE'] = True
    app.config['CELERY_BROKER_URL'] = 'redis://localhost/0'
    app.config['CELERY_DEFAULT_QUEUE'] = "default"
    app.config['CELERY_DEFAULT_EXCHANGE'] = "default"
    app.config['CELERY_DEFAULT_EXCHANGE_TYPE'] = "direct"
    app.config['CELERY_DEFAULT_ROUTING_KEY'] = "default"
    app.config['CELERY_DISABLE_RATE_LIMITS'] = True
    app.config['CELERY_IGNORE_RESULT'] = True
    app.config['CELERY_RESULT_BACKEND'] = None
    app.config['CELERY_RESULT_SERIALIZER'] = 'changes_json'
    app.config['CELERY_SEND_EVENTS'] = False
    app.config['CELERY_TASK_RESULT_EXPIRES'] = 1
    app.config['CELERY_TASK_SERIALIZER'] = 'changes_json'
    app.config['CELERYD_PREFETCH_MULTIPLIER'] = 1
    app.config['CELERYD_MAX_TASKS_PER_CHILD'] = 10000

    # By default, Celery logs writes to stdout/stderr as WARNING, which
    # is a bit harsh considering that some of the code is code we don't
    # own calling 'print'. This flips the default back to INFO, which seems
    # more appropriate. Can be overridden by the Changes config.
    app.config['CELERY_REDIRECT_STDOUTS_LEVEL'] = 'INFO'

    app.config['CELERY_QUEUES'] = (
        Queue('job.sync', routing_key='job.sync'),
        Queue('job.create', routing_key='job.create'),
        Queue('celery', routing_key='celery'),
        Queue('events', routing_key='events'),
        Queue('default', routing_key='default'),
        Queue('delete', routing_key='delete'),
        Queue('repo.sync', Exchange('fanout', 'fanout'), routing_key='repo.sync'),
        Queue('grouper.sync', routing_key='grouper.sync'),
        Broadcast('repo.update'),
    )
    app.config['CELERY_ROUTES'] = {
        'create_job': {
            'queue': 'job.create',
            'routing_key': 'job.create',
        },
        'sync_job': {
            'queue': 'job.sync',
            'routing_key': 'job.sync',
        },
        'sync_job_step': {
            'queue': 'job.sync',
            'routing_key': 'job.sync',
        },
        'sync_build': {
            'queue': 'job.sync',
            'routing_key': 'job.sync',
        },
        'check_repos': {
            'queue': 'repo.sync',
            'routing_key': 'repo.sync',
        },
        'sync_grouper': {
            'queue': 'grouper.sync',
            'routing_key': 'grouper.sync',
        },
        'sync_repo': {
            'queue': 'repo.sync',
            'routing_key': 'repo.sync',
        },
        'run_event_listener': {
            'queue': 'events',
            'routing_key': 'events',
        },
        'fire_signal': {
            'queue': 'events',
            'routing_key': 'events',
        },
        'update_local_repos': {
            'queue': 'repo.update',
        },
        'delete_old_data': {
            'queue': 'delete',
            'routing_key': 'delete',
        },
        'delete_old_data_10m': {
            'queue': 'delete',
            'routing_key': 'delete',
        },
        'delete_old_data_5h_delayed': {
            'queue': 'delete',
            'routing_key': 'delete',
        },
    }

    app.config['EVENT_LISTENERS'] = (
        ('changes.listeners.mail.build_finished_handler', 'build.finished'),
        ('changes.listeners.green_build.revision_result_updated_handler', 'revision_result.updated'),
        ('changes.listeners.build_revision.revision_created_handler', 'revision.created'),
        ('changes.listeners.build_finished_notifier.build_finished_handler', 'build.finished'),
        ('changes.listeners.phabricator_listener.build_finished_handler', 'build.finished'),
        ('changes.listeners.analytics_notifier.build_finished_handler', 'build.finished'),
        ('changes.listeners.analytics_notifier.job_finished_handler', 'job.finished'),
        ('changes.listeners.revision_result.revision_result_build_finished_handler', 'build.finished'),
        ('changes.listeners.stats_notifier.build_finished_handler', 'build.finished'),
        ('changes.listeners.snapshot_build.build_finished_handler', 'build.finished'),
    )

    # restrict outbound notifications to the given domains
    app.config['MAIL_DOMAIN_WHITELIST'] = ()

    app.config['DEBUG_TB_ENABLED'] = True

    app.config['DEBUG_TB_PANELS'] = ('flask_debugtoolbar.panels.versions.VersionDebugPanel',
                                     'flask_debugtoolbar.panels.timer.TimerDebugPanel',
                                     'flask_debugtoolbar.panels.headers.HeaderDebugPanel',
                                     'flask_debugtoolbar.panels.request_vars.RequestVarsDebugPanel',
                                     # Disable the config vars panel by default; it can contain sensitive information.
                                     # 'flask_debugtoolbar.panels.config_vars.ConfigVarsDebugPanel',
                                     'flask_debugtoolbar.panels.template.TemplateDebugPanel',
                                     'flask_debugtoolbar.panels.sqlalchemy.SQLAlchemyDebugPanel',
                                     'flask_debugtoolbar.panels.logger.LoggingPanel',
                                     'flask_debugtoolbar.panels.profiler.ProfilerDebugPanel')

    # celerybeat must be running for our cleanup tasks to execute
    # e.g. celery worker -B
    app.config['CELERYBEAT_SCHEDULE'] = {
        'cleanup-tasks': {
            'task': 'cleanup_tasks',
            'schedule': timedelta(minutes=1),
        },
        'check-repos': {
            'task': 'check_repos',
            'schedule': timedelta(minutes=2),
        },
        'sync-grouper': {
            'task': 'sync_grouper',
            'schedule': timedelta(minutes=1),
        },
        'aggregate-flaky-tests': {
            'task': 'aggregate_flaky_tests',
            # Hour 7 GMT is midnight PST, hopefully a time of low load
            'schedule': crontab(hour=7, minute=0),
        },
        'delete-old-data-10m': {
            'task': 'delete_old_data_10m',
            'schedule': timedelta(minutes=10),
        },
        'delete-old-data-5h-delayed': {
            'task': 'delete_old_data_5h_delayed',
            # This task runs every 4 hours but looks at 5 hours worth of tests
            # so consecutive runs will look at sets of tests that will overlap.
            # This is to make it unlikely to miss tests in between.
            #
            # While this is looking at 5 hours worth of tests, this should not be long running
            # as the shorter delete tasks will catch most cases and this checks
            # a time frame that should've been cleaned by them already.
            'schedule': crontab(hour='*/4'),
        },
        'update-local-repos': {
            'task': 'update_local_repos',
            'schedule': timedelta(minutes=1),
        }
    }
    app.config['CELERY_TIMEZONE'] = 'UTC'

    app.config['SENTRY_DSN'] = None
    app.config['SENTRY_INCLUDE_PATHS'] = [
        'changes',
    ]

    app.config['KOALITY_URL'] = None
    app.config['KOALITY_API_KEY'] = None

    app.config['GOOGLE_CLIENT_ID'] = None
    app.config['GOOGLE_CLIENT_SECRET'] = None
    app.config['GOOGLE_DOMAIN'] = None

    # must be a URL-safe base64-encoded 32-byte key
    app.config['COOKIE_ENCRYPTION_KEY'] = 'theDefaultKeyIs32BytesLongAndTotallyURLSafe='

    app.config['REPO_ROOT'] = None

    app.config['DEFAULT_FILE_STORAGE'] = 'changes.storage.s3.S3FileStorage'
    app.config['S3_ACCESS_KEY'] = None
    app.config['S3_SECRET_KEY'] = None
    app.config['S3_BUCKET'] = None

    app.config['PHABRICATOR_LINK_HOST'] = None
    app.config['PHABRICATOR_API_HOST'] = None
    app.config['PHABRICATOR_USERNAME'] = None
    app.config['PHABRICATOR_CERT'] = None

    # Configuration to access Zookeeper - currently used to discover mesos master leader instance
    # E.g., if mesos master is configured to talk to zk://zk1:2181,zk2:2181/mesos,
    # set ZOOKEEPER_HOSTS = 'zk1:2181,zk2:2181'
    #     ZOOKEEPER_MESOS_MASTER_PATH = '/mesos'
    #
    # This is only used to control mesos slave offline/online status from within Changes

    # Comma-separated list of host:port (or ip:port) to Zookeeper instances.
    app.config['ZOOKEEPER_HOSTS'] = 'zk:2181'
    # Namespace within zookeeper where mesos master election is performed.
    app.config['ZOOKEEPER_MESOS_MASTER_PATH'] = '/mesos'

    # List of valid tables to be written to when reporting project analytics.
    # Analytics artifacts targeting tables not listed here will be considered invalid.
    app.config['ANALYTICS_PROJECT_TABLES'] = []
    # URL any project analytics JSON entries will be posted to.
    # Entries will be posted as JSON, with the intended table specified as 'source' in the URL params.
    app.config['ANALYTICS_PROJECT_POST_URL'] = None

    app.config['SUPPORT_CONTACT'] = 'support'

    app.config['MAIL_DEFAULT_SENDER'] = '[email protected]'
    app.config['BASE_URI'] = 'http://localhost:5000'

    # if set to a string, most (all?) of the frontend js will make API calls
    # to the host this string is set to (e.g. http://changes.bigcompany.com)
    # THIS IS JUST FOR EASIER TESTING IN DEVELOPMENT. Although it won't even
    # work in prod: you'll have to start chrome with --disable-web-security to
    # make this work. Override this in your changes.conf.py file
    app.config['WEBAPP_USE_ANOTHER_HOST'] = None

    # Custom changes content unique to your deployment. This is intended to
    # customize the look and feel, provide contextual help and add custom links
    # to other internal tools. You should put your files in webapp/custom and
    # link them here.
    #
    # e.g. /acmecorp-changes/changes.js
    #
    # Some of the custom_content hooks can show images. Assume that the webserver
    # is willing to serve any file within the directory of the js file
    app.config['WEBAPP_CUSTOM_JS'] = None
    # This can be a .less file. We import it after the variables.less,
    # so you can override them in your file
    # Note: if you change this and nothing seems to happen, try deleting
    # webapp/.webassets-cache and bundled.css. This probably won't happen, though
    # If not specified, we will search for CUSTOM_CSS_FILE in the custom dir.
    app.config['WEBAPP_CUSTOM_CSS'] = None

    # In minutes, the timeout applied to jobs without a timeout specified at build time.
    # A timeout should nearly always be specified; this is just a safeguard so that
    # unspecified timeout doesn't mean "is allowed to run indefinitely".
    app.config['DEFAULT_JOB_TIMEOUT_MIN'] = 60

    # Number of milliseconds a transaction can run before triggering a warning.
    app.config['TRANSACTION_MS_WARNING_THRESHOLD'] = 2500

    # Hard maximum number of jobsteps to retry for a given job
    app.config['JOBSTEP_RETRY_MAX'] = 6
    # Maximum number of machines that we'll retry jobsteps for. This allows us
    # to retry more jobsteps if it's always the same machine failing.
    app.config['JOBSTEP_MACHINE_RETRY_MAX'] = 2

    # the PHID of the user creating quarantine tasks. We can use this to show
    # the list of open quarantine tasks inline
    app.config['QUARANTINE_PHID'] = None

    # The max length a test's output to be stored. If it is longer, the it will
    # be truncated.
    app.config['TEST_MESSAGE_MAX_LEN'] = 64 * 1024

    # List of packages needed to install bazel and any environment.
    app.config['BAZEL_APT_PKGS'] = ['bazel']

    # rsync source for encap
    # Example: rsync://example.com/encap/
    app.config['ENCAP_RSYNC_URL'] = None

    # In some configurations, build slaves might not have access to the Changes API via the
    # normal address; if PATCH_BASE_URI is specified, it'll be used as the base URI for
    # PATCH_URI variables provided to build slaves.
    app.config['PATCH_BASE_URI'] = None

    # name of default cluster to use for autogenerated jobs
    app.config['DEFAULT_CLUSTER'] = None

    # Maximum number of cpus allowed for a bazel executor. Since we expose `bazel.cpus` to
    # the user, this number needs to be bounded to avoid runaway resource allocation (by always
    # allocating large chunks of resources, like 12-16 cores), and to avoid invalid configuration
    # (like, requesting more cpus than available on a single slave, typically 32)
    app.config['MAX_CPUS_PER_EXECUTOR'] = 16

    # Minimum memory allowed per executor (in MB)
    app.config['MIN_MEM_MB_PER_EXECUTOR'] = 1024

    # Maximum memory allowed per executor (in MB)
    app.config['MAX_MEM_MB_PER_EXECUTOR'] = 16384

    # Maximum number of bazel executors allowed.
    app.config['MAX_EXECUTORS'] = 10

    # Absolute path to Bazel root (passed via --output_root to Bazel)
    # Storing bazel cache in tmpfs could be a bad idea because:
    #  - tmpfs means any files stored here will be stored purely in RAM and will eat into container limits
    #  - these containers are not persisted from the snapshot
    #
    # Bazel will create parent directories (if the user has appropriate permissions), if missing.
    app.config['BAZEL_ROOT_PATH'] = '/tmp/bazel_changes'

    # List of mandatory flags to be passed to `bazel test`
    app.config['BAZEL_MANDATORY_TEST_FLAGS'] = [
        '--spawn_strategy=sandboxed',
        '--genrule_strategy=sandboxed',
        '--keep_going',
    ]

    app.config['BAZEL_ADDITIONAL_TEST_FLAGS_WHITELIST_REGEX'] = [
        r'^--test_env=[A-Za-z0-9=]+',
        r'^--test_arg=[A-Za-z0-9=]+',
        r'^--define=[A-Za-z0-9=]+',
    ]

    app.config['SELECTIVE_TESTING_PROPAGATION_LIMIT'] = 30

    app.config['SELECTIVE_TESTING_ENABLED'] = False

    # Debug config entries passed to every autobazel jobstep
    app.config['BAZEL_DEBUG_CONFIG'] = {}

    # Extra test setup commands to be executed before collect-targets or `bazel test` invocations.
    app.config['BAZEL_EXTRA_SETUP_CMD'] = ['exit 0']

    # Jobsteps go from 'pending_allocation' to 'allocated' once an external scheduler claims them, and
    # once they begin running they're updated to 'in_progress'. If the scheduler somehow fails or drops
    # the task, this value is used to time out the 'allocated' status and revert back to 'pending_allocation'.
    # For current and expected schedulers, we don't allocate unless we think we can execute immediately, so
    # a 3 minute timeout is conservative and should be safe.
    app.config['JOBSTEP_ALLOCATION_TIMEOUT_SECONDS'] = 3 * 60

    app.config.update(config)

    if _read_config:
        if os.environ.get('CHANGES_CONF'):
            # CHANGES_CONF=/etc/changes.conf.py
            app.config.from_envvar('CHANGES_CONF')
        else:
            # Look for ~/.changes/changes.conf.py
            path = os.path.normpath(os.path.expanduser('~/.changes/changes.conf.py'))
            app.config.from_pyfile(path, silent=True)

    # default the DSN for changes-client to the server's DSN
    app.config.setdefault('CLIENT_SENTRY_DSN', app.config['SENTRY_DSN'])

    # Backwards compatibility with old configs containing BASE_URI
    if 'WEB_BASE_URI' not in app.config and 'BASE_URI' in app.config:
        app.config['WEB_BASE_URI'] = app.config['BASE_URI']
    if 'INTERNAL_BASE_URI' not in app.config and 'BASE_URI' in app.config:
        app.config['INTERNAL_BASE_URI'] = app.config['BASE_URI']

    parsed_url = urlparse(app.config['WEB_BASE_URI'])
    app.config.setdefault('PREFERRED_URL_SCHEME', 'https')

    if app.debug:
        app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
    else:
        app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 30

    app.url_map.converters['uuid'] = UUIDConverter

    # now that config is set up, let's ensure the CUSTOM_JS / CUSTOM_CSS
    # variables are safe (within the changes directory) and convert them to
    # absolute paths
    if app.config['WEBAPP_CUSTOM_CSS']:
        app.config['WEBAPP_CUSTOM_CSS'] = os.path.join(
            PROJECT_ROOT, 'webapp/custom/', app.config['WEBAPP_CUSTOM_CSS'])

        enforce_is_subdir(
            app.config['WEBAPP_CUSTOM_CSS'],
            os.path.join(PROJECT_ROOT, 'webapp/custom'))
    else:
        app.config['WEBAPP_CUSTOM_CSS'] = _find_custom_css()

    if app.config['WEBAPP_CUSTOM_JS']:
        app.config['WEBAPP_CUSTOM_JS'] = os.path.join(
            PROJECT_ROOT, 'webapp/custom/', app.config['WEBAPP_CUSTOM_JS'])

        enforce_is_subdir(
            app.config['WEBAPP_CUSTOM_JS'],
            os.path.join(PROJECT_ROOT, 'webapp/custom'))

    # init sentry first
    sentry.init_app(app)

    @app.before_request
    def capture_user(*args, **kwargs):
        from changes.api.auth import get_current_user
        user = get_current_user()
        if user is not None:
            sentry.client.user_context({
                'id': user.id,
                'email': user.email,
            })

    api.init_app(app)
    db.init_app(app)
    mail.init_app(app)
    queue.init_app(app)
    redis.init_app(app)
    statsreporter.init_app(app)

    configure_debug_toolbar(app)

    from raven.contrib.celery import register_signal, register_logger_signal
    register_signal(sentry.client)
    register_logger_signal(sentry.client, loglevel=logging.WARNING)

    # configure debug routes first
    if app.debug:
        configure_debug_routes(app)

    configure_templates(app)

    # TODO: these can be moved to wsgi app entrypoints
    configure_api_routes(app)
    configure_web_routes(app)

    configure_jobs(app)
    configure_transaction_logging(app)

    rules_file = app.config.get('CATEGORIZE_RULES_FILE')
    if rules_file:
        # Fail at startup if we have a bad rules file.
        categorize.load_rules(rules_file)

    import jinja2
    webapp_template_folder = os.path.join(PROJECT_ROOT, 'webapp/html')
    template_folder = os.path.join(PROJECT_ROOT, 'templates')
    template_loader = jinja2.ChoiceLoader([
                app.jinja_loader,
                jinja2.FileSystemLoader([webapp_template_folder, template_folder])
                ])
    app.jinja_loader = template_loader

    return app

Example 39

Project: tractor
Source File: kick-tires.py
View license
def galaxies():
    """Make diagnostic plots of EXP/DEV galaxy model fits for one DECaLS brick.

    Reads the DR1 tractor catalog and coadd JPEG for brick '3166p025',
    selects suspicious sub-populations (large ellipticity errors, extreme
    ellipticities, implausibly small errors, large radii, bright-but-small
    fits, negative fluxes, high reduced chi2) and saves summary/cutout
    plots through a PlotSequence.

    NOTE(review): the sys.exit(0) partway down makes everything after it
    unreachable; the SDSS-comparison plots below it are dead code kept
    here for reference.
    """
    ps = PlotSequence('kick')
    plt.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.95)
    brick = '3166p025'

    decals = Decals()
    b = decals.get_brick_by_name(brick)
    brickwcs = wcs_for_brick(b)

    # A catalog of sources overlapping one DECaLS CCD, arbitrarily:
    # python projects/desi/forced-photom-decam.py decam/CP20140810_g_v2/c4d_140816_032035_ooi_g_v2.fits.fz 1 DR1 f.fits
    #T = fits_table('cat.fits')

    T = fits_table(os.path.join('dr1', 'tractor', brick[:3],
                                'tractor-%s.fits' % brick))
    print(len(T), 'catalog sources')
    print(np.unique(T.brick_primary))
    T.cut(T.brick_primary)
    print(len(T), 'primary')

    print('Out of bounds:', np.unique(T.out_of_bounds))
    print('Left blob:', np.unique(T.left_blob))

    img = plt.imread(os.path.join('dr1', 'coadd', brick[:3], brick,
                                  'decals-%s-image.jpg' % brick))
    # Flip vertically so image row order matches catalog pixel coordinates.
    img = img[::-1,:,:]
    print('Image:', img.shape)

    if False:
        resid = plt.imread(os.path.join('dr1', 'coadd', brick[:3], brick,
                                      'decals-%s-resid.jpg' % brick))
        resid = resid[::-1,:,:]

    # Convert shape-parameter inverse variances to 1-sigma errors.
    T.shapeexp_e1_err = 1./np.sqrt(T.shapeexp_e1_ivar)
    T.shapeexp_e2_err = 1./np.sqrt(T.shapeexp_e2_ivar)
    T.shapeexp_r_err  = 1./np.sqrt(T.shapeexp_r_ivar)
    T.shapedev_e1_err = 1./np.sqrt(T.shapedev_e1_ivar)
    T.shapedev_e2_err = 1./np.sqrt(T.shapedev_e2_ivar)
    T.shapedev_r_err  = 1./np.sqrt(T.shapedev_r_ivar)

    # Columns 1,2,4 of decam_flux hold g,r,z (presumably band-indexed
    # as in the ugriz/grz conventions used elsewhere here).
    T.gflux = T.decam_flux[:,1]
    T.rflux = T.decam_flux[:,2]
    T.zflux = T.decam_flux[:,4]

    # Type strings are fixed-width, hence the trailing space.
    I = np.flatnonzero(T.type == 'EXP ')
    J = np.flatnonzero(T.type == 'DEV ')

    E = T[I]
    D = T[J]

    # Each entry is (subset-table, description) for the plotting loop below.
    cutobjs = []

    cut = np.logical_or(E.shapeexp_e1_err > 1., E.shapeexp_e2_err > 1.)
    I = np.flatnonzero(cut)
    print(len(I), 'EXP with large ellipticity error')
    cutobjs.append((E[I], 'EXP ellipticity error > 1'))

    # Drop those from E so later EXP cuts don't re-select them.
    E.cut(np.logical_not(cut))

    I = np.flatnonzero(np.logical_or(D.shapedev_e1_err > 1., D.shapedev_e2_err > 1.))
    print(len(I), 'DEV with large ellipticity error')
    cutobjs.append((D[I], 'DEV ellipticity error > 1'))

    I = np.flatnonzero(np.logical_or(np.abs(E.shapeexp_e1) > 0.5,
                                     np.abs(E.shapeexp_e2) > 0.5))
    cutobjs.append((E[I], 'EXP with ellipticity > 0.5'))

    I = np.flatnonzero(np.logical_or(np.abs(D.shapedev_e1) > 0.5,
                                     np.abs(D.shapedev_e2) > 0.5))
    # BUG FIX: I indexes D here, but the original appended E[I] --
    # mislabeling EXP sources as DEV (and potentially indexing past the
    # end of E).  Use D[I] to match the cut actually computed.
    cutobjs.append((D[I], 'DEV with ellipticity > 0.5'))

    I = np.flatnonzero(np.logical_or(E.shapeexp_e1_err < 3e-3,
                                     E.shapeexp_e2_err < 3e-3))
    cutobjs.append((E[I], 'EXP with small ellipticity errors (<3e-3)'))

    I = np.flatnonzero(np.logical_or(D.shapedev_e1_err < 3e-3,
                                     D.shapedev_e2_err < 3e-3))
    cutobjs.append((D[I], 'DEV with small ellipticity errors (<3e-3)'))

    I = np.flatnonzero(D.shapedev_r > 10.)
    cutobjs.append((D[I], 'DEV with large radius (>10")'))

    I = np.flatnonzero(D.shapedev_r_err < 2e-3)
    cutobjs.append((D[I], 'DEV with small radius errors (<2e-3)'))

    I = np.flatnonzero((D.rflux > 100.) * (D.shapedev_r < 5.))
    cutobjs.append((D[I], 'DEV, small & bright'))

    I = np.flatnonzero((E.rflux > 100.) * (E.shapeexp_r < 5.))
    cutobjs.append((E[I], 'EXP, small & bright'))

    # I = np.argsort(-T.decam_flux[:,2])
    # cutobjs.append((T[I], 'brightest objects'))

    I = np.flatnonzero(np.logical_or(D.rflux < -5., D.gflux < -5))
    cutobjs.append((D[I], 'DEV with neg g or r flux'))

    I = np.flatnonzero(np.logical_or(E.rflux < -5., E.gflux < -5))
    cutobjs.append((E[I], 'EXP with neg g or r flux'))

    I = np.flatnonzero(T.decam_rchi2[:,2] > 5.)
    cutobjs.append((T[I], 'rchi2 > 5'))


    plt.subplots_adjust(left=0.1, right=0.95, bottom=0.1, top=0.95,
                        hspace=0.05, wspace=0.05)


    # plt.clf()
    # p1 = plt.semilogy(T.shapeexp_e1[I], T.shapeexp_e1_ivar[I], 'b.')
    # p2 = plt.semilogy(T.shapeexp_e2[I], T.shapeexp_e2_ivar[I], 'r.')
    # plt.xlabel('Ellipticity e')
    # plt.ylabel('Ellipticity inverse-variance e_ivar')
    # plt.title('EXP galaxies')
    # plt.legend([p1[0],p2[0]], ['e1','e2'])
    # ps.savefig()

    plt.clf()
    p1 = plt.semilogy(E.shapeexp_e1, E.shapeexp_e1_err, 'b.')
    p2 = plt.semilogy(E.shapeexp_e2, E.shapeexp_e2_err, 'r.')
    plt.xlabel('Ellipticity e')
    plt.ylabel('Ellipticity error e_err')
    plt.title('EXP galaxies')
    plt.legend([p1[0],p2[0]], ['e1','e2'])
    ps.savefig()

    # plt.clf()
    # p1 = plt.semilogy(T.shapedev_e1[J], T.shapedev_e1_ivar[J], 'b.')
    # p2 = plt.semilogy(T.shapedev_e2[J], T.shapedev_e2_ivar[J], 'r.')
    # plt.xlabel('Ellipticity e')
    # plt.ylabel('Ellipticity inverse-variance e_ivar')
    # plt.title('DEV galaxies')
    # plt.legend([p1[0],p2[0]], ['e1','e2'])
    # ps.savefig()

    plt.clf()
    p1 = plt.semilogy(D.shapedev_e1, D.shapedev_e1_err, 'b.')
    p2 = plt.semilogy(D.shapedev_e2, D.shapedev_e2_err, 'r.')
    plt.xlabel('Ellipticity e')
    plt.ylabel('Ellipticity error e_err')
    plt.title('DEV galaxies')
    plt.legend([p1[0],p2[0]], ['e1','e2'])
    ps.savefig()


    plt.clf()
    p1 = plt.loglog(D.shapedev_r, D.shapedev_r_err, 'b.')
    p2 = plt.loglog(E.shapeexp_r, E.shapeexp_r_err, 'r.')
    plt.xlabel('Radius r')
    plt.ylabel('Radius error r_err')
    plt.title('DEV, EXP galaxies')
    plt.legend([p1[0],p2[0]], ['deV','exp'])
    ps.savefig()



    plt.clf()
    p1 = plt.loglog(D.rflux, D.shapedev_r, 'b.')
    p2 = plt.loglog(E.rflux, E.shapeexp_r, 'r.')
    plt.xlabel('r-band flux')
    plt.ylabel('Radius r')
    plt.title('DEV, EXP galaxies')
    plt.legend([p1[0],p2[0]], ['deV','exp'])
    ps.savefig()



    plt.clf()
    p1 = plt.loglog(-D.rflux, D.shapedev_r, 'b.')
    p2 = plt.loglog(-E.rflux, E.shapeexp_r, 'r.')
    plt.xlabel('Negative r-band flux')
    plt.ylabel('Radius r')
    plt.title('DEV, EXP galaxies')
    plt.legend([p1[0],p2[0]], ['deV','exp'])
    ps.savefig()

    plt.clf()
    # BUG FIX: capture the line handles here -- the original passed the
    # stale p1/p2 from the previous figure to plt.legend(), so the legend
    # referenced artists belonging to a different (already-saved) plot.
    p1 = plt.loglog(D.rflux, D.decam_rchi2[:,2], 'b.')
    p2 = plt.loglog(E.rflux, E.decam_rchi2[:,2], 'r.')
    plt.xlabel('r-band flux')
    plt.ylabel('rchi2 in r')
    plt.title('DEV, EXP galaxies')
    plt.legend([p1[0],p2[0]], ['deV','exp'])
    ps.savefig()



    # Cutout plots for each suspicious sub-population selected above.
    for objs,desc in cutobjs:
        if len(objs) == 0:
            print('No objects in cut', desc)
            continue

        # Show at most one page (4x6 grid) of objects per cut.
        rows,cols = 4,6
        objs = objs[:rows*cols]

        if False:
            plt.clf()
            dimshow(img)
            plt.plot(objs.bx, objs.by, 'rx')
            ps.savefig()

        plt.clf()
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95,
                            hspace=0.05, wspace=0.05)

        plot_objects(objs, None, img, brickwcs)
        plt.suptitle(desc)
        ps.savefig()

        plt.clf()
        plot_objects(objs, T, img, brickwcs, fracflux=True)
        plt.suptitle(desc)
        ps.savefig()

        # NOTE(review): dead branch; it references S and o.y0/o.x0 which
        # are undefined here -- would raise NameError/AttributeError if
        # ever enabled.
        if False:
            plt.clf()
            for i,o in enumerate(objs):
                plt.subplot(rows, cols, i+1)
                H,W,three = img.shape
                dimshow(resid[o.y0:min(H, o.by+S),
                              o.x0:min(W, o.bx+S), :], ticks=False)
            plt.suptitle(desc)
            ps.savefig()




    # NOTE(review): everything below this exit is unreachable dead code,
    # kept from an earlier SDSS-vs-DECaLS comparison.
    sys.exit(0)


    print('RA', T.ra.min(), T.ra.max())
    print('Dec', T.dec.min(), T.dec.max())

    # Uhh, how does *this* happen?!  Fitting gone wild I guess
    # T.cut((T.ra > 0) * (T.ra < 360) * (T.dec > -90) * (T.dec < 90))
    # print 'RA', T.ra.min(), T.ra.max()
    # print 'Dec', T.dec.min(), T.dec.max()
    # rlo,rhi = [np.percentile(T.ra,  p) for p in [1,99]]
    # dlo,dhi = [np.percentile(T.dec, p) for p in [1,99]]
    # print 'RA', rlo,rhi
    # print 'Dec', dlo,dhi
    # plt.clf()
    # plothist(T.ra, T.dec, 100, range=((rlo,rhi),(dlo,dhi)))
    # plt.xlabel('RA')
    # plt.ylabel('Dec')
    # ps.savefig()

    # decals = Decals()
    # B = decals.get_bricks()
    # #B.about()
    # brick = B[B.brickid == brickid]
    # assert(len(brick) == 1)
    # brick = brick[0]
    # wcs = wcs_for_brick(brick)
    # ccds = decals.get_ccds()
    # ccds.cut(ccds_touching_wcs(wcs, ccds))
    # print len(ccds), 'CCDs'
    # #ccds.about()
    # ccds.cut(ccds.filter == 'r')
    # print len(ccds), 'CCDs'
    # S = []
    # for ccd in ccds:
    #     im = DecamImage(ccd)
    #     S.append(fits_table(im.sdssfn))
    # S = merge_tables(S)
    # print len(S), 'total SDSS'
    # #nil,I = np.unique(S.ra, return_index=True)
    # nil,I = np.unique(['%.5f %.5f' % (r,d) for r,d in zip(S.ra,S.dec)], return_index=True)
    # S.cut(I)
    # print len(S), 'unique'
    # 
    # I,J,d = match_radec(T.ra, T.dec, S.ra, S.dec, 1./3600.)
    # print len(I), 'matches'
    # 
    # plt.clf()
    # plt.loglog(S.r_psfflux[J], T.decam_r_nanomaggies[I], 'r.')
    # ps.savefig()

    #plt.clf()
    #plt.loglog(T.sdss_modelflux[:,2], T.decam_r_nanomaggies, 'r.')
    #ps.savefig()


    for bindex,band in [(1,'g'), (2,'r'), (4,'z')]:
        sflux = T.sdss_modelflux[:,bindex]
        dflux = T.get('decam_%s_nanomaggies' % band)
        I = np.flatnonzero(sflux > 10.)
        med = np.median(dflux[I] / sflux[I])
        # plt.clf()
        # plt.loglog(sflux, dflux / sflux, 'ro', mec='r', ms=4, alpha=0.1)
        # plt.axhline(med, color='k')
        # plt.ylim(0.5, 2.)
        # ps.savefig()

        corr = dflux / med
        T.set('decam_%s_nanomaggies_corr' % band, corr)
        T.set('decam_%s_mag_corr' % band, NanoMaggies.nanomaggiesToMag(corr))

        dflux = T.get('decam_%s_nanomaggies_corr' % band)
        plt.clf()
        #plt.loglog(sflux, dflux / sflux, 'o', mec='b', ms=4, alpha=0.1)
        plt.loglog(sflux, dflux / sflux, 'b.', alpha=0.01)
        plt.xlim(1e-1, 3e3)
        plt.axhline(1., color='k')
        plt.ylim(0.5, 2.)
        plt.xlabel('SDSS flux (nmgy)')
        plt.ylabel('DECam flux / SDSS flux')
        plt.title('%s band' % band)
        ps.savefig()

    bands = 'grz'

    # for band in bands:
    #     plt.clf()
    #     sn = T.get('decam_%s_nanomaggies' % band) * np.sqrt(T.get('decam_%s_nanomaggies_invvar' % band))
    #     mag = T.get('decam_%s_mag_corr' % band)
    #     plt.semilogy(mag, sn, 'b.')
    #     plt.axis([20, 26, 1, 100])
    #     ps.savefig()

    ccmap = dict(g='g', r='r', z='m')
    plt.clf()
    for band in bands:
        sn = T.get('decam_%s_nanomaggies' % band) * np.sqrt(T.get('decam_%s_nanomaggies_invvar' % band))
        mag = T.get('decam_%s_mag_corr' % band)
        cc = ccmap[band]
        #plt.semilogy(mag, sn, '.', color=cc, alpha=0.2)
        plt.semilogy(mag, sn, '.', color=cc, alpha=0.01, mec='none')
    plt.xlabel('mag')
    plt.ylabel('Flux Signal-to-Noise')
    tt = [1,2,3,4,5,10,20,30,40,50]
    plt.yticks(tt, ['%i' % t for t in tt])
    plt.axhline(5., color='k')
    plt.axis([21, 26, 1, 20])
    plt.title('DECaLS depth')
    ps.savefig()



    [gsn,rsn,zsn] = [T.get('decam_%s_nanomaggies' % band) * np.sqrt(T.get('decam_%s_nanomaggies_invvar' % band))
                     for band in bands]
    TT = T[(gsn > 5.) * (rsn > 5.) * (zsn > 5.)]

    # plt.clf()
    # plt.plot(g-r, r-z, 'k.', alpha=0.2)
    # plt.xlabel('g - r (mag)')
    # plt.ylabel('r - z (mag)')
    # plt.xlim(-0.5, 2.5)
    # plt.ylim(-0.5, 3)
    # ps.savefig()

    plt.clf()
    lp = []
    cut = (TT.sdss_objc_type == 6)
    g,r,z = [NanoMaggies.nanomaggiesToMag(TT.sdss_psfflux[:,i])
             for i in [1,2,4]]
    p = plt.plot((g-r)[cut], (r-z)[cut], '.', alpha=0.3, color='b')
    lp.append(p[0])
    cut = (TT.sdss_objc_type == 3)
    g,r,z = [NanoMaggies.nanomaggiesToMag(TT.sdss_modelflux[:,i])
             for i in [1,2,4]]
    p = plt.plot((g-r)[cut], (r-z)[cut], '.', alpha=0.3, color='r')
    lp.append(p[0])
    plt.xlabel('g - r (mag)')
    plt.ylabel('r - z (mag)')
    plt.xlim(-0.5, 2.5)
    plt.ylim(-0.5, 3)
    plt.legend(lp, ['stars', 'galaxies'])
    plt.title('SDSS')
    ps.savefig()


    g = TT.decam_g_mag_corr
    r = TT.decam_r_mag_corr
    z = TT.decam_z_mag_corr

    plt.clf()
    lt,lp = [],[]
    for cut,cc,tt in [(TT.sdss_objc_type == 6, 'b', 'stars'),
                      (TT.sdss_objc_type == 3, 'r', 'galaxies'),
                      (TT.sdss_objc_type == 0, 'g', 'faint')]:
        p = plt.plot((g-r)[cut], (r-z)[cut], '.', alpha=0.3, color=cc)
        lt.append(tt)
        lp.append(p[0])
    plt.xlabel('g - r (mag)')
    plt.ylabel('r - z (mag)')
    plt.xlim(-0.5, 2.5)
    plt.ylim(-0.5, 3)
    plt.legend(lp, lt)
    plt.title('DECaLS')
    ps.savefig()



    # Stars/galaxies in subplots

    plt.clf()
    lp = []
    cut = (TT.sdss_objc_type == 6)
    g,r,z = [NanoMaggies.nanomaggiesToMag(TT.sdss_psfflux[:,i])
             for i in [1,2,4]]
    plt.subplot(1,2,1)
    p = plt.plot((g-r)[cut], (r-z)[cut], '.', alpha=0.02, color='b')
    # Off-plot proxy point so the legend marker is fully opaque.
    px = plt.plot(100, 100, '.', color='b')
    lp.append(px[0])
    plt.xlabel('g - r (mag)')
    plt.ylabel('r - z (mag)')
    plt.xlim(-0.5, 2.5)
    plt.ylim(-0.5, 3)
    cut = (TT.sdss_objc_type == 3)
    g,r,z = [NanoMaggies.nanomaggiesToMag(TT.sdss_modelflux[:,i])
             for i in [1,2,4]]
    plt.subplot(1,2,2)
    p = plt.plot((g-r)[cut], (r-z)[cut], '.', alpha=0.02, color='r')
    px = plt.plot(100, 100, '.', color='r')
    lp.append(px[0])
    plt.xlabel('g - r (mag)')
    plt.ylabel('r - z (mag)')
    plt.xlim(-0.5, 2.5)
    plt.ylim(-0.5, 3)
    plt.figlegend(lp, ['stars', 'galaxies'], 'upper right')
    plt.suptitle('SDSS')
    ps.savefig()

    g = TT.decam_g_mag_corr
    r = TT.decam_r_mag_corr
    z = TT.decam_z_mag_corr

    plt.clf()
    lt,lp = [],[]
    for i,(cut,cc,tt) in enumerate([
        (TT.sdss_objc_type == 6, 'b', 'stars'),
        (TT.sdss_objc_type == 3, 'r', 'galaxies'),
        #(TT.sdss_objc_type == 0, 'g', 'faint'),
        ]):
        plt.subplot(1,2,i+1)
        p = plt.plot((g-r)[cut], (r-z)[cut], '.', alpha=0.02, color=cc)
        lt.append(tt)
        px = plt.plot(100, 100, '.', color=cc)
        lp.append(px[0])
        plt.xlabel('g - r (mag)')
        plt.ylabel('r - z (mag)')
        plt.xlim(-0.5, 2.5)
        plt.ylim(-0.5, 3)
    plt.figlegend(lp, lt, 'upper right')
    plt.suptitle('DECaLS')
    ps.savefig()

Example 40

Project: tractor
Source File: forcedphoht-des-wise.py
View license
def one_tile(tile, opt, savepickle, ps, tiles, tiledir, tempoutdir, T=None, hdr=None):

    bands = opt.bands
    outfn = opt.output % (tile.coadd_id)
    savewise_outfn = opt.save_wise_output % (tile.coadd_id)

    sband = 'r'
    bandnum = 'ugriz'.index(sband)

    tt0 = Time()
    print()
    print('Coadd tile', tile.coadd_id)

    thisdir = get_unwise_tile_dir(tiledir, tile.coadd_id)
    fn = os.path.join(thisdir, 'unwise-%s-w%i-img-m.fits' % (tile.coadd_id, bands[0]))
    if os.path.exists(fn):
        print('Reading', fn)
        wcs = Tan(fn)
    else:
        print('File', fn, 'does not exist; faking WCS')
        from unwise_coadd import get_coadd_tile_wcs
        wcs = get_coadd_tile_wcs(tile.ra, tile.dec)

    r0,r1,d0,d1 = wcs.radec_bounds()
    print('RA,Dec bounds:', r0,r1,d0,d1)
    H,W = wcs.get_height(), wcs.get_width()

    if T is None:
        
        T = merge_tables([fits_table(fn, columns=[x.upper() for x in [
            'chi2_psf_r', 'chi2_model_r', 'mag_psf_r',
            #'mag_disk_r',
            'mag_spheroid_r', 'spheroid_reff_world',
            'spheroid_aspect_world', 'spheroid_theta_world',
            #'disk_scale_world', 'disk_aspect_world',
            #'disk_theta_world',
            'alphamodel_j2000', 'deltamodel_j2000']],
                                     column_map=dict(CHI2_PSF_R='chi2_psf',
                                                     CHI2_MODEL_R='chi2_model',
                                                     MAG_PSF_R='mag_psf',
                                                     MAG_SPHEROID_R='mag_spheroid_r',
                                                     ))
                          for fn in
                          ['DES_SNX3cat_000001.fits', 'DES_SNX3cat_000002.fits']]
                          )
        T.mag_disk = np.zeros(len(T), np.float32) + 99.
        print('Read total of', len(T), 'DES sources')
        ok,T.x,T.y = wcs.radec2pixelxy(T.alphamodel_j2000, T.deltamodel_j2000)
        margin = int(60. * wcs.pixel_scale())
        print('Margin:', margin, 'pixels')
        T.cut((T.x > -margin) * (T.x < (W+margin)) *
              (T.y > -margin) * (T.y < (H+margin)))
        print('Cut to', len(T), 'in bounds')
        if opt.photoObjsOnly:
            return
    print(len(T), 'objects')
    if len(T) == 0:
        return

    defaultflux = 1.

    # hack
    T.x = (T.x - 1.).astype(np.float32)
    T.y = (T.y - 1.).astype(np.float32)
    margin = 20.
    I = np.flatnonzero((T.x >= -margin) * (T.x < W+margin) *
                       (T.y >= -margin) * (T.y < H+margin))
    T.cut(I)
    print('Cut to margins: N objects:', len(T))
    if len(T) == 0:
        return

    wanyband = wband = 'w'

    classmap = {}

    print('Creating tractor sources...')
    cat = get_se_modelfit_cat(T, bands=[wanyband])
    print('Created', len(T), 'sources')
    assert(len(cat) == len(T))

    pixscale = wcs.pixel_scale()
    # crude intrinsic source radii, in pixels
    sourcerad = np.zeros(len(cat))
    for i in range(len(cat)):
        src = cat[i]
        if isinstance(src, PointSource):
            continue
        elif isinstance(src, HoggGalaxy):
            sourcerad[i] = (src.nre * src.shape.re / pixscale)
        elif isinstance(src, FixedCompositeGalaxy):
            sourcerad[i] = max(src.shapeExp.re * ExpGalaxy.nre,
                               src.shapeDev.re * DevGalaxy.nre) / pixscale
    print('sourcerad range:', min(sourcerad), max(sourcerad))

    # Find WISE-only catalog sources
    wfn = os.path.join(tempoutdir, 'wise-sources-%s.fits' % (tile.coadd_id))
    WISE = read_wise_sources(wfn, r0,r1,d0,d1, allwise=True)

    for band in bands:
        mag = WISE.get('w%impro' % band)
        nm = NanoMaggies.magToNanomaggies(mag)
        WISE.set('w%inm' % band, nm)
        print('Band', band, 'max WISE catalog flux:', max(nm))
        print('  (min mag:', mag.min(), ')')

    unmatched = np.ones(len(WISE), bool)
    I,J,d = match_radec(WISE.ra, WISE.dec, T.ra, T.dec, 4./3600.)
    unmatched[I] = False
    UW = WISE[unmatched]
    print('Got', len(UW), 'unmatched WISE sources')

    if opt.savewise:
        fitwiseflux = {}
        for band in bands:
            fitwiseflux[band] = np.zeros(len(UW))

    # Record WISE fluxes for catalog matches.
    # (this provides decent initialization for 'minsb' approx.)
    wiseflux = {}
    for band in bands:
        wiseflux[band] = np.zeros(len(T))
        if len(I) == 0:
            continue
        # X[I] += Y[J] with duplicate I doesn't work.
        #wiseflux[band][J] += WISE.get('w%inm' % band)[I]
        lhs = wiseflux[band]
        rhs = WISE.get('w%inm' % band)[I]
        print('Band', band, 'max matched WISE flux:', max(rhs))
        for j,f in zip(J, rhs):
            lhs[j] += f

    ok,UW.x,UW.y = wcs.radec2pixelxy(UW.ra, UW.dec)
    UW.x -= 1.
    UW.y -= 1.

    T.coadd_id = np.array([tile.coadd_id] * len(T))

    inbounds = np.flatnonzero((T.x >= -0.5) * (T.x < W-0.5) *
                              (T.y >= -0.5) * (T.y < H-0.5))

    print('Before looping over bands:', Time()-tt0)
   
    for band in bands:
        tb0 = Time()
        print()
        print('Coadd tile', tile.coadd_id)
        print('Band', band)
        wband = 'w%i' % band

        imfn = os.path.join(thisdir, 'unwise-%s-w%i-img-m.fits'    % (tile.coadd_id, band))
        ivfn = os.path.join(thisdir, 'unwise-%s-w%i-invvar-m.fits.gz' % (tile.coadd_id, band))
        ppfn = os.path.join(thisdir, 'unwise-%s-w%i-std-m.fits.gz'    % (tile.coadd_id, band))
        nifn = os.path.join(thisdir, 'unwise-%s-w%i-n-m.fits.gz'      % (tile.coadd_id, band))

        print('Reading', imfn)
        wcs = Tan(imfn)
        r0,r1,d0,d1 = wcs.radec_bounds()
        print('RA,Dec bounds:', r0,r1,d0,d1)
        ra,dec = wcs.radec_center()
        print('Center:', ra,dec)
        img = fitsio.read(imfn)
        print('Reading', ivfn)
        invvar = fitsio.read(ivfn)
        print('Reading', ppfn)
        pp = fitsio.read(ppfn)
        print('Reading', nifn)
        nims = fitsio.read(nifn)
        print('Median # ims:', np.median(nims))

        good = (nims > 0)
        invvar[np.logical_not(good)] = 0.

        sig1 = 1./np.sqrt(np.median(invvar[good]))
        minsig = getattr(opt, 'minsig%i' % band)
        minsb = sig1 * minsig
        print('Sigma1:', sig1, 'minsig', minsig, 'minsb', minsb)

        # Load the average PSF model (generated by wise_psf.py)
        print('Reading PSF from', opt.psffn)
        P = fits_table(opt.psffn, hdu=band)
        psf = GaussianMixturePSF(P.amp, P.mean, P.var)

        # Render the PSF profile for figuring out source radii for
        # approximation purposes.
        R = 100
        psf.radius = R
        pat = psf.getPointSourcePatch(0., 0.)
        assert(pat.x0 == pat.y0)
        assert(pat.x0 == -R)
        psfprofile = pat.patch[R, R:]
        #print 'PSF profile:', psfprofile

        # Reset default flux based on min radius
        defaultflux = minsb / psfprofile[opt.minradius]
        print('Setting default flux', defaultflux)

        # Set WISE source radii based on flux
        UW.rad = np.zeros(len(UW), int)
        wnm = UW.get('w%inm' % band)
        for r,pro in enumerate(psfprofile):
            flux = minsb / pro
            UW.rad[wnm > flux] = r
        UW.rad = np.maximum(UW.rad + 1, 3)

        # Set SDSS fluxes based on WISE catalog matches.
        wf = wiseflux[band]
        I = np.flatnonzero(wf > defaultflux)
        wfi = wf[I]
        print('Initializing', len(I), 'fluxes based on catalog matches')
        for i,flux in zip(I, wf[I]):
            assert(np.isfinite(flux))
            cat[i].getBrightness().setBand(wanyband, flux)

        # Set SDSS radii based on WISE flux
        rad = np.zeros(len(I), int)
        for r,pro in enumerate(psfprofile):
            flux = minsb / pro
            rad[wfi > flux] = r
        srad2 = np.zeros(len(cat), int)
        srad2[I] = rad
        del rad

        # Set radii
        for i in range(len(cat)):
            src = cat[i]
            # set fluxes
            b = src.getBrightness()
            if b.getBand(wanyband) <= defaultflux:
                b.setBand(wanyband, defaultflux)
                
            R = max([opt.minradius, sourcerad[i], srad2[i]])
            # ??  This is used to select which sources are in-range
            sourcerad[i] = R
            if isinstance(src, PointSource):
                src.fixedRadius = R
                src.minradius = opt.minradius
                
            elif (isinstance(src, HoggGalaxy) or
                  isinstance(src, FixedCompositeGalaxy)):
                src.halfsize = R
                
        # We used to dice the image into blocks/cells...
        fullIV = np.zeros(len(cat))
        fskeys = ['prochi2', 'pronpix', 'profracflux', 'proflux', 'npix', 'pronexp']
        fitstats = dict([(k, np.zeros(len(cat))) for k in fskeys])

        twcs = ConstantFitsWcs(wcs)
        sky = 0.
        tsky = ConstantSky(sky)

        if ps:
            tag = '%s W%i' % (tile.coadd_id, band)
            
            plt.clf()
            n,b,p = plt.hist(img.ravel(), bins=100,
                             range=(-10*sig1, 20*sig1), log=True,
                             histtype='step', color='b')
            mx = max(n)
            plt.ylim(0.1, mx)
            plt.xlim(-10*sig1, 20*sig1)
            plt.axvline(sky, color='r')
            plt.title('%s: Pixel histogram' % tag)
            ps.savefig()

        if savepickle:
            mods = []
            cats = []

        # SDSS and WISE source margins beyond the image margins ( + source radii )
        smargin = 1
        wmargin = 1

        tim = Image(data=img, invvar=invvar, psf=psf, wcs=twcs,
                    sky=tsky, photocal=LinearPhotoCal(1., band=wanyband),
                    name='Coadd %s W%i' % (tile.coadd_id, band))

        # Relevant SDSS sources:
        m = smargin + sourcerad
        I = np.flatnonzero(((T.x+m) >= -0.5) * ((T.x-m) < (W-0.5)) *
                           ((T.y+m) >= -0.5) * ((T.y-m) < (H-0.5)))
        inbox = ((T.x[I] >= -0.5) * (T.x[I] < (W-0.5)) *
                 (T.y[I] >= -0.5) * (T.y[I] < (H-0.5)))
        # Inside this cell
        srci = I[inbox]
        # In the margin
        margi = I[np.logical_not(inbox)]

        # sources in the ROI box
        subcat = [cat[i] for i in srci]

        # include *copies* of sources in the margins
        # (that way we automatically don't save the results)
        subcat.extend([cat[i].copy() for i in margi])
        assert(len(subcat) == len(I))

        # add WISE-only sources in the expanded region
        m = wmargin + UW.rad
        J = np.flatnonzero(((UW.x+m) >= -0.5) * ((UW.x-m) < (W-0.5)) *
                           ((UW.y+m) >= -0.5) * ((UW.y-m) < (H-0.5)))

        if opt.savewise:
            jinbox = ((UW.x[J] >= -0.5) * (UW.x[J] < (W-0.5)) *
                      (UW.y[J] >= -0.5) * (UW.y[J] < (H-0.5)))
            uwcat = []
        wnm = UW.get('w%inm' % band)
        nomag = 0
        for ji,j in enumerate(J):
            if not np.isfinite(wnm[j]):
                nomag += 1
                continue
            ptsrc = PointSource(RaDecPos(UW.ra[j], UW.dec[j]),
                                      NanoMaggies(**{wanyband: wnm[j]}))
            ptsrc.radius = UW.rad[j]
            subcat.append(ptsrc)
            if opt.savewise:
                if jinbox[ji]:
                    uwcat.append((j, ptsrc))
                
        print('WISE-only:', nomag, 'of', len(J), 'had invalid mags')
        print('Sources:', len(srci), 'in the box,', len(I)-len(srci), 'in the margins, and', len(J), 'WISE-only')
        print('Creating a Tractor with image', tim.shape, 'and', len(subcat), 'sources')
        tractor = Tractor([tim], subcat)
        tractor.disable_cache()

        print('Running forced photometry...')
        t0 = Time()
        tractor.freezeParamsRecursive('*')

        if opt.sky:
            tractor.thawPathsTo('sky')
            print('Initial sky values:')
            for tim in tractor.getImages():
                print(tim.getSky())

        tractor.thawPathsTo(wanyband)

        wantims = (savepickle or (ps is not None) or opt.save_fits)

        R = tractor.optimize_forced_photometry(
            minsb=minsb, mindlnp=1., sky=opt.sky, minFlux=None,
            fitstats=True, fitstat_extras=[('pronexp', [nims])],
            variance=True, shared_params=False,
            use_ceres=opt.ceres, BW=opt.ceresblock, BH=opt.ceresblock,
            wantims=wantims, negfluxval=0.1*sig1)
        print('That took', Time()-t0)

        if wantims:
            ims0 = R.ims0
            ims1 = R.ims1
        IV,fs = R.IV, R.fitstats

        if opt.sky:
            print('Fit sky values:')
            for tim in tractor.getImages():
                print(tim.getSky())

        if opt.savewise:
            for (j,src) in uwcat:
                fitwiseflux[band][j] = src.getBrightness().getBand(wanyband)

        if opt.save_fits:
            (dat,mod,ie,chi,roi) = ims1[0]

            tag = 'fit-%s-w%i' % (tile.coadd_id, band)
            fitsio.write('%s-data.fits' % tag, dat, clobber=True)
            fitsio.write('%s-mod.fits' % tag,  mod, clobber=True)
            fitsio.write('%s-chi.fits' % tag,  chi, clobber=True)

        if ps:
            tag = '%s W%i' % (tile.coadd_id, band)

            (dat,mod,ie,chi,roi) = ims1[0]

            plt.clf()
            plt.imshow(dat, interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-3*sig1, vmax=10*sig1)
            plt.colorbar()
            plt.title('%s: data' % tag)
            ps.savefig()

            # plt.clf()
            # plt.imshow(1./ie, interpolation='nearest', origin='lower',
            #            cmap='gray', vmin=0, vmax=10*sig1)
            # plt.colorbar()
            # plt.title('%s: sigma' % tag)
            # ps.savefig()

            plt.clf()
            plt.imshow(mod, interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-3*sig1, vmax=10*sig1)
            plt.colorbar()
            plt.title('%s: model' % tag)
            ps.savefig()

            plt.clf()
            plt.imshow(chi, interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-5, vmax=+5)
            plt.colorbar()
            plt.title('%s: chi' % tag)
            ps.savefig()

            # plt.clf()
            # plt.imshow(np.round(chi), interpolation='nearest', origin='lower',
            #            cmap='jet', vmin=-5, vmax=+5)
            # plt.colorbar()
            # plt.title('Chi')
            # ps.savefig()

            plt.clf()
            plt.imshow(chi, interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-20, vmax=+20)
            plt.colorbar()
            plt.title('%s: chi 2' % tag)
            ps.savefig()

            plt.clf()
            n,b,p = plt.hist(chi.ravel(), bins=100,
                             range=(-10, 10), log=True,
                             histtype='step', color='b')
            mx = max(n)
            plt.ylim(0.1, mx)
            plt.axvline(0, color='r')
            plt.title('%s: chi' % tag)
            ps.savefig()

            # fn = ps.basefn + '-chi.fits'
            # fitsio.write(fn, chi, clobber=True)
            # print 'Wrote', fn

        if savepickle:
            if ims1 is None:
                mod = None
            else:
                im,mod,ie,chi,roi = ims1[0]
            mods.append(mod)
            cats.append((
                srci, margi, UW.x[J], UW.y[J],
                T.x[srci], T.y[srci], T.x[margi], T.y[margi],
                [src.copy() for src in cat],
                [src.copy() for src in subcat]))

        if len(srci):
            # Save fit stats
            fullIV[srci] = IV[:len(srci)]
            for k in fskeys:
                x = getattr(fs, k)
                fitstats[k][srci] = np.array(x)

        nm = np.array([src.getBrightness().getBand(wanyband) for src in cat])
        nm_ivar = fullIV
        T.set(wband + '_nanomaggies', nm.astype(np.float32))
        T.set(wband + '_nanomaggies_ivar', nm_ivar.astype(np.float32))
        dnm = np.zeros(len(nm_ivar), np.float32)
        okiv = (nm_ivar > 0)
        dnm[okiv] = (1./np.sqrt(nm_ivar[okiv])).astype(np.float32)
        okflux = (nm > 0)
        mag = np.zeros(len(nm), np.float32)
        mag[okflux] = (NanoMaggies.nanomaggiesToMag(nm[okflux])).astype(np.float32)
        dmag = np.zeros(len(nm), np.float32)
        ok = (okiv * okflux)
        dmag[ok] = (np.abs((-2.5 / np.log(10.)) * dnm[ok] / nm[ok])).astype(np.float32)

        mag[np.logical_not(okflux)] = np.nan
        dmag[np.logical_not(ok)] = np.nan
        
        T.set(wband + '_mag', mag)
        T.set(wband + '_mag_err', dmag)
        for k in fskeys:
            T.set(wband + '_' + k, fitstats[k].astype(np.float32))

        if ps:
            I,J,d = match_radec(WISE.ra, WISE.dec, T.ra, T.dec, 4./3600.)

            plt.clf()
            lo,cathi = 10,18
            if band == 3:
                lo,cathi = 8, 13
            elif band == 4:
                #lo,cathi = 4.5, 10.5
                lo,cathi = 4.5, 12
            loghist(WISE.get('w%impro'%band)[I], T.get(wband+'_mag')[J],
                    range=((lo,cathi),(lo,cathi)), bins=200)
            plt.xlabel('WISE W%i mag' % band)
            plt.ylabel('Tractor W%i mag' % band)
            plt.title('WISE catalog vs Tractor forced photometry')
            plt.axis([cathi,lo,cathi,lo])
            ps.savefig()

        print('Tile', tile.coadd_id, 'band', wband, 'took', Time()-tb0)

    T.cut(inbounds)

    T.delete_column('psfflux')
    T.delete_column('cmodelflux')
    T.delete_column('devflux')
    T.delete_column('expflux')
    T.treated_as_pointsource = T.treated_as_pointsource.astype(np.uint8)
    T.pointsource = T.pointsource.astype(np.uint8)

    T.writeto(outfn, header=hdr)
    print('Wrote', outfn)

    if savepickle:
        fn = opt.output % (tile.coadd_id)
        fn = fn.replace('.fits','.pickle')
        pickle_to_file((mods, cats, T, sourcerad), fn)
        print('Pickled', fn)

    print('Tile', tile.coadd_id, 'took', Time()-tt0)

Example 41

View license
def main():
    """Convert the RST Discourse Treebank into JSON files.

    For each dataset split (TRAINING, TEST), this aligns every ``*.edus``
    file from the RST-DTB with its Penn Treebank parse trees, recovers
    which EDU starts each paragraph, and dumps one JSON file per split
    (``rst_discourse_tb_edus_{TRAINING,TEST}.json``) into ``--output_dir``.

    Relies on module-level helpers defined elsewhere in this file
    (``file_mapping``, ``convert_ptb_tree``, ``fix_rst_treebank_tree_str``,
    ``convert_parens_in_rst_tree_str``, ``reformat_rst_tree``,
    ``extract_converted_terminals``, ``extract_preterminals``,
    ``convert_paren_tokens_to_ptb_format``, ``extract_edus_tokens``,
    ``TREE_PRINT_MARGIN``).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('rst_discourse_tb_dir',
                        help='directory for the RST Discourse Treebank.  \
                              This should have a subdirectory \
                              data/RSTtrees-WSJ-main-1.0.')
    parser.add_argument('ptb_dir',
                        help='directory for the Penn Treebank.  This should \
                              have a subdirectory parsed/mrg/wsj.')
    parser.add_argument('--output_dir',
                        help='directory where the output JSON files go.',
                        default='.')
    args = parser.parse_args()

    logging.basicConfig(format=('%(asctime)s - %(name)s - %(levelname)s - ' +
                                '%(message)s'), level=logging.INFO)

    # Forewarn the user: some warnings below are expected and benign.
    logging.warning(
        " Warnings related to minor issues that are difficult to resolve " +
        " will be logged for the following files: " +
        " file1.edus, file5.edus, wsj_0678.out.edus, and wsj_2343.out.edus." +
        " Multiple warnings 'not enough syntax trees' will be produced" +
        " because the RSTDTB has footers that are not in the PTB (e.g.," +
        " indicating where a story is written). Also, there are some loose" +
        " match warnings because of differences in formatting between" +
        " treebanks.")

    # Process each split independently; one JSON output file per split.
    for dataset in ['TRAINING', 'TEST']:
        logging.info(dataset)

        outputs = []

        for path_index, path in enumerate(
                sorted(glob(os.path.join(args.rst_discourse_tb_dir,
                                         'data',
                                         'RSTtrees-WSJ-main-1.0',
                                         dataset,
                                         '*.edus')))):

            path_basename = os.path.basename(path)
            # if path_basename in file_mapping:
            #     # skip the not-so-well-formatted files "file1" to "file5"
            #     continue

            # Per-document state; tokens_doc is rebuilt from the parsed
            # trees further below.
            tokens_doc = []
            edu_start_indices = []

            logging.info('{} {}'.format(path_index, path_basename))
            # Map odd RST-DTB names ("file1" etc.) to their PTB name, then
            # strip the trailing 9 characters (presumably ".out.edus") to
            # get the bare PTB document id, e.g. "wsj_0678".
            ptb_id = (file_mapping[path_basename] if
                      path_basename in file_mapping else
                      path_basename)[:-9]
            # PTB .mrg files are sharded by the two section digits of the id.
            ptb_path = os.path.join(args.ptb_dir, 'parsed', 'mrg', 'wsj',
                                    ptb_id[4:6], '{}.mrg'.format(ptb_id))

            # Read the PTB file and split it into one tree string per
            # sentence; the "( (" prefix consumed by the split is restored
            # when re-parsing each piece.
            with open(ptb_path) as f:
                doc = re.sub(r'\s+', ' ', f.read()).strip()
                trees = [ParentedTree.fromstring('( ({}'.format(x)) for x
                         in re.split(r'\(\s*\(', doc) if x]

            for t in trees:
                convert_ptb_tree(t)

            # Read the EDU strings and the corresponding RST tree (.dis).
            with open(path) as f:
                edus = [line.strip() for line in f.readlines()]
            path_outfile = path[:-5]
            path_dis = "{}.dis".format(path_outfile)
            with open(path_dis) as f:
                rst_tree_str = f.read().strip()
                rst_tree_str = fix_rst_treebank_tree_str(rst_tree_str)
                rst_tree_str = convert_parens_in_rst_tree_str(rst_tree_str)
                rst_tree = ParentedTree.fromstring(rst_tree_str)
                reformat_rst_tree(rst_tree)

            # Identify which EDUs are at the beginnings of paragraphs.
            edu_starts_paragraph = []
            with open(path_outfile) as f:
                outfile_doc = f.read().strip()
                paragraphs = re.split(r'\n\n+', outfile_doc)
                # Filter out paragraphs that don't include a word character.
                paragraphs = [x for x in paragraphs if re.search(r'\w', x)]
                # Remove extra nonword characters to make alignment easier
                # (to avoid problems with the minor discrepancies that exist
                #  in the two versions of the documents.)
                paragraphs = [re.sub(r'\W', r'', p.lower())
                              for p in paragraphs]

                # Greedy character-level alignment: consume paragraphs while
                # matching each (normalized) EDU against the front of the
                # current paragraph buffer.
                p_idx = -1
                paragraph = ""
                for edu_index, edu in enumerate(edus):
                    logging.debug('edu: {}, paragraph: {}, p_idx: {}'
                                  .format(edu, paragraph, p_idx))
                    edu = re.sub(r'\W', r'', edu.lower())
                    starts_paragraph = False
                    crossed_paragraphs = False
                    # Pull in paragraphs until the buffer covers this EDU.
                    while len(paragraph) < len(edu):
                        assert not crossed_paragraphs or starts_paragraph
                        starts_paragraph = True
                        p_idx += 1
                        paragraph += paragraphs[p_idx]
                        if len(paragraph) < len(edu):
                            crossed_paragraphs = True
                            logging.warning(
                                'A paragraph is split across trees.' +
                                ' doc: {}, chars: {}, EDU: {}'
                                .format(path_basename,
                                        paragraphs[p_idx:p_idx + 2], edu))

                    # The EDU must align with the start of the buffer.
                    assert paragraph.index(edu) == 0
                    logging.debug('edu_starts_paragraph = {}'
                                  .format(starts_paragraph))
                    edu_starts_paragraph.append(starts_paragraph)
                    paragraph = paragraph[len(edu):].strip()
                # All paragraphs should have been consumed exactly.
                assert p_idx == len(paragraphs) - 1
                if sum(edu_starts_paragraph) != len(paragraphs):
                    logging.warning(('The number of sentences that start a' +
                                     ' paragraph is not equal to the number' +
                                     ' of paragraphs.  This is probably due' +
                                     ' to trees being split across' +
                                     ' paragraphs. doc: {}')
                                    .format(path_basename))

            # Token-level alignment state: walk PTB tokens and EDU strings
            # in lockstep, recording (tree, token, edu) start indices.
            edu_index = -1
            tok_index = 0
            tree_index = 0

            edu = ""
            tree = trees[0]

            tokens_doc = [extract_converted_terminals(t) for t in trees]
            tokens = tokens_doc[0]
            preterminals = [extract_preterminals(t) for t in trees]

            while edu_index < len(edus) - 1:
                # if we are out of tokens for the sentence we are working
                # with, move to the next sentence.
                if tok_index >= len(tokens):
                    tree_index += 1
                    if tree_index >= len(trees):
                        # RST-DTB footer text that has no PTB parse:
                        # tokenize and POS-tag the leftover EDUs ourselves
                        # and append flat (S ...) trees for them.
                        logging.warning(('Not enough syntax trees for {},' +
                                         ' probably because the RSTDB' +
                                         ' contains a footer that is not in' +
                                         ' the PTB. The remaining EDUs will' +
                                         ' be automatically tagged.')
                                        .format(path_basename))
                        unparsed_edus = ' '.join(edus[edu_index + 1:])
                        # The tokenizer splits '---' into '--' '-'.
                        # This is a hack to get around that.
                        unparsed_edus = re.sub(r'---', '--', unparsed_edus)
                        for tagged_sent in \
                            [nltk.pos_tag(convert_paren_tokens_to_ptb_format( \
                             TreebankWordTokenizer().tokenize(x)))
                             for x in nltk.sent_tokenize(unparsed_edus)]:
                            new_tree = ParentedTree.fromstring('((S {}))' \
                                .format(' '.join(['({} {})'.format(tag, word)
                                                  for word, tag
                                                  in tagged_sent])))
                            trees.append(new_tree)
                            tokens_doc.append(
                                extract_converted_terminals(new_tree))
                            preterminals.append(extract_preterminals(new_tree))

                    tree = trees[tree_index]
                    tokens = tokens_doc[tree_index]
                    tok_index = 0

                tok = tokens[tok_index]

                # if edu is the empty string, then the previous edu was
                # completed by the last token,
                # so this token starts the next edu.
                if not edu:
                    edu_index += 1
                    edu = edus[edu_index]
                    # Normalize the EDU text toward PTB conventions before
                    # matching tokens against it.
                    edu = re.sub(r'>\s*', r'', edu).replace('&amp;', '&')
                    edu = re.sub(r'---', r'--', edu)
                    edu = edu.replace('. . .', '...')

                    # annoying edge cases: hand-crafted per-document patches
                    # for known discrepancies between the RST-DTB EDU text
                    # and the PTB tokens.
                    if path_basename == 'file1.edus':
                        edu = edu.replace('founded by',
                                          'founded by his grandfather.')
                    elif (path_basename == 'wsj_0660.out.edus'
                          or path_basename == 'wsj_1368.out.edus'
                          or path_basename == "wsj_1371.out.edus"):
                        edu = edu.replace('S.p. A.', 'S.p.A.')
                    elif path_basename == 'wsj_1329.out.edus':
                        edu = edu.replace('G.m.b. H.', 'G.m.b.H.')
                    elif path_basename == 'wsj_1367.out.edus':
                        edu = edu.replace('-- that turban --',
                                          '-- that turban')
                    elif path_basename == 'wsj_1377.out.edus':
                        edu = edu.replace('Part of a Series',
                                          'Part of a Series }')
                    elif path_basename == 'wsj_1974.out.edus':
                        edu = edu.replace(r'5/ 16', r'5/16')
                    elif path_basename == 'file2.edus':
                        edu = edu.replace('read it into the record,',
                                          'read it into the record.')
                    elif path_basename == 'file3.edus':
                        edu = edu.replace('about $to $', 'about $2 to $4')
                    elif path_basename == 'file5.edus':
                        # There is a PTB error in wsj_2172.mrg:
                        # The word "analysts" is missing from the parse.
                        # It's gone without a trace :-/
                        edu = edu.replace('panic among analysts',
                                          'panic among')
                        edu = edu.replace('his bid Oct. 17', 'his bid Oct. 5')
                        edu = edu.replace('his bid on Oct. 17',
                                          'his bid on Oct. 5')
                        edu = edu.replace('to commit $billion,',
                                          'to commit $3 billion,')
                        edu = edu.replace('received $million in fees',
                                          'received $8 million in fees')
                        edu = edu.replace('`` in light', '"in light')
                        edu = edu.replace('3.00 a share', '2 a share')
                        edu = edu.replace(" the Deal.", " the Deal.'")
                        edu = edu.replace("' Why doesn't", "Why doesn't")
                    elif path_basename == 'wsj_1331.out.edus':
                        edu = edu.replace('`S', "'S")
                    elif path_basename == 'wsj_1373.out.edus':
                        edu = edu.replace('... An N.V.', 'An N.V.')
                        edu = edu.replace('features.', 'features....')
                    elif path_basename == 'wsj_1123.out.edus':
                        edu = edu.replace('" Reuben', 'Reuben')
                        edu = edu.replace('subscribe to.', 'subscribe to."')
                    elif path_basename == 'wsj_2317.out.edus':
                        edu = edu.replace('. The lower', 'The lower')
                        edu = edu.replace('$4 million', '$4 million.')
                    elif path_basename == 'wsj_1376.out.edus':
                        edu = edu.replace('Elizabeth.', 'Elizabeth.\'"')
                        edu = edu.replace('\'" In', 'In')
                    elif path_basename == 'wsj_1105.out.edus':
                        # PTB error: a sentence starts with an end quote.
                        # For simplicity, we'll just make the
                        # EDU string look like the PTB sentence.
                        edu = edu.replace('By lowering prices',
                                          '"By lowering prices')
                        edu = edu.replace(' 70% off."', ' 70% off.')
                    elif path_basename == 'wsj_1125.out.edus':
                        # PTB error: a sentence ends with an start quote.
                        edu = edu.replace('developer.', 'developer."')
                        edu = edu.replace('"So developers', 'So developers')
                    elif path_basename == 'wsj_1158.out.edus':
                        edu = re.sub(r'\s*\-$', r'', edu)
                        # PTB error: a sentence starts with an end quote.
                        edu = edu.replace(' virtues."', ' virtues.')
                        edu = edu.replace('So much for', '"So much for')
                    elif path_basename == 'wsj_0632.out.edus':
                        # PTB error: a sentence starts with an end quote.
                        edu = edu.replace(' individual.', ' individual."')
                        edu = edu.replace('"If there ', 'If there ')
                    elif path_basename == 'wsj_2386.out.edus':
                        # PTB error: a sentence starts with an end quote.
                        edu = edu.replace('lenders."', 'lenders.')
                        edu = edu.replace('Mr. P', '"Mr. P')
                    elif path_basename == 'wsj_1128.out.edus':
                        # PTB error: a sentence ends with an start quote.
                        edu = edu.replace('it down.', 'it down."')
                        edu = edu.replace('"It\'s a real"', "It's a real")
                    elif path_basename == 'wsj_1323.out.edus':
                        # PTB error (or at least a very unusual edge case):
                        # "--" ends a sentence.
                        edu = edu.replace('-- damn!', 'damn!')
                        edu = edu.replace('from the hook', 'from the hook --')
                    elif path_basename == 'wsj_2303.out.edus':
                        # PTB error: a sentence ends with an start quote.
                        edu = edu.replace('Simpson in an interview.',
                                          'Simpson in an interview."')
                        edu = edu.replace('"Hooker\'s', 'Hooker\'s')
                    # wsj_2343.out.edus also has an error that can't be easily
                    # fixed: an EDU spans 2 sentences, ("to analyze what...").

                    # An EDU should never skip over an entire tree; warn if
                    # the tree index jumped by more than one.
                    if edu_start_indices \
                            and tree_index - edu_start_indices[-1][0] > 1:
                        logging.warning(("SKIPPED A TREE. file = {}" +
                                         " tree_index = {}," +
                                         " edu_start_indices[-1][0] = {}," +
                                         " edu index = {}")
                                        .format(path_basename, tree_index,
                                                edu_start_indices[-1][0],
                                                edu_index))

                    edu_start_indices.append((tree_index, tok_index,
                                              edu_index))

                # remove the next token from the edu, along with any whitespace
                if edu.startswith(tok):
                    edu = edu[len(tok):].strip()
                elif (re.search(r'[^a-zA-Z0-9]', edu[0])
                      and edu[1:].startswith(tok)):
                    # One stray non-alphanumeric character before the token:
                    # skip it and keep going.
                    logging.warning(("loose match: tok = {}, " +
                                     "remainder of EDU: {}").format(tok, edu))
                    edu = edu[len(tok) + 1:].strip()
                else:
                    # Last resort: both sides start with punctuation-only
                    # material; drop the EDU's leading punctuation run.
                    m_tok = re.search(r'^[^a-zA-Z ]+$', tok)
                    m_edu = re.search(r'^[^a-zA-Z ]+(.*)', edu)
                    if not m_tok or not m_edu:
                        raise Exception(('\n\npath_index: {}\ntok: {}\n' +
                                         'edu: {}\nfull_edu:{}\nleaves:' +
                                         '{}\n\n').format(path_index, tok, edu,
                                                          edus[edu_index],
                                                          tree.leaves()))
                    logging.warning("loose match: {} ==> {}".format(tok, edu))
                    edu = m_edu.groups()[0].strip()

                tok_index += 1

            # Assemble the per-document record for the JSON output.
            output = {"doc_id": ptb_id,
                      "path_basename": path_basename,
                      "tokens": tokens_doc,
                      "edu_strings": edus,
                      "syntax_trees": [t.pprint(margin=TREE_PRINT_MARGIN)
                                       for t in trees],
                      "token_tree_positions": [[x.treeposition() for x in
                                                preterminals_sentence]
                                               for preterminals_sentence
                                               in preterminals],
                      "pos_tags": [[x.label() for x in preterminals_sentence]
                                   for preterminals_sentence in preterminals],
                      "edu_start_indices": edu_start_indices,
                      "rst_tree": rst_tree.pprint(margin=TREE_PRINT_MARGIN),
                      "edu_starts_paragraph": edu_starts_paragraph}

            assert len(edu_start_indices) == len(edus)
            assert len(edu_starts_paragraph) == len(edus)

            # check that the EDUs match up
            edu_tokens = extract_edus_tokens(edu_start_indices, tokens_doc)
            for edu_index, (edu, edu_token_list) \
                    in enumerate(zip(edus, edu_tokens)):
                edu_nospace = re.sub(r'\s+', '', edu).lower()
                edu_tokens_nospace = ''.join(edu_token_list).lower()
                distance = nltk.metrics.distance.edit_distance(
                    edu_nospace, edu_tokens_nospace)
                # NOTE(review): the message says "> 3" but the threshold
                # actually tested is > 4 — confirm which is intended.
                if distance > 4:
                    logging.warning(("EDIT DISTANCE > 3 IN {}: " +
                                     "edu string = {}, edu tokens = {}, " +
                                     "edu idx = {}")
                                    .format(path_basename, edu,
                                            edu_token_list, edu_index))
                if not re.search(r'[A-Za-z0-9]', edu_tokens_nospace):
                    logging.warning(("PUNCTUATION-ONLY EDU IN {}: " +
                                     "edu tokens = {}, edu idx = {}")
                                    .format(path_basename, edu_token_list,
                                            edu_index))

            outputs.append(output)

        # One JSON file per split, written under --output_dir.
        with open(os.path.join(args.output_dir, ('rst_discourse_tb_edus_' +
                                                 '{}.json').format(dataset)),
                  'w') as outfile:
            json.dump(outputs, outfile)

Example 42

Project: openerp-7.0
Source File: translate.py
View license
def extend_trans_generate(lang, modules, cr):
    """Collect every translatable term for *modules* and return export rows.

    Legacy OpenERP 7 / Python 2 code. The function walks, in order:
    ``ir.model.data`` records (views, wizards, fields, reports, translatable
    model columns), model-level constraints, and finally the module source
    trees on disk (Python, Mako, JS, QWeb XML) via Babel's extractor.  Each
    term is pushed into the closure-local ``_to_translate`` list and, at the
    end, paired with any existing translation for *lang*.

    :param lang: language code used to look up existing translations;
        falsy to produce an untranslated template (POT-style) export
    :param modules: collection of module names; the special values 'all'
        and 'all_installed' widen the scope
    :param cr: database cursor (its ``dbname`` selects the registry pool)
    :return: list of ``[module, type, name, id, source, translation,
        comments]`` rows, sorted by the internal tuple ordering
    """
    dbname = cr.dbname

    pool = pooler.get_pool(dbname)
    trans_obj = pool.get('ir.translation')
    model_data_obj = pool.get('ir.model.data')
    uid = 1  # superuser: translation export needs unrestricted read access
    # NOTE(review): `l` is sorted but never used afterwards — looks vestigial.
    l = pool.models.items()
    l.sort()

    # Base queries: one for all ir.model.data entries, one restricted to
    # the ir.model records themselves (used later for constraint messages).
    query = 'SELECT name, model, res_id, module'    \
            '  FROM ir_model_data'

    query_models = """SELECT m.id, m.model, imd.module
            FROM ir_model AS m, ir_model_data AS imd
            WHERE m.id = imd.res_id AND imd.model = 'ir.model' """

    # Narrow both queries according to the requested module scope.
    if 'all_installed' in modules:
        query += ' WHERE module IN ( SELECT name FROM ir_module_module WHERE state = \'installed\') '
        query_models += " AND imd.module in ( SELECT name FROM ir_module_module WHERE state = 'installed') "
    query_param = None
    if 'all' not in modules:
        query += ' WHERE module IN %s'
        query_models += ' AND imd.module in %s'
        query_param = (tuple(modules),)
    query += ' ORDER BY module, model, name'
    query_models += ' ORDER BY module, model'

    cr.execute(query, query_param)

    # Accumulator shared by all the nested helpers below.  Each entry is
    # (module, source, name, id, type, comments); duplicates are skipped.
    _to_translate = []
    def push_translation(module, type, name, id, source, comments=None):
        """Queue one term for export, skipping trivial/duplicate sources."""
        # NOTE(review): `tuple` shadows the builtin — kept as-is.
        tuple = (module, source, name, id, type, comments or [])
        # empty and one-letter terms are ignored, they probably are not meant to be
        # translated, and would be very hard to translate anyway.
        if not source or len(source.strip()) <= 1:
            _logger.debug("Ignoring empty or 1-letter source term: %r", tuple)
            return
        if tuple not in _to_translate:
            _to_translate.append(tuple)

    def encode(s):
        """Return *s* as a UTF-8 byte string (Python 2 unicode handling)."""
        if isinstance(s, unicode):
            return s.encode('utf8')
        return s

    # --- Pass 1: terms attached to ir.model.data records ------------------
    for (xml_name,model,res_id,module) in cr.fetchall():
        module = encode(module)
        model = encode(model)
        xml_name = "%s.%s" % (module, encode(xml_name))

        if not pool.get(model):
            _logger.error("Unable to find object %r", model)
            continue

        # Skip dangling ir.model.data entries pointing at deleted records.
        exists = pool.get(model).exists(cr, uid, res_id)
        if not exists:
            _logger.warning("Unable to find object %r with id %d", model, res_id)
            continue
        obj = pool.get(model).browse(cr, uid, res_id)

        if model=='ir.ui.view':
            # View architecture: extract terms from the arch XML.
            d = etree.XML(encode(obj.arch))
            for t in trans_parse_view(d):
                push_translation(module, 'view', encode(obj.model), 0, t)
        elif model=='ir.actions.wizard':
            # Old-style wizards: walk each state's form definition.
            service_name = 'wizard.'+encode(obj.wiz_name)
            import openerp.netsvc as netsvc
            if netsvc.Service._services.get(service_name):
                obj2 = netsvc.Service._services[service_name]
                for state_name, state_def in obj2.states.iteritems():
                    if 'result' in state_def:
                        result = state_def['result']
                        if result['type'] != 'form':
                            continue
                        name = "%s,%s" % (encode(obj.wiz_name), state_name)

                        # Maps field attribute -> (translation type, value extractor).
                        def_params = {
                            'string': ('wizard_field', lambda s: [encode(s)]),
                            'selection': ('selection', lambda s: [encode(e[1]) for e in ((not callable(s)) and s or [])]),
                            'help': ('help', lambda s: [encode(s)]),
                        }

                        # export fields
                        if not result.has_key('fields'):
                            _logger.warning("res has no fields: %r", result)
                            continue
                        for field_name, field_def in result['fields'].iteritems():
                            res_name = name + ',' + field_name

                            for fn in def_params:
                                if fn in field_def:
                                    transtype, modifier = def_params[fn]
                                    for val in modifier(field_def[fn]):
                                        push_translation(module, transtype, res_name, 0, val)

                        # export arch
                        arch = result['arch']
                        if arch and not isinstance(arch, UpdateableStr):
                            d = etree.XML(arch)
                            for t in trans_parse_view(d):
                                push_translation(module, 'wizard_view', name, 0, t)

                        # export button labels
                        for but_args in result['state']:
                            button_name = but_args[0]
                            button_label = but_args[1]
                            res_name = name + ',' + button_name
                            push_translation(module, 'wizard_button', res_name, 0, button_label)

        elif model=='ir.model.fields':
            # Field definitions: label, help text, translatable values and
            # selection labels.
            try:
                field_name = encode(obj.name)
            except AttributeError, exc:
                _logger.error("name error in %s: %s", xml_name, str(exc))
                continue
            objmodel = pool.get(obj.model)
            if not objmodel or not field_name in objmodel._columns:
                continue
            field_def = objmodel._columns[field_name]

            name = "%s,%s" % (encode(obj.model), field_name)
            push_translation(module, 'field', name, 0, encode(field_def.string))

            if field_def.help:
                push_translation(module, 'help', name, 0, encode(field_def.help))

            if field_def.translate:
                # Export the current value of every record for this
                # translatable column, unless the record is itself managed
                # via ir.model.data (those are exported elsewhere).
                ids = objmodel.search(cr, uid, [])
                obj_values = objmodel.read(cr, uid, ids, [field_name])
                for obj_value in obj_values:
                    res_id = obj_value['id']
                    if obj.name in ('ir.model', 'ir.ui.menu'):
                        res_id = 0
                    model_data_ids = model_data_obj.search(cr, uid, [
                        ('model', '=', model),
                        ('res_id', '=', res_id),
                        ])
                    if not model_data_ids:
                        push_translation(module, 'model', name, 0, encode(obj_value[field_name]))

            if hasattr(field_def, 'selection') and isinstance(field_def.selection, (list, tuple)):
                for dummy, val in field_def.selection:
                    push_translation(module, 'selection', name, 0, encode(val))

        elif model=='ir.actions.report.xml':
            name = encode(obj.report_name)
            fname = ""
            ##### Changes for Aeroo ######
            if obj.report_type == 'aeroo':
                # Aeroo reports keep their terms in ir.translation already;
                # re-export those directly.
                trans_ids = trans_obj.search(cr, uid, [('type', '=', 'report'),('res_id', '=', obj.id)])
                for t in trans_obj.read(cr, uid, trans_ids, ['name','src']):
                    # NOTE(review): xml_name (a string) is passed as the `id`
                    # argument here, while every other call passes an int —
                    # confirm this is intended by the Aeroo patch.
                    push_translation(module, "report", t['name'], xml_name, t['src'])
            ##############################
            else:
                # RML/XSL reports: parse the report file for terms.
                if obj.report_rml:
                    fname = obj.report_rml
                    parse_func = trans_parse_rml
                    report_type = "report"
                elif obj.report_xsl:
                    fname = obj.report_xsl
                    parse_func = trans_parse_xsl
                    report_type = "xsl"
                if fname and obj.report_type in ('pdf', 'xsl'):
                    try:
                        report_file = misc.file_open(fname)
                        try:
                            d = etree.parse(report_file)
                            for t in parse_func(d.iter()):
                                push_translation(module, report_type, name, 0, t)
                        finally:
                            report_file.close()
                    except (IOError, etree.XMLSyntaxError):
                        _logger.exception("couldn't export translation for report %s %s %s", name, report_type, fname)

        # Regardless of model: export values of this record's own
        # translatable columns.
        for field_name,field_def in obj._table._columns.items():
            if field_def.translate:
                name = model + "," + field_name
                try:
                    trad = getattr(obj, field_name) or ''
                except:
                    trad = ''
                # NOTE(review): xml_name is passed as the `id` argument,
                # unlike the 0 used in the branches above — confirm intended.
                push_translation(module, 'model', name, xml_name, encode(trad))

        # End of data for ir.model.data query results

    # --- Pass 2: constraint messages declared on each model ----------------
    cr.execute(query_models, query_param)

    def push_constraint_msg(module, term_type, model, msg):
        """Queue a constraint message unless it is a callable (untranslatable)."""
        if not hasattr(msg, '__call__'):
            push_translation(encode(module), term_type, encode(model), 0, encode(msg))

    def push_local_constraints(module, model, cons_type='sql_constraints'):
        """Climb up the class hierarchy and ignore inherited constraints
           from other modules"""
        term_type = 'sql_constraint' if cons_type == 'sql_constraints' else 'constraint'
        # The human-readable message sits at index 2 for SQL constraints
        # ((name, sql, msg)) and index 1 for Python ones ((func, msg, fields)).
        msg_pos = 2 if cons_type == 'sql_constraints' else 1
        for cls in model.__class__.__mro__:
            if getattr(cls, '_module', None) != module:
                continue
            constraints = getattr(cls, '_local_' + cons_type, [])
            for constraint in constraints:
                push_constraint_msg(module, term_type, model._name, constraint[msg_pos])
            
    for (_, model, module) in cr.fetchall():
        model_obj = pool.get(model)

        if not model_obj:
            _logger.error("Unable to find object %r", model)
            continue

        if model_obj._constraints:
            push_local_constraints(module, model_obj, 'constraints')

        if model_obj._sql_constraints:
            push_local_constraints(module, model_obj, 'sql_constraints')

    # --- Pass 3: source-code terms extracted from files on disk ------------
    def get_module_from_path(path, mod_paths=None):
        """Map an absolute file path to the module it belongs to."""
        if not mod_paths:
            # First, construct a list of possible paths
            def_path = os.path.abspath(os.path.join(tools.config['root_path'], 'addons'))     # default addons path (base)
            ad_paths= map(lambda m: os.path.abspath(m.strip()),tools.config['addons_path'].split(','))
            mod_paths=[def_path]
            for adp in ad_paths:
                mod_paths.append(adp)
                # NOTE(review): a relative adp is appended twice (once just
                # above, once here) — looks redundant; confirm before changing.
                if not os.path.isabs(adp):
                    mod_paths.append(adp)
                elif adp.startswith(def_path):
                    mod_paths.append(adp[len(def_path)+1:])
        for mp in mod_paths:
            if path.startswith(mp) and (os.path.dirname(path) != mp):
                path = path[len(mp)+1:]
                return path.split(os.path.sep)[0]
        return 'base'   # files that are not in a module are considered as being in 'base' module

    modobj = pool.get('ir.module.module')
    installed_modids = modobj.search(cr, uid, [('state', '=', 'installed')])
    installed_modules = map(lambda m: m['name'], modobj.read(cr, uid, installed_modids, ['name']))

    root_path = os.path.join(tools.config['root_path'], 'addons')

    # Candidate directories to scan: configured addons paths plus root.
    apaths = map(os.path.abspath, map(str.strip, tools.config['addons_path'].split(',')))
    if root_path in apaths:
        path_list = apaths
    else :
        path_list = [root_path,] + apaths

    # Also scan these non-addon paths
    for bin_path in ['osv', 'report' ]:
        path_list.append(os.path.join(tools.config['root_path'], bin_path))

    # NOTE(review): the format string has no %s placeholder, so path_list
    # is silently dropped from the log message.
    _logger.debug("Scanning modules at paths: ", path_list)

    mod_paths = []

    def verified_module_filepaths(fname, path, root):
        """Resolve a file to its module; return Nones if out of scope."""
        fabsolutepath = join(root, fname)
        frelativepath = fabsolutepath[len(path):]
        display_path = "addons%s" % frelativepath
        module = get_module_from_path(fabsolutepath, mod_paths=mod_paths)
        if ('all' in modules or module in modules) and module in installed_modules:
            return module, fabsolutepath, frelativepath, display_path
        return None, None, None, None

    def babel_extract_terms(fname, path, root, extract_method="python", trans_type='code',
                               extra_comments=None, extract_keywords={'_': None}):
        """Run Babel's extractor on one file and queue every term found.

        NOTE(review): mutable default for ``extract_keywords`` — harmless
        here since it is never mutated, but fragile.
        """
        module, fabsolutepath, _, display_path = verified_module_filepaths(fname, path, root)
        extra_comments = extra_comments or []
        if module:
            src_file = open(fabsolutepath, 'r')
            try:
                for lineno, message, comments in extract.extract(extract_method, src_file,
                                                                 keywords=extract_keywords):
                    push_translation(module, trans_type, display_path, lineno,
                                     encode(message), comments + extra_comments)
            except Exception:
                _logger.exception("Failed to extract terms from %s", fabsolutepath)
            finally:
                src_file.close()

    for path in path_list:
        _logger.debug("Scanning files of modules at %s", path)
        for root, dummy, files in osutil.walksymlinks(path):
            for fname in fnmatch.filter(files, '*.py'):
                babel_extract_terms(fname, path, root)
            # mako provides a babel extractor: http://docs.makotemplates.org/en/latest/usage.html#babel
            for fname in fnmatch.filter(files, '*.mako'):
                babel_extract_terms(fname, path, root, 'mako', trans_type='report')
            # Javascript source files in the static/src/js directory, rest is ignored (libs)
            if fnmatch.fnmatch(root, '*/static/src/js*'):
                for fname in fnmatch.filter(files, '*.js'):
                    babel_extract_terms(fname, path, root, 'javascript',
                                        extra_comments=[WEB_TRANSLATION_COMMENT],
                                        extract_keywords={'_t': None, '_lt': None})
            # QWeb template files
            if fnmatch.fnmatch(root, '*/static/src/xml*'):
                for fname in fnmatch.filter(files, '*.xml'):
                    babel_extract_terms(fname, path, root, 'openerp.tools.translate:babel_extract_qweb',
                                        extra_comments=[WEB_TRANSLATION_COMMENT])

    # --- Final: resolve existing translations and build the export rows ----
    out = []
    _to_translate.sort()
    # translate strings marked as to be translated
    for module, source, name, id, type, comments in _to_translate:
        trans = '' if not lang else trans_obj._get_source(cr, uid, name, type, lang, source)
        out.append([module, type, name, id, source, encode(trans) or '', comments])
    return out

Example 43

Project: openerp-7.0
Source File: translate.py
View license
def extend_trans_generate(lang, modules, cr):
    dbname = cr.dbname

    pool = pooler.get_pool(dbname)
    trans_obj = pool.get('ir.translation')
    model_data_obj = pool.get('ir.model.data')
    uid = 1
    l = pool.models.items()
    l.sort()

    query = 'SELECT name, model, res_id, module'    \
            '  FROM ir_model_data'

    query_models = """SELECT m.id, m.model, imd.module
            FROM ir_model AS m, ir_model_data AS imd
            WHERE m.id = imd.res_id AND imd.model = 'ir.model' """

    if 'all_installed' in modules:
        query += ' WHERE module IN ( SELECT name FROM ir_module_module WHERE state = \'installed\') '
        query_models += " AND imd.module in ( SELECT name FROM ir_module_module WHERE state = 'installed') "
    query_param = None
    if 'all' not in modules:
        query += ' WHERE module IN %s'
        query_models += ' AND imd.module in %s'
        query_param = (tuple(modules),)
    query += ' ORDER BY module, model, name'
    query_models += ' ORDER BY module, model'

    cr.execute(query, query_param)

    _to_translate = []
    def push_translation(module, type, name, id, source, comments=None):
        tuple = (module, source, name, id, type, comments or [])
        # empty and one-letter terms are ignored, they probably are not meant to be
        # translated, and would be very hard to translate anyway.
        if not source or len(source.strip()) <= 1:
            _logger.debug("Ignoring empty or 1-letter source term: %r", tuple)
            return
        if tuple not in _to_translate:
            _to_translate.append(tuple)

    def encode(s):
        if isinstance(s, unicode):
            return s.encode('utf8')
        return s

    for (xml_name,model,res_id,module) in cr.fetchall():
        module = encode(module)
        model = encode(model)
        xml_name = "%s.%s" % (module, encode(xml_name))

        if not pool.get(model):
            _logger.error("Unable to find object %r", model)
            continue

        exists = pool.get(model).exists(cr, uid, res_id)
        if not exists:
            _logger.warning("Unable to find object %r with id %d", model, res_id)
            continue
        obj = pool.get(model).browse(cr, uid, res_id)

        if model=='ir.ui.view':
            d = etree.XML(encode(obj.arch))
            for t in trans_parse_view(d):
                push_translation(module, 'view', encode(obj.model), 0, t)
        elif model=='ir.actions.wizard':
            service_name = 'wizard.'+encode(obj.wiz_name)
            import openerp.netsvc as netsvc
            if netsvc.Service._services.get(service_name):
                obj2 = netsvc.Service._services[service_name]
                for state_name, state_def in obj2.states.iteritems():
                    if 'result' in state_def:
                        result = state_def['result']
                        if result['type'] != 'form':
                            continue
                        name = "%s,%s" % (encode(obj.wiz_name), state_name)

                        def_params = {
                            'string': ('wizard_field', lambda s: [encode(s)]),
                            'selection': ('selection', lambda s: [encode(e[1]) for e in ((not callable(s)) an