python/ash2shukla/streamlit-bokeh-events/streamlit_bokeh_events/__init__.py

streamlit_bokeh_events
__init__.py
from bokeh.embed import json_item
import json

import os
from random import choices
from string import ascii_letters
import streamlit.components.v1 as components

_RELEASE = True

if not _RELEASE:
    _component_func = components.declare_component(
        "streamlit_bokeh_events", url="http://localhost:3001",
    )
else:
    parent_dir = os.path.dirname(os.path.abspath(__file__))
    build_dir = os.path.join(parent_dir, "frontend/build")
    _component_func = components.declare_component("streamlit_bokeh_events", path=build_dir)


def streamlit_bokeh_events(bokeh_plot=None, events="", key=None, debounce_time=1000, refresh_on_update=True, override_height=None):
    """Returns event dict

    Keyword arguments:
    bokeh_plot -- Bokeh figure object (default None)
    events -- Comma separated list of events dispatched by bokeh eg. "event1,event2,event3" (default "")
    debounce_time -- Time in ms to wait before dispatching latest event (default 1000)
    refresh_on_update -- Should the chart be re-rendered on refresh (default False)
        : Set to False if you are not updating the datasource at runtime
    override_height -- Override plot viewport height
    """
    if key is None:
        raise ValueError("key can not be None.")

    div_id = "".join(choices(ascii_letters, k=16))
    fig_dict = json_item(bokeh_plot, div_id)
    json_figure = json.dumps(fig_dict)
    component_value = _component_func(
        bokeh_plot=json_figure,
        events=events,
        key=key,
        _id=div_id,
        default=None,
        debounce_time=debounce_time,
        refresh_on_update=refresh_on_update,
        override_height=override_height
    )
    return component_value


if not _RELEASE:
    import streamlit as st
    import pandas as pd
    from bokeh.plotting import figure
    from bokeh.models import ColumnDataSource, CustomJS
    from bokeh.models import DataTable, TableColumn
    from bokeh.plotting import figure

    st.set_page_config(layout="wide")
    # import function
    # from streamlit_bokeh_events import streamlit_bokeh_events
    col1, col2 = st.beta_columns(2)
    df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')
    # create plot
    cds = ColumnDataSource(df)
    columns = list(map(lambda colname: TableColumn(field=colname, title=colname), df.columns))

    # define events
    cds.selected.js_on_change(
        "indices",
        CustomJS(
            args=dict(source=cds),
            code="""
            document.dispatchEvent(
                new CustomEvent("INDEX_SELECT", {detail: {data: source.selected.indices}})
            )
            """
        )
    )

    table = DataTable(source=cds, columns=columns)
    with col1:
        result = streamlit_bokeh_events(
            bokeh_plot=table,
            events="INDEX_SELECT",
            key="foo",
            refresh_on_update=False,
            debounce_time=0,
            override_height=500
        )
        if result:
            if result.get("INDEX_SELECT"):
                st.write(df.iloc[result.get("INDEX_SELECT")["data"]])

    plot = figure(tools="lasso_select,zoom_in")
    df["colors"] = df.species.replace({"setosa": "#583d72", "versicolor": "#9f5f80", "virginica": "#ffba93"})
    cds_lasso = ColumnDataSource(df)
    cds_lasso.selected.js_on_change(
        "indices",
        CustomJS(
            args=dict(source=cds_lasso),
            code="""
            document.dispatchEvent(
                new CustomEvent("LASSO_SELECT", {detail: {data: source.selected.indices}})
            )
            """
        )
    )

    plot.circle("sepal_length", "sepal_width", fill_alpha=0.5, color="colors", size=10, line_color=None, source=cds_lasso)
    with col2:
        result_lasso = streamlit_bokeh_events(
            bokeh_plot=plot,
            events="LASSO_SELECT",
            key="bar",
            refresh_on_update=False,
            debounce_time=0)
        if result_lasso:
            if result_lasso.get("LASSO_SELECT"):
                st.write(df.iloc[result_lasso.get("LASSO_SELECT")["data"]])