bokeh.models.Span

Here are examples of the Python API bokeh.models.Span, taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.

29 Examples

3 View Source File : figure.py
License : GNU General Public License v3.0
Project Creator : happydasch

    def _plot_hlines(self, obj):
        '''
        Draw the object's horizontal reference lines onto ``self.figure``.

        Line positions are read from ``obj.plotinfo`` under ``plothlines``;
        when that is empty, ``plotyhlines`` is used instead. Every position
        becomes a full-width bokeh ``Span`` whose colour, dash style and
        width come from the active scheme.
        '''
        info = obj.plotinfo
        # Fall back to 'plotyhlines' only when 'plothlines' yields nothing.
        locations = info._get('plothlines', []) or info._get('plotyhlines', [])

        line_color = convert_color(self._scheme.hlinescolor)
        for location in locations:
            # matplotlib linestyle name -> bokeh dash pattern
            dash = self._style_mpl2bokeh[self._scheme.hlinesstyle]
            self.figure.renderers.append(
                Span(location=location,
                     dimension='width',
                     line_color=line_color,
                     line_dash=dash,
                     line_width=self._scheme.hlineswidth))

    def fill_nan(self):

3 View Source File : user_performance.py
License : MIT License
Project Creator : kurusugawa-computer

    def _plot_average_line(fig: bokeh.plotting.Figure, value: float, dimension: str):
        """
        Add a thin red reference line marking an average value.

        Args:
            fig (bokeh.plotting.Figure): figure that receives the line as a layout annotation.
            value (float): position of the line.
            dimension (str): "width" (line spans the plot horizontally) or "height" (vertically).
        """
        average_span = bokeh.models.Span(location=value, dimension=dimension,
                                         line_color="red", line_width=0.5)
        fig.add_layout(average_span)

    @staticmethod

3 View Source File : user_performance.py
License : MIT License
Project Creator : kurusugawa-computer

    def _plot_quartile_line(fig: bokeh.plotting.Figure, quartile: tuple[float, float, float], dimension: str):
        """
        Add one thin blue reference line per quartile value.

        Args:
            fig (bokeh.plotting.Figure): figure that receives the lines as layout annotations.
            quartile (tuple[float, float, float]): quartiles, as (25th percentile, 50th percentile, 75th percentile).
            dimension (str): "width" or "height" -- the direction in which each line spans the plot.
        """
        quartile_spans = (
            bokeh.models.Span(location=q, dimension=dimension, line_color="blue", line_width=0.5)
            for q in quartile
        )
        for quartile_span in quartile_spans:
            fig.add_layout(quartile_span)

    @staticmethod

3 View Source File : analyze.py
License : GNU General Public License v3.0
Project Creator : varadaio

def add_constant_line(p, dim, value, line_color='black',
                      line_dash='dashed', line_width=2):
    """Draw a constant reference line on plot ``p``.

    :param p: plot that receives the line via ``add_layout``
    :param dim: span dimension passed to ``Span`` ('width' or 'height')
    :param value: location of the line
    :param line_color: line colour (default 'black')
    :param line_dash: dash style (default 'dashed')
    :param line_width: line width (default 2)
    """
    p.add_layout(Span(location=value, dimension=dim,
                      line_color=line_color, line_dash=line_dash,
                      line_width=line_width))


@run

0 View Source File : plotters.py
License : MIT License
Project Creator : andyljones

def x_zeroline(f):
    """Add a full-height line at location 0 to figure ``f``."""
    zero_line = bom.Span(location=0, dimension='height')
    f.add_layout(zero_line)

def default_tools(f):

0 View Source File : variant_qc_plots.py
License : BSD 3-Clause "New" or "Revised" License
Project Creator : broadinstitute

def plot_metric(df: pd.DataFrame,
                y_name: str,
                cols: List[str],
                y_fun: Callable[[pd.Series], Union[float, int]] = lambda x: x,
                cut: int = None,
                plot_all: bool = True,
                plot_bi_allelics: bool = True,
                plot_singletons: bool = True,
                plot_bi_allelic_singletons: bool = True,
                plot_adj: bool = False,
                colors: Dict[str, str] = None,
                link_cumul_y: bool = True,
                size_prop: str = 'area'
                ) -> Tabs:
    """
    Generic function for generating QC metric plots using a plotting-ready DataFrame (obtained from `get_binned_models_pd`).
    The DataFrame needs `rank_id`, `bin`, `model` (the model name; needs to be added to the binned table(s)),
    `min_score` and `max_score` columns, plus `bi_allelic` / `singleton` boolean columns for the tabs that filter on them.

    This function generates scatter plots with the metric bin on the x-axis and a user-defined function on the y-axis.
    The data for the y-axis function needs to come from the columns specified in `cols`. The function is specified with
    the `y_fun` argument, and the data columns are accessed as a list.
    As an example, plotting the transition to transversion ratio is done as follows::

        plot_metric(snvs, 'Ti/Tv', ['n_ti', 'n_tv'], y_fun=lambda x: x[0]/x[1], colors=colors)

    In this command, `x[0]` corresponds to the first column selected (`'n_ti'`) and `x[1]` to the second (`'n_tv'`).

    This function plots a tab for each of the plot condition(s) selected: all, bi-allelics, singletons, bi-allelic singletons.
    Within each tab, each row contains a non-cumulative and a cumulative plot of the bins / values.
    If `plot_adj` is set, then an extra row is added plotting only variants in release samples where AC_ADJ>0.
    The bin for these sites is computed based on those variants only.
    Rows with no data are skipped (with a logged warning); tabs with no rows at all are omitted.

    :param pd.DataFrame df: Input data
    :param str y_name: Name of the metric plotted on the y-axis
    :param list of str cols: Columns used to compute the metric plotted
    :param callable y_fun: Function to apply to the columns to generate the metric
    :param int cut: Where to draw the bin cut (no cut line is drawn when None)
    :param bool plot_all: Whether to plot a tab with all variants
    :param bool plot_bi_allelics: Whether to plot a tab with bi-allelic variants only
    :param bool plot_singletons: Whether to plot a tab with singleton variants only
    :param bool plot_bi_allelic_singletons:  Whether to plot a tab with bi-allelic singleton variants only
    :param bool plot_adj: Whether to plot additional rows with adj variants in release samples only
    :param dict of str -> str colors: Mapping of model name -> color (unmapped models are drawn in gray)
    :param bool link_cumul_y: If set, y-axes of cumulative and non-cumulative plots are linked
    :param str size_prop: Either 'size' or 'area' can be specified. If either is specified, the points will be sized proportionally to the amount of data in that point.
    :return: Plot
    :rtype: Tabs
    """

    def get_row(df: pd.DataFrame, y_name: str, cols: List[str], y_fun: Callable[[pd.Series], Union[float, int]], titles: List[str], link_cumul_y: bool, cut: int = None) -> Row:
        """
        Generates a single row with two plots: a regular scatter plot and a cumulative one.
        Both plots have bins on the x-axis. The y-axis is computed by applying the function `y_fun` on the columns `cols`.

        Data source is shared between the two plots so that highlighting / selection is linked.
        X-axis is shared between the two plots.
        Y-axis is shared if `link_cumul_y` is `True`.

        Note: adds `non_cumul`, `cumul` and `<col>_cumul` columns to `df` in place.
        """

        def get_plot(data_source: ColumnDataSource, y_name: str, y_col_name: str, titles: List[str], data_ranges: Tuple[DataRange1d, DataRange1d], cut: int = None) -> Plot:
            """
            Generates a single scatter plot panel.
            """

            p = figure(
                title=titles[0],
                x_axis_label='bin',
                y_axis_label=y_name,
                tools="save,pan,box_zoom,reset,wheel_zoom,box_select,lasso_select,help,hover")
            p.x_range = data_ranges[0]
            p.y_range = data_ranges[1]

            # `is not None` so that an explicit cut at bin 0 is still drawn.
            if cut is not None:
                p.add_layout(Span(location=cut, dimension='height', line_color='red', line_dash='dashed'))

            # Add circle glyphs one model at a time so that no default legend is generated.
            # All models share one ColumnDataSource, so a BooleanFilter view selects each model's rows.
            # Models are sorted so the legend order is deterministic (set iteration order is not).
            legend_items = []
            for model in sorted(set(data_source.data['model'])):
                view = CDSView(source=data_source, filters=[BooleanFilter([x == model for x in data_source.data['model']])])
                legend_items.append((model, [p.circle('bin', y_col_name, color='_color', size='_size', source=data_source, view=view)]))

            p.select_one(HoverTool).tooltips = [('model', '@model'),
                                                ('bin', '@bin'),
                                                (y_name, f'@{y_col_name}'),
                                                ('min_score', '@min_score'),
                                                ('max_score', '@max_score'),
                                                ('n_data_points', '@_n')
                                                ] + [(col, f'@{col}') for col in cols]
            set_plots_defaults(p)

            # Add legend above the plot area
            legend = Legend(items=legend_items, orientation='horizontal', location=(0, 0), click_policy="hide")
            p.add_layout(legend, 'above')

            # Add subtitles if any
            for title in titles[1:]:
                p.add_layout(Title(text=title, text_font_size=qc_plots_settings['subtitle.text_font_size']), 'above')

            return p

        # Compute non-cumulative values by applying `y_fun`
        df['non_cumul'] = df[cols].apply(y_fun, axis=1)

        # Cumulative sums of every data column within each model, in one groupby pass.
        # Rows arrive ordered by (model, bin) because prepare_pd's groupby sorts its keys by default.
        cumul = df.groupby('model')[cols].cumsum()
        for col in cols:
            df[f'{col}_cumul'] = cumul[col]
        df['cumul'] = df[[f'{col}_cumul' for col in cols]].apply(y_fun, axis=1)

        # Create data ranges that are either shared or distinct depending on the link_cumul_y parameter
        non_cumul_data_ranges = (DataRange1d(), DataRange1d())
        cumul_data_ranges = non_cumul_data_ranges if link_cumul_y else (non_cumul_data_ranges[0], DataRange1d())
        data_source = ColumnDataSource(df)

        return Row(get_plot(data_source, y_name, 'non_cumul', titles, non_cumul_data_ranges, cut),
                   get_plot(data_source, y_name, 'cumul', [titles[0] + ', cumulative'] + titles[1:], cumul_data_ranges, cut))

    def prepare_pd(df: pd.DataFrame, cols: List[str], colors: Dict[str, str] = None, size_prop: str = None):
        """
        Groups a pandas DataFrame by model and bin while keeping relevant columns only.
        Adds 3 columns used for plotting:
        1. A _color column (model color from `colors`, gray when unmapped)
        2. A _n column containing the number of data points
        3. A _size column containing the size of data points based on the `size_prop` and `qc_plot_settings` parameters
        """
        colors = colors if colors is not None else {}
        df = df.groupby(['model', 'bin']).agg({**{col: 'sum' for col in cols},
                                               'min_score': 'min', 'max_score': 'max'})
        df = df.reset_index()
        df['_color'] = [colors.get(x, 'gray') for x in df['model']]
        df['_n'] = df[cols].sum(axis=1)
        df['_size'] = get_point_size_col(df['_n'], size_prop)
        return df

    def rank_label(adj: str, rank_prefix: str) -> str:
        """Readable rank name built from rank_id prefixes, e.g. 'adj biallelic singleton'; 'overall' if both are empty."""
        parts = [prefix.rstrip('_').replace('_', ' ') for prefix in (adj, rank_prefix) if prefix]
        return ' '.join(parts) if parts else 'overall'

    def add_tab(tab_title: str, mask, rank_prefixes: List[str], make_subtitle: Callable[[str, str], str]) -> None:
        """Append a tab with one row per (adj strategy, rank prefix); rows without data are skipped, empty tabs omitted."""
        children = []
        for adj in adj_strats:
            for rank_prefix in rank_prefixes:
                titles = [y_name, make_subtitle(adj, rank_prefix)]
                rank_mask = df.rank_id == f'{adj}{rank_prefix}rank'
                selected = df.loc[rank_mask if mask is None else mask & rank_mask]
                plot_df = prepare_pd(selected, cols, colors, size_prop)
                if len(plot_df) > 0:
                    children.append(get_row(plot_df, y_name, cols, y_fun, titles, link_cumul_y, cut))
                else:
                    logger.warning('No data found for plot: %s', '\t'.join(titles))

        if children:
            tabs.append(Panel(child=Column(children=children), title=tab_title))

    colors = colors if colors is not None else {}
    tabs = []
    adj_strats = ['', 'adj_'] if plot_adj else ['']

    if plot_all:
        add_tab('All', None, [''],
                lambda adj, _rank: 'Adj variants (adj rank)' if adj else 'All variants')

    if plot_bi_allelics:
        add_tab('Bi-allelic', df.bi_allelic, ['', 'biallelic_'],
                lambda adj, rank: f'Bi-allelic variants ({rank_label(adj, rank)} rank)')

    if plot_singletons:
        add_tab('Singletons', df.singleton, ['', 'singleton_'],
                lambda adj, rank: f'Singletons ({rank_label(adj, rank)} rank)')

    if plot_bi_allelic_singletons:
        add_tab('Bi-allelic singletons', df.bi_allelic & df.singleton, ['', 'biallelic_singleton_'],
                lambda adj, rank: f'Bi-allelic singletons ({rank_label(adj, rank)} rank)')

    return Tabs(tabs=tabs)


def plot_score_distributions(data_type, models: Union[Dict[str, str], List[str]], snv: bool, cut: int, colors: Dict[str, str] = None) -> Tabs:

0 View Source File : variant_qc_plots.py
License : BSD 3-Clause "New" or "Revised" License
Project Creator : broadinstitute

def plot_score_distributions(data_type, models: Union[Dict[str, str], List[str]], snv: bool, cut: int, colors: Dict[str, str] = None) -> Tabs:
    """
    Generates plots of model scores distributions:
    One tab per model.
    Within each tab, there is a 2x2 grid of plots:
    - One row showing the score distribution across the entire data
    - One row showing the score distribution across the release-samples, adj data only (release_sample_AC_ADJ > 0)
    - One column showing the histogram of the score
    - One column showing the normalized cumulative histogram of the score

    Cutoff is highlighted by a dashed red line

    :param str data_type: One of 'exomes' or 'genomes'
    :param list of str or dict of str -> str models: Which models to plot. Can either be a list of models or a dict with mapping from model id to model name for display.
    :param bool snv: Whether to plot SNVs or Indels
    :param int cut: Bin cut on the entire data to highlight
    :param dict of str -> str colors: Optional colors to use (model name -> desired color); models missing from it use "#033649"
    :return: Plots of the score distributions
    :rtype: Tabs
    :raises ValueError: if the binned table has no bin `cut` with rank_id 'rank' for a model
    """
    default_color = "#033649"

    def finish_plot(plot, cut_value, x_range, y_range):
        """Overlay the dashed red cut line, attach the shared ranges and apply the default styling."""
        plot.add_layout(Span(location=cut_value, dimension='height', line_color='red', line_dash='dashed'))
        plot.x_range = x_range
        plot.y_range = y_range
        set_plots_defaults(plot)

    if not isinstance(models, dict):
        models = {m: m for m in models}

    if colors is None:
        colors = {}

    tabs = []
    for model_id, model_name in models.items():
        color = colors.get(model_name, default_color)

        if model_id in ['vqsr', 'cnn', 'rf_2.0.2', 'rf_2.0.2_beta']:
            ht = hl.read_table(score_ranking_path(data_type, model_id, binned=False))
        else:
            ht = hl.read_table(rf_path(data_type, 'rf_result', run_hash=model_id))

        ht = ht.filter(hl.is_snp(ht.alleles[0], ht.alleles[1]), keep=snv)
        binned_ht = hl.read_table(score_ranking_path(data_type, model_id, binned=True))
        binned_ht = binned_ht.filter(binned_ht.snv, keep=snv)

        # Lowest score found in the cut bin of the overall ('rank') ranking.
        cut_value = binned_ht.aggregate(hl.agg.filter((binned_ht.bin == cut) & (binned_ht.rank_id == 'rank'), hl.agg.min(binned_ht.min_score)))
        if cut_value is None:
            # An empty aggregation comes back as a missing value; fail clearly instead of later in hail/format.
            raise ValueError(f"No bin {cut} with rank_id 'rank' found for model {model_id}")

        min_score, max_score = (-20, 30) if model_id == 'vqsr' else (0.0, 1.0)
        agg_values = ht.aggregate(hl.struct(
            score_hist=[hl.agg.hist(ht.score, min_score, max_score, 100),
                        hl.agg.filter(ht.ac > 0, hl.agg.hist(ht.score, min_score, max_score, 100))],
            adj_counts=hl.agg.filter(ht.ac > 0, hl.agg.counter(ht.score >= cut_value))
        ))
        score_hist = agg_values.score_hist

        # counter() only contains keys that were observed, so either side of the cut may be absent.
        n_adj_pass = agg_values.adj_counts.get(True, 0)
        n_adj_fail = agg_values.adj_counts.get(False, 0)
        n_adj = n_adj_pass + n_adj_fail
        adj_cut = '{0:.2f}'.format(100 * n_adj_pass / n_adj) if n_adj else 'N/A'

        rows = []
        x_range = DataRange1d()
        # Histograms share y_range[0] and cumulative plots share y_range[1] across both rows.
        y_range = [DataRange1d(), DataRange1d()]
        for adj in [False, True]:
            title = '{0}, {1} cut (score = {2:.2f})'.format('Adj' if adj else 'All', adj_cut if adj else cut, cut_value)
            # score_hist[False] is the all-variants histogram, score_hist[True] the adj-only one.
            hist = score_hist[adj]

            p = plot_hail_hist(hist, title=title + "\n", fill_color=color)
            finish_plot(p, cut_value, x_range, y_range[0])

            p_cumul = plot_hail_hist_cumulative(hist, title=title + ', cumulative', line_color=color)
            finish_plot(p_cumul, cut_value, x_range, y_range[1])

            rows.append([p, p_cumul])

        tabs.append(Panel(child=gridplot(rows), title=model_name))
    return Tabs(tabs=tabs)


def get_binned_concordance_pd(data_type: str, truth_samples: List[str], models: Union[Dict[str, str], List[str]]) -> pd.DataFrame:

0 View Source File : view.py
License : BSD 3-Clause "New" or "Revised" License
Project Creator : JaneliaSciComp

def init(_bokeh_document):
    global bokeh_document, cluster_dot_palette, snippet_palette, p_cluster, cluster_dots, p_cluster_dots, precomputed_dots, snippets_dy, p_snippets, snippets_label_sources_clustered, snippets_label_sources_annotated, snippets_wave_sources, snippets_wave_glyphs, snippets_gram_sources, snippets_gram_glyphs, snippets_quad_grey, dot_size_cluster, dot_alpha_cluster, cluster_circle_fuchsia, p_waveform, p_spectrogram, p_probability, probability_source, probability_glyph, spectrogram_source, spectrogram_glyph, waveform_span_red, spectrogram_span_red, waveform_quad_grey_clustered, waveform_quad_grey_annotated, waveform_quad_grey_pan, waveform_quad_fuchsia, spectrogram_quad_grey_clustered, spectrogram_quad_grey_annotated, spectrogram_quad_grey_pan, spectrogram_quad_fuchsia, snippets_quad_fuchsia, waveform_source, waveform_glyph, waveform_label_source_clustered, waveform_label_source_annotated, spectrogram_label_source_clustered, spectrogram_label_source_annotated, which_layer, which_species, which_word, which_nohyphen, which_kind, color_picker, circle_radius, dot_size, dot_alpha, zoom_context, zoom_offset, zoomin, zoomout, reset, panleft, panright, allleft, allout, allright, save_indicator, label_count_widgets, label_text_widgets, play, play_callback, video_toggle, video_div, undo, redo, detect, misses, configuration_file, train, leaveoneout, leaveallout, xvalidate, mistakes, activations, cluster, visualize, accuracy, freeze, classify, ethogram, compare, congruence, status_ticker, waitfor, file_dialog_source, file_dialog_source, configuration_contents, logs, logs_folder, model, model_file, wavtfcsvfiles, wavtfcsvfiles_string, groundtruth, groundtruth_folder, validationfiles, testfiles, validationfiles_string, testfiles_string, wantedwords, wantedwords_string, labeltypes, labeltypes_string, prevalences, prevalences_string, copy, labelsounds, makepredictions, fixfalsepositives, fixfalsenegatives, generalize, tunehyperparameters, findnovellabels, examineerrors, testdensely, 
doit, time_sigma_string, time_smooth_ms_string, frequency_n_ms_string, frequency_nw_string, frequency_p_string, frequency_smooth_ms_string, nsteps_string, restore_from_string, save_and_validate_period_string, validate_percentage_string, mini_batch_string, kfold_string, activations_equalize_ratio_string, activations_max_samples_string, pca_fraction_variance_to_retain_string, tsne_perplexity_string, tsne_exaggeration_string, umap_neighbors_string, umap_distance_string, cluster_algorithm, cluster_these_layers, precision_recall_ratios_string, context_ms_string, shiftby_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, optimizer, learning_rate_string, replicates_string, batch_seed_string, weights_seed_string, file_dialog_string, file_dialog_table, readme_contents, wordcounts, wizard_buttons, action_buttons, parameter_buttons, parameter_textinputs, wizard2actions, action2parameterbuttons, action2parametertextinputs, status_ticker_update, status_ticker_pre, status_ticker_post, model_parameters

    bokeh_document = _bokeh_document

    M.cluster_circle_color = M.cluster_circle_color

    if '#' in M.cluster_dot_palette:
      cluster_dot_palette = ast.literal_eval(M.cluster_dot_palette)
    else:
      cluster_dot_palette = getattr(palettes, M.cluster_dot_palette)

    snippet_palette = getattr(palettes, M.snippets_colormap)

    dot_size_cluster = ColumnDataSource(data=dict(ds=[M.state["dot_size"]]))
    dot_alpha_cluster = ColumnDataSource(data=dict(da=[M.state["dot_alpha"]]))

    cluster_dots = ColumnDataSource(data=dict(dx=[], dy=[], dz=[], dl=[], dc=[]))
    cluster_circle_fuchsia = ColumnDataSource(data=dict(cx=[], cy=[], cz=[], cr=[], cc=[]))
    p_cluster = ScatterNd(dx='dx', dy='dy', dz='dz', dl='dl', dc='dc',
                          dots_source=cluster_dots,
                          cx='cx', cy='cy', cz='cz', cr='cr', cc='cc',
                          circle_fuchsia_source=cluster_circle_fuchsia,
                          ds='ds',
                          dot_size_source=dot_size_cluster,
                          da='da',
                          dot_alpha_source=dot_alpha_cluster,
                          width=M.gui_width_pix//2)
    p_cluster.on_change("click_position", lambda a,o,n: C.cluster_tap_callback(n))

    precomputed_dots = None

    snippets_dy = 2*M.snippets_waveform + 2*M.snippets_spectrogram

    p_snippets = figure(plot_width=M.gui_width_pix//2, \
                        background_fill_color='#FFFFFF', toolbar_location=None)
    p_snippets.toolbar.active_drag = None
    p_snippets.grid.visible = False
    p_snippets.xaxis.visible = False
    p_snippets.yaxis.visible = False

    snippets_gram_sources=[None]*(M.snippets_nx*M.snippets_ny)
    snippets_gram_glyphs=[None]*(M.snippets_nx*M.snippets_ny)
    for ixy in range(M.snippets_nx*M.snippets_ny):
        snippets_gram_sources[ixy]=[None]*M.audio_nchannels
        snippets_gram_glyphs[ixy]=[None]*M.audio_nchannels
        for ichannel in range(M.audio_nchannels):
            snippets_gram_sources[ixy][ichannel] = ColumnDataSource(data=dict(image=[]))
            snippets_gram_glyphs[ixy][ichannel] = p_snippets.image('image',
                    source=snippets_gram_sources[ixy][ichannel],
                    palette=M.spectrogram_colormap)

    snippets_quad_grey = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_snippets.quad('left','right','top','bottom',source=snippets_quad_grey,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey")

    snippets_wave_sources=[None]*(M.snippets_nx*M.snippets_ny)
    snippets_wave_glyphs=[None]*(M.snippets_nx*M.snippets_ny)
    for ixy in range(M.snippets_nx*M.snippets_ny):
        snippets_wave_sources[ixy]=[None]*M.audio_nchannels
        snippets_wave_glyphs[ixy]=[None]*M.audio_nchannels
        for ichannel in range(M.audio_nchannels):
            snippets_wave_sources[ixy][ichannel]=ColumnDataSource(data=dict(x=[], y=[]))
            snippets_wave_glyphs[ixy][ichannel]=p_snippets.line(
                    'x', 'y', source=snippets_wave_sources[ixy][ichannel])

    xdata = [(i%M.snippets_nx)*(M.snippets_gap_pix+M.snippets_pix)
             for i in range(M.snippets_nx*M.snippets_ny)]
    ydata = [-(i//M.snippets_nx*snippets_dy-1)
             for i in range(M.snippets_nx*M.snippets_ny)]
    text = ['' for i in range(M.snippets_nx*M.snippets_ny)]
    snippets_label_sources_clustered = ColumnDataSource(data=dict(x=xdata, y=ydata, text=text))
    p_snippets.text('x', 'y', source=snippets_label_sources_clustered, text_font_size='6pt',
                    text_baseline='top',
                    text_color='black' if M.snippets_waveform else 'white')

    xdata = [(i%M.snippets_nx)*(M.snippets_gap_pix+M.snippets_pix)
             for i in range(M.snippets_nx*M.snippets_ny)]
    ydata = [-(i//M.snippets_nx*snippets_dy+1+2*(M.snippets_waveform and M.snippets_spectrogram))
             for i in range(M.snippets_nx*M.snippets_ny)]
    text_annotated = ['' for i in range(M.snippets_nx*M.snippets_ny)]
    snippets_label_sources_annotated = ColumnDataSource(data=dict(x=xdata, y=ydata,
                                                                  text=text_annotated))
    p_snippets.text('x', 'y', source=snippets_label_sources_annotated,
                    text_font_size='6pt',
                    text_color='white' if M.snippets_spectrogram else 'black')

    snippets_quad_fuchsia = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_snippets.quad('left','right','top','bottom',source=snippets_quad_fuchsia,
                fill_color=None, line_color="fuchsia")

    p_snippets.on_event(Tap, C.snippets_tap_callback)
    p_snippets.on_event(DoubleTap, C.snippets_doubletap_callback)

    p_waveform = figure(plot_width=M.gui_width_pix,
                        plot_height=M.context_waveform_height_pix,
                        background_fill_color='#FFFFFF', toolbar_location=None)
    p_waveform.toolbar.active_drag = None
    p_waveform.grid.visible = False
    if M.context_spectrogram:
        p_waveform.xaxis.visible = False
    else:
        p_waveform.xaxis.axis_label = 'Time (sec)'
    p_waveform.yaxis.axis_label = ""
    p_waveform.yaxis.ticker = []
    p_waveform.x_range.range_padding = p_waveform.y_range.range_padding = 0.0
    p_waveform.y_range.start = -1
    p_waveform.y_range.end = 1
    p_waveform.title.text=' '

    waveform_span_red = Span(location=0, dimension='height', line_color='red')
    p_waveform.add_layout(waveform_span_red)
    waveform_span_red.visible=False

    waveform_quad_grey_clustered = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_waveform.quad('left','right','top','bottom',source=waveform_quad_grey_clustered,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    waveform_quad_grey_annotated = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_waveform.quad('left','right','top','bottom',source=waveform_quad_grey_annotated,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    waveform_quad_grey_pan = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_waveform.quad('left','right','top','bottom',source=waveform_quad_grey_pan,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    waveform_quad_fuchsia = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_waveform.quad('left','right','top','bottom',source=waveform_quad_fuchsia,
                fill_color=None, line_color="fuchsia", level='underlay')

    waveform_source=[None]*M.audio_nchannels
    waveform_glyph=[None]*M.audio_nchannels
    for ichannel in range(M.audio_nchannels):
        waveform_source[ichannel] = ColumnDataSource(data=dict(x=[], y=[]))
        waveform_glyph[ichannel] = p_waveform.line('x', 'y', source=waveform_source[ichannel])

    waveform_label_source_clustered = ColumnDataSource(data=dict(x=[], y=[], text=[]))
    p_waveform.text('x', 'y', source=waveform_label_source_clustered,
                   text_font_size='6pt', text_align='center', text_baseline='top',
                   text_line_height=0.8, level='underlay')
    waveform_label_source_annotated = ColumnDataSource(data=dict(x=[], y=[], text=[]))
    p_waveform.text('x', 'y', source=waveform_label_source_annotated,
                   text_font_size='6pt', text_align='center', text_baseline='bottom',
                   text_line_height=0.8, level='underlay')

    p_waveform.on_event(DoubleTap, lambda e: C.context_doubletap_callback(e, 0))

    p_waveform.on_event(PanStart, C.waveform_pan_start_callback)
    p_waveform.on_event(Pan, C.waveform_pan_callback)
    p_waveform.on_event(PanEnd, C.waveform_pan_end_callback)
    p_waveform.on_event(Tap, C.waveform_tap_callback)

    p_spectrogram = figure(plot_width=M.gui_width_pix,
                           plot_height=M.context_spectrogram_height_pix,
                           background_fill_color='#FFFFFF', toolbar_location=None)
    p_spectrogram.toolbar.active_drag = None
    p_spectrogram.x_range.range_padding = p_spectrogram.y_range.range_padding = 0
    p_spectrogram.xgrid.visible = False
    p_spectrogram.ygrid.visible = True
    p_spectrogram.xaxis.axis_label = 'Time (sec)'
    p_spectrogram.yaxis.axis_label = 'Frequency (' + M.context_spectrogram_units + ')'
    p_spectrogram.yaxis.ticker = list(range(1+M.audio_nchannels))

    spectrogram_source = [None]*M.audio_nchannels
    spectrogram_glyph = [None]*M.audio_nchannels
    for ichannel in range(M.audio_nchannels):
        spectrogram_source[ichannel] = ColumnDataSource(data=dict(image=[]))
        spectrogram_glyph[ichannel] = p_spectrogram.image('image',
                                                          source=spectrogram_source[ichannel],
                                                          palette=M.spectrogram_colormap,
                                                          level="image")

    p_spectrogram.on_event(MouseWheel, C.spectrogram_mousewheel_callback)

    p_spectrogram.on_event(DoubleTap,
                           lambda e: C.context_doubletap_callback(e, M.audio_nchannels/2))

    p_spectrogram.on_event(PanStart, C.spectrogram_pan_start_callback)
    p_spectrogram.on_event(Pan, C.spectrogram_pan_callback)
    p_spectrogram.on_event(PanEnd, C.spectrogram_pan_end_callback)
    p_spectrogram.on_event(Tap, C.spectrogram_tap_callback)

    spectrogram_span_red = Span(location=0, dimension='height', line_color='red')
    p_spectrogram.add_layout(spectrogram_span_red)
    spectrogram_span_red.visible=False

    spectrogram_quad_grey_clustered = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_spectrogram.quad('left','right','top','bottom',source=spectrogram_quad_grey_clustered,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    spectrogram_quad_grey_annotated = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_spectrogram.quad('left','right','top','bottom',source=spectrogram_quad_grey_annotated,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    spectrogram_quad_grey_pan = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_spectrogram.quad('left','right','top','bottom',source=spectrogram_quad_grey_pan,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    spectrogram_quad_fuchsia = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_spectrogram.quad('left','right','top','bottom',source=spectrogram_quad_fuchsia,
                fill_color=None, line_color="fuchsia", level='underlay')

    spectrogram_label_source_clustered = ColumnDataSource(data=dict(x=[], y=[], text=[]))
    p_spectrogram.text('x', 'y', source=spectrogram_label_source_clustered,
                   text_font_size='6pt', text_align='center', text_baseline='top',
                   text_line_height=0.8, level='underlay', text_color='white')
    spectrogram_label_source_annotated = ColumnDataSource(data=dict(x=[], y=[], text=[]))
    p_spectrogram.text('x', 'y', source=spectrogram_label_source_annotated,
                   text_font_size='6pt', text_align='center', text_baseline='bottom',
                   text_line_height=0.8, level='underlay', text_color='white')

    TOOLTIPS = """
          <  div> < div> < span style="color:@colors;">@labels < /span> < /div> < /div>
    """

    p_probability = figure(plot_width=M.gui_width_pix, tooltips=TOOLTIPS,
                           plot_height=M.context_probability_height_pix,
                           background_fill_color='#FFFFFF', toolbar_location=None)
    p_probability.toolbar.active_drag = None
    p_probability.grid.visible = False
    p_probability.yaxis.axis_label = "Probability"
    p_probability.x_range.range_padding = p_probability.y_range.range_padding = 0.0
    p_probability.y_range.start = 0
    p_probability.y_range.end = 1
    p_probability.xaxis.visible = False

    probability_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[], labels=[]))
    probability_glyph = p_probability.multi_line(xs='xs', ys='ys',
                                                 source=probability_source, color='colors')

    probability_span_red = Span(location=0, dimension='height', line_color='red')
    p_probability.add_layout(probability_span_red)
    probability_span_red.visible=False

    which_layer = Select(title="layer:")
    which_layer.on_change('value', lambda a,o,n: C.layer_callback(n))

    which_species = Select(title="species:")
    which_species.on_change('value', lambda a,o,n: C.species_callback(n))

    which_word = Select(title="word:")
    which_word.on_change('value', lambda a,o,n: C.word_callback(n))

    which_nohyphen = Select(title="no hyphen:")
    which_nohyphen.on_change('value', lambda a,o,n: C.nohyphen_callback(n))

    which_kind = Select(title="kind:")
    which_kind.on_change('value', lambda a,o,n: C.kind_callback(n))

    color_picker = ColorPicker(title="color:", disabled=True)
    color_picker.on_change("color", lambda a,o,n: C.color_picker_callback(n))

    circle_radius = Slider(start=0, end=10, step=1, \
                           value=M.state["circle_radius"], \
                           title="circle radius", \
                           disabled=True)
    circle_radius.on_change("value_throttled", C.circle_radius_callback)

    dot_size = Slider(start=1, end=24, step=1, \
                      value=M.state["dot_size"], \
                      title="dot size", \
                      disabled=True)
    dot_size.on_change("value", C.dot_size_callback)

    dot_alpha = Slider(start=0.01, end=1.0, step=0.01, \
                       value=M.state["dot_alpha"], \
                       title="dot alpha", \
                       disabled=True)
    dot_alpha.on_change("value", C.dot_alpha_callback)

    cluster_update()

    zoom_context = TextInput(value=str(M.context_width_ms),
                             title="context (msec):",
                             disabled=True)
    zoom_context.on_change("value", C.zoom_context_callback)

    zoom_offset = TextInput(value=str(M.context_offset_ms),
                            title="offset (msec):",
                            disabled=True)
    zoom_offset.on_change("value", C.zoom_offset_callback)

    zoomin = Button(label='\u2191', disabled=True)
    zoomin.on_click(C.zoomin_callback)

    zoomout = Button(label='\u2193', disabled=True)
    zoomout.on_click(C.zoomout_callback)

    reset = Button(label='\u25ef', disabled=True)
    reset.on_click(C.zero_callback)

    panleft = Button(label='\u2190', disabled=True)
    panleft.on_click(C.panleft_callback)

    panright = Button(label='\u2192', disabled=True)
    panright.on_click(C.panright_callback)

    allleft = Button(label='\u21e4', disabled=True)
    allleft.on_click(C.allleft_callback)

    allout = Button(label='\u2913', disabled=True)
    allout.on_click(C.allout_callback)

    allright = Button(label='\u21e5', disabled=True)
    allright.on_click(C.allright_callback)

    save_indicator = Button(label='0')

    label_count_callbacks=[]
    label_count_widgets=[]
    label_text_callbacks=[]
    label_text_widgets=[]

    for i in range(M.nlabels):
        label_count_callbacks.append(lambda i=i: C.label_count_callback(i))
        label_count_widgets.append(Button(label='0', css_classes=['hide-label'], width=40))
        label_count_widgets[-1].on_click(label_count_callbacks[-1])

        label_text_callbacks.append(lambda a,o,n,i=i: C.label_text_callback(n,i))
        label_text_widgets.append(TextInput(value=M.state['labels'][i],
                                            css_classes=['hide-label']))
        label_text_widgets[-1].on_change("value", label_text_callbacks[-1])

    C.label_count_callback(M.ilabel)

    play = Button(label='play', disabled=True)
    play_callback = CustomJS(args=dict(waveform_span_red=waveform_span_red,
                                       spectrogram_span_red=spectrogram_span_red,
                                       probability_span_red=probability_span_red),
                             code=C.play_callback_code % ("",""))
    play.js_on_event(ButtonClick, play_callback)
    play.on_change('disabled', lambda a,o,n: reset_video())
    play.js_on_change('disabled', play_callback)

    video_toggle = Toggle(label='video', active=False, disabled=True)
    video_toggle.on_click(lambda x: context_update())

    video_div = Div(text=""" < video id="context_video"> < /video>""", width=0, height=0)

    undo = Button(label='undo', disabled=True)
    undo.on_click(C.undo_callback)

    redo = Button(label='redo', disabled=True)
    redo.on_click(C.redo_callback)

    detect = Button(label='detect')
    detect.on_click(lambda: C.action_callback(detect, C.detect_actuate))

    misses = Button(label='misses')
    misses.on_click(lambda: C.action_callback(misses, C.misses_actuate))

    train = Button(label='train')
    train.on_click(lambda: C.action_callback(train, C.train_actuate))

    leaveoneout = Button(label='omit one')
    leaveoneout.on_click(lambda: C.action_callback(leaveoneout,
                                                   lambda: C.leaveout_actuate(False)))

    leaveallout = Button(label='omit all')
    leaveallout.on_click(lambda: C.action_callback(leaveallout,
                                                   lambda: C.leaveout_actuate(True)))

    xvalidate = Button(label='x-validate')
    xvalidate.on_click(lambda: C.action_callback(xvalidate, C.xvalidate_actuate))

    mistakes = Button(label='mistakes')
    mistakes.on_click(lambda: C.action_callback(mistakes, C.mistakes_actuate))

    activations = Button(label='activations')
    activations.on_click(lambda: C.action_callback(activations, C.activations_actuate))

    cluster = Button(label='cluster')
    cluster.on_click(lambda: C.action_callback(cluster, C.cluster_actuate))

    visualize = Button(label='visualize')
    visualize.on_click(lambda: C.action_callback(visualize, C.visualize_actuate))

    accuracy = Button(label='accuracy')
    accuracy.on_click(lambda: C.action_callback(accuracy, C.accuracy_actuate))

    freeze = Button(label='freeze')
    freeze.on_click(lambda: C.action_callback(freeze, C.freeze_actuate))

    classify = Button(label='classify')
    classify.on_click(C.classify_callback)

    ethogram = Button(label='ethogram')
    ethogram.on_click(lambda: C.action_callback(ethogram, C.ethogram_actuate))

    compare = Button(label='compare')
    compare.on_click(lambda: C.action_callback(compare, C.compare_actuate))

    congruence = Button(label='congruence')
    congruence.on_click(lambda: C.action_callback(congruence, C.congruence_actuate))

    status_ticker_pre=" < div style='overflow:auto; white-space:nowrap; width:"+str(M.gui_width_pix-126)+"px'>status: "
    status_ticker_post=" < /div>"
    status_ticker = Div(text=status_ticker_pre+status_ticker_post)

    file_dialog_source = ColumnDataSource(data=dict(names=[], sizes=[], dates=[]))
    file_dialog_source.selected.on_change('indices', C.file_dialog_callback)

    file_dialog_columns = [
        TableColumn(field="names", title="Name", width=M.gui_width_pix//2-50-115-30),
        TableColumn(field="sizes", title="Size", width=50, \
                    formatter=NumberFormatter(format="0 b")),
        TableColumn(field="dates", title="Date", width=115, \
                    formatter=DateFormatter(format="%Y-%m-%d %H:%M:%S")),
    ]
    file_dialog_table = DataTable(source=file_dialog_source, \
                                  columns=file_dialog_columns, \
                                  height=727, width=M.gui_width_pix//2-11, \
                                  index_position=None,
                                  fit_columns=False)

    waitfor = Toggle(label='wait for last job', active=False, disabled=True, width=100)
    waitfor.on_click(C.waitfor_callback)

    logs = Button(label='logs folder:', width=110)
    logs.on_click(C.logs_callback)
    logs_folder = TextInput(value=M.state['logs'], title="", disabled=False)
    logs_folder.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    model = Button(label='checkpoint file:', width=110)
    model.on_click(C.model_callback)
    model_file = TextInput(value=M.state['model'], title="", disabled=False)
    model_file.on_change('value', model_file_update)

    wavtfcsvfiles = Button(label='wav,tf,csv files:', width=110)
    wavtfcsvfiles.on_click(C.wavtfcsvfiles_callback)
    wavtfcsvfiles_string = TextInput(value=M.state['wavtfcsvfiles'], title="", disabled=False)
    wavtfcsvfiles_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    groundtruth = Button(label='ground truth:', width=110)
    groundtruth.on_click(C.groundtruth_callback)
    groundtruth_folder = TextInput(value=M.state['groundtruth'], title="", disabled=False)
    groundtruth_folder.on_change('value', lambda a,o,n: groundtruth_update())

    validationfiles = Button(label='validation files:', width=110)
    validationfiles.on_click(C.validationfiles_callback)
    validationfiles_string = TextInput(value=M.state['validationfiles'], title="", disabled=False)
    validationfiles_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    testfiles = Button(label='test files:', width=110)
    testfiles.on_click(C.testfiles_callback)
    testfiles_string = TextInput(value=M.state['testfiles'], title="", disabled=False)
    testfiles_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    wantedwords = Button(label='wanted words:', width=110)
    wantedwords.on_click(C.wantedwords_callback)
    wantedwords_string = TextInput(value=M.state['wantedwords'], title="", disabled=False)
    wantedwords_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    labeltypes = Button(label='label types:', width=110)
    labeltypes_string = TextInput(value=M.state['labeltypes'], title="", disabled=False)
    labeltypes_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    prevalences = Button(label='prevalences:', width=110)
    prevalences.on_click(C.prevalences_callback)
    prevalences_string = TextInput(value=M.state['prevalences'], title="", disabled=False)
    prevalences_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    copy = Button(label='copy')
    copy.on_click(C.copy_callback)

    labelsounds = Button(label='label sounds')
    labelsounds.on_click(lambda: C.wizard_callback(labelsounds))

    makepredictions = Button(label='make predictions')
    makepredictions.on_click(lambda: C.wizard_callback(makepredictions))

    fixfalsepositives = Button(label='fix false positives')
    fixfalsepositives.on_click(lambda: C.wizard_callback(fixfalsepositives))

    fixfalsenegatives = Button(label='fix false negatives')
    fixfalsenegatives.on_click(lambda: C.wizard_callback(fixfalsenegatives))

    generalize = Button(label='test generalization')
    generalize.on_click(lambda: C.wizard_callback(generalize))

    tunehyperparameters = Button(label='tune h-parameters')
    tunehyperparameters.on_click(lambda: C.wizard_callback(tunehyperparameters))

    findnovellabels = Button(label='find novel labels')
    findnovellabels.on_click(lambda: C.wizard_callback(findnovellabels))

    examineerrors = Button(label='examine errors')
    examineerrors.on_click(lambda: C.wizard_callback(examineerrors))

    testdensely = Button(label='test densely')
    testdensely .on_click(lambda: C.wizard_callback(testdensely))

    doit = Button(label='do it!', disabled=True)
    doit.on_click(C.doit_callback)

    time_sigma_string = TextInput(value=M.state['time_sigma'], \
                                  title="time σ", \
                                  disabled=False)
    time_sigma_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    time_smooth_ms_string = TextInput(value=M.state['time_smooth_ms'], \
                                      title="time smooth", \
                                      disabled=False)
    time_smooth_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    frequency_n_ms_string = TextInput(value=M.state['frequency_n_ms'], \
                                      title="freq N (msec)", \
                                      disabled=False)
    frequency_n_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    frequency_nw_string = TextInput(value=M.state['frequency_nw'], \
                                    title="freq NW", \
                                    disabled=False)
    frequency_nw_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    frequency_p_string = TextInput(value=M.state['frequency_p'], \
                                   title="freq ρ", \
                                   disabled=False)
    frequency_p_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    frequency_smooth_ms_string = TextInput(value=M.state['frequency_smooth_ms'], \
                                           title="freq smooth", \
                                           disabled=False)
    frequency_smooth_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    nsteps_string = TextInput(value=M.state['nsteps'], title="# steps", disabled=False)
    nsteps_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    restore_from_string = TextInput(value=M.state['restore_from'], title="restore from", disabled=False)
    restore_from_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    save_and_validate_period_string = TextInput(value=M.state['save_and_validate_interval'], \
                                                title="validate period", \
                                                disabled=False)
    save_and_validate_period_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    validate_percentage_string = TextInput(value=M.state['validate_percentage'], \
                                           title="validate %", \
                                           disabled=False)
    validate_percentage_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    mini_batch_string = TextInput(value=M.state['mini_batch'], \
                                  title="mini-batch", \
                                  disabled=False)
    mini_batch_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    kfold_string = TextInput(value=M.state['kfold'], title="k-fold",  disabled=False)
    kfold_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    activations_equalize_ratio_string = TextInput(value=M.state['activations_equalize_ratio'], \
                                             title="equalize ratio", \
                                             disabled=False)
    activations_equalize_ratio_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    activations_max_samples_string = TextInput(value=M.state['activations_max_samples'], \
                                          title="max samples", \
                                          disabled=False)
    activations_max_samples_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    pca_fraction_variance_to_retain_string = TextInput(value=M.state['pca_fraction_variance_to_retain'], \
                                                       title="PCA fraction", \
                                                       disabled=False)
    pca_fraction_variance_to_retain_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    tsne_perplexity_string = TextInput(value=M.state['tsne_perplexity'], \
                                       title="perplexity", \
                                       disabled=False)
    tsne_perplexity_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    tsne_exaggeration_string = TextInput(value=M.state['tsne_exaggeration'], \
                                        title="exaggeration", \
                                        disabled=False)
    tsne_exaggeration_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    umap_neighbors_string = TextInput(value=M.state['umap_neighbors'], \
                                      title="neighbors", \
                                      disabled=False)
    umap_neighbors_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    umap_distance_string = TextInput(value=M.state['umap_distance'], \
                                     title="distance", \
                                     disabled=False)
    umap_distance_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    precision_recall_ratios_string = TextInput(value=M.state['precision_recall_ratios'], \
                                               title="P/Rs", \
                                               disabled=False)
    precision_recall_ratios_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))
    
    context_ms_string = TextInput(value=M.state['context_ms'], \
                                  title="context (msec)", \
                                  disabled=False)
    context_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    shiftby_ms_string = TextInput(value=M.state['shiftby_ms'], \
                                  title="shift by (msec)", \
                                  disabled=False)
    shiftby_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    representation = Select(title="representation", height=50, \
                            value=M.state['representation'], \
                            options=["waveform", "spectrogram", "mel-cepstrum"])
    representation.on_change('value', lambda a,o,n: C.generic_parameters_callback(''))

    window_ms_string = TextInput(value=M.state['window_ms'], \
                                 title="window (msec)", \
                                 disabled=False)
    window_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    stride_ms_string = TextInput(value=M.state['stride_ms'], \
                                 title="stride (msec)", \
                                 disabled=False)
    stride_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    mel_dct_string = TextInput(value=M.state['mel&dct'], \
                               title="Mel & DCT", \
                               disabled=False)
    mel_dct_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    optimizer = Select(title="optimizer", height=50, \
                       value=M.state['optimizer'], \
                       options=[("sgd","SGD"), ("adam","Adam"), ("adagrad","AdaGrad"), \
                                ("rmsprop","RMSProp")])
    optimizer.on_change('value', lambda a,o,n: C.generic_parameters_callback(''))

    learning_rate_string = TextInput(value=M.state['learning_rate'], \
                                     title="learning rate", \
                                     disabled=False)
    learning_rate_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    model_parameters = OrderedDict()
    for parameter in M.model_parameters:
      if parameter[2]=='':
        thisparameter = TextInput(value=M.state[parameter[0]], \
                                  title=parameter[1], \
                                  disabled=False, width=94)
        thisparameter.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))
      else:
        thisparameter = Select(value=M.state[parameter[0]], \
                               title=parameter[1], \
                               options=parameter[2], \
                               height=50, width=94)
      model_parameters[parameter[0]] = thisparameter

    configuration_contents = TextAreaInput(rows=49-3*np.ceil(len(model_parameters)/6).astype(np.int),
                                           max_length=50000, \
                                           disabled=True, css_classes=['fixedwidth'])
    if M.configuration_file:
        with open(M.configuration_file, 'r') as fid:
            configuration_contents.value = fid.read()


    cluster_algorithm = Select(title="cluster", height=50, \
                               value=M.state['cluster_algorithm'], \
                               options=["PCA 2D", "PCA 3D", \
                                        "tSNE 2D", "tSNE 3D", \
                                        "UMAP 2D", "UMAP 3D"])
    cluster_algorithm.on_change('value', lambda a,o,n: C.generic_parameters_callback(''))

    cluster_these_layers = MultiSelect(title='layers', \
                                       value=M.state['cluster_these_layers'], \
                                       options=[],
                                       height=108)
    cluster_these_layers.on_change('value', lambda a,o,n: C.generic_parameters_callback(''))
    cluster_these_layers_update()

    replicates_string = TextInput(value=M.state['replicates'], \
                                  title="replicates", \
                                  disabled=False)
    replicates_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    batch_seed_string = TextInput(value=M.state['batch_seed'], \
                                  title="batch seed", \
                                  disabled=False)
    batch_seed_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    weights_seed_string = TextInput(value=M.state['weights_seed'], \
                                    title="weights seed", \
                                    disabled=False)
    weights_seed_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    file_dialog_string = TextInput(disabled=False)
    file_dialog_string.on_change("value", C.file_dialog_path_callback)
    file_dialog_string.value = M.state['file_dialog_string']
     
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)),'..','..','README.md'), 'r', encoding='utf-8') as fid:
        contents = fid.read()
    html = markdown.markdown(contents, extensions=['tables','toc'])
    readme_contents = Div(text=html, style={'overflow':'scroll','width':'600px','height':'1397px'})

    wordcounts = Div(text="")
    wordcounts_update()

    wizard_buttons = set([
        labelsounds,
        makepredictions,
        fixfalsepositives,
        fixfalsenegatives,
        generalize,
        tunehyperparameters,
        findnovellabels,
        examineerrors,
        testdensely])

    action_buttons = set([
        detect,
        train,
        leaveoneout,
        leaveallout,
        xvalidate,
        mistakes,
        activations,
        cluster,
        visualize,
        accuracy,
        freeze,
        classify,
        ethogram,
        misses,
        compare,
        congruence])

    parameter_buttons = set([
        logs,
        model,
        wavtfcsvfiles,
        groundtruth,
        validationfiles,
        testfiles,
        wantedwords,
        labeltypes,
        prevalences])

    parameter_textinputs = set([
        logs_folder,
        model_file,
        wavtfcsvfiles_string,
        groundtruth_folder,
        validationfiles_string,
        testfiles_string,
        wantedwords_string,
        labeltypes_string,
        prevalences_string,

        time_sigma_string,
        time_smooth_ms_string,
        frequency_n_ms_string,
        frequency_nw_string,
        frequency_p_string,
        frequency_smooth_ms_string,
        nsteps_string,
        restore_from_string,
        save_and_validate_period_string,
        validate_percentage_string,
        mini_batch_string,
        kfold_string,
        activations_equalize_ratio_string,
        activations_max_samples_string,
        pca_fraction_variance_to_retain_string,
        tsne_perplexity_string,
        tsne_exaggeration_string,
        umap_neighbors_string,
        umap_distance_string,
        cluster_algorithm,
        cluster_these_layers,
        precision_recall_ratios_string,
        replicates_string,
        batch_seed_string,
        weights_seed_string,

        context_ms_string,
        shiftby_ms_string,
        representation,
        window_ms_string,
        stride_ms_string,
        mel_dct_string,
        optimizer,
        learning_rate_string] +

        list(model_parameters.values()))

    wizard2actions = {
            labelsounds: [detect,train,activations,cluster,visualize],
            makepredictions: [train, accuracy, freeze, classify, ethogram],
            fixfalsepositives: [activations, cluster, visualize],
            fixfalsenegatives: [detect, misses, activations, cluster, visualize],
            generalize: [leaveoneout, leaveallout, accuracy],
            tunehyperparameters: [xvalidate, accuracy, compare],
            findnovellabels: [detect, train, activations, cluster, visualize],
            examineerrors: [detect, mistakes, activations, cluster, visualize],
            testdensely: [detect, activations, cluster, visualize, classify, ethogram, congruence],
            None: action_buttons }

    action2parameterbuttons = {
            detect: [wavtfcsvfiles],
            train: [logs, groundtruth, wantedwords, testfiles, labeltypes],
            leaveoneout: [logs, groundtruth, validationfiles, testfiles, wantedwords, labeltypes],
            leaveallout: [logs, groundtruth, validationfiles, testfiles, wantedwords, labeltypes],
            xvalidate: [logs, groundtruth, testfiles, wantedwords, labeltypes],
            mistakes: [groundtruth],
            activations: [logs, model, groundtruth, wantedwords, labeltypes],
            cluster: [groundtruth],
            visualize: [groundtruth],
            accuracy: [logs],
            freeze: [logs, model],
            classify: [logs, model, wavtfcsvfiles, wantedwords, prevalences],
            ethogram: [model, wavtfcsvfiles],
            misses: [wavtfcsvfiles],
            compare: [logs],
            congruence: [groundtruth, validationfiles, testfiles],
            None: parameter_buttons }

    action2parametertextinputs = {
            detect: [wavtfcsvfiles_string, time_sigma_string, time_smooth_ms_string, frequency_n_ms_string, frequency_nw_string, frequency_p_string, frequency_smooth_ms_string],
            train: [context_ms_string, shiftby_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, optimizer, learning_rate_string, replicates_string, batch_seed_string, weights_seed_string, logs_folder, groundtruth_folder, testfiles_string, wantedwords_string, labeltypes_string, nsteps_string, restore_from_string, save_and_validate_period_string, validate_percentage_string, mini_batch_string] + list(model_parameters.values()),
            leaveoneout: [context_ms_string, shiftby_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, optimizer, learning_rate_string, batch_seed_string, weights_seed_string, logs_folder, groundtruth_folder, validationfiles_string, testfiles_string, wantedwords_string, labeltypes_string, nsteps_string, restore_from_string, save_and_validate_period_string, mini_batch_string] + list(model_parameters.values()),
            leaveallout: [context_ms_string, shiftby_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, optimizer, learning_rate_string, batch_seed_string, weights_seed_string, logs_folder, groundtruth_folder, validationfiles_string, testfiles_string, wantedwords_string, labeltypes_string, nsteps_string, restore_from_string, save_and_validate_period_string, mini_batch_string] + list(model_parameters.values()),
            xvalidate: [context_ms_string, shiftby_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, optimizer, learning_rate_string, batch_seed_string, weights_seed_string, logs_folder, groundtruth_folder, testfiles_string, wantedwords_string, labeltypes_string, nsteps_string, restore_from_string, save_and_validate_period_string, mini_batch_string, kfold_string] + list(model_parameters.values()),
            mistakes: [groundtruth_folder],
            activations: [context_ms_string, shiftby_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, logs_folder, model_file, groundtruth_folder, wantedwords_string, labeltypes_string, activations_equalize_ratio_string, activations_max_samples_string, mini_batch_string] + list(model_parameters.values()),
            cluster: [groundtruth_folder, cluster_algorithm, cluster_these_layers, pca_fraction_variance_to_retain_string, tsne_perplexity_string, tsne_exaggeration_string, umap_neighbors_string, umap_distance_string],
            visualize: [groundtruth_folder],
            accuracy: [logs_folder, precision_recall_ratios_string],
            freeze: [context_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, logs_folder, model_file] + list(model_parameters.values()),
            classify: [context_ms_string, shiftby_ms_string, representation, stride_ms_string, logs_folder, model_file, wavtfcsvfiles_string, wantedwords_string, prevalences_string] + list(model_parameters.values()),
            ethogram: [model_file, wavtfcsvfiles_string],
            misses: [wavtfcsvfiles_string],
            compare: [logs_folder],
            congruence: [groundtruth_folder, validationfiles_string, testfiles_string],
            None: parameter_textinputs }

0 View Source File : _plotting.py
License : GNU Affero General Public License v3.0
Project Creator : kernc

def plot(*, results: pd.Series,
         df: pd.DataFrame,
         indicators: List[_Indicator],
         filename='', plot_width=None,
         plot_equity=True, plot_return=False, plot_pl=True,
         plot_volume=True, plot_drawdown=False,
         smooth_equity=False, relative_equity=True,
         superimpose=True, resample=True,
         reverse_indicators=True,
         show_legend=True, open_browser=True):
    """
    Render backtest results as a column of linked Bokeh figures and show it.

    Like much of GUI code everywhere, this is a mess.

    Layout, top to bottom: the optional equity / return / drawdown /
    profit-loss figures, then the main OHLC candlestick figure (with trade
    lines and overlay indicators), then the optional volume figure and one
    figure per non-overlay indicator. All of them share the OHLC figure's
    x range.

    Parameters
    ----------
    results : pd.Series
        Backtest results. Read here: ``'_equity_curve'`` (a frame with
        ``Equity``, ``DrawdownPct`` and ``DrawdownDuration`` columns, indexed
        like *df*), ``'_trades'`` (a frame with ``EntryBar``, ``ExitBar``,
        ``EntryPrice``, ``ExitPrice``, ``ExitTime``, ``Size``, ``ReturnPct``
        and optionally ``count`` columns), and the ``_strategy`` attribute,
        which is used to derive a default *filename*.
    df : pd.DataFrame
        OHLCV data. Only the ``OHLCV_AGG`` columns are used. Its index must
        equal the equity curve's index; this is asserted.
    indicators : list of _Indicator
        Indicator arrays. Each carries a ``name`` and an ``_opts`` dict that
        is read for ``plot``, ``overlay``, ``scatter`` and ``color``.
    filename : str
        Passed to ``_bokeh_reset``. When empty and not running in a Jupyter
        notebook, it is derived from ``str(results._strategy)``.
    plot_width : int or None
        Figure width. ``None`` makes the grid use ``sizing_mode='stretch_width'``.
    plot_equity, plot_return, plot_pl : bool
        Draw the equity, return and profit/loss figures. All three are forced
        off when there are no trades.
    plot_volume : bool
        Draw the volume figure. Forced off when ``df.Volume`` is all null.
    plot_drawdown : bool
        Draw a separate drawdown figure. When False, the max-drawdown point
        is marked on the equity figure instead.
    smooth_equity : bool
        Plot equity linearly interpolated between trade exits and a few
        points of interest, instead of bar by bar.
    relative_equity : bool
        Normalize equity by its first value and format it as a percentage.
    superimpose : bool or str
        Draw downsampled candles behind the main ones (datetime index only).
        A string is used as the pandas resample rule. Any other truthy value
        picks a rule from the index's datetime resolution.
    resample : bool or str
        Passed to ``_maybe_resample_data`` for datetime-indexed data
        (presumably limits the number of plotted candles; see that helper).
    reverse_indicators : bool
        Reverse the order of the non-overlay indicator figures.
    show_legend : bool
        Initial visibility of each figure's legend.
    open_browser : bool
        When False, ``show`` is called with ``browser='none'``.

    Returns
    -------
    The ``gridplot`` layout that was passed to ``show``.
    """
    # We need to reset global Bokeh state, otherwise subsequent runs of
    # plot() contain some previous run's cruft data (was noticed when
    # TestPlot.test_file_size() test was failing).
    if not filename and not IS_JUPYTER_NOTEBOOK:
        filename = _windos_safe_filename(str(results._strategy))
    _bokeh_reset(filename)

    COLORS = [BEAR_COLOR, BULL_COLOR]
    BAR_WIDTH = .8

    assert df.index.equals(results['_equity_curve'].index)
    equity_data = results['_equity_curve'].copy(deep=False)
    trades = results['_trades']

    plot_volume = plot_volume and not df.Volume.isnull().all()
    plot_equity = plot_equity and not trades.empty
    plot_return = plot_return and not trades.empty
    plot_pl = plot_pl and not trades.empty
    is_datetime_index = isinstance(df.index, pd.DatetimeIndex)

    from .lib import OHLCV_AGG
    # ohlc df may contain many columns. We're only interested in, and pass on to Bokeh, these
    df = df[list(OHLCV_AGG.keys())].copy(deep=False)

    # Limit data to max_candles
    if is_datetime_index:
        df, indicators, equity_data, trades = _maybe_resample_data(
            resample, df, indicators, equity_data, trades)

    # From here on the x axis is the integer bar position (0..n-1).
    # The original (possibly datetime) index is kept in df['datetime']
    # for tick labels and tooltips.
    df.index.name = None  # Provides source name @index
    df['datetime'] = df.index  # Save original, maybe datetime index
    df = df.reset_index(drop=True)
    equity_data = equity_data.reset_index(drop=True)
    index = df.index

    new_bokeh_figure = partial(
        _figure,
        x_axis_type='linear',
        width=plot_width,
        height=400,
        tools="xpan,xwheel_zoom,box_zoom,undo,redo,reset,save",
        active_drag='xpan',
        active_scroll='xwheel_zoom')

    # Allow panning 5% beyond either end of the data
    pad = (index[-1] - index[0]) / 20

    fig_ohlc = new_bokeh_figure(
        x_range=Range1d(index[0], index[-1],
                        min_interval=10,
                        bounds=(index[0] - pad,
                                index[-1] + pad)) if index.size > 1 else None)
    figs_above_ohlc, figs_below_ohlc = [], []

    source = ColumnDataSource(df)
    # 'inc' is '1' for up bars (Close >= Open) and '0' otherwise; factor_cmap
    # below maps these string factors onto COLORS.
    source.add((df.Close >= df.Open).values.astype(np.uint8).astype(str), 'inc')

    trade_source = ColumnDataSource(dict(
        index=trades['ExitBar'],
        datetime=trades['ExitTime'],
        exit_price=trades['ExitPrice'],
        size=trades['Size'],
        returns_positive=(trades['ReturnPct'] > 0).astype(int).astype(str),
    ))

    inc_cmap = factor_cmap('inc', COLORS, ['0', '1'])
    cmap = factor_cmap('returns_positive', COLORS, ['0', '1'])
    colors_darker = [lightness(BEAR_COLOR, .35),
                     lightness(BULL_COLOR, .35)]
    trades_cmap = factor_cmap('returns_positive', colors_darker, ['0', '1'])

    if is_datetime_index:
        # Ticks are integer bar positions; the JS formatter maps each tick to
        # the stored datetime and formats the visible set once, caching the
        # result on the formatter object (this.labels).
        fig_ohlc.xaxis.formatter = CustomJSTickFormatter(
            args=dict(axis=fig_ohlc.xaxis[0],
                      formatter=DatetimeTickFormatter(days=['%d %b', '%a %d'],
                                                      months=['%m/%Y', "%b'%y"]),
                      source=source),
            code='''
this.labels = this.labels || formatter.doFormat(ticks
                                                .map(i => source.data.datetime[i])
                                                .filter(t => t !== undefined));
return this.labels[index] || "";
        ''')

    NBSP = '\N{NBSP}' * 4
    # Per-bar High/Low, later extended with overlay indicator columns; its row
    # min/max feed 'ohlc_low'/'ohlc_high' used by the y-autoscale callback.
    ohlc_extreme_values = df[['High', 'Low']].copy(deep=False)
    ohlc_tooltips = [
        ('x, y', NBSP.join(('$index',
                            '$y{0,0.0[0000]}'))),
        ('OHLC', NBSP.join(('@Open{0,0.0[0000]}',
                            '@High{0,0.0[0000]}',
                            '@Low{0,0.0[0000]}',
                            '@Close{0,0.0[0000]}'))),
        ('Volume', '@Volume{0,0}')]

    def new_indicator_figure(**kwargs):
        """Create a short figure that shares the OHLC x range and hides its x axis."""
        kwargs.setdefault('height', 90)
        fig = new_bokeh_figure(x_range=fig_ohlc.x_range,
                               active_scroll='xwheel_zoom',
                               active_drag='xpan',
                               **kwargs)
        fig.xaxis.visible = False
        fig.yaxis.minor_tick_line_color = None
        return fig

    def set_tooltips(fig, tooltips=(), vline=True, renderers=()):
        """Attach a HoverTool to *fig*, prefixed with a date (or bar #) row."""
        tooltips = list(tooltips)
        renderers = list(renderers)

        if is_datetime_index:
            formatters = {'@datetime': 'datetime'}
            tooltips = [("Date", "@datetime{%c}")] + tooltips
        else:
            formatters = {}
            tooltips = [("#", "@index")] + tooltips
        fig.add_tools(HoverTool(
            point_policy='follow_mouse',
            renderers=renderers, formatters=formatters,
            tooltips=tooltips, mode='vline' if vline else 'mouse'))

    def _plot_equity_section(is_return=False):
        """Equity section (or cumulative return when *is_return*); appended above OHLC."""
        # Max DD Dur. line
        equity = equity_data['Equity'].copy()
        dd_end = equity_data['DrawdownDuration'].idxmax()
        if np.isnan(dd_end):
            dd_start = dd_end = equity.index[0]
        else:
            dd_start = equity[:dd_end].idxmax()
            # If DD not extending into the future, get exact point of intersection with equity
            if dd_end != equity.index[-1]:
                # Fractional bar position where equity recrosses the prior peak
                dd_end = np.interp(equity[dd_start],
                                   (equity[dd_end - 1], equity[dd_end]),
                                   (dd_end - 1, dd_end))

        if smooth_equity:
            interest_points = pd.Index([
                # Beginning and end
                equity.index[0], equity.index[-1],
                # Peak equity and peak DD
                equity.idxmax(), equity_data['DrawdownPct'].idxmax(),
                # Include max dd end points. Otherwise the MaxDD line looks amiss.
                dd_start, int(dd_end), min(int(dd_end + 1), equity.size - 1),
            ])
            select = pd.Index(trades['ExitBar']).union(interest_points)
            select = select.unique().dropna()
            equity = equity.iloc[select].reindex(equity.index)
            equity.interpolate(inplace=True)

        assert equity.index.equals(equity_data.index)

        if relative_equity:
            equity /= equity.iloc[0]
        if is_return:
            equity -= equity.iloc[0]

        yaxis_label = 'Return' if is_return else 'Equity'
        source_key = 'eq_return' if is_return else 'equity'
        source.add(equity, source_key)
        fig = new_indicator_figure(
            y_axis_label=yaxis_label,
            **({} if plot_drawdown else dict(height=110)))

        # High-watermark drawdown dents
        fig.patch('index', 'equity_dd',
                  source=ColumnDataSource(dict(
                      index=np.r_[index, index[::-1]],
                      equity_dd=np.r_[equity, equity.cummax()[::-1]]
                  )),
                  fill_color='#ffffea', line_color='#ffcb66')

        # Equity line
        r = fig.line('index', source_key, source=source, line_width=1.5, line_alpha=1)
        if relative_equity:
            tooltip_format = f'@{source_key}{{+0,0.[000]%}}'
            tick_format = '0,0.[00]%'
            legend_format = '{:,.0f}%'
        else:
            tooltip_format = f'@{source_key}{{$ 0,0}}'
            tick_format = '$ 0.0 a'
            legend_format = '${:,.0f}'
        set_tooltips(fig, [(yaxis_label, tooltip_format)], renderers=[r])
        fig.yaxis.formatter = NumeralTickFormatter(format=tick_format)

        # Peaks
        argmax = equity.idxmax()
        fig.scatter(argmax, equity[argmax],
                    legend_label='Peak ({})'.format(
                        legend_format.format(equity[argmax] * (100 if relative_equity else 1))),
                    color='cyan', size=8)
        fig.scatter(index[-1], equity.values[-1],
                    legend_label='Final ({})'.format(
                        legend_format.format(equity.iloc[-1] * (100 if relative_equity else 1))),
                    color='blue', size=8)

        if not plot_drawdown:
            drawdown = equity_data['DrawdownPct']
            argmax = drawdown.idxmax()
            fig.scatter(argmax, equity[argmax],
                        legend_label='Max Drawdown (-{:.1f}%)'.format(100 * drawdown[argmax]),
                        color='red', size=8)
        dd_timedelta_label = df['datetime'].iloc[int(round(dd_end))] - df['datetime'].iloc[dd_start]
        fig.line([dd_start, dd_end], equity.iloc[dd_start],
                 line_color='red', line_width=2,
                 legend_label=f'Max Dd Dur. ({dd_timedelta_label})'
                 .replace(' 00:00:00', '')
                 .replace('(0 days ', '('))

        figs_above_ohlc.append(fig)

    def _plot_drawdown_section():
        """Drawdown section"""
        fig = new_indicator_figure(y_axis_label="Drawdown")
        drawdown = equity_data['DrawdownPct']
        argmax = drawdown.idxmax()
        source.add(drawdown, 'drawdown')
        r = fig.line('index', 'drawdown', source=source, line_width=1.3)
        fig.scatter(argmax, drawdown[argmax],
                    legend_label='Peak (-{:.1f}%)'.format(100 * drawdown[argmax]),
                    color='red', size=8)
        set_tooltips(fig, [('Drawdown', '@drawdown{-0.[0]%}')], renderers=[r])
        fig.yaxis.formatter = NumeralTickFormatter(format="-0.[0]%")
        return fig

    def _plot_pl_section():
        """Profit/Loss markers section"""
        fig = new_indicator_figure(y_axis_label="Profit / Loss")
        fig.add_layout(Span(location=0, dimension='width', line_color='#666666',
                            line_dash='dashed', line_width=1))
        # Long and short trades are plotted as separate series so each gets
        # its own marker shape; the other side's rows are NaN (not drawn).
        returns_long = np.where(trades['Size'] > 0, trades['ReturnPct'], np.nan)
        returns_short = np.where(trades['Size'] < 0, trades['ReturnPct'], np.nan)
        # Marker size scales linearly with |trade size| into [8, 20]
        size = trades['Size'].abs()
        size = np.interp(size, (size.min(), size.max()), (8, 20))
        trade_source.add(returns_long, 'returns_long')
        trade_source.add(returns_short, 'returns_short')
        trade_source.add(size, 'marker_size')
        if 'count' in trades:
            trade_source.add(trades['count'], 'count')
        r1 = fig.scatter('index', 'returns_long', source=trade_source, fill_color=cmap,
                         marker='triangle', line_color='black', size='marker_size')
        r2 = fig.scatter('index', 'returns_short', source=trade_source, fill_color=cmap,
                         marker='inverted_triangle', line_color='black', size='marker_size')
        tooltips = [("Size", "@size{0,0}")]
        if 'count' in trades:
            tooltips.append(("Count", "@count{0,0}"))
        set_tooltips(fig, tooltips + [("P/L", "@returns_long{+0.[000]%}")],
                     vline=False, renderers=[r1])
        set_tooltips(fig, tooltips + [("P/L", "@returns_short{+0.[000]%}")],
                     vline=False, renderers=[r2])
        fig.yaxis.formatter = NumeralTickFormatter(format="0.[00]%")
        return fig

    def _plot_volume_section():
        """Volume section"""
        fig = new_indicator_figure(y_axis_label="Volume")
        fig.xaxis.formatter = fig_ohlc.xaxis[0].formatter
        fig.xaxis.visible = True
        fig_ohlc.xaxis.visible = False  # Show only Volume's xaxis
        r = fig.vbar('index', BAR_WIDTH, 'Volume', source=source, color=inc_cmap)
        set_tooltips(fig, [('Volume', '@Volume{0.00 a}')], renderers=[r])
        fig.yaxis.formatter = NumeralTickFormatter(format="0 a")
        return fig

    def _plot_superimposed_ohlc():
        """Superimposed, downsampled vbars"""
        time_resolution = pd.DatetimeIndex(df['datetime']).resolution
        resample_rule = (superimpose if isinstance(superimpose, str) else
                         dict(day='M',
                              hour='D',
                              minute='H',
                              second='T',
                              millisecond='S').get(time_resolution))
        if not resample_rule:
            warnings.warn(
                f"Can't superimpose OHLC data with rule '{resample_rule}' "
                f"(index datetime resolution: '{time_resolution}'). Skipping.",
                stacklevel=4)
            return

        df2 = (df.assign(_width=1).set_index('datetime')
               .resample(resample_rule, label='left')
               .agg(dict(OHLCV_AGG, _width='count')))

        # Check if resampling was downsampling; error on upsampling
        orig_freq = _data_period(df['datetime'])
        resample_freq = _data_period(df2.index)
        if resample_freq < orig_freq:
            raise ValueError('Invalid value for `superimpose`: Upsampling not supported.')
        if resample_freq == orig_freq:
            warnings.warn('Superimposed OHLC plot matches the original plot. Skipping.',
                          stacklevel=4)
            return

        # Place each coarse candle at the centre of the run of fine bars it
        # aggregates (x axis is integer bar position), '_width' bars wide.
        df2.index = df2['_width'].cumsum().shift(1).fillna(0)
        df2.index += df2['_width'] / 2 - .5
        df2['_width'] -= .1  # Candles don't touch

        df2['inc'] = (df2.Close >= df2.Open).astype(int).astype(str)
        df2.index.name = None
        source2 = ColumnDataSource(df2)
        fig_ohlc.segment('index', 'High', 'index', 'Low', source=source2, color='#bbbbbb')
        colors_lighter = [lightness(BEAR_COLOR, .92),
                          lightness(BULL_COLOR, .92)]
        fig_ohlc.vbar('index', '_width', 'Open', 'Close', source=source2, line_color=None,
                      fill_color=factor_cmap('inc', colors_lighter, ['0', '1']))

    def _plot_ohlc():
        """Main OHLC bars"""
        fig_ohlc.segment('index', 'High', 'index', 'Low', source=source, color="black")
        r = fig_ohlc.vbar('index', BAR_WIDTH, 'Open', 'Close', source=source,
                          line_color="black", fill_color=inc_cmap)
        return r

    def _plot_ohlc_trades():
        """Trade entry / exit markers on OHLC plot"""
        trade_source.add(trades[['EntryBar', 'ExitBar']].values.tolist(), 'position_lines_xs')
        trade_source.add(trades[['EntryPrice', 'ExitPrice']].values.tolist(), 'position_lines_ys')
        fig_ohlc.multi_line(xs='position_lines_xs', ys='position_lines_ys',
                            source=trade_source, line_color=trades_cmap,
                            legend_label=f'Trades ({len(trades)})',
                            line_width=8, line_alpha=1, line_dash='dotted')

    def _plot_indicators():
        """Strategy indicators; returns the list of non-overlay indicator figures."""

        def _too_many_dims(value):
            assert value.ndim >= 2
            if value.ndim > 2:
                warnings.warn(f"Can't plot indicators with >2D ('{value.name}')",
                              stacklevel=5)
                return True
            return False

        class LegendStr(str):
            # The legend string is such a string that only matches
            # itself if it's the exact same object. This ensures
            # legend items are listed separately even when they have the
            # same string contents. Otherwise, Bokeh would always consider
            # equal strings as one and the same legend item.
            def __eq__(self, other):
                return self is other

        ohlc_colors = colorgen()
        indicator_figs = []

        for i, value in enumerate(indicators):
            value = np.atleast_2d(value)

            # Use .get()! A user might have assigned a Strategy.data-evolved
            # _Array without Strategy.I()
            if not value._opts.get('plot') or _too_many_dims(value):
                continue

            is_overlay = value._opts['overlay']
            is_scatter = value._opts['scatter']
            if is_overlay:
                fig = fig_ohlc
            else:
                fig = new_indicator_figure()
                indicator_figs.append(fig)
            tooltips = []
            colors = value._opts['color']
            # User colors cycle if given; otherwise an overlay uses a single
            # next color for all its lines, a separate figure a fresh palette.
            colors = colors and cycle(_as_list(colors)) or (
                cycle([next(ohlc_colors)]) if is_overlay else colorgen())
            legend_label = LegendStr(value.name)
            for j, arr in enumerate(value, 1):
                color = next(colors)
                source_name = f'{legend_label}_{i}_{j}'
                if arr.dtype == bool:
                    arr = arr.astype(int)
                source.add(arr, source_name)
                tooltips.append(f'@{{{source_name}}}{{0,0.0[0000]}}')
                if is_overlay:
                    ohlc_extreme_values[source_name] = arr
                    if is_scatter:
                        fig.scatter(
                            'index', source_name, source=source,
                            legend_label=legend_label, color=color,
                            line_color='black', fill_alpha=.8,
                            marker='circle', radius=BAR_WIDTH / 2 * 1.5)
                    else:
                        fig.line(
                            'index', source_name, source=source,
                            legend_label=legend_label, line_color=color,
                            line_width=1.3)
                else:
                    if is_scatter:
                        r = fig.scatter(
                            'index', source_name, source=source,
                            legend_label=LegendStr(legend_label), color=color,
                            marker='circle', radius=BAR_WIDTH / 2 * .9)
                    else:
                        r = fig.line(
                            'index', source_name, source=source,
                            legend_label=LegendStr(legend_label), line_color=color,
                            line_width=1.3)
                    # Add dashed centerline just because
                    mean = float(pd.Series(arr).mean())
                    if not np.isnan(mean) and (abs(mean) < .1 or
                                               round(abs(mean), 1) == .5 or
                                               round(abs(mean), -1) in (50, 100, 200)):
                        fig.add_layout(Span(location=float(mean), dimension='width',
                                            line_color='#666666', line_dash='dashed',
                                            line_width=.5))
            if is_overlay:
                ohlc_tooltips.append((legend_label, NBSP.join(tooltips)))
            else:
                set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r])
                # If the sole indicator line on this figure,
                # have the legend only contain text without the glyph
                if len(value) == 1:
                    fig.legend.glyph_width = 0
        return indicator_figs

    # Construct figure ...

    if plot_equity:
        _plot_equity_section()

    if plot_return:
        _plot_equity_section(is_return=True)

    if plot_drawdown:
        figs_above_ohlc.append(_plot_drawdown_section())

    if plot_pl:
        figs_above_ohlc.append(_plot_pl_section())

    if plot_volume:
        fig_volume = _plot_volume_section()
        figs_below_ohlc.append(fig_volume)

    if superimpose and is_datetime_index:
        _plot_superimposed_ohlc()

    ohlc_bars = _plot_ohlc()
    _plot_ohlc_trades()
    indicator_figs = _plot_indicators()
    if reverse_indicators:
        indicator_figs = indicator_figs[::-1]
    figs_below_ohlc.extend(indicator_figs)

    set_tooltips(fig_ohlc, ohlc_tooltips, vline=True, renderers=[ohlc_bars])

    source.add(ohlc_extreme_values.min(1), 'ohlc_low')
    source.add(ohlc_extreme_values.max(1), 'ohlc_high')

    # Re-fit the OHLC (and volume) y ranges whenever the shared x range changes
    custom_js_args = dict(ohlc_range=fig_ohlc.y_range,
                          source=source)
    if plot_volume:
        custom_js_args.update(volume_range=fig_volume.y_range)

    fig_ohlc.x_range.js_on_change('end', CustomJS(args=custom_js_args,
                                                  code=_AUTOSCALE_JS_CALLBACK))

    plots = figs_above_ohlc + [fig_ohlc] + figs_below_ohlc
    # One CrosshairTool instance added to every figure links the crosshair
    linked_crosshair = CrosshairTool(dimensions='both')

    for f in plots:
        if f.legend:
            f.legend.visible = show_legend
            f.legend.location = 'top_left'
            f.legend.border_line_width = 1
            f.legend.border_line_color = '#333333'
            f.legend.padding = 5
            f.legend.spacing = 0
            f.legend.margin = 0
            f.legend.label_text_font_size = '8pt'
            f.legend.click_policy = "hide"
        f.min_border_left = 0
        f.min_border_top = 3
        f.min_border_bottom = 6
        f.min_border_right = 10
        f.outline_line_color = '#666666'

        f.add_tools(linked_crosshair)
        wheelzoom_tool = next(wz for wz in f.tools if isinstance(wz, WheelZoomTool))
        wheelzoom_tool.maintain_focus = False

    kwargs = {}
    if plot_width is None:
        kwargs['sizing_mode'] = 'stretch_width'

    fig = gridplot(
        plots,
        ncols=1,
        toolbar_location='right',
        toolbar_options=dict(logo=None),
        merge_tools=True,
        **kwargs
    )
    show(fig, browser=None if open_browser else 'none')
    return fig


def plot_heatmaps(heatmap: pd.Series, agg: Union[Callable, str], ncols: int,

0 View Source File : classes.py
License : GNU General Public License v3.0
Project Creator : lucasbellinaso

  def buildBokehFigs(self):
    from bokeh.models import ColumnDataSource,Span,Band,Label
    from bokeh.plotting import figure as BkFig
    #BOKEH FIGURES:
    #Vector data:
    self.bodesource = ColumnDataSource( data={'omega':[], 'freqHz': [],
              'magdBT':[],'magT':[],'magdBG':[],'magG': [],'angT':[],'angG':[]})
    self.gpbodesource = ColumnDataSource(data ={'fHz':[],'magdB':[],'angdeg':[]})
    self.gzbodesource = ColumnDataSource(data ={'fHz':[],'magdB':[],'angdeg':[]})
    self.cpbodesource = ColumnDataSource(data ={'fHz':[],'magdB':[],'angdeg':[]})
    self.czbodesource = ColumnDataSource(data ={'fHz':[],'magdB':[],'angdeg':[]})
    self.PM_GMsource = ColumnDataSource( data = {'PMfcHz': [1.,1.],'GMfHz':[2.,2.],
                                 'ylimsmag':[-200,200], 'ylimsang':[-720,720] })
    self.rlocussource = ColumnDataSource(data={'x':[],'y':[],'K':[]})
    self.gprlocussource = ColumnDataSource(data={'x':[],'y':[],'K':[]})
    self.gzrlocussource = ColumnDataSource(data={'x':[],'y':[],'K':[]})
    self.cprlocussource = ColumnDataSource(data={'x':[],'y':[],'K':[]})
    self.czrlocussource = ColumnDataSource(data={'x':[],'y':[],'K':[]})
    self.krlocussource = ColumnDataSource(data={'x':[],'y':[],'K':[]})
    self.stepsource = ColumnDataSource(data={'t_s':[],'stepRYmf':[],'stepUYma':[],'stepRUmf':[]})
    #self.tRespsource = ColumnDataSource(data={'t_s':[],'r':[],'du':[],'dy':[],'dm':[],'u':[],'y':[]})
    self.tRespsource = ColumnDataSource(data=self.CtrAnWgt.waveVec_dict)

    #Shadows:
    MAX_OVERSHOT = 0.01*self.CtrAnWgt.OShotIn.value + 1
    MAX_RISE_TIME, MAX_SETTLING_TIME = self.CtrAnWgt.RTimeIn.value, self.CtrAnWgt.STimeIn.value
    _thetaZ = np.linspace(0,np.pi,100)
    _costh, _sinthN, _sinth = np.cos(_thetaZ), -np.sin(_thetaZ), np.sin(_thetaZ)
    self.shadowsource = ColumnDataSource(
        data={'x_s': [0,1e4],     'ylow': [-1e4,1e4],  'yup': [1e4,1e4],  
            'xn_z': [-1e4,-1], 'xp_z': [1,1e4] , 'zero':[0,0],
            'overshot':[MAX_OVERSHOT, MAX_OVERSHOT], 
            'risetime':[MAX_RISE_TIME,1e4] , 'riselevel':[0.9,0.9],
            'settlingtime':[MAX_SETTLING_TIME,1e4],
            'setlevel1':[0.98,0.98], 'setlevel2':[1.02,1.02]  } )
    self.shadowZsource=ColumnDataSource(
               data = {'x_z':_costh, 'ylow_z':_sinthN, 'yup_z':_sinth, 
                       'ylow': 100*[-1e4], 'yup': 100*[1e4]})

    self.shadows = {
    'rloc_s': Band(base='x_s', lower='ylow', upper='yup', level='underlay',
            source=self.shadowsource, fill_color='lightgrey', line_color='black'),
    'rloc_z1': Band(base='xn_z', lower='ylow', upper='yup', level='underlay', 
            source=self.shadowsource, fill_color='lightgrey'),
    'rloc_z2': Band(base='xp_z', lower='ylow', upper='yup', level='underlay', 
            source=self.shadowsource, fill_color='lightgrey'),
    'rloc_z3': Band(base='x_z', lower='ylow', upper='ylow_z', level='underlay',
           source=self.shadowZsource,fill_color='lightgrey',line_color='black'),
    'rloc_z4': Band(base='x_z', lower='yup_z', upper='yup', level='underlay',
           source=self.shadowZsource,fill_color='lightgrey',line_color='black'),
    'ovsht': Band(base='x_s', lower='overshot', upper='yup', level='underlay',
            source=self.shadowsource,fill_color='lightgrey', visible=True),
    'riset': Band(base='risetime', lower='ylow', upper='riselevel', 
            level='underlay', source=self.shadowsource,fill_color='lightgrey'),
    'sett1': Band(base='settlingtime', lower='riselevel', upper='setlevel1', 
                 level='underlay', source=self.shadowsource,fill_color='lightgrey'),
    'sett2': Band(base='settlingtime', lower='setlevel2', upper='overshot', 
                 level='underlay', source=self.shadowsource,fill_color='lightgrey') }                    

    _TTS_BD1 = [('sys',"$name"),("f","$x Hz"),("mag","$y dB")]
    _TTS_BD2 = [('sys',"$name"),("f","$x Hz"),("ang","$y°")]
    _TTS_RLOC= [("real","@x"),("imag","@y"),('K','@K{0.00 a}')]
    _TTS_TRESP = [('signal', "$name"), ("t", "$x s"), ("value", "$y") ]
    self.figMag = BkFig(title="Bode Magnitude", plot_height=300, plot_width=400,
               toolbar_location="above", tooltips = _TTS_BD1, x_axis_type="log",
               x_axis_label='f (Hz)', y_axis_label='mag (dB)')
    self.figAng =  BkFig(title="Bode Angle", plot_height=300, plot_width=400,
                toolbar_location="above", tooltips = _TTS_BD2, x_axis_type="log",
                x_axis_label='f (Hz)', y_axis_label='ang (°)')
    self.figAng.x_range  = self.figMag.x_range   #same axis
    self.figAng.yaxis.ticker=np.linspace(-720,720,17)
    self.figRLoc=  BkFig(title="Root Locus", plot_height=300, plot_width=400,
                toolbar_location="above", tooltips = _TTS_RLOC,
                x_axis_label='real', y_axis_label='imag')
    #self.figRLoc.hover.line_policy = 'interp'
    self.figTResp = BkFig(title="Step Response", plot_height=300, plot_width=400,
                toolbar_location="above", tooltips = _TTS_TRESP,
                x_axis_label='time (s)', y_axis_label='y') 
    self.figTResp2= BkFig(title="Disturbance Simulation", plot_height=300, plot_width=800, 
                toolbar_location="above", tooltips = _TTS_TRESP,
                x_axis_label='time (s)', y_axis_label='y, r, dy, dm')
 
    self.Bkgrid = bokeh.layouts.layout([[self.figMag, self.figRLoc],
                                        [self.figAng, self.figTResp],
                                        self.figTResp2])


    if self.dt in [None, 0.0]:   #continuous time
      self.figRLoc.add_layout(self.shadows['rloc_s'])
    else:                 #discrete time
        for strkey in ['rloc_z1', 'rloc_z2', 'rloc_z3', 'rloc_z4']:
          self.figRLoc.add_layout(self.shadows[strkey])
        self.Nyquistlimits = Span(location=0.5/self.dt,
                                 dimension='height', line_color='black',
                                 line_dash='dotted', line_width=1)
        self.figMag.add_layout(self.Nyquistlimits)
        self.figAng.add_layout(self.Nyquistlimits)
    for strkey in ['ovsht', 'riset', 'sett1', 'sett2']:
      self.figTResp.add_layout(self.shadows[strkey])
      #self.figTResp2.add_layout(self.shadows[strkey])


    #Bode Diagram:
    bodemagT=self.figMag.line(x='freqHz', y='magdBT',color="blue",line_width=1.5,
                alpha=0.8,name='|T(s)|',legend_label='T(s)',source=self.bodesource)
    bodemagG=self.figMag.line(x='freqHz', y='magdBG',color="green",line_width=1.5, 
                alpha=0.8,name='|Gp(s)|',line_dash='dashed',
                legend_label='Gp(s)',source=self.bodesource)
    bodeangT=self.figAng.line(x='freqHz', y='angT', color="blue", line_width=1.5,
                alpha=0.8, name='∡T(s)', source=self.bodesource)
    bodeangG=self.figAng.line(x='freqHz', y='angG',color="green", line_width=1.5,
                alpha=0.8,name='∡Gp(s)',line_dash='dashed',source=self.bodesource)
    bodeGpmag = self.figMag.x( x='fHz',y='magdB',line_color='blue', size=10,
                 name='Gp poles', source = self.gpbodesource)
    bodeGpang=self.figAng.x(x='fHz',y='angdeg',line_color='blue', size=10,
                 name='Gp poles angle', source = self.gpbodesource)
    bodeGzmag = self.figMag.circle(x='fHz',y='magdB',line_color='blue',size=8,
                 name='Gp zeros',fill_color=None, source = self.gzbodesource)
    bodeGzang=self.figAng.circle(x='fHz',y='angdeg',line_color='blue',size=8,
              name='Gp zeros angle', fill_color=None,source = self.gzbodesource)
    bodeCpmag = self.figMag.x(x='fHz',y='magdB',line_color='red',size=10,
                 name='C poles', source = self.cpbodesource)
    bodeCpang=self.figAng.x(x='fHz',y='angdeg',line_color='red',size=10,
                 name='C poles angle', source = self.cpbodesource)
    bodeCzmag = self.figMag.circle(x='fHz',y='magdB',line_color='red',size=8,
                 name='C zeros', fill_color=None, source = self.czbodesource)
    bodeCzang=self.figAng.circle(x='fHz',y='angdeg',line_color='red',size=8,
                name='C zeros angle',fill_color=None, source = self.czbodesource)
    self.GMSpan = Span(location=1, dimension='height',
                       line_color='black', line_dash='dotted', line_width=1)
    self.PMSpan = Span(location=1, dimension='height',
                       line_color='black', line_dash='dotted', line_width=1)
    self.PMtxt = Label(x=5, y=5, x_units='screen', y_units='screen', 
                         text=' ',  render_mode='css',border_line_color=None,
                        background_fill_color='white',text_font_size = '11px')
    self.GMtxt = Label(x=5, y=20, x_units='screen', y_units='screen', 
                         text=' ',  render_mode='css',border_line_color=None,
                        background_fill_color='white',text_font_size = '11px')
    #self.Clbltxt = Label(x=5, y=20, x_units='screen', y_units='screen', 
    #                     text=' ',  render_mode='css',border_line_color=None,
    #                    background_fill_color='white',text_font_size = '11px')
    #self.Clbltxt.text = 'C(s) = ' if self.dt in [None, 0.0] else 'C(z) = '
    #self.Cgaintxt = Label(x=40, y=20, x_units='screen', y_units='screen', 
    #                     text='K',  render_mode='css',border_line_color=None,
    #                    background_fill_color='white',text_font_size = '11px')
    #self.Cnumtxt = Label(x=100, y=30, x_units='screen', y_units='screen', 
    #                     text='N',  render_mode='css',border_line_color=None,
    #                    background_fill_color='white',text_font_size = '11px')
    #self.Cdentxt = Label(x=100, y=10, x_units='screen', y_units='screen', 
    #                     text='D',  render_mode='css',border_line_color=None,
    #                    background_fill_color='white',text_font_size = '11px')
    self.figMag.add_layout(self.GMSpan), self.figAng.add_layout(self.GMSpan)
    self.figMag.add_layout(self.PMSpan), self.figAng.add_layout(self.PMSpan)
    self.figAng.add_layout(self.PMtxt), self.figAng.add_layout(self.GMtxt)
    #self.figMag.add_layout(self.Clbltxt), self.figMag.add_layout(self.Cgaintxt)
    #self.figMag.add_layout(self.Cnumtxt), self.figMag.add_layout(self.Cdentxt)

    #Root Locus:
    rlocusline = self.figRLoc.dot(x='x',y='y',color='blue',
                                  name='rlocus', source = self.rlocussource)
    rlocusGpoles = self.figRLoc.x(x='x',y='y',color='blue', size=10,
                                name='Gp pole', source = self.gprlocussource)
    rlocusGzeros = self.figRLoc.circle(x='x',y='y',line_color='blue',size=8,
                 name='Gp zero', fill_color=None, source = self.gzrlocussource)
    rlocusCpoles = self.figRLoc.x(x='x',y='y',color='red', size=10,
                                name='C pole', source = self.cprlocussource)
    rlocusCzeros = self.figRLoc.circle(x='x',y='y',line_color='red',size=8,
                 name='C zero', fill_color=None, source = self.czrlocussource)
    rlocusMF = self.figRLoc.square(x='x',y='y', line_color='deeppink',size=8,
                 name='K', fill_color='deeppink', source = self.krlocussource)
    rlocuslinehv = self.figRLoc.line(x='x',y='y',line_alpha=0, 
                                  name='rlocus2', source = self.rlocussource)
    self.figRLoc.hover.renderers=[rlocuslinehv, rlocusGpoles, rlocusGzeros,
                                  rlocusCpoles, rlocusCzeros, rlocusMF]
    #self.figRLoc.hover.mode='mouse'   
    #self.figRLoc.hover.line_policy='next'
    #self.figRLoc.hover.point_policy='snap_to_data'
    self.Stabilitytxt = Label(x=10, y=200, x_units='screen', y_units='screen', 
                         text=' ',  render_mode='css',border_line_color=None,
                        background_fill_color='white',text_font_size = '11px')
    self.figRLoc.add_layout(self.Stabilitytxt)

    #Step response:
    self.figTResp.extra_y_ranges = {'u_range': bokeh.models.Range1d()}
    self.figTResp.add_layout(bokeh.models.LinearAxis(y_range_name="u_range",
                                                     axis_label='u'), 'right')
    self.figTResp.y_range = bokeh.models.Range1d(start = -0.1, end = 1.4)
    #add_graf = self.figTResp.line if self.dt in [None, 0.0] else self.figTResp.dot
    if self.dt in [None, 0.0]:
        stepR2Y=self.figTResp.line(x='t_s', y='stepRYmf',color="blue",line_width=1.5, name='y',
                        legend_label='y (closed loop)',  source=self.stepsource)
        stepU2Y=self.figTResp.line(x='t_s', y='stepUYma',color="green",
               legend_label='y (open loop)', line_dash='dashed', line_width=1.0,
                                name='y (ol)',source=self.stepsource, visible=False)
        stepR2U=self.figTResp.line(x='t_s', y='stepRUmf',color="red",
                       line_width=1.0, name='u',legend_label='u (closed loop)',
            line_dash='dashed', source=self.stepsource, y_range_name = 'u_range', visible=False)
    else:
        stepR2Y=self.figTResp.dot(x='t_s', y='stepRYmf',color="blue",
                                 line_width=1.5, name='y',  size=15,
                        legend_label='y (closed loop)',  source=self.stepsource)
        stepU2Y=self.figTResp.dot(x='t_s', y='stepUYma',color="green", size=15,
                                legend_label='y (open loop)', line_width=1.0,
                                name='y (ol)',source=self.stepsource, visible=False)
        stepR2U=self.figTResp.dot(x='t_s', y='stepRUmf',color="red", size=15,
                       line_width=1.0, name='u',legend_label='u (closed loop)',
                       source=self.stepsource, y_range_name = 'u_range', visible=False)
    self.figTResp.legend.location = 'bottom_right'
    self.figTResp.legend.click_policy = 'hide'

    #Disturbances response:
    self.figTResp2.extra_y_ranges = {'u_range': bokeh.models.Range1d()}
    self.figTResp2.add_layout(bokeh.models.LinearAxis(y_range_name="u_range",
                                                     axis_label='u, du'), 'right')
    self.figTResp2.y_range = bokeh.models.Range1d(start = -0.1, end = 1.4)
    if self.dt in [None, 0.0]:
        tRespY=self.figTResp2.line(x='t_s', y='y',color="blue", line_width=1.5, name='y',
                        legend_label='y',  source=self.tRespsource)
        tRespU=self.figTResp2.line(x='t_s', y='u',color="red", legend_label='u', line_width=1.5,
                      name='u',source=self.tRespsource, y_range_name = 'u_range', visible=False)
        tRespDU=self.figTResp2.line(x='t_s', y='du',color="indianred", line_width=1.0, name='du',legend_label='du',
                                    line_dash='dashed', source=self.tRespsource, y_range_name = 'u_range', visible=False)
        tRespR=self.figTResp2.line(x='t_s', y='r',color="green", line_width=1.0, name='r',legend_label='r',
                                  line_dash='dashed', source=self.tRespsource)
        tRespDY=self.figTResp2.line(x='t_s', y='dy',color="deepskyblue", line_width=1.0, name='dy',legend_label='dy',
                                    line_dash='dashed', source=self.tRespsource, visible=False)
        tRespDM=self.figTResp2.line(x='t_s', y='dm',color="lime", line_width=1.0, name='dm',legend_label='dm',
                                    line_dash='dashed', source=self.tRespsource, visible=False)
    else:
        tRespY=self.figTResp2.dot(x='t_s', y='y',color="blue", line_width=1.5, size=15, name='y',
                        legend_label='y',  source=self.tRespsource)
        tRespU=self.figTResp2.dot(x='t_s', y='u',color="red", legend_label='u', line_width=1.5, size=15,
                      name='u',source=self.tRespsource, y_range_name = 'u_range', visible=False)
        tRespDU=self.figTResp2.dot(x='t_s', y='du',color="indianred", line_width=1.0, name='du',legend_label='du',
                              size=15, line_dash='dashed', source=self.tRespsource, y_range_name = 'u_range', visible=False)
        tRespR=self.figTResp2.dot(x='t_s', y='r',color="green", line_width=1.0, name='r',legend_label='r',size=15,
                                  line_dash='dashed', source=self.tRespsource)
        tRespDY=self.figTResp2.dot(x='t_s', y='dy',color="deepskyblue", line_width=1.0, name='dy',legend_label='dy',
                                size=15,line_dash='dashed', source=self.tRespsource, visible=False)
        tRespDM=self.figTResp2.dot(x='t_s', y='dm',color="lime", line_width=1.0, name='dm',legend_label='dm',
                               size=15, line_dash='dashed', source=self.tRespsource, visible=False)
    self.figTResp2.legend.location = 'bottom_right'
    self.figTResp2.legend.click_policy = 'hide'

  def updateTFAndBokeh(self,b):

0 View Source File : interact.py
License : MIT License
Project Creator : nasa

def make_lightcurve_figure_elements(lc, lc_source):
    """Make the lightcurve figure elements.

    Parameters
    ----------
    lc : LightCurve
        Lightcurve to be shown.
    lc_source : bokeh.plotting.ColumnDataSource
        Bokeh object that enables the visualization; expected to provide the
        ``time``, ``flux``, ``cadence``, ``time_iso``, ``quality_code`` and
        ``quality`` columns referenced by the glyphs and tooltips below.

    Returns
    -------
    fig : `bokeh.plotting.figure` instance
        Figure holding the step line, the selectable circles, the hover
        tool and the cadence marker.
    vertical_line : Span
        Vertical line marking the current cadence, initially placed at
        ``lc.time[0]``.
    """
    # Read the mission once. A lightcurve without a ``mission`` attribute
    # falls back to the generic title and axis label instead of raising
    # (previously only the axis-label code was guarded against this).
    mission = getattr(lc, 'mission', None)
    if mission == 'K2':
        title = "Lightcurve for {} (K2 C{})".format(
            lc.label, lc.campaign)
    elif mission == 'Kepler':
        title = "Lightcurve for {} (Kepler Q{})".format(
            lc.label, lc.quarter)
    elif mission == 'TESS':
        title = "Lightcurve for {} (TESS Sec. {})".format(
            lc.label, lc.sector)
    else:
        title = "Lightcurve for target {}".format(lc.label)

    fig = figure(title=title, plot_height=340, plot_width=600,
                 tools="pan,wheel_zoom,box_zoom,tap,reset",
                 toolbar_location="below",
                 border_fill_color="whitesmoke")
    fig.title.offset = -10
    fig.yaxis.axis_label = 'Flux (e/s)'
    # The x-axis label names the time offset for known missions.
    if mission in ('K2', 'Kepler'):
        fig.xaxis.axis_label = 'Time - 2454833 (days)'
    elif mission == 'TESS':
        fig.xaxis.axis_label = 'Time - 2457000 (days)'
    else:
        fig.xaxis.axis_label = 'Time (days)'

    ylims = get_lightcurve_y_limits(lc_source)
    fig.y_range = Range1d(start=ylims[0], end=ylims[1])

    # Add step lines, circles, and hover-over tooltips
    fig.step('time', 'flux', line_width=1, color='gray',
             source=lc_source, nonselection_line_color='gray',
             nonselection_line_alpha=1.0)
    # Invisible (fill_color=None) circles that light up on hover/selection;
    # they are the only renderer the hover tool is attached to.
    circ = fig.circle('time', 'flux', source=lc_source, fill_alpha=0.3, size=8,
                      line_color=None, selection_color="firebrick",
                      nonselection_fill_alpha=0.0,
                      nonselection_fill_color="grey",
                      nonselection_line_color=None,
                      nonselection_line_alpha=0.0,
                      fill_color=None, hover_fill_color="firebrick",
                      hover_alpha=0.9, hover_line_color="white")
    tooltips = [("Cadence", "@cadence"),
                ("Time ({})".format(lc.time_format.upper()),
                 "@time{0,0.000}"),
                ("Time (ISO)", "@time_iso"),
                ("Flux", "@flux"),
                ("Quality Code", "@quality_code"),
                ("Quality Flag", "@quality")]
    fig.add_tools(HoverTool(tooltips=tooltips, renderers=[circ],
                            mode='mouse', point_policy="snap_to_data"))

    # Vertical line to indicate the cadence
    vertical_line = Span(location=lc.time[0], dimension='height',
                         line_color='firebrick', line_width=4, line_alpha=0.5)
    fig.add_layout(vertical_line)

    return fig, vertical_line


def add_gaia_figure_elements(tpf, fig, magnitude_limit=18):

0 View Source File : interact_bls.py
License : MIT License
Project Creator : nasa

def make_bls_figure_elements(result, bls_source, help_source):
    """Make a line plot of a BLS result.

    Parameters
    ----------
    result : BLS.model result
        BLS model result to plot; its ``power`` and ``period`` arrays set
        the initial axis ranges.
    bls_source : bokeh.plotting.ColumnDataSource
        Bokeh style source object for plotting BLS source (``period`` and
        ``power`` columns).
    help_source : bokeh.plotting.ColumnDataSource
        Bokeh style source object for rendering help button (``period``,
        ``power``, ``helpme`` and ``help`` columns).

    Returns
    -------
    fig : bokeh.plotting.figure
        Bokeh figure object
    vertical_line : bokeh.models.Span
        Vertical line to highlight current selected period
    """

    # Build Figure
    fig = figure(title='BLS Periodogram', plot_height=340, plot_width=450,
                 tools="pan,box_zoom,tap,reset",
                 toolbar_location="below",
                 border_fill_color="#FFFFFF", x_axis_type='log', active_drag="box_zoom")
    fig.title.offset = -10
    fig.yaxis.axis_label = 'Power'
    fig.xaxis.axis_label = 'Period [days]'
    # Pad the power range by 5% on each side; the period range is exact.
    fig.y_range = Range1d(start=result.power.min() * 0.95, end=result.power.max() * 1.05)
    fig.x_range = Range1d(start=result.period.min(), end=result.period.max())

    # Add circles for the selection of new period. These are always hidden
    fig.circle('period', 'power',
               source=bls_source,
               fill_alpha=0.,
               size=6,
               line_color=None,
               selection_color="white",
               nonselection_fill_alpha=0.0,
               nonselection_fill_color='white',
               nonselection_line_color=None,
               nonselection_line_alpha=0.0,
               fill_color=None,
               hover_fill_color="white",
               hover_alpha=0.,
               hover_line_color="white")

    # Add line for the BLS power
    fig.line('period', 'power', line_width=1, color='#191919',
             source=bls_source, nonselection_line_color='#191919',
             nonselection_line_alpha=1.0)

    # Vertical line to indicate the current period.
    # NOTE(review): 0 lies outside a log axis's domain, so the line is
    # presumably invisible until the caller moves it to a real period.
    vertical_line = Span(location=0, dimension='height',
                         line_color='firebrick', line_width=3, line_alpha=0.5)
    fig.add_layout(vertical_line)

    # Help button: a "?"-style text glyph with a circle around it that
    # carries the hover tooltip.
    question_mark = Text(x="period", y="power", text="helpme", text_color="grey",
                         text_align='center', text_baseline="middle",
                         text_font_size='12px', text_font_style='bold',
                         text_alpha=0.6)
    fig.add_glyph(help_source, question_mark)
    # Named help_circles (not ``help``) to avoid shadowing the builtin.
    help_circles = fig.circle('period', 'power', alpha=0.0, size=15, source=help_source,
                              line_width=2, line_color='grey', line_alpha=0.6)
    tooltips = help_source.data['help'][0]
    fig.add_tools(HoverTool(tooltips=tooltips, renderers=[help_circles],
                            mode='mouse', point_policy="snap_to_data"))

    return fig, vertical_line


def show_interact_widget(lc, notebook_url='localhost:8888', minimum_period=None,

0 View Source File : visualization.py
License : MIT License
Project Creator : pedromartins4

def plot_rsi(stock):
    """Plot the 15-day RSI series with shaded bands below 30 and above 70.

    Parameters
    ----------
    stock
        Data source passed as ``source=``; must provide ``date`` and
        ``rsi_15`` columns (presumably a ``ColumnDataSource`` — confirm
        against callers).

    Returns
    -------
    bokeh.plotting.figure
        The RSI figure with a 0-100 y range and ticks at 30, 50 and 70.
    """
    p = figure(x_axis_type="datetime", plot_width=WIDTH_PLOT, plot_height=200, title="RSI 15 days",
               tools=TOOLS, toolbar_location='above')

    p.line(x='date', y='rsi_15', line_width=2, color=BLUE, source=stock)

    # Translucent red band below 30 and green band above 70.
    low_box = BoxAnnotation(top=30, fill_alpha=0.1, fill_color=RED)
    p.add_layout(low_box)
    high_box = BoxAnnotation(bottom=70, fill_alpha=0.1, fill_color=GREEN)
    p.add_layout(high_box)

    # Horizontal midline at 50. Span is an annotation, so it goes through
    # add_layout like the boxes above rather than into p.renderers.
    hline = Span(location=50, dimension='width', line_color='black', line_width=0.5)
    p.add_layout(hline)

    p.y_range = Range1d(0, 100)
    p.yaxis.ticker = [30, 50, 70]
    # "%d%%" renders "30%"; the previous "%f%%" produced "30.000000%".
    p.yaxis.formatter = PrintfTickFormatter(format="%d%%")
    p.grid.grid_line_alpha = 0.3

    return p


#### On-Balance Volume (OBV)
def plot_obv(stock):

0 View Source File : outputs.py
License : MIT License
Project Creator : PSLmodels

def liability_plot(df_base, df_reform, span, mtr_opt):
    """Build the tax-liability comparison chart (base vs. reform).

    Each of ``df_base`` / ``df_reform`` is wrapped in a ``ColumnDataSource``
    and must provide the "Axis", "Individual Income Tax" and "Payroll Tax"
    columns. Base lines are solid and reform lines dashed. Two toggle
    buttons show/hide each set, and clicking a legend entry mutes a line.

    Parameters
    ----------
    df_base, df_reform
        Tabular data for the base and reform scenarios.
    span : number
        x position of the dotted marker; also shown comma-grouped in its
        caption.
    mtr_opt : str
        Name of the varied input, used for the x-axis label, the marker
        caption and the output title.

    Returns
    -------
    dict
        ``media_type`` "bokeh", a ``title`` string, and under ``data`` the
        ``json_item`` serialization of the figure plus its toggles.
    """
    base_src = ColumnDataSource(df_base)
    reform_src = ColumnDataSource(df_reform)
    fig = figure(plot_width=600, plot_height=500,
                 x_range=(-10000, 300000), y_range=(-20000, 100000),
                 tools="pan, zoom_in, zoom_out, reset", active_drag="pan")
    fig.yaxis.axis_label = "Tax Liabilities"
    fig.yaxis.formatter = NumeralTickFormatter(format="$0,000")

    # Dotted vertical marker at the filer's value, captioned near the bottom.
    fig.add_layout(Span(location=span, dimension='height', line_color='black',
                        line_dash='dotted', line_width=1.5))
    fig.add_layout(Label(x=span, y=25, y_units='screen', x_offset=10,
                         text=f"{mtr_opt}: ${span:,}", text_color='#303030',
                         text_font="arial", text_font_style="italic",
                         text_font_size="10pt"))
    # Light grey horizontal line at y = 0.
    fig.add_layout(Span(location=0, dimension='width',
                        line_color='#bfbfbf', line_width=1.5))

    # (column, colour, legend label) for each plotted series, in draw order.
    series = (
        ("Individual Income Tax", '#2b83ba', "Individual Income Tax Liability"),
        ("Payroll Tax", '#abdda4', 'Payroll Tax Liability'),
    )

    def draw(source, **style):
        return [fig.line(x="Axis", y=col, line_color=colour, muted_color=colour,
                         line_width=2, legend_label=label, muted_alpha=0.1,
                         source=source, **style)
                for col, colour, label in series]

    base_lines = draw(base_src)
    reform_lines = draw(reform_src, line_dash='dashed')
    for renderer in base_lines + reform_lines:
        renderer.muted = False

    plot_js = """
    object1.visible = toggle.active
    object2.visible = toggle.active
    """

    def make_toggle(label, renderers):
        # The callback is created before its toggle and wired to it afterwards.
        callback = CustomJS(code=plot_js, args={})
        toggle = Toggle(label=label, button_type="default", active=True)
        wiring = {"toggle": toggle}
        for idx, renderer in enumerate(renderers, start=1):
            wiring[f"object{idx}"] = renderer
        callback.args = wiring
        toggle.js_on_change('active', callback)
        return toggle

    base_toggle = make_toggle("Base (Solid)", base_lines)
    reform_toggle = make_toggle("Reform (Dashed)", reform_lines)

    fig.xaxis.formatter = NumeralTickFormatter(format="$0,000")
    fig.xaxis.axis_label = mtr_opt
    fig.xaxis.minor_tick_line_color = None
    fig.legend.click_policy = "mute"

    layout = column(fig, row(base_toggle, reform_toggle))
    return {
        "media_type": "bokeh",
        "title": f"Tax Liabilities by {mtr_opt} (Holding Other Inputs Constant)",
        "data": json_item(layout),
    }


def rate_plot(df_base, df_reform, span, mtr_opt):

0 View Source File : outputs.py
License : MIT License
Project Creator : PSLmodels

def rate_plot(df_base, df_reform, span, mtr_opt):
    """Build the tax-rate comparison chart (base vs. reform).

    Each of ``df_base`` / ``df_reform`` is wrapped in a ``ColumnDataSource``
    and must provide the "Axis", "IATR", "PATR", "Income Tax MTR" and
    "Payroll Tax MTR" columns. Base lines are solid and reform lines dashed.
    Only the income-tax average-rate lines start unmuted. Two toggle buttons
    show/hide each set, and clicking a legend entry mutes a line.

    Parameters
    ----------
    df_base, df_reform
        Tabular data for the base and reform scenarios.
    span : number
        x position of the dotted marker; also shown comma-grouped in its
        caption.
    mtr_opt : str
        Name of the varied input, used for the x-axis label, the marker
        caption and the output title.

    Returns
    -------
    dict
        ``media_type`` "bokeh", a ``title`` string, and under ``data`` the
        ``json_item`` serialization of the figure plus its toggles.
    """
    base_src = ColumnDataSource(df_base)
    reform_src = ColumnDataSource(df_reform)
    fig = figure(plot_width=600, plot_height=500,
                 x_range=(-10000, 300000), y_range=(-0.3, 0.5),
                 tools="pan, zoom_in, zoom_out, reset", active_drag="pan")
    fig.yaxis.axis_label = "Tax Rate"
    fig.yaxis.formatter = NumeralTickFormatter(format="0%")

    # Dotted vertical marker at the filer's value, captioned near the bottom.
    fig.add_layout(Span(location=span, dimension='height', line_color='black',
                        line_dash='dotted', line_width=1.5))
    fig.add_layout(Label(x=span, y=25, y_units='screen', x_offset=10,
                         text=f"{mtr_opt}: ${span:,}", text_color='#303030',
                         text_font="arial", text_font_style="italic",
                         text_font_size="10pt"))
    # Light grey horizontal line at y = 0.
    fig.add_layout(Span(location=0, dimension='width',
                        line_color='#bfbfbf', line_width=1.5))

    # (column, colour, legend label, initially muted) in draw order; this
    # is also the object1..object4 order used by the toggles.
    series = (
        ("IATR", '#2b83ba', "Income Tax Average Rate", False),
        ("PATR", '#abdda4', 'Payroll Tax Average Rate', True),
        ("Income Tax MTR", '#fdae61', "Income Tax Marginal Rate", True),
        ("Payroll Tax MTR", '#d7191c', 'Payroll Tax Marginal Rate', True),
    )

    def draw(source, **style):
        return [fig.line(x="Axis", y=col, line_color=colour, muted_color=colour,
                         line_width=2, legend_label=label, muted_alpha=0.1,
                         source=source, **style)
                for col, colour, label, _ in series]

    base_lines = draw(base_src)
    reform_lines = draw(reform_src, line_dash='dashed')
    for lines in (base_lines, reform_lines):
        for renderer, (_, _, _, start_muted) in zip(lines, series):
            renderer.muted = start_muted

    plot_js = """
    object1.visible = toggle.active
    object2.visible = toggle.active
    object3.visible = toggle.active
    object4.visible = toggle.active
    """

    def make_toggle(label, renderers):
        # The callback is created before its toggle and wired to it afterwards.
        callback = CustomJS(code=plot_js, args={})
        toggle = Toggle(label=label, button_type="default", active=True)
        wiring = {"toggle": toggle}
        for idx, renderer in enumerate(renderers, start=1):
            wiring[f"object{idx}"] = renderer
        callback.args = wiring
        toggle.js_on_change('active', callback)
        return toggle

    base_toggle = make_toggle("Base (Solid)", base_lines)
    reform_toggle = make_toggle("Reform (Dashed)", reform_lines)

    fig.xaxis.formatter = NumeralTickFormatter(format="$0,000")
    fig.xaxis.axis_label = mtr_opt
    fig.xaxis.minor_tick_line_color = None
    fig.legend.click_policy = "mute"

    layout = column(fig, row(base_toggle, reform_toggle))
    return {
        "media_type": "bokeh",
        "title": f"Tax Rates by {mtr_opt} (Holding Other Inputs Constant)",
        "data": json_item(layout),
    }


def credit_plot(df_base, df_reform, span, mtr_opt):

0 View Source File : outputs.py
License : MIT License
Project Creator : PSLmodels

def credit_plot(df_base, df_reform, span, mtr_opt):
    """Build the tax-credit comparison chart (base vs. reform).

    Each of ``df_base`` / ``df_reform`` is wrapped in a ``ColumnDataSource``
    and must provide the "Axis", "EITC", "CTC", "CTC Refundable" and
    "Child care credit" columns. Base lines are solid and reform lines
    dashed. All credits except the EITC start muted. Two toggle buttons
    show/hide each set, and clicking a legend entry mutes a line.

    Parameters
    ----------
    df_base, df_reform
        Tabular data for the base and reform scenarios.
    span : number
        x position of the dotted marker; also shown comma-grouped in its
        caption.
    mtr_opt : str
        Name of the varied input, used for the x-axis label, the marker
        caption and the output title.

    Returns
    -------
    dict
        ``media_type`` "bokeh", a ``title`` string, and under ``data`` the
        ``json_item`` serialization of the figure plus its toggles.
    """
    base_src = ColumnDataSource(df_base)
    reform_src = ColumnDataSource(df_reform)
    fig = figure(plot_width=600, plot_height=500, x_range=(-2500, 70000),
                 tools="pan, zoom_in, zoom_out, reset", active_drag="pan")

    # Dotted vertical marker at the filer's value, captioned near the bottom.
    fig.add_layout(Span(location=span, dimension='height', line_color='black',
                        line_dash='dotted', line_width=1.5))
    fig.add_layout(Label(x=span, y=45, y_units='screen', x_offset=10,
                         text=f"{mtr_opt}: ${span:,}", text_color='#303030',
                         text_font="arial", text_font_style="italic",
                         text_font_size="10pt"))
    # Light grey horizontal line at y = 0.
    fig.add_layout(Span(location=0, dimension='width',
                        line_color='#bfbfbf', line_width=1.5))

    # (column, colour, legend label) for each plotted series, in draw order.
    series = (
        ("EITC", '#2b83ba', "Earned Income Tax Credit"),
        ("CTC", '#abdda4', 'Nonrefundable Child Tax Credit'),
        ("CTC Refundable", '#fdae61', 'Refundable Child Tax Credit'),
        ("Child care credit", '#d7191c', 'Child and Dependent Care Credit'),
    )
    # Columns whose lines start muted (the EITC line is left untouched).
    start_muted = ("CTC", "CTC Refundable", "Child care credit")
    # Order in which the lines are bound to object1..object4 in the toggles.
    toggle_order = ("EITC", "Child care credit", "CTC", "CTC Refundable")

    def draw(source, **style):
        return {col: fig.line(x="Axis", y=col, line_color=colour, muted_color=colour,
                              line_width=2, legend_label=label, muted_alpha=0.1,
                              source=source, **style)
                for col, colour, label in series}

    base_lines = draw(base_src)
    reform_lines = draw(reform_src, line_dash='dashed')
    for lines in (base_lines, reform_lines):
        for col in start_muted:
            lines[col].muted = True

    plot_js = """
    object1.visible = toggle.active
    object2.visible = toggle.active
    object3.visible = toggle.active
    object4.visible = toggle.active
    """

    def make_toggle(label, lines):
        # The callback is created before its toggle and wired to it afterwards.
        callback = CustomJS(code=plot_js, args={})
        toggle = Toggle(label=label, button_type="default", active=True)
        wiring = {"toggle": toggle}
        for idx, col in enumerate(toggle_order, start=1):
            wiring[f"object{idx}"] = lines[col]
        callback.args = wiring
        toggle.js_on_change('active', callback)
        return toggle

    base_toggle = make_toggle("Base (Solid)", base_lines)
    reform_toggle = make_toggle("Reform (Dashed)", reform_lines)

    fig.yaxis.formatter = NumeralTickFormatter(format="$0,000")
    fig.yaxis.axis_label = "Tax Credits"
    fig.xaxis.formatter = NumeralTickFormatter(format="$0,000")
    fig.xaxis.axis_label = mtr_opt
    fig.xaxis.minor_tick_line_color = None
    fig.legend.click_policy = "mute"

    layout = column(fig, row(base_toggle, reform_toggle))
    return {
        "media_type": "bokeh",
        "title": f"Tax Credits by {mtr_opt} (Holding Other Inputs Constant)",
        "data": json_item(layout),
    }

0 View Source File : pyphi_plots.py
License : MIT License
Project Creator : salvadorgarciamunoz

def _loadings_column_figure(var_names, values, title, axis_label, tools,
                            tooltips, plotwidth, xgrid):
    """Build one bar chart of per-variable loadings for a single component."""
    source1 = ColumnDataSource(data=dict(x_=var_names, y_=values, names=var_names))
    p = figure(x_range=var_names, title=title,
               tools=tools, tooltips=tooltips, plot_width=plotwidth)
    p.vbar(x='x_', top='y_', source=source1, width=0.5)
    p.ygrid.grid_line_color = None
    p.xgrid.grid_line_color = 'lightgray' if xgrid else None
    p.yaxis.axis_label = axis_label
    # Horizontal reference line at zero loading.
    p.add_layout(Span(location=0, dimension='width', line_color='black', line_width=2))
    p.xaxis.major_label_orientation = 45
    return p


def loadings(mvmobj,*,plotwidth=600,xgrid=False):
    """
    Column plots of loadings
    by Salvador Garcia-Munoz 
    ([email protected] ,[email protected])

    mvmobj: A model created with phi.pca or phi.pls. A model containing 'Q'
            is treated as PLS (plots 'Ws' for X and 'Q' for Y); otherwise
            it is treated as PCA (plots 'P' for X). Optional 'varidX' /
            'varidY' entries supply the variable names.
    plotwidth: width in pixels of every bar chart.
    xgrid: if True, draw light-gray vertical grid lines; otherwise none.

    Writes one HTML page per space ("Loadings X Space_<n>.html" and, for
    PLS, "Loadings Y Space_<n>.html", with <n> a random integer in 0..1000)
    and shows it. Each page has one bar chart per latent variable.
    """
    A = mvmobj['T'].shape[1]
    num_varX = mvmobj['P'].shape[0]
    is_pls = 'Q' in mvmobj
    lv_prefix = 'LV #' if is_pls else 'PC #'
    lv_labels = [lv_prefix + str(a) for a in range(1, A + 1)]
    if 'varidX' in mvmobj:
        XVar = mvmobj['varidX']
    else:
        XVar = ['XVar #' + str(n) for n in range(1, num_varX + 1)]

    TOOLS = "save,wheel_zoom,box_zoom,pan,reset,box_select,lasso_select"
    TOOLTIPS = [
                ("Variable:","@names")
                ]

    def show_space(space, model_kind, matrix_key, symbol, var_names, tools):
        # One HTML page per space, one chart per latent variable; an empty
        # model (A == 0) yields an empty page instead of a NameError.
        rnd_num = str(int(np.round(1000*np.random.random_sample())))
        output_file("Loadings " + space + " Space_" + rnd_num + ".html",
                    title=space + ' Loadings ' + model_kind)
        matrix = mvmobj[matrix_key]
        p_list = [
            _loadings_column_figure(
                var_names, matrix[:, i].tolist(),
                title=space + " Space Loadings " + lv_labels[i],
                axis_label=symbol + ' [' + str(i + 1) + ']',
                tools=tools, tooltips=TOOLTIPS,
                plotwidth=plotwidth, xgrid=xgrid)
            for i in range(A)
        ]
        show(column(p_list))

    if is_pls:
        num_varY = mvmobj['Q'].shape[0]
        if 'varidY' in mvmobj:
            YVar = mvmobj['varidY']
        else:
            YVar = ['YVar #' + str(n) for n in range(1, num_varY + 1)]
        show_space('X', 'PLS', 'Ws', 'W*', XVar, TOOLS)
        # The Y-space charts use a reduced tool set (no wheel zoom / selection).
        show_space('Y', 'PLS', 'Q', 'Q', YVar, "save,box_zoom,pan,reset")
    else:
        show_space('X', 'PCA', 'P', 'P', XVar, TOOLS)

def loadings_map(mvmobj,dims,*,plotwidth=600):

0 View Source File : pyphi_plots.py
License : MIT License
Project Creator : salvadorgarciamunoz

def loadings_map(mvmobj,dims,*,plotwidth=600):
    """
    Scatter plot overlaying X and Y loadings
    by Salvador Garcia-Munoz
    ([email protected] ,[email protected])

    For PLS models ('Q' in mvmobj) the X weights Ws (dark blue) and the
    Y loadings Q (red) are overlaid; every axis of each set is scaled so its
    largest absolute value is 1. For PCA models the X loadings P are drawn
    as stored. Both axes span [-1.5, 1.5] and each point is labelled with
    its variable name.

    mvmobj:    A model created with phi.pca or phi.pls
    dims:      what latent spaces to plot in x and y axes e.g. dims=[1,2]
    plotwidth: width of the figure in pixels
    """
    A = mvmobj['T'].shape[1]
    is_pls = 'Q' in mvmobj
    lv_prefix = 'LV #' if is_pls else 'PC #'
    lv_labels = [lv_prefix + str(a) for a in range(1, A + 1)]
    if 'varidX' in mvmobj:
        XVar = mvmobj['varidX']
    else:
        XVar = ['XVar #' + str(n) for n in range(1, mvmobj['P'].shape[0] + 1)]

    rnd_num = str(int(np.round(1000 * np.random.random_sample())))
    output_file("Loadings Map" + rnd_num + ".html", title='Loadings Map')

    ix, iy = dims[0] - 1, dims[1] - 1

    def _unit_scale(v):
        # Scale so max |v| == 1; an all-zero vector is returned as-is (avoids 0/0 -> NaN).
        vmax = np.max(np.abs(v))
        return v / vmax if vmax > 0 else v

    if is_pls:
        if 'varidY' in mvmobj:
            YVar = mvmobj['varidY']
        else:
            YVar = ['YVar #' + str(n) for n in range(1, mvmobj['Q'].shape[0] + 1)]
        # (data source, marker colour) pairs: X weights first, then Y loadings.
        layers = [
            (ColumnDataSource(data=dict(x=_unit_scale(mvmobj['Ws'][:, ix]),
                                        y=_unit_scale(mvmobj['Ws'][:, iy]),
                                        names=XVar)), 'darkblue'),
            (ColumnDataSource(data=dict(x=_unit_scale(mvmobj['Q'][:, ix]),
                                        y=_unit_scale(mvmobj['Q'][:, iy]),
                                        names=YVar)), 'red'),
        ]
        title = "Loadings Map LV[" + str(dims[0]) + "] - LV[" + str(dims[1]) + "]"
    else:
        layers = [
            (ColumnDataSource(data=dict(x=mvmobj['P'][:, ix],
                                        y=mvmobj['P'][:, iy],
                                        names=XVar)), 'darkblue'),
        ]
        title = "Loadings Map PC[" + str(dims[0]) + "] - PC[" + str(dims[1]) + "]"

    TOOLS = "save,wheel_zoom,box_zoom,pan,reset,box_select,lasso_select"
    TOOLTIPS = [
            ("index", "$index"),
            ("(x,y)", "($x, $y)"),
            ("Variable:","@names")
            ]

    # `width` is accepted by both Bokeh 2.x and 3.x (`plot_width` was removed in 3.0).
    p = figure(tools=TOOLS, tooltips=TOOLTIPS, width=plotwidth, title=title,
               x_range=(-1.5, 1.5), y_range=(-1.5, 1.5))
    for source, color in layers:
        p.scatter('x', 'y', source=source, size=10, color=color)
    p.xaxis.axis_label = lv_labels[ix]
    p.yaxis.axis_label = lv_labels[iy]

    # LabelSet's render_mode was removed in Bokeh 3.0; canvas rendering is the default.
    for source, _ in layers:
        p.add_layout(LabelSet(x='x', y='y', text='names', level='glyph',
                              x_offset=5, y_offset=5, source=source,
                              text_color='darkgray'))

    # Axes through the origin.
    p.add_layout(Span(location=0, dimension='height', line_color='black', line_width=2))
    p.add_layout(Span(location=0, dimension='width', line_color='black', line_width=2))
    show(p)
    return

def weighted_loadings(mvmobj,*,plotwidth=600,xgrid=False):

0 View Source File : pyphi_plots.py
License : MIT License
Project Creator : salvadorgarciamunoz

def weighted_loadings(mvmobj,*,plotwidth=600,xgrid=False):
    """
    Column plots of loadings weighted by r2x/r2y correspondingly
    by Salvador Garcia-Munoz
    ([email protected] ,[email protected])

    For each latent variable i a bar chart of r2xpv[:, i] * Ws[:, i] (PLS)
    or r2xpv[:, i] * P[:, i] (PCA) is drawn. PLS models additionally get
    r2ypv[:, i] * Q[:, i] for the Y space. Each space is written to its own
    HTML page with the charts stacked in a column.

    mvmobj:    A model created with phi.pca or phi.pls
    plotwidth: width of each chart in pixels
    xgrid:     draw light-gray vertical grid lines when True
    """
    A = mvmobj['T'].shape[1]
    is_pls = 'Q' in mvmobj
    lv_prefix = 'LV #' if is_pls else 'PC #'
    lv_labels = [lv_prefix + str(a) for a in range(1, A + 1)]
    if 'varidX' in mvmobj:
        XVar = mvmobj['varidX']
    else:
        XVar = ['XVar #' + str(n) for n in range(1, mvmobj['P'].shape[0] + 1)]

    TOOLS = "save,wheel_zoom,box_zoom,pan,reset,box_select,lasso_select"
    TOOLTIPS = [
                ("Variable:","@names")
                ]

    def _show_weighted_bars(var_names, r2pv, loadings, space, axis_prefix, html_title):
        # Writes one HTML page with a bar chart of r2pv[:, i] * loadings[:, i] per component.
        rnd_num = str(int(np.round(1000 * np.random.random_sample())))
        output_file("Loadings " + space + " Space_" + rnd_num + ".html", title=html_title)
        plots = []
        for i in range(A):
            p = figure(x_range=var_names, title=space + " Space Weighted Loadings " + lv_labels[i],
                       tools=TOOLS, tooltips=TOOLTIPS, width=plotwidth)
            weighted = (r2pv[:, i] * loadings[:, i]).tolist()
            source1 = ColumnDataSource(data=dict(x_=var_names, y_=weighted, names=var_names))
            p.vbar(x='x_', top='y_', source=source1, width=0.5)
            p.ygrid.grid_line_color = None
            p.xgrid.grid_line_color = 'lightgray' if xgrid else None
            p.yaxis.axis_label = axis_prefix + ' [' + str(i + 1) + ']'
            # Zero reference line.
            p.add_layout(Span(location=0, dimension='width', line_color='black', line_width=2))
            p.xaxis.major_label_orientation = 45
            plots.append(p)
        show(column(plots))

    if is_pls:
        if 'varidY' in mvmobj:
            YVar = mvmobj['varidY']
        else:
            YVar = ['YVar #' + str(n) for n in range(1, mvmobj['Q'].shape[0] + 1)]
        _show_weighted_bars(XVar, mvmobj['r2xpv'], mvmobj['Ws'], 'X', 'W*', 'X Weighted Loadings PLS')
        _show_weighted_bars(YVar, mvmobj['r2ypv'], mvmobj['Q'], 'Y', 'Q', 'Y Weighted Loadings PLS')
    else:
        _show_weighted_bars(XVar, mvmobj['r2xpv'], mvmobj['P'], 'X', 'P', 'X Weighted Loadings PCA')
    return
 
def vip(mvmobj,*,plotwidth=600):

0 View Source File : pyphi_plots.py
License : MIT License
Project Creator : salvadorgarciamunoz

def score_scatter(mvmobj,xydim,*,CLASSID=False,colorby=False,Xnew=False,add_ci=False,add_labels=False,add_legend=True,plotwidth=600,plotheight=600):
    '''
    Score scatter plot
    by Salvador Garcia-Munoz
    ([email protected] ,[email protected])

    mvmobj     : PLS or PCA object from pyphi
    xydim      : LV to plot on x and y axes. eg [1,2] will plot t1 vs t2
    CLASSID    : Pandas DataFrame with CLASSIDS (False = no colouring)
    colorby    : Category (one of the CLASSIDS columns) to color by
    Xnew       : New data (numpy array or DataFrame whose first column holds
                 observation IDs) to project with phi.pls_pred / phi.pca_pred
                 and plot instead of the model's own scores
    add_ci     : when = True adds the 95% (gold) / 99% (red) confidence curves
                 computed from the model's scores
    add_labels : When = True labels each point with Obs ID
    add_legend : When = True (and CLASSID given) adds a clickable legend
    plotwidth  : If omitted, width is 600
    plotheight : If omitted, height is 600

    Raises TypeError if Xnew is neither a numpy array nor a pandas DataFrame.
    '''
    # Resolve observation IDs and the score matrix to plot.
    if isinstance(Xnew, bool):
        if 'obsidX' in mvmobj:
            ObsID_ = mvmobj['obsidX']
        else:
            ObsID_ = ['Obs #' + str(n) for n in range(1, mvmobj['T'].shape[0] + 1)]
        T_matrix = mvmobj['T']
    else:
        if isinstance(Xnew, np.ndarray):
            X_ = Xnew.copy()
            ObsID_ = ['Obs #' + str(n) for n in range(1, Xnew.shape[0] + 1)]
        elif isinstance(Xnew, pd.DataFrame):
            X_ = np.array(Xnew.values[:, 1:]).astype(float)
            ObsID_ = Xnew.values[:, 0].astype(str).tolist()
        else:
            raise TypeError('Xnew must be a numpy array or a pandas DataFrame')
        if 'Q' in mvmobj:
            xpred = phi.pls_pred(X_, mvmobj)
        else:
            xpred = phi.pca_pred(X_, mvmobj)
        T_matrix = xpred['Tnew']

    ObsNum_ = ['Obs #' + str(n) for n in range(1, len(ObsID_) + 1)]
    # Column slices keep shape (n, 1), as the original plotting code expects.
    x_ = T_matrix[:, [xydim[0] - 1]]
    y_ = T_matrix[:, [xydim[1] - 1]]
    plot_title = 'Score Scatter t[' + str(xydim[0]) + '] - t[' + str(xydim[1]) + ']'
    TOOLS = "save,wheel_zoom,box_zoom,pan,reset,box_select,lasso_select"

    def _finish_axes(p):
        # Optional confidence curves, axis titles and axes through the origin.
        if add_ci:
            T_aux = np.hstack((mvmobj['T'][:, [xydim[0] - 1]], mvmobj['T'][:, [xydim[1] - 1]]))
            st = (T_aux.T @ T_aux) / T_aux.shape[0]
            xd95, xd99, yd95p, yd95n, yd99p, yd99n = phi.scores_conf_int_calc(st, mvmobj['T'].shape[0])
            p.line(xd95, yd95p, line_color="gold", line_dash='dashed')
            p.line(xd95, yd95n, line_color="gold", line_dash='dashed')
            p.line(xd99, yd99p, line_color="red", line_dash='dashed')
            p.line(xd99, yd99n, line_color="red", line_dash='dashed')
        p.xaxis.axis_label = 't [' + str(xydim[0]) + ']'
        p.yaxis.axis_label = 't [' + str(xydim[1]) + ']'
        p.add_layout(Span(location=0, dimension='height', line_color='black', line_width=2))
        p.add_layout(Span(location=0, dimension='width', line_color='black', line_width=2))

    # `np.bool` is gone (NumPy 1.24) or means np.bool_ (NumPy 2), so test the builtin bool.
    if isinstance(CLASSID, bool):  # No CLASSIDS
        rnd_num = str(int(np.round(1000 * np.random.random_sample())))
        output_file("Score_Scatter_" + rnd_num + ".html", title=plot_title)
        source = ColumnDataSource(data=dict(x=x_, y=y_, ObsID=ObsID_, ObsNum=ObsNum_))
        TOOLTIPS = [
                ("Obs #", "@ObsNum"),
                ("(x,y)", "($x, $y)"),
                ("Obs: ","@ObsID")
                ]
        p = figure(tools=TOOLS, tooltips=TOOLTIPS, width=plotwidth, height=plotheight, title=plot_title)
        p.scatter('x', 'y', source=source, size=7)
        if add_labels:
            p.add_layout(LabelSet(x='x', y='y', text='ObsID', level='glyph',
                                  x_offset=5, y_offset=5, source=source))
        _finish_axes(p)
        show(p)
        return

    # YES CLASSIDS: one glyph renderer per class, each in its own colour.
    Classes_ = np.unique(CLASSID[colorby]).tolist()
    classid_ = list(CLASSID[colorby])
    # Evenly spaced rainbow colours as '#rrggbb'; cm.get_cmap was removed in matplotlib 3.9.
    rgba = cm.rainbow(np.linspace(0, 1, len(Classes_)), 1, True)
    bokeh_palette = ["#%02x%02x%02x" % (r, g, b) for r, g, b in rgba[:, 0:3]]
    rnd_num = str(int(np.round(1000 * np.random.random_sample())))
    output_file("Score_Scatter_" + rnd_num + ".html", title=plot_title)
    TOOLTIPS = [
            ("Obs #", "@ObsNum"),
            ("(x,y)", "($x, $y)"),
            ("Obs: ","@ObsID"),
            ("Class:","@Class")
            ]
    p = figure(tools=TOOLS, tooltips=TOOLTIPS, toolbar_location="above",
               width=plotwidth, height=plotheight, title=plot_title)
    legend_it = []
    for color_, cls in zip(bokeh_palette, Classes_):
        members = [i for i in range(len(ObsID_)) if classid_[i] == cls]
        source = ColumnDataSource(data=dict(
            x=[x_[i][0] for i in members],
            y=[y_[i][0] for i in members],
            ObsID=[ObsID_[i] for i in members],
            ObsNum=[ObsNum_[i] for i in members],
            Class=[cls] * len(members)))
        c = p.scatter('x', 'y', source=source, color=color_)
        if add_legend:
            # Legend labels must be strings; numeric class ids are converted.
            legend_it.append((str(cls), [c]))
        if add_labels:
            p.add_layout(LabelSet(x='x', y='y', text='ObsID', level='glyph',
                                  x_offset=5, y_offset=5, source=source))
    _finish_axes(p)
    if add_legend:
        legend = Legend(items=legend_it, location='top_right')
        p.add_layout(legend, 'right')
        legend.click_policy = "hide"
    show(p)
    return

def score_line(mvmobj,dim,*,CLASSID=False,colorby=False,Xnew=False,add_ci=False,add_labels=False,add_legend=True,plotline=True,plotwidth=600,plotheight=600):

0 View Source File : pyphi_plots.py
License : MIT License
Project Creator : salvadorgarciamunoz

def diagnostics(mvmobj,*,Xnew=False,Ynew=False,score_plot_xydim=False,plotwidth=600,ht2_logscale=False,spe_logscale=False):
    """
    Plot calculated Hotelling's T2 and SPE
    by Salvador Garcia-Munoz
    ([email protected] ,[email protected])

    mvmobj: A model created with phi.pca or phi.pls

    Xnew/Ynew:     Data used to calculate diagnostics[numpy arrays or pandas dataframes]
                   (a DataFrame's first column holds observation IDs). When Xnew
                   is omitted the model's own T2/SPE values are plotted. Ynew is
                   only used for PLS models and adds an SPE Y plot.

    optional:

    score_plot_xydim: will add a score scatter plot at the bottom
                      if sent with a list of [dimx, dimy] where dimx/dimy
                      are integers and refer to the latent space to plot
                      in the x and y axes of the scatter plot. e.g. [1,2] will
                      add a t1-t2 plot
    plotwidth:        width in pixels of every plot
    ht2_logscale:     plot log10(T2) and log10 of its limits
    spe_logscale:     plot log10(SPE X) and log10 of its limits

    Raises TypeError if Xnew is neither a numpy array nor a pandas DataFrame.
    """
    # `np.bool` is gone (NumPy 1.24) or means np.bool_ (NumPy 2): test the builtin bool.
    add_score_plot = not isinstance(score_plot_xydim, bool)
    is_pls = 'Q' in mvmobj

    if isinstance(Xnew, bool):  # No Xnew was given: plot everything from the model
        if 'obsidX' in mvmobj:
            ObsID_ = mvmobj['obsidX']
        else:
            ObsID_ = ['Obs #' + str(n) for n in range(1, mvmobj['T'].shape[0] + 1)]
        Obs_num = np.arange(mvmobj['T'].shape[0]) + 1
        T_matrix = mvmobj['T']
        t2_ = mvmobj['T2']
        spex_ = mvmobj['speX']
        spey_ = mvmobj['speY'] if is_pls else False
    else:  # Xnew was given
        if isinstance(Xnew, np.ndarray):
            X_ = Xnew.copy()
            ObsID_ = ['Obs #' + str(n) for n in range(1, Xnew.shape[0] + 1)]
        elif isinstance(Xnew, pd.DataFrame):
            X_ = np.array(Xnew.values[:, 1:]).astype(float)
            ObsID_ = Xnew.values[:, 0].astype(str).tolist()
        else:
            raise TypeError('Xnew must be a numpy array or a pandas DataFrame')
        if add_score_plot:
            xpred = phi.pls_pred(X_, mvmobj) if is_pls else phi.pca_pred(X_, mvmobj)
            T_matrix = xpred['Tnew']
        t2_ = phi.hott2(mvmobj, Xnew=Xnew)
        Obs_num = np.arange(t2_.shape[0]) + 1
        if is_pls and not isinstance(Ynew, bool):
            spex_, spey_ = phi.spe(mvmobj, Xnew, Ynew=Ynew)
        else:
            spex_ = phi.spe(mvmobj, Xnew)
            spey_ = False

    if ht2_logscale:
        t2_ = np.log10(t2_)
    if spe_logscale:
        spex_ = np.log10(spex_)
    # SPE Y (always on a linear scale) is plotted only when it is available.
    has_spey = is_pls and not isinstance(spey_, bool)

    data = dict(x=Obs_num, ObsID=ObsID_,
                ObsNum=['Obs #' + str(n) for n in range(1, len(ObsID_) + 1)],
                t2=t2_, spex=spex_)
    if has_spey:
        data['spey'] = spey_
    if add_score_plot:
        data['tx'] = T_matrix[:, [score_plot_xydim[0] - 1]]
        data['ty'] = T_matrix[:, [score_plot_xydim[1] - 1]]
    source = ColumnDataSource(data=data)

    TOOLS = "save,wheel_zoom,box_zoom,reset,lasso_select"
    TOOLTIPS = [
            ("Obs #", "@x"),
            ("(x,y)", "($x, $y)"),
            ("Obs: ","@ObsID")
            ]

    def _lim(key, logscale):
        # A model limit, log10-transformed when that axis is on log scale.
        return np.log10(mvmobj[key]) if logscale else mvmobj[key]

    def _add_limit_lines(p, key95, key99, logscale):
        # Horizontal 95% (gold) and 99% (red) limit lines across all observations.
        x_ends = [0, Obs_num[-1]]
        for key, color in ((key95, 'gold'), (key99, 'red')):
            lim = _lim(key, logscale)
            p.line(x_ends, [lim, lim], line_color=color)

    rnd_num = str(int(np.round(1000 * np.random.random_sample())))
    output_file("Diagnostics" + rnd_num + ".html", title='Diagnostics')

    p = figure(tools=TOOLS, tooltips=TOOLTIPS, width=plotwidth, title="Hotelling's T2")
    p.scatter('x', 't2', source=source)
    _add_limit_lines(p, 'T2_lim95', 'T2_lim99', ht2_logscale)
    p.xaxis.axis_label = 'Observation sequence'
    p.yaxis.axis_label = "HT2"
    p_list = [p]

    p = figure(tools=TOOLS, tooltips=TOOLTIPS, width=plotwidth, title='SPE X')
    p.scatter('x', 'spex', source=source)
    _add_limit_lines(p, 'speX_lim95', 'speX_lim99', spe_logscale)
    p.xaxis.axis_label = 'Observation sequence'
    p.yaxis.axis_label = 'SPE X-Space'
    p_list.append(p)

    # Outlier map: T2 vs SPE X with the 99% limits of both as reference lines.
    p = figure(tools=TOOLS, tooltips=TOOLTIPS, width=plotwidth, title='Outlier Map')
    p.scatter('t2', 'spex', source=source)
    p.add_layout(Span(location=_lim('T2_lim99', ht2_logscale), dimension='height',
                      line_color='red', line_width=1))
    p.add_layout(Span(location=_lim('speX_lim99', spe_logscale), dimension='width',
                      line_color='red', line_width=1))
    p.xaxis.axis_label = "Hotelling's T2"
    p.yaxis.axis_label = 'SPE X-Space'
    p_list.append(p)

    if has_spey:
        p = figure(tools=TOOLS, tooltips=TOOLTIPS, width=plotwidth, height=400, title='SPE Y')
        p.scatter('x', 'spey', source=source)
        _add_limit_lines(p, 'speY_lim95', 'speY_lim99', False)
        p.xaxis.axis_label = 'Observation sequence'
        p.yaxis.axis_label = 'SPE Y-Space'
        p_list.append(p)

    if add_score_plot:
        p = figure(tools=TOOLS, tooltips=TOOLTIPS, width=plotwidth, title='Score Scatter')
        p.scatter('tx', 'ty', source=source, size=7)
        # Confidence curves are always computed from the model's own scores.
        T_aux = np.hstack((mvmobj['T'][:, [score_plot_xydim[0] - 1]],
                           mvmobj['T'][:, [score_plot_xydim[1] - 1]]))
        st = (T_aux.T @ T_aux) / T_aux.shape[0]
        xd95, xd99, yd95p, yd95n, yd99p, yd99n = phi.scores_conf_int_calc(st, mvmobj['T'].shape[0])
        p.line(xd95, yd95p, line_color="gold", line_dash='dashed')
        p.line(xd95, yd95n, line_color="gold", line_dash='dashed')
        p.line(xd99, yd99p, line_color="red", line_dash='dashed')
        p.line(xd99, yd99n, line_color="red", line_dash='dashed')
        p.xaxis.axis_label = 't [' + str(score_plot_xydim[0]) + ']'
        p.yaxis.axis_label = 't [' + str(score_plot_xydim[1]) + ']'
        p.add_layout(Span(location=0, dimension='height', line_color='black', line_width=2))
        p.add_layout(Span(location=0, dimension='width', line_color='black', line_width=2))
        p_list.append(p)

    show(column(p_list))
    return

def predvsobs(mvmobj,X,Y,*,CLASSID=False,colorby=False,x_space=False):

0 View Source File : pyphi_plots.py
License : MIT License
Project Creator : salvadorgarciamunoz

def contributions_plot(mvmobj,X,cont_type,*,Y=False,from_obs=False,to_obs=False,lv_space=False,plotwidth=800,plotheight=600,xgrid=False):
    """
    Calculate contributions to diagnostics
    by Salvador Garcia-Munoz
    ([email protected] ,[email protected])

    mvmobj : A dictionary created by phi.pls or phi.pca

    X/Y:     Data [numpy arrays or pandas dataframes] - Y space is optional
             and only used for cont_type='spe' with a PLS model

    cont_type: 'ht2'
               'spe'
               'scores'

    from_obs: Scalar or list of scalars with observation(s) number(s) | first element is #0
              - OR -
              Strings or list of strings with observation(s) name(s) [if X/Y are pandas data frames]
              Used to off set calculations for scores or ht2
              False will calculate with respect to origin *default if not sent*

    to_obs: Scalar or list of scalars with observation(s) number(s)| first element is #0
            - OR -
            Strings or list of strings with observation(s) name(s) [if X/Y are pandas data frames]
            To calculate contributions for (required)

            *Note: from_obs is ignored when cont_type='spe'*

    lv_space: Latent spaces over which to do the calculations [applicable to 'ht2' and 'scores']

    Raises ValueError if to_obs is missing/empty, if an observation name is
    used with a numpy X, or if a name is not found. Raises TypeError for any
    other kind of observation reference.
    """
    # Observation names exist only when X is a DataFrame (first column holds the IDs).
    obs_ids = X.values[:, 0].tolist() if isinstance(X, pd.DataFrame) else None

    def _obs_index(obs):
        # Map one observation reference (row position or name) to a row position.
        if isinstance(obs, str):
            if obs_ids is None:
                raise ValueError('Observation names require X to be a pandas DataFrame: ' + repr(obs))
            return obs_ids.index(obs)
        # bool is a subclass of int but is never a valid observation reference.
        if isinstance(obs, (int, np.integer)) and not isinstance(obs, bool):
            return int(obs)
        raise TypeError('Invalid observation reference: ' + repr(obs))

    def _resolve(obs):
        # Scalars map to a single index; lists map element-wise.
        if isinstance(obs, list):
            return [_obs_index(o) for o in obs]
        return _obs_index(obs)

    def _obs_text(obs):
        # Human-readable form of an observation reference for the plot title.
        if isinstance(obs, list):
            return ", ".join(map(str, obs))
        return str(obs)

    if isinstance(to_obs, bool) or (isinstance(to_obs, list) and not to_obs):
        raise ValueError('to_obs must identify at least one observation')
    to_obs_ = _resolve(to_obs)
    # A bool from_obs (default False) means "with respect to the origin".
    from_obs_ = False if isinstance(from_obs, bool) else _resolve(from_obs)

    # Y contributions exist only for SPE of PLS models; otherwise Y is ignored.
    use_y = (not isinstance(Y, bool)) and cont_type == 'spe' and 'Q' in mvmobj
    if use_y:
        Xconts, Yconts = phi.contributions(mvmobj, X, cont_type, Y=Y, from_obs=from_obs_,
                                           to_obs=to_obs_, lv_space=lv_space)
    else:
        Xconts = phi.contributions(mvmobj, X, cont_type, Y=False, from_obs=from_obs_,
                                   to_obs=to_obs_, lv_space=lv_space)

    if 'varidX' in mvmobj:
        XVar = mvmobj['varidX']
    else:
        XVar = ['XVar #' + str(n) for n in range(1, mvmobj['P'].shape[0] + 1)]

    rnd_num = str(int(np.round(1000 * np.random.random_sample())))
    output_file("Contributions" + rnd_num + ".html", title='Contributions')
    from_txt = "" if isinstance(from_obs, bool) else " from obs: " + _obs_text(from_obs)
    to_txt = ", to obs: " + _obs_text(to_obs)

    def _contribution_bars(var_names, conts, title):
        # Bar chart of the first row of `conts` against the variable names.
        p = figure(x_range=var_names, height=plotheight, width=plotwidth, title=title,
                   tools="save,box_zoom,pan,reset")
        p.vbar(x=var_names, top=conts[0].tolist(), width=0.5)
        p.ygrid.grid_line_color = None
        p.xgrid.grid_line_color = 'lightgray' if xgrid else None
        p.yaxis.axis_label = 'Contributions to ' + cont_type
        p.add_layout(Span(location=0, dimension='width', line_color='black', line_width=2))
        p.xaxis.major_label_orientation = 45
        return p

    p_list = [_contribution_bars(XVar, Xconts, "Contributions Plot" + from_txt + to_txt)]

    if use_y:
        if 'varidY' in mvmobj:
            YVar = mvmobj['varidY']
        else:
            YVar = ['YVar #' + str(n) for n in range(1, mvmobj['Q'].shape[0] + 1)]
        p_list.append(_contribution_bars(YVar, Yconts, "Contributions Plot"))

    show(column(p_list))
    return

def plot_spectra(X,*,xaxis=False,plot_title='Main Title',tab_title='Tab Title',xaxis_label='X- axis',yaxis_label='Y- axis'): 

0 View Source File : pyphi_plots.py
License : MIT License
Project Creator : salvadorgarciamunoz

def mb_weights(mvmobj,*,plotwidth=600,plotheight=400):
    """
    Plot the block (super) weights of a multi-block PLS model.

    One bar chart is drawn per latent variable (LV). Bar heights are the
    column ``mvmobj['Wt'][:, i]`` and categories are the names in
    ``mvmobj['Xblocknames']``. The charts are stacked in one column and shown
    in a single HTML file named ``blockweights_<random>.html``.

    by Salvador Garcia-Munoz
    ([email protected] ,[email protected])

    Args:
        mvmobj: A multi-block PLS model created with phi.mbpls. The keys
            read are 'T' (its column count is used as the number of LVs),
            'Wt' and 'Xblocknames'.
        plotwidth (int): Width of each chart in pixels.
        plotheight (int): Height of each chart in pixels.

    Returns:
        None; the charts are displayed with bokeh's ``show``.
    """
    A = mvmobj['T'].shape[1]
    XVar = mvmobj['Xblocknames']
    # Configure the output file once. Calling output_file once per LV only
    # produced random names that were superseded before show() ran.
    rnd_num = str(int(np.round(1000 * np.random.random_sample())))
    output_file("blockweights_" + rnd_num + ".html", title="Block Weights")
    # Initialised up front so a model with zero LVs does not hit a NameError.
    p_list = []
    for i in range(A):
        lv_label = 'LV #' + str(i + 1)
        px = figure(x_range=XVar, title="Block weights for MBPLS " + lv_label,
                    tools="save,box_zoom,hover,reset", tooltips=[("Var:", "@x_")],
                    plot_width=plotwidth, plot_height=plotheight)
        source1 = ColumnDataSource(data=dict(x_=XVar, y_=mvmobj['Wt'][:, i].tolist(), names=XVar))
        px.vbar(x='x_', top='y_', source=source1, width=0.5)
        px.y_range.range_padding = 0.1
        px.ygrid.grid_line_color = None
        px.axis.minor_tick_line_color = None
        px.outline_line_color = None
        # Previously 'Wt1]' (unmatched bracket); now 'Wt[1]'.
        px.yaxis.axis_label = 'Wt[' + str(i + 1) + ']'
        px.xaxis.major_label_orientation = 45
        # Black reference line at zero weight.
        hline = Span(location=0, dimension='width', line_color='black', line_width=2)
        px.renderers.extend([hline])
        p_list.append(px)
    show(column(p_list))

    return

        

def mb_r2pb(mvmobj,*,plotwidth=600,plotheight=400):

0 View Source File : pyphi_plots.py
License : MIT License
Project Creator : salvadorgarciamunoz

def mb_vip(mvmobj,*,plotwidth=600,plotheight=400):
    """
    Bar chart of block VIP scores for a multi-block PLS model.

    Each block's score is the sum over latent variables a of
    ``Wt[block, a] * r2y[a]``. Blocks are plotted from highest to lowest
    score.

    by Salvador Garcia-Munoz
    ([email protected] ,[email protected])

    mvmobj: A multi-block PLS model created with phi.mbpls. The keys read
        are 'T' (its column count is the number of LVs), 'Wt', 'r2y' and
        'Xblocknames'.
    plotwidth, plotheight: chart size in pixels.
    """
    n_lv = mvmobj['T'].shape[1]
    block_names = mvmobj['Xblocknames']
    block_weights = mvmobj['Wt']
    r2y = mvmobj['r2y']

    if n_lv <= 1:
        # Single LV: r2y is used as-is, whether it is a scalar or a
        # length-1 array.
        scores = block_weights[:, [0]] * r2y
    else:
        scores = np.zeros((block_weights.shape[0], 1))
        for lv in range(n_lv):
            scores = scores + block_weights[:, [lv]] * r2y[lv]

    scores = np.reshape(scores, -1)
    # Descending order of importance.
    order = np.argsort(scores)[::-1]
    ranked_names = [block_names[k] for k in order]
    ranked_scores = scores[order]

    rnd_num = str(int(np.round(1000 * np.random.random_sample())))
    output_file("blockvip" + rnd_num + ".html", title="Block VIP")
    source1 = ColumnDataSource(data=dict(x_=ranked_names, y_=ranked_scores.tolist(), names=ranked_names))
    px = figure(x_range=ranked_names, title="Block VIP for MBPLS",
                tools="save,box_zoom,hover,reset", tooltips=[("Block:", "@x_")],
                plot_width=plotwidth, plot_height=plotheight)

    px.vbar(x='x_', top='y_', source=source1, width=0.5)
    px.y_range.range_padding = 0.1
    px.ygrid.grid_line_color = None
    px.axis.minor_tick_line_color = None
    px.outline_line_color = None
    px.yaxis.axis_label = 'Block VIP'
    px.xaxis.major_label_orientation = 45
    px.renderers.extend([Span(location=0, dimension='width', line_color='black', line_width=2)])
    show(px)
    return

0 View Source File : callout.py
License : Apache License 2.0
Project Creator : spotify

    def line(self,
             location,
             orientation='width',
             line_color='black',
             line_dash='solid',
             line_width=2,
             line_alpha=1.0):
        """Add a line callout (a bokeh ``Span``) to the chart.

        Args:
            location (numeric): Position of the line in data units. For a
                vertical ('height') line on a datetime x-axis, the value is
                first passed through the axes' epoch-millisecond conversion.
            orientation (str, optional): (default: 'width')
                - 'width'
                - 'height'
            line_color (str, optional): Color name or hex value.
                See chartify.color_palettes.show() for available color names.
            line_dash (str, optional): Dash style for the line. One of:
                - 'solid'
                - 'dashed'
                - 'dotted'
                - 'dotdash'
                - 'dashdot'
            line_width (int, optional): Width of the line
            line_alpha (float, optional): Alpha of the line. Between 0 and 1.

        Returns:
            Current chart object
        """
        axes = self._chart.axes
        # Vertical lines on a datetime x-axis must be placed in epoch ms.
        if orientation == 'height' and isinstance(axes, DatetimeXNumericalYAxes):
            location = axes._convert_timestamp_to_epoch_ms(location)

        span_kwargs = dict(
            location=location,
            dimension=orientation,
            line_color=colors.Color(line_color).get_hex_l(),
            line_dash=line_dash,
            line_width=line_width,
            location_units='data',
            line_alpha=line_alpha,
        )
        self._chart.figure.add_layout(bokeh.models.Span(**span_kwargs))
        return self._chart

    def line_segment(self,

0 View Source File : benchmark.py
License : GNU General Public License v3.0
Project Creator : thiagopbueno

def plot_total_reward(
    data, colors, group_by=None, filter_regex=None, width=800, height=600
):
    """Horizontal bar chart of mean total reward per experiment.

    Args:
        data: DataFrame with an ``"experiment_id"`` column (a "/"-separated
            path) and a ``"mean"`` column used as the bar length.
        colors: Not used by this function; kept for interface compatibility.
        group_by (str, optional): Parameter name. The first path folder
            matching ``"<group_by>="`` becomes an extra y-axis factor level
            and is removed from the experiment's displayed name.
        filter_regex (str, optional): Keep only rows whose experiment_id
            matches this regex. Rows with a missing experiment_id are dropped.
        width, height (int): Plot size in pixels.

    Returns:
        A bokeh figure with red/green vertical lines at the mean and maximum
        of ``data["mean"]``, or None when no rows remain after filtering.
    """
    if filter_regex:
        # na=False: missing ids become False instead of NaN. A NaN entry
        # would make the boolean mask unusable for indexing.
        mask = data["experiment_id"].str.contains(filter_regex, regex=True, na=False)
        data = data[mask]

    if data.empty:
        return None

    # group_by is a literal parameter name, not a pattern. Escape it so
    # characters like "." or "+" match literally.
    group_by_pattern = re.compile(re.escape(str(group_by)) + "=") if group_by else None

    factors = []

    for path in data["experiment_id"]:
        folders = path.split("/")

        category_idx, group_by_idx = None, None
        for i, folder in enumerate(folders):
            # The first folder containing "=" starts the parameter part of
            # the path. If there is none, category_idx stays None and both
            # slices below cover the whole path.
            if category_idx is None and "=" in folder:
                category_idx = i

            if (
                group_by_pattern is not None
                and group_by_idx is None
                and group_by_pattern.search(folder)
            ):
                group_by_idx = i

        category = "/".join(folders[:category_idx])

        if group_by_idx is not None:
            subcategory = folders[group_by_idx]
            del folders[group_by_idx]
            name = "/".join(folders[category_idx:])
            factors.append((category, subcategory, name))
        else:
            name = "/".join(folders[category_idx:])
            factors.append((category, name))

    p = figure(
        title="Total Rewards",
        toolbar_location="above",
        y_range=FactorRange(*factors),
        plot_height=height,
        plot_width=width,
    )

    p.title.text_font_size = "12pt"
    p.yaxis.major_label_text_font_size = "9pt"
    p.yaxis.group_text_font_size = "10pt"

    p.hbar(y=factors, right=data["mean"], height=0.6)

    mean_value = data["mean"].mean()
    max_value = data["mean"].max()
    vline1 = Span(
        location=mean_value, dimension="height", line_color="red", line_width=2
    )
    vline2 = Span(
        location=max_value, dimension="height", line_color="green", line_width=2
    )
    p.renderers.extend([vline1, vline2])

    return p


@st.cache

0 View Source File : traces_visualizer.py
License : GNU General Public License v3.0
Project Creator : thiagopbueno

def plot_total_reward_per_run(dataframes_dict):
    """Bar chart of the total reward of each run, with a red line at the mean.

    Args:
        dataframes_dict (dict): Maps a file path containing a ``/run<N>/``
            component to that run's reward data. ``df.sum()`` is used as the
            run's total. Run indices are expected to lie in
            ``0 .. len(dataframes_dict) - 1``.

    Returns:
        A bokeh figure.

    Raises:
        ValueError: If a path has no ``/run<N>/`` component.
    """
    n_runs = len(dataframes_dict)
    x = [str(run) for run in np.arange(0, n_runs)]
    # Start from NaN instead of np.empty. If some run index is missing, its
    # slot holds NaN rather than uninitialised memory, and the mean below
    # ignores it.
    total_rewards = np.full((n_runs,), np.nan)

    for filepath, df in dataframes_dict.items():
        run_regex = re.search(r".*/run(.*)/.*", filepath)
        if run_regex is None:
            raise ValueError(f"cannot find a '/run<N>/' component in {filepath!r}")
        run = int(run_regex.group(1))
        total_rewards[run] = df.sum()

    p = figure(
        title="Total Reward per Episode",
        toolbar_location="above",
        x_axis_label="Runs",
        x_range=x,
        plot_height=400,
        plot_width=700,
        background_fill_color="#fafafa",
    )

    p.title.text_font_size = "14pt"
    p.xaxis.axis_label_text_font_size = "12pt"

    p.vbar(x=x, top=total_rewards, width=0.6)

    mean = np.nanmean(total_rewards)
    hline = Span(location=mean, dimension="width", line_color="red", line_width=2)
    p.renderers.append(hline)

    return p


def plot_total_reward_histogram(dataframes):

0 View Source File : traces_visualizer.py
License : GNU General Public License v3.0
Project Creator : thiagopbueno

def plot_total_reward_histogram(dataframes):
    """Density histogram (30 bins) of per-episode total rewards.

    Each element of ``dataframes`` is reduced with ``.sum()`` to one total.
    A red vertical line marks the mean of the totals; green lines mark
    mean - std and mean + std.

    Returns:
        A bokeh figure.
    """
    totals = [frame.sum() for frame in dataframes]
    center = np.mean(totals)
    spread = np.std(totals)

    density, bin_edges = np.histogram(totals, density=True, bins=30)

    p = figure(
        title="Histogram",
        toolbar_location="above",
        x_axis_label="Total Reward per Episode",
        plot_height=400,
        plot_width=700,
        background_fill_color="#fafafa",
    )

    p.title.text_font_size = "14pt"
    p.y_range.start = 0
    p.xaxis.axis_label_text_font_size = "12pt"
    p.xaxis.major_label_text_font_size = "11pt"
    p.yaxis.major_label_text_font_size = "11pt"

    p.quad(top=density, bottom=0, left=bin_edges[:-1], right=bin_edges[1:])

    markers = [
        (center, "red"),
        (center - spread, "green"),
        (center + spread, "green"),
    ]
    p.renderers.extend(
        [
            Span(location=loc, dimension="height", line_color=color, line_width=2)
            for loc, color in markers
        ]
    )

    return p


def plot_rewards_per_run(dataframes):

0 View Source File : bokeh_waveform_plot.py
License : BSD 3-Clause "New" or "Revised" License
Project Creator : XENONnT

def plot_event(peaks, signal, labels, event, colors, yscale='linear'):
    """
    Plot the peaks of an event and mark the main/alternative S1/S2 signals.

    Each non-empty entry of ``signal`` is drawn as a vertical line (dotted
    when its key contains 'alt', dashed otherwise) at its first element's
    ``center_time``. A rotated text label from ``labels`` is added beside it.

    :param peaks: Peaks in event
    :param signal: Dictionary containing main/alt. S1/S2
    :param labels: dict with labels to be used
    :param event: Event to set correctly x-ranges.
    :param colors: Colors to be used for unknown, s1 and s2 signals.
    :param yscale: string of yscale type.

    :return: bokeh.plotting.figure instance
    """
    waveform = plot_peaks(peaks, time_scalar=1000, colors=colors, yscale=yscale)

    t_first = peaks[0]['time']
    t_last = strax.endtime(peaks)[-1]

    # Labels are positioned in data coordinates (pixel placement did not
    # work). The height is 10% below the largest per-sample value after
    # dividing each peak's data by its dt.
    peak_max = np.max((peaks['data'].T / peaks['dt']).T)
    label_y = peak_max - 0.1 * peak_max

    for key, sig in signal.items():
        if not sig.shape[0]:
            continue
        x_pos = (sig[0]['center_time'] - t_first) / 1000
        marker = bokeh.models.Span(location=x_pos,
                                   dimension='height',
                                   line_alpha=0.6,
                                   )
        text = bokeh.models.Label(x=x_pos,
                                  y=label_y,
                                  angle=np.pi / 2,
                                  text=labels[key],
                                  )
        marker.line_dash = 'dotted' if 'alt' in key else 'dashed'
        waveform.add_layout(marker)
        waveform.add_layout(text)

    # Extend the x-range 10% beyond the first and last peak on each side,
    # but never past the event's own start/end.
    span_len = (t_last - t_first) / 10**3
    event_lo = (event['time'] - t_first) / 10**3
    event_hi = (event['endtime'] - t_first) / 10**3
    waveform.x_range.start = max(-0.1 * span_len, event_lo)
    waveform.x_range.end = min(1.1 * span_len, event_hi)
    return waveform


def plot_peak_detail(peak,