python/JaneliaSciComp/SongExplorer/src/gui/view.py

view.py
import os
from bokeh.models.widgets import RadioButtonGroup, TextInput, Button, Div, DateFormatter, TextAreaInput, Select, NumberFormatter, Slider, Toggle, ColorPicker, MultiSelect
from bokeh.models.formatters import FuncTickFormatter
from bokeh.models import ColumnDataSource, TableColumn, DataTable, LayoutDOM, Span
from bokeh.plotting import figure
from bokeh.transform import linear_cmap
from bokeh.events import Tap, DoubleTap, PanStart, Pan, PanEnd, ButtonClick, MouseWheel
from bokeh.models.callbacks import CustomJS
from bokeh.models.markers import Circle
import numpy as np
import glob
from datetime import datetime
import markdown
import pandas as pd
import wave
import scipy.io.wavfile as spiowav
from scipy.signal import decimate, spectrogram
import logging 
import base64
import io
from natsort import natsorted
import pims
import av
from bokeh import palettes
from itertools import cycle, product
import ast
from bokeh.core.properties import Instance, String, List, Float
from bokeh.util.compiler import TypeScript
import asyncio
from collections import OrderedDict

bokehlog = logging.getLogger("songexplorer") 

import model as M
import controller as C

# Module-level placeholders for every widget, data source, glyph and lookup
# table used by the view.  Each name is bound to None here; the functions
# below refer to them as module globals.  The number of names on the left-hand
# side must equal the multiplier on the right-hand side, so add or remove names
# and adjust the count together.
# NOTE(review): `file_dialog_source` is listed twice in the target list --
# confirm whether one of them was meant to be a different name.
# NOTE(review): as laid out here the target list continues onto a second
# physical line without parentheses or a backslash, which Python needs for a
# multi-line statement -- confirm against the upstream file.
bokeh_document, cluster_dot_palette, snippet_palette, p_cluster, cluster_dots, p_snippets, snippets_label_sources_clustered, snippets_label_sources_annotated, snippets_wave_sources, snippets_wave_glyphs, snippets_gram_sources, snippets_gram_glyphs, snippets_quad_grey, dot_size_cluster, dot_alpha_cluster, cluster_circle_fuchsia, p_waveform, p_spectrogram, p_probability, probability_source, probability_glyph, spectrogram_source, spectrogram_glyph, waveform_span_red, spectrogram_span_red, waveform_quad_grey_clustered, waveform_quad_grey_annotated, waveform_quad_grey_pan, waveform_quad_fuchsia, spectrogram_quad_grey_clustered, spectrogram_quad_grey_annotated, spectrogram_quad_grey_pan, spectrogram_quad_fuchsia, snippets_quad_fuchsia, waveform_source, waveform_glyph, waveform_label_source_clustered, waveform_label_source_annotated, spectrogram_label_source_clustered, spectrogram_label_source_annotated, which_layer, which_species, which_word, which_nohyphen, which_kind, color_picker, circle_radius, dot_size, dot_alpha, zoom_context, zoom_offset, zoomin, zoomout, reset, panleft, panright, allleft, allout, allright, save_indicator, label_count_widgets, label_text_widgets, play, play_callback, video_toggle, video_div, undo, redo, detect, misses, configuration_file, train, leaveoneout, leaveallout, xvalidate, mistakes, activations, cluster, visualize, accuracy, freeze, classify, ethogram, compare, congruence, status_ticker, waitfor, file_dialog_source, file_dialog_source, configuration_contents, logs, logs_folder, model, model_file, wavtfcsvfiles, wavtfcsvfiles_string, groundtruth, groundtruth_folder, validationfiles, testfiles, validationfiles_string, testfiles_string, wantedwords, wantedwords_string, labeltypes, labeltypes_string, prevalences, prevalences_string, copy, labelsounds, makepredictions, fixfalsepositives, fixfalsenegatives, generalize, tunehyperparameters, findnovellabels, examineerrors, testdensely, doit, time_sigma_string, time_smooth_ms_string, 
frequency_n_ms_string, frequency_nw_string, frequency_p_string, frequency_smooth_ms_string, nsteps_string, restore_from_string, save_and_validate_period_string, validate_percentage_string, mini_batch_string, kfold_string, activations_equalize_ratio_string, activations_max_samples_string, pca_fraction_variance_to_retain_string, tsne_perplexity_string, tsne_exaggeration_string, umap_neighbors_string, umap_distance_string, cluster_algorithm, cluster_these_layers, precision_recall_ratios_string, context_ms_string, shiftby_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, optimizer, learning_rate_string, replicates_string, batch_seed_string, weights_seed_string, file_dialog_string, file_dialog_table, readme_contents, wordcounts, wizard_buttons, action_buttons, parameter_buttons, parameter_textinputs, wizard2actions, action2parameterbuttons, action2parametertextinputs, model_parameters = [None]*164

class ScatterNd(LayoutDOM):
    """Custom Bokeh layout model that renders a 2-D or 3-D scatter plot with Plotly.

    The browser-side view is the TypeScript in ``__implementation__``.  It
    loads Plotly from ``https://cdn.plot.ly/plotly-latest.min.js`` and creates
    a plot with two traces: trace 0 is the selection marker (a layout circle
    shape in 2-D, a ``mesh3d`` sphere built from icosphere vertices in 3-D) and
    trace 1 holds the dots.  The view redraws when ``dots_source``,
    ``dot_size_source``, ``dot_alpha_source`` or ``circle_fuchsia_source``
    change, and writes the coordinates of a clicked point to
    ``click_position``.

    Whether the plot is 2-D or 3-D is decided in the view's ``ndims()``: an
    empty ``dz`` column gives 0 (nothing is drawn), a NaN in the first ``dz``
    entry gives 2, anything else gives 3.

    The string-valued properties do not hold data themselves; each one names
    the column to read from the corresponding ``ColumnDataSource``.
    """

    # NOTE(review): the dots trace sets hovertemplate to
    # "%{text} < extra> < /extra>".  Plotly's tag for the secondary hover box
    # is written "<extra></extra>" without spaces -- confirm whether the
    # spaces here are intentional.
    __implementation__ = TypeScript("""
import {LayoutDOM, LayoutDOMView} from "models/layouts/layout_dom"
import {ColumnDataSource} from "models/sources/column_data_source"
import {LayoutItem} from "core/layout"
import * as p from "core/properties"

declare namespace Plotly {
  class newPlot { constructor(el: HTMLElement, data: object, OPTIONS: object) }
}

let OPTIONS2 = {
  margin: { l: 0, r: 0, b: 0, t: 0 },
  showlegend: false,
  xaxis: { visible: false },
  yaxis: { visible: false },
  hovermode: 'closest',
  shapes: [ {
      type: 'circle',
      xref: 'x', yref: 'y',
      x0: 0, y0: 0,
      x1: 0, y1: 0,
      line: { color: 'fuchsia' } } ]
}
let OPTIONS3 = {
  margin: { l: 0, r: 0, b: 0, t: 0 },
  hovermode: 'closest',
  hoverlabel: { bgcolor: 'white' },
  showlegend: false,
  scene: {
    xaxis: { visible: false },
    yaxis: { visible: false },
    zaxis: { visible: false },
  },
  shapes: [],
}

// https://github.com/caosdoar/spheres
let icosphere12 = [[0.525731, 0.850651, 0]]
let icosphere42 = icosphere12.slice().concat([[0.809017, 0.5, 0.309017],
                                              [0, 0, 1]])
let icosphere162 = icosphere42.slice().concat([[0.69378, 0.702046, 0.160622],
                                               [0.587785, 0.688191, 0.425325],
                                               [0.433889, 0.862668, 0.259892],
                                               [0.273267, 0.961938, 0],
                                               [0.16246, 0.951057, 0.262866]])

// @ts-ignore
let xicosphere = []
// @ts-ignore
let yicosphere = []
// @ts-ignore
let zicosphere = []
icosphere162.forEach((x)=>{
  // @ts-ignore
  let V = []
  for (let i=1; i==1 || i==-1 && x[0]>0; i-=2) {
    for (let j=1; j==1 || j==-1 && x[1]>0; j-=2) {
      for (let k=1; k==1 || k==-1 && x[2]>0; k-=2) {
        // @ts-ignore
        V = V.concat([i*x[0], j*x[1], k*x[2]])
      }
    }
  }
  // @ts-ignore
  xicosphere = xicosphere.concat(V)
  V.push(V.shift())
  // @ts-ignore
  yicosphere = yicosphere.concat(V)
  V.push(V.shift())
  // @ts-ignore
  zicosphere = zicosphere.concat(V)
});

export class ScatterNdView extends LayoutDOMView {
  model: ScatterNd

  initialize(): void {
    super.initialize()

    const url = "https://cdn.plot.ly/plotly-latest.min.js"
    const script = document.createElement("script")
    script.onload = () => this._init()
    script.async = false
    script.src = url
    document.head.appendChild(script)
  }

  ndims() {
    if (this.model.dots_source.data[this.model.dz].length==0) {
      return 0 }
    else if (isNaN(this.model.dots_source.data[this.model.dz][0])) {
      return 2 }
    return 3
  }

  get_dots_data() {
    return {x: this.model.dots_source.data[this.model.dx],
            y: this.model.dots_source.data[this.model.dy],
            z: this.model.dots_source.data[this.model.dz],
            text: this.model.dots_source.data[this.model.dl],
            marker: {
              color: this.model.dots_source.data[this.model.dc],
              size: this.model.dot_size_source.data[this.model.ds][0],
              opacity: this.model.dot_alpha_source.data[this.model.da][0],
            }
           };
  }

  set_circle_fuchsia_data2() {
    if (this.model.circle_fuchsia_source.data[this.model.cx].length==0) {
      OPTIONS2.shapes[0].x0 = 0
      OPTIONS2.shapes[0].y0 = 0
      OPTIONS2.shapes[0].x1 = 0
      OPTIONS2.shapes[0].y1 = 0 }
    else {
      OPTIONS2.shapes[0].line.color = this.model.circle_fuchsia_source.data[this.model.cc][0]
      let x = this.model.circle_fuchsia_source.data[this.model.cx][0]
      let y = this.model.circle_fuchsia_source.data[this.model.cy][0]
      let r = this.model.circle_fuchsia_source.data[this.model.cr][0]
      OPTIONS2.shapes[0].x0 = x-r
      OPTIONS2.shapes[0].y0 = y-r
      OPTIONS2.shapes[0].x1 = x- -r
      OPTIONS2.shapes[0].y1 = y- -r }
  }

  get_circle_fuchsia_data3() {
    if (this.model.circle_fuchsia_source.data[this.model.cx].length==0) {
      return {type: 'mesh3d',
              x:[0], y:[0], z:[0],
             }; }
    else {
      let radius = this.model.circle_fuchsia_source.data[this.model.cr][0]
      return {type: 'mesh3d',
              // @ts-ignore
              x: xicosphere.map(x=>x*radius+this.model.circle_fuchsia_source.data[this.model.cx][0]),
              // @ts-ignore
              y: yicosphere.map(x=>x*radius+this.model.circle_fuchsia_source.data[this.model.cy][0]),
              // @ts-ignore
              z: zicosphere.map(x=>x*radius+this.model.circle_fuchsia_source.data[this.model.cz][0]),
             }; }
  }

  private _init(): void {
    new Plotly.newPlot(this.el,
                       [{alphahull: 1.0,
                         opacity: 0.2,
                        },
                        {hovertemplate: "%{text} < extra> < /extra>",
                         mode: 'markers',
                        }],
                       {xaxis: { visible: false },
                        yaxis: { visible: false } });

    this.connect(this.model.dots_source.change, () => {
      let new_data = this.get_dots_data()
      let N = this.ndims()
      if (N==2) {
        this.set_circle_fuchsia_data2()
        // @ts-ignore
        Plotly.update(this.el, {type: '', x:[[]], y:[[]], z:[[]]}, OPTIONS2, [0]);
        // @ts-ignore
        Plotly.update(this.el,
                      {type: 'scatter',
                       x: [new_data['x']], y: [new_data['y']],
                       text: [new_data['text']],
                       marker: new_data['marker'] },
                      OPTIONS2,
                      [1]);
      }
      else if (N==3) {
        // @ts-ignore
        Plotly.update(this.el, {type: 'mesh3d', x:[[]], y:[[]], z:[[]]}, OPTIONS3, [0]);
        // @ts-ignore
        Plotly.update(this.el,
                       {type: 'scatter3d',
                        x: [new_data['x']], y: [new_data['y']], z: [new_data['z']],
                        text: [new_data['text']],
                        marker: new_data['marker'] },
                       OPTIONS3,
                       [1]);
      }
    });

    this.connect(this.model.dot_size_source.change, () => {
      let new_data = this.get_dots_data()
      // @ts-ignore
      Plotly.restyle(this.el, { marker: new_data['marker'] }, [1]);
    });

    this.connect(this.model.dot_alpha_source.change, () => {
      let new_data = this.get_dots_data()
      // @ts-ignore
      Plotly.restyle(this.el, { marker: new_data['marker'] }, [1]);
    });

    // @ts-ignore
    ( < HTMLDivElement>this.el).on('plotly_click', (data) => {
      let N = this.ndims()
      if (N==2) {
        // @ts-ignore
        this.model.click_position = [data.points[0].x,data.points[0].y] }
      else if (N==3) {
        // @ts-ignore
        this.model.click_position = [data.points[0].x,data.points[0].y,data.points[0].z] }
    });

    this.connect(this.model.circle_fuchsia_source.change, () => {
      let N = this.ndims()
      if (N==2) {
        this.set_circle_fuchsia_data2()
        // @ts-ignore
        Plotly.relayout(this.el, OPTIONS2); }
      else if (N==3) {
        let new_data = this.get_circle_fuchsia_data3()
        // @ts-ignore
        Plotly.restyle(this.el,
                       {x: [new_data['x']], y: [new_data['y']], z: [new_data['z']],
                        color: this.model.circle_fuchsia_source.data[this.model.cc][0]},
                       [0]); }
    });
  }

  get child_models(): LayoutDOM[] { return [] }

  _update_layout(): void {
    this.layout = new LayoutItem()
    this.layout.set_sizing(this.box_sizing())
  }
}

export namespace ScatterNd {
  export type Attrs = p.AttrsOf < Props>

  export type Props = LayoutDOM.Props & {
    cx: p.Property < string>
    cy: p.Property < string>
    cz: p.Property < string>
    cr: p.Property < string>
    cc: p.Property < string>
    dx: p.Property < string>
    dy: p.Property < string>
    dz: p.Property < string>
    dl: p.Property < string>
    dc: p.Property < string>
    ds: p.Property < string>
    da: p.Property < string>
    click_position: p.Property < number[]>
    circle_fuchsia_source: p.Property < ColumnDataSource>
    dots_source: p.Property < ColumnDataSource>
    dot_size_source: p.Property < ColumnDataSource>
    dot_alpha_source: p.Property < ColumnDataSource>
  }
}

export interface ScatterNd extends ScatterNd.Attrs {}

export class ScatterNd extends LayoutDOM {
  properties: ScatterNd.Props

  constructor(attrs?: Partial < ScatterNd.Attrs>) { super(attrs) }

  static __name__ = "ScatterNd"

  static init_ScatterNd() {
    this.prototype.default_view = ScatterNdView

    this.define < ScatterNd.Props>({
      cx: [ p.String   ],
      cy: [ p.String   ],
      cz: [ p.String   ],
      cr: [ p.String   ],
      cc: [ p.String   ],
      dx: [ p.String   ],
      dy: [ p.String   ],
      dz: [ p.String   ],
      dl: [ p.String   ],
      dc: [ p.String   ],
      ds: [ p.String   ],
      da: [ p.String   ],
      click_position:  [ p.Array   ],
      circle_fuchsia_source: [ p.Instance ],
      dots_source: [ p.Instance ],
      dot_size_source: [ p.Instance ],
      dot_alpha_source: [ p.Instance ],
    })
  }
}
"""
)

    # Column names within circle_fuchsia_source: centre x / y / z, radius and
    # colour of the selection marker.
    cx = String
    cy = String
    cz = String
    cr = String
    cc = String

    # Column names within dots_source: x / y / z coordinates, hover text (dl)
    # and marker colour (dc).  ds and da name the single-valued columns read
    # from dot_size_source and dot_alpha_source respectively.
    dx = String
    dy = String
    dz = String
    dl = String
    dc = String
    ds = String
    da = String

    # Set by the view on a plotly_click event to [x, y] (2-D) or [x, y, z] (3-D).
    click_position = List(Float)

    # Data sources the view listens to for changes.
    circle_fuchsia_source = Instance(ColumnDataSource)
    dots_source = Instance(ColumnDataSource)
    dot_size_source = Instance(ColumnDataSource)
    dot_alpha_source = Instance(ColumnDataSource)

def cluster_initialize(newcolors=True):
    """Load ``cluster.npz`` and precompute everything the cluster plot needs.

    Reads ``cluster.npz`` from the folder named by the ``groundtruth_folder``
    widget, sorts the samples (and each layer's activations) by start tick,
    derives the species / word / no-hyphen / kind filter lists from the sample
    labels and kinds, records per-layer axis extents, and fills the
    module-level ``precomputed_dots`` table with one dict of plot data
    (keys ``x``, ``y``, ``z``, ``l``, ``c``) per
    (layer, species, word, nohyphen, kind) combination.  Finally the filter
    widgets are populated, the sliders are enabled and the current filter
    indices on ``M`` are reset to 0.

    Args:
        newcolors: when true, rebuild ``M.cluster_dot_colors`` by cycling
            through ``cluster_dot_palette``; when false the existing mapping
            is left as it is.

    Returns:
        ``True`` on success.  ``False`` (after logging an error) if the file
        does not exist or if no layer in it has clustered activations.
    """
    global precomputed_dots
    global p_cluster_xmax, p_cluster_ymax, p_cluster_zmax
    global p_cluster_xmin, p_cluster_ymin, p_cluster_zmin

    cluster_file = os.path.join(groundtruth_folder.value,'cluster.npz')
    if not os.path.isfile(cluster_file):
        bokehlog.info("ERROR: "+cluster_file+" not found")
        return False
    # NOTE(review): allow_pickle=True unpickles arbitrary objects, so this must
    # only ever be pointed at trusted files.
    # The arrays are fully read on item access, so the archive can be closed
    # as soon as both have been fetched.
    with np.load(cluster_file, allow_pickle=True) as npzfile:
        M.clustered_samples = npzfile['samples']
        M.clustered_activations = npzfile['activations_clustered']

    # Order everything by the start tick of each sample.
    M.clustered_starts_sorted = [x['ticks'][0] for x in M.clustered_samples]
    isort = np.argsort(M.clustered_starts_sorted)
    layer0 = None   # index of the last layer that actually has activations
    for i in range(len(M.clustered_activations)):
        if M.clustered_activations[i] is not None:
            layer0 = i
            M.clustered_activations[i] = M.clustered_activations[i][isort,:]
    M.clustered_samples = [M.clustered_samples[x] for x in isort]
    M.clustered_starts_sorted = [M.clustered_starts_sorted[x] for x in isort]

    M.clustered_stops = [x['ticks'][1] for x in M.clustered_samples]
    M.iclustered_stops_sorted = np.argsort(M.clustered_stops)

    if layer0 is None:
        # Without this guard the code below would fail with a NameError.
        bokehlog.info("ERROR: "+cluster_file+" contains no clustered layers")
        return False

    # Samples whose first two coordinates in layer0 are NaN are never plotted.
    # NOTE(review): the mask is taken from layer0 only and then applied to
    # every layer -- confirm that NaNs coincide across layers.
    cluster_isnotnan = [not np.isnan(x[0]) and not np.isnan(x[1]) \
                        for x in M.clustered_activations[layer0]]

    M.nlayers = len(M.clustered_activations)
    M.ndcluster = np.shape(M.clustered_activations[layer0])[1]
    cluster_dots.data.update(dx=[], dy=[], dz=[], dl=[], dc=[])
    cluster_circle_fuchsia.data.update(cx=[], cy=[], cz=[], cr=[], cc=[])

    # Build the option lists for the filter widgets.  Each list starts with
    # '' (natsorted puts it first), which the selection code treats as
    # "no filter".  Labels containing '-' are split into a "species-" prefix
    # and a "-word" suffix; labels without '-' go into nohyphens.
    M.layers = ["input"]+["hidden #"+str(i) for i in range(1,M.nlayers-1)]+["output"]
    M.species = set([x['label'].split('-')[0]+'-' \
                     for x in M.clustered_samples if '-' in x['label']])
    M.species |= set([''])
    M.species = natsorted(list(M.species))
    M.words = set(['-'+x['label'].split('-')[1] \
                   for x in M.clustered_samples if '-' in x['label']])
    M.words |= set([''])
    M.words = natsorted(list(M.words))
    M.nohyphens = set([x['label'] for x in M.clustered_samples if '-' not in x['label']])
    M.nohyphens |= set([''])
    M.nohyphens = natsorted(list(M.nohyphens))
    M.kinds = set([x['kind'] for x in M.clustered_samples])
    M.kinds |= set([''])
    M.kinds = natsorted(list(M.kinds))

    if newcolors:
        # x[0][:-1] drops the trailing '-' of the species so that
        # "species-" + "-word" becomes "species-word".
        allcombos = [x[0][:-1]+x[1] for x in product(M.species[1:], M.words[1:])]
        M.cluster_dot_colors = { l:c for l,c in zip(allcombos+ M.nohyphens[1:],
                                                    cycle(cluster_dot_palette)) }
    M.clustered_labels = set([x['label'] for x in M.clustered_samples])

    p_cluster_xmin, p_cluster_xmax = [0]*M.nlayers, [0]*M.nlayers
    p_cluster_ymin, p_cluster_ymax = [0]*M.nlayers, [0]*M.nlayers
    p_cluster_zmin, p_cluster_zmax = [0]*M.nlayers, [0]*M.nlayers
    precomputed_dots = [None]*M.nlayers
    for ilayer in range(M.nlayers):
        precomputed_dots[ilayer] = [None]*len(M.species)
        if M.clustered_activations[ilayer] is not None:
            p_cluster_xmin[ilayer] = np.min(M.clustered_activations[ilayer][:,0])
            p_cluster_xmax[ilayer] = np.max(M.clustered_activations[ilayer][:,0])
            p_cluster_ymin[ilayer] = np.min(M.clustered_activations[ilayer][:,1])
            p_cluster_ymax[ilayer] = np.max(M.clustered_activations[ilayer][:,1])
            if M.ndcluster==3:
                p_cluster_zmin[ilayer] = np.min(M.clustered_activations[ilayer][:,2])
                p_cluster_zmax[ilayer] = np.max(M.clustered_activations[ilayer][:,2])
        for (ispecies,specie) in enumerate(M.species):
            precomputed_dots[ilayer][ispecies] = [None]*len(M.words)
            for (iword,word) in enumerate(M.words):
                precomputed_dots[ilayer][ispecies][iword] = [None]*len(M.nohyphens)
                for (inohyphen,nohyphen) in enumerate(M.nohyphens):
                    precomputed_dots[ilayer][ispecies][iword][inohyphen] = \
                            [None]*len(M.kinds)
                    for (ikind,kind) in enumerate(M.kinds):
                        # A no-hyphen filter is only combined with the empty
                        # species and word filters.
                        if inohyphen!=0 and (ispecies!=0 or iword!=0):
                            continue
                        if M.clustered_activations[ilayer] is None:
                            continue
                        # Side effect: M.ilayer ends up as the last layer that
                        # reached this point.
                        M.ilayer=ilayer
                        bidx = np.logical_and([specie in x['label'] and \
                                               word in x['label'] and \
                                               (nohyphen=="" or nohyphen==x['label']) and \
                                               (kind=="" or kind==x['kind']) \
                                               for x in M.clustered_samples], \
                                               cluster_isnotnan)
                        if not any(bidx):
                            continue
                        if inohyphen>0:
                            colors = [M.cluster_dot_colors[nohyphen] for b in bidx if b]
                        else:
                            # Labels without an assigned colour are drawn black.
                            colors = [M.cluster_dot_colors[x['label']] \
                                      if x['label'] in M.cluster_dot_colors else "black" \
                                      for x,b in zip(M.clustered_samples,bidx) if b]
                        data = {'x': M.clustered_activations[ilayer][bidx,0], \
                                'y': M.clustered_activations[ilayer][bidx,1], \
                                'l': [x['label'] for x,b in zip(M.clustered_samples,bidx) if b], \
                                'c': colors }
                        # In 2-D the z column is all NaN; ScatterNdView.ndims()
                        # uses a NaN first z value to select the 2-D plot.
                        if M.ndcluster==2:
                            data['z'] = [np.nan]*len(M.clustered_activations[ilayer][bidx,1])
                        else:
                            data['z'] = M.clustered_activations[ilayer][bidx,2]
                        precomputed_dots[ilayer][ispecies][iword][inohyphen][ikind] = data

    which_layer.options = M.layers
    which_species.options = M.species
    which_word.options = M.words
    which_nohyphen.options = M.nohyphens
    which_kind.options = M.kinds

    circle_radius.disabled=False
    dot_size.disabled=False
    dot_alpha.disabled=False

    M.ispecies=0
    M.iword=0
    M.inohyphen=0
    M.ikind=0

    return True

def cluster_update():
    """Push the dots for the current filter selection into ``cluster_dots``.

    Looks up the entry of ``precomputed_dots`` selected by ``M.ilayer``,
    ``M.ispecies``, ``M.iword``, ``M.inohyphen`` and ``M.ikind`` and writes it
    to the ``cluster_dots`` data source.  Eight fully transparent dots are
    always appended at the corners of the layer's bounding box.  The
    ``circle_radius`` slider's ``end`` and ``step`` are then rescaled to the
    smallest extent of the layer.

    Returns nothing.  Does nothing beyond enabling ``dot_alpha`` if
    ``precomputed_dots`` is ``None``.
    """
    dot_alpha.disabled=False
    if precomputed_dots is None:
        return
    ilayer = M.ilayer
    selected_dots = precomputed_dots[ilayer][M.ispecies][M.iword][M.inohyphen][M.ikind]
    transparent = ['#ffffff00']*8
    if selected_dots is None:
        # Nothing matches the current filters: show only invisible placeholders.
        # NOTE(review): ScatterNdView.ndims() treats a non-NaN first dz entry
        # as 3-D, so this all-zero placeholder is reported as 3-D even when
        # M.ndcluster==2 -- confirm that is intended.
        kwargs = dict(dx=[0]*8, dy=[0]*8, dz=[0]*8, dl=['']*8, dc=transparent)
    else:
        # product() enumerates the eight corners with x varying slowest and z
        # fastest, i.e. x: min*4, max*4; y: min, min, max, max, ...; z: min, max, ...
        corners_x, corners_y, corners_z = zip(*product(
                (p_cluster_xmin[ilayer], p_cluster_xmax[ilayer]),
                (p_cluster_ymin[ilayer], p_cluster_ymax[ilayer]),
                (p_cluster_zmin[ilayer], p_cluster_zmax[ilayer])))
        kwargs = dict(dx=[*selected_dots['x'], *corners_x],
                      dy=[*selected_dots['y'], *corners_y],
                      dz=[*selected_dots['z'], *corners_z],
                      dl=[*selected_dots['l'], *['']*8],
                      dc=[*selected_dots['c'], *transparent])
    cluster_dots.data.update(**kwargs)
    extent = min(p_cluster_xmax[ilayer] - p_cluster_xmin[ilayer],
                 p_cluster_ymax[ilayer] - p_cluster_ymin[ilayer])
    if M.ndcluster==3:
        extent = min(extent, p_cluster_zmax[ilayer] - p_cluster_zmin[ilayer])
    # Clamp so that neither the slider's end nor its step can become zero when
    # all points coincide along some axis.
    radius_max = max(np.finfo(np.float32).eps, extent)
    circle_radius.end = radius_max
    circle_radius.step = radius_max/100

def within_an_annotation(sample):
    """Return the index of the annotation that fully contains `sample`, or -1.

    The candidate is the last entry of ``M.annotated_starts_sorted`` whose
    start tick is <= the sample's start tick.  It is accepted only if the stop
    tick of ``M.annotated_samples`` at that same index is >= the sample's stop
    tick.

    Args:
        sample: mapping with a ``'ticks'`` entry of ``[start, stop]``.

    Returns:
        The index into ``M.annotated_samples``, or ``-1`` if there are no
        annotations or the candidate does not enclose the sample.

    Assumes ``M.annotated_samples`` is ordered like
    ``M.annotated_starts_sorted`` -- TODO confirm against where they are built.
    """
    if len(M.annotated_starts_sorted)>0:
        # side='right' minus one gives the last start that is <= the sample start
        ifrom = np.searchsorted(M.annotated_starts_sorted, sample['ticks'][0],
                                side='right') - 1
        if 0 <= ifrom < len(M.annotated_starts_sorted) and \
                    M.annotated_samples[ifrom]['ticks'][1] >= sample['ticks'][1]:
            return ifrom
    return -1

def snippets_update(redraw_wavs):
    """Refresh the snippets grid with the samples nearest the selected cluster point.

    Among the clustered samples that pass the current species / word /
    no-hyphen / kind filters, the ones closest to (M.xcluster, M.ycluster
    [, M.zcluster]) in the current layer are assigned, nearest first, to the
    M.snippets_nx * M.snippets_ny grid slots, provided they lie within the
    "circle_radius" entry of M.state.  Slots beyond that are marked empty
    (M.nearest_samples[i] = -1).  Label text, the fuchsia selection quad and
    the grey quads are always updated.

    Args:
        redraw_wavs: when true, audio is re-read from each sample's file and
            the waveform and/or spectrogram sources and glyphs of every slot
            are updated as well; when false only labels and quads change.
    """
    if len(M.species)==0:
        return
    # Outline the currently selected snippet in fuchsia, or clear the outline.
    # NOTE(review): `M.isnippet>0` also rules out index 0 -- confirm which
    # value is meant to signal "no snippet selected".
    if M.isnippet>0 and not np.isnan(M.xcluster) and not np.isnan(M.ycluster) \
                and (M.ndcluster==2 or not np.isnan(M.zcluster)):
        snippets_quad_fuchsia.data.update(
                left=[M.xsnippet*(M.snippets_gap_pix+M.snippets_pix)],
                right=[(M.xsnippet+1)*(M.snippets_gap_pix+M.snippets_pix)-
                       M.snippets_gap_pix],
                top=[-M.ysnippet*snippets_dy+1],
                bottom=[-M.ysnippet*snippets_dy-1 \
                        -2*(M.snippets_waveform and M.snippets_spectrogram)])
    else:
        snippets_quad_fuchsia.data.update(left=[], right=[], top=[], bottom=[])

    # Indices of the clustered samples that pass the current filters.
    isubset = np.where([M.species[M.ispecies] in x['label'] and
                      M.words[M.iword] in x['label'] and
                      (M.nohyphens[M.inohyphen]=="" or \
                       M.nohyphens[M.inohyphen]==x['label']) and
                      (M.kinds[M.ikind]=="" or \
                       M.kinds[M.ikind]==x['kind']) for x in M.clustered_samples])[0]
    # Euclidean distance of each of those samples from the selected point.
    origin = [M.xcluster,M.ycluster]
    if M.ndcluster==3:
        origin.append(M.zcluster)
    distance = [] if M.clustered_activations[M.ilayer] is None else \
               np.linalg.norm(M.clustered_activations[M.ilayer][isubset,:] - origin, \
                              axis=1)
    isort = np.argsort(distance)
    # Per-slot accumulators; the inner lists are indexed by audio channel.
    ywavs, scales = [], []
    gram_freqs, gram_times, gram_images, ilows, ihighs  = [], [], [], [], []
    labels_clustered, labels_annotated = [], []
    for isnippet in range(M.snippets_nx*M.snippets_ny):
        if isnippet < len(distance) and \
                    distance[isort[isnippet]]  <  float(M.state["circle_radius"]):
            M.nearest_samples[isnippet] = isubset[isort[isnippet]]
            thissample = M.clustered_samples[M.nearest_samples[isnippet]]
            labels_clustered.append(thissample['label'])
            # Also show the label of an enclosing manual annotation, if any.
            iannotated = within_an_annotation(thissample)
            if iannotated == -1:
                labels_annotated.append('')
            else:
                labels_annotated.append(M.annotated_samples[iannotated]['label'])
            midpoint = np.mean(thissample['ticks'], dtype=int)
            if redraw_wavs:
                _, wavs = spiowav.read(thissample['file'], mmap=True)
                # Give single-channel recordings an explicit channel axis.
                if np.ndim(wavs)==1:
                  wavs = np.expand_dims(wavs, axis=1)
                # Window of roughly M.snippets_tic samples centred on the
                # midpoint, clipped to the bounds of the recording.
                start_frame = max(0, midpoint-M.snippets_tic//2)
                nframes_to_get = min(np.shape(wavs)[0] - start_frame,
                                     M.snippets_tic+1,
                                     M.snippets_tic+1+(midpoint-M.snippets_tic//2))
                left_pad = max(0, M.snippets_pix-nframes_to_get if start_frame==0 else 0)
                right_pad = max(0, M.snippets_pix-nframes_to_get if start_frame>0 else 0)
                # Every element is reassigned by channel index below, so the
                # shared [] created by list multiplication is never mutated.
                ywav = [[]]*M.audio_nchannels
                scale = [[]]*M.audio_nchannels
                gram_freq = [[]]*M.audio_nchannels
                gram_time = [[]]*M.audio_nchannels
                gram_image = [[]]*M.audio_nchannels
                ilow = [[]]*M.audio_nchannels
                ihigh = [[]]*M.audio_nchannels
                for ichannel in range(M.audio_nchannels):
                    wavi = wavs[start_frame : start_frame+nframes_to_get, ichannel]
                    if M.snippets_waveform:
                        wavi_downsampled = decimate(wavi, M.snippets_decimate_by,
                                        n=M.filter_order,
                                        ftype='iir', zero_phase=True)
                        # NOTE(review): np.pad returns a new array and its
                        # result is discarded here, so left_pad / right_pad
                        # have no effect.  Simply assigning the result would
                        # put NaN into wavi_trimmed and hence into the
                        # np.max-based scale below -- needs a deliberate fix.
                        np.pad(wavi_downsampled, ((left_pad, right_pad),),
                               'constant', constant_values=(np.nan,))
                        wavi_trimmed = wavi_downsampled[:M.snippets_pix]
                        # The peak is capped one below the int16 maximum so
                        # that the palette index computed from it further down
                        # stays below len(snippet_palette).
                        scale[ichannel]=np.minimum(np.iinfo(np.int16).max-1,
                                                   np.max(np.abs(wavi_trimmed)))
                        ywav[ichannel]=wavi_trimmed/scale[ichannel]
                    if M.snippets_spectrogram:
                        window_length = round(M.spectrogram_length_ms[ichannel]/1000*M.audio_tic_rate)
                        gram_freq[ichannel], gram_time[ichannel], gram_image[ichannel] = \
                                spectrogram(wavi,
                                            fs=M.audio_tic_rate,
                                            window=M.spectrogram_window,
                                            nperseg=window_length,
                                            noverlap=round(window_length*M.spectrogram_overlap))
                        # Frequency-bin indices closest to the low / high limits
                        # configured for this channel.
                        ilow[ichannel] = np.argmin(np.abs(gram_freq[ichannel] - \
                                                          M.spectrogram_low_hz[ichannel]))
                        ihigh[ichannel] = np.argmin(np.abs(gram_freq[ichannel] - \
                                                           M.spectrogram_high_hz[ichannel]))
                ywavs.append(ywav)
                scales.append(scale)
                gram_freqs.append(gram_freq)
                gram_times.append(gram_time)
                gram_images.append(gram_image)
                ilows.append(ilow)
                ihighs.append(ihigh)
        else:
            # Empty slot: blank labels, an all-NaN waveform and no spectrogram.
            # ilows / ihighs get no entry here; slots are filled nearest-first,
            # so once one slot is empty all later ones are empty too.
            M.nearest_samples[isnippet] = -1
            labels_clustered.append('')
            labels_annotated.append('')
            scales.append([0]*M.audio_nchannels)
            ywavs.append([np.full(M.snippets_pix,np.nan)]*M.audio_nchannels)
            gram_images.append([])
    snippets_label_sources_clustered.data.update(text=labels_clustered)
    snippets_label_sources_annotated.data.update(text=labels_annotated)
    # Despite the *_clustered names, these collect the grey quads drawn around
    # slots whose sample lies inside an annotation (labels_annotated non-empty).
    left_clustered, right_clustered, top_clustered, bottom_clustered = [], [], [], []
    for isnippet in range(M.snippets_nx*M.snippets_ny):
        # Grid position of this slot: column ix, row iy.
        ix, iy = isnippet%M.snippets_nx, isnippet//M.snippets_nx
        if redraw_wavs:
            xdata = range(ix*(M.snippets_gap_pix+M.snippets_pix),
                          (ix+1)*(M.snippets_gap_pix+M.snippets_pix)-M.snippets_gap_pix)
            for ichannel in range(M.audio_nchannels):
                if M.snippets_waveform:
                    # Stack the channels vertically within the slot's row.
                    ydata = -iy*snippets_dy + \
                            (M.audio_nchannels-1-2*ichannel)/M.audio_nchannels + \
                            ywavs[isnippet][ichannel]/M.audio_nchannels
                    snippets_wave_sources[isnippet][ichannel].data.update(x=xdata, y=ydata)
                    # Line colour is picked from snippet_palette by peak amplitude.
                    ipalette = int(np.floor(scales[isnippet][ichannel] /
                                            np.iinfo(np.int16).max *
                                            len(snippet_palette)))
                    snippets_wave_glyphs[isnippet][ichannel].glyph.line_color = snippet_palette[ipalette]
                if M.snippets_spectrogram and gram_images[isnippet]:
                    snippets_gram_glyphs[isnippet][ichannel].glyph.x = xdata[0]
                    snippets_gram_glyphs[isnippet][ichannel].glyph.y = \
                            -iy*snippets_dy -1 \
                            -2*(M.snippets_waveform and M.snippets_spectrogram) \
                            +M.audio_nchannels-1 - ichannel
                    # Width is the x span plus one step.
                    snippets_gram_glyphs[isnippet][ichannel].glyph.dw = xdata[-1] - xdata[0] + \
                                                                        xdata[1] - xdata[0]
                    snippets_gram_glyphs[isnippet][ichannel].glyph.dh = 2/M.audio_nchannels
                    # Show log10 power, restricted to the ilow..ihigh frequency rows.
                    snippets_gram_sources[isnippet][ichannel].data.update(image=[np.log10( \
                            gram_images[isnippet][ichannel][ilows[isnippet][ichannel]:1+ihighs[isnippet][ichannel],:])])
                else:
                    snippets_gram_sources[isnippet][ichannel].data.update(image=[])
        if labels_annotated[isnippet]!='':
            left_clustered.append(ix*(M.snippets_gap_pix+M.snippets_pix))
            right_clustered.append((ix+1)*(M.snippets_gap_pix+M.snippets_pix)-M.snippets_gap_pix)
            top_clustered.append(-iy*snippets_dy+1)
            bottom_clustered.append(-iy*snippets_dy-1 \
                                    -2*(M.snippets_waveform and M.snippets_spectrogram))
    snippets_quad_grey.data.update(left=left_clustered, right=right_clustered,
                                   top=top_clustered, bottom=bottom_clustered)

def nparray2base64wav(data, samplerate):
    """Encode `data` as a mono, 2-byte-per-sample WAV and return it base64-encoded.

    Args:
        data: array whose raw bytes (``data.tobytes()``) are written as the
            audio frames.  The header declares a sample width of 2 bytes, so
            the caller presumably passes int16 data -- TODO confirm.
        samplerate: frame rate stored in the WAV header.

    Returns:
        The complete WAV file, base64-encoded, as a ``str``.
    """
    # The context managers guarantee the writer and the buffer are closed even
    # if writing fails.  Closing the wave writer finalises the header but
    # leaves the underlying buffer open for reading.
    with io.BytesIO() as fid:
        with wave.open(fid, "w") as wav:
            wav.setframerate(samplerate)
            wav.setnchannels(1)
            wav.setsampwidth(2)
            wav.writeframes(data.tobytes())
        return base64.b64encode(fid.getvalue()).decode('utf-8')

def nparray2base64mp4(filename, start_sec, stop_sec):
    """Re-encode a time window of a video file as H.264/MP4 and return it base64-encoded.

    The video is opened with pims, the frames between `start_sec` and
    `stop_sec` are encoded into an in-memory MP4 container with PyAV, and the
    resulting bytes are returned as base64 text.  As a side effect the
    dimensions of `video_div` are set from the video's frame shape.

    Parameters:
        filename: path of the video file to read.
        start_sec: start of the window, in seconds.
        stop_sec: end of the window (exclusive frame), in seconds.

    Returns:
        str: the MP4 file, base64-encoded and decoded as UTF-8.
    """
    vid = pims.open(filename)

    # int(round(...)) works whether round() hands back a Python int or a numpy
    # scalar; the former has no .astype() and the np.int alias no longer
    # exists in recent NumPy releases.
    start_frame = int(round(start_sec * vid.frame_rate))
    stop_frame = int(round(stop_sec * vid.frame_rate))

    fid=io.BytesIO()
    container = av.open(fid, mode='w', format='mp4')

    stream = container.add_stream('h264', rate=vid.frame_rate)
    # NOTE(review): frame_shape[0] is used as the width and frame_shape[1] as
    # the height -- confirm this matches the axis order pims reports.
    stream.width = video_div.width = vid.frame_shape[0]
    stream.height = video_div.height = vid.frame_shape[1]
    stream.pix_fmt = 'yuv420p'

    for iframe in range(start_frame, stop_frame):
        frame = av.VideoFrame.from_ndarray(np.array(vid[iframe]), format='rgb24')
        for packet in stream.encode(frame):
            container.mux(packet)

    # calling encode() with no frame flushes whatever the encoder still buffers
    for packet in stream.encode():
        container.mux(packet)

    container.close()
    fid.seek(0)
    ret_val = base64.b64encode(fid.read()).decode('utf-8')
    fid.close()
    return ret_val

# _context_update() might be able to be folded back into context_update() with bokeh 2.0;
# the same goes for _doit_callback() and _groundtruth_update().
# see https://discourse.bokeh.org/t/bokeh-server-is-it-possible-to-push-updates-to-js-in-the-middle-of-a-python-callback/3455/4

def reset_video():
    """Rebuild the play callback's JavaScript with empty audio and video payloads."""
    no_media = ("", "")
    play_callback.code = C.play_callback_code % no_media

def __context_update(wavi, tapped_sample, istart_bounded, ilength):
    """Fill the play callback with base64 audio and, if enabled, matching video.

    Parameters:
        wavi: audio samples handed to nparray2base64wav().
        tapped_sample: path of the audio file; its folder is searched for a video.
        istart_bounded: first tic of the window (divided by M.audio_tic_rate
                        to obtain seconds).
        ilength: length of the window in tics.

    When the video toggle is active, the folder of `tapped_sample` is scanned
    for files that share its name up to the extension and end in .avi, .mp4 or
    .mov (case-insensitive).  A video is embedded only if exactly one such file
    exists; the toggle's button type is then reset to "default".
    """
    base64vid = ""
    if video_toggle.active:
        folder, audio_name = os.path.split(tapped_sample)
        audio_stem = os.path.splitext(audio_name)[0]
        video_extensions = ['.avi','.mp4','.mov']

        candidates = []
        for entry in os.listdir(folder):
            entry_stem, entry_ext = os.path.splitext(entry)
            if entry!=audio_name and \
               entry_stem==audio_stem and \
               entry_ext.lower() in video_extensions:
                candidates.append(entry)

        # zero or several matching videos both result in no video
        if len(candidates)==1:
            start_sec = istart_bounded / M.audio_tic_rate
            stop_sec = (istart_bounded+ilength) / M.audio_tic_rate
            base64vid = nparray2base64mp4(os.path.join(folder, candidates[0]),
                                          start_sec, stop_sec)
        video_toggle.button_type="default"

    base64wav = nparray2base64wav(wavi, M.audio_tic_rate)
    play_callback.code = C.play_callback_code % (base64wav, base64vid)

def _context_update(wavi, tapped_sample, istart_bounded, ilength):
    """Mark the video toggle as busy, then defer the media encoding by one tick.

    The button type is switched to "warning" when the video toggle is active,
    and __context_update() is scheduled with the same arguments through
    bokeh_document.add_next_tick_callback().
    """
    def encode_media():
        return __context_update(wavi, tapped_sample, istart_bounded, ilength)

    if video_toggle.active:
        video_toggle.button_type="warning"
    bokeh_document.add_next_tick_callback(encode_media)

def context_update():
    """Redraw the context plots and controls for the currently selected snippet.

    When M.isnippet is non-negative, the audio file of the selected clustered
    sample is read, a window of M.context_width_tic tics centred on the sample
    (shifted by M.context_offset_tic) is extracted, and from it are derived:
    the decimated and rescaled waveform per channel, the spectrogram per
    channel, the per-label probability traces (read from sibling
    "<file minus 4 chars>-<label>.wav" files when present), and the grey /
    fuchsia quads and text labels marking clustered and annotated samples
    that overlap the window.  Audio/video encoding for playback is scheduled
    on the next bokeh tick.

    When M.isnippet is negative, the playback and zoom/pan controls are
    disabled, the highlight glyphs are cleared and the video is reset.

    In both cases the bokeh data sources are updated at the end; if nothing
    was computed they receive NaN / empty placeholders.
    """
    p_waveform.title.text = p_spectrogram.title.text = ''

    # placeholders used when there is nothing to show
    tapped_ticks = [np.nan, np.nan]
    istart = np.nan
    scales = [0]*M.audio_nchannels
    ywav = [np.full(1,np.nan)]*M.audio_nchannels
    xwav = [np.full(1,np.nan)]*M.audio_nchannels
    gram_freq = [np.full(1,np.nan)]*M.audio_nchannels
    gram_time = [np.full(1,np.nan)]*M.audio_nchannels
    gram_image = [np.full((1,1),np.nan)]*M.audio_nchannels
    yprob = [np.full(1,np.nan)]*len(M.clustered_labels)
    xprob = [np.full(1,np.nan)]*len(M.clustered_labels)
    ilow = [0]*M.audio_nchannels
    ihigh = [1]*M.audio_nchannels
    xlabel_clustered, tlabel_clustered = [], []
    xlabel_annotated, tlabel_annotated = [], []
    left_clustered, right_clustered = [], []
    left_annotated, right_annotated = [], []

    if M.isnippet>=0:
        play.disabled=False
        video_toggle.disabled=False
        zoom_context.disabled=False
        zoom_offset.disabled=False
        zoomin.disabled=False
        zoomout.disabled=False
        reset.disabled=False
        panleft.disabled=False
        panright.disabled=False
        allleft.disabled=False
        allout.disabled=False
        allright.disabled=False
        tapped_sample = M.clustered_samples[M.isnippet]
        tapped_ticks = tapped_sample['ticks']
        M.context_midpoint_tic = np.mean(tapped_ticks, dtype=int)
        istart = M.context_midpoint_tic-M.context_width_tic//2 + M.context_offset_tic
        if M.context_waveform:
            p_waveform.title.text = tapped_sample['file']
        elif M.context_spectrogram:
            p_spectrogram.title.text = tapped_sample['file']
        _, wavs = spiowav.read(tapped_sample['file'], mmap=True)
        if np.ndim(wavs)==1:
            # single-channel files come back 1-D; add a channel axis
            wavs = np.expand_dims(wavs, axis=1)
        M.file_nframes = np.shape(wavs)[0]

        # optional per-label probability recordings stored next to the audio file
        probs = [None]*len(M.clustered_labels)
        for ilabel,label in enumerate(M.clustered_labels):
            prob_wavfile = tapped_sample['file'][:-4]+'-'+label+'.wav'
            if os.path.isfile(prob_wavfile):
                prob_tic_rate, probs[ilabel] = spiowav.read(prob_wavfile, mmap=True)

        # only proceed if the requested window overlaps the recording at all
        if istart+M.context_width_tic>0 and istart < M.file_nframes:
            istart_bounded = np.maximum(0, istart)
            context_tic_adjusted = M.context_width_tic+1-(istart_bounded-istart)
            ilength = np.minimum(M.file_nframes-istart_bounded, context_tic_adjusted)

            # decimate only when there are more tics per pixel than the filter allows
            tic2pix = M.context_width_tic / M.gui_width_pix
            context_decimate_by = round(tic2pix/M.filter_ratio_max) if \
                     tic2pix>M.filter_ratio_max else 1
            context_pix = round(M.context_width_tic / context_decimate_by)

            if any([isinstance(x, np.ndarray) for x in probs]):
                tic_rate_ratio = prob_tic_rate / M.audio_tic_rate
                tic2pix = round(M.context_width_tic*tic_rate_ratio / M.gui_width_pix)
                prob_decimate_by = round(tic2pix/M.filter_ratio_max) if \
                         tic2pix>M.filter_ratio_max else 1
                prob_pix = round(M.context_width_tic*tic_rate_ratio / prob_decimate_by)

            for ichannel in range(M.audio_nchannels):
                wavi = wavs[istart_bounded : istart_bounded+ilength, ichannel]
                if len(wavi) < M.context_width_tic+1:
                    # zero-pad whichever side of the window falls outside the file
                    npad = M.context_width_tic+1-len(wavi)
                    if istart < 0:
                        wavi = np.concatenate((np.full((npad,),0), wavi))
                    if istart+M.context_width_tic>M.file_nframes:
                        wavi = np.concatenate((wavi, np.full((npad,),0)))

                if ichannel==0:
                    if bokeh_document:
                        # Bind the first channel to its own name: the callback
                        # runs after this loop has finished, by which time
                        # `wavi` refers to the last channel.
                        wav_to_play = wavi
                        bokeh_document.add_next_tick_callback(lambda: \
                                _context_update(wav_to_play,
                                                tapped_sample['file'],
                                                istart_bounded,
                                                ilength))

                if M.context_waveform:
                    wavi_downsampled = decimate(wavi, context_decimate_by, n=M.filter_order,
                                                ftype='iir', zero_phase=True)
                    wavi_trimmed = wavi_downsampled[:context_pix]

                    # normalize by the peak amplitude, then clip and rescale the
                    # [low, high] zoom range onto [-1, 1]
                    scales[ichannel]=np.minimum(np.iinfo(np.int16).max-1,
                                                np.max(np.abs(wavi_trimmed)))
                    wavi_scaled = wavi_trimmed/scales[ichannel]
                    icliplow = np.where((wavi_scaled < M.context_waveform_low[ichannel]))[0]
                    icliphigh = np.where((wavi_scaled > M.context_waveform_high[ichannel]))[0]
                    wavi_zoomed = np.copy(wavi_scaled)
                    wavi_zoomed[icliplow] = M.context_waveform_low[ichannel]
                    wavi_zoomed[icliphigh] = M.context_waveform_high[ichannel]
                    ywav[ichannel] = (wavi_zoomed - M.context_waveform_low[ichannel]) / \
                                      (M.context_waveform_high[ichannel] - M.context_waveform_low[ichannel]) \
                                      * 2 - 1
                    xwav[ichannel]=[(istart+i*context_decimate_by)/M.audio_tic_rate \
                                     for i in range(len(wavi_trimmed))]
                else:
                    # without a waveform only the two end points of the time axis are needed
                    xwav[ichannel] = [istart/M.audio_tic_rate,
                                       (istart+(context_pix-1)*context_decimate_by)/M.audio_tic_rate]

                if M.context_spectrogram:
                    window_length = round(M.spectrogram_length_ms[ichannel]/1000*M.audio_tic_rate)
                    gram_freq[ichannel], gram_time[ichannel], gram_image[ichannel] = \
                            spectrogram(wavi,
                                        fs=M.audio_tic_rate,
                                        window=M.spectrogram_window,
                                        nperseg=window_length,
                                        noverlap=round(window_length*M.spectrogram_overlap))
                    # indices of the frequency bins nearest the requested band edges
                    ilow[ichannel] = np.argmin(np.abs(gram_freq[ichannel] - \
                                                      M.spectrogram_low_hz[ichannel]))
                    ihigh[ichannel] = np.argmin(np.abs(gram_freq[ichannel] - \
                                                       M.spectrogram_high_hz[ichannel]))

            for ilabel in range(len(M.clustered_labels)):
                if not isinstance(probs[ilabel], np.ndarray):  continue
                prob_istart = int(np.rint(istart_bounded*tic_rate_ratio))
                prob_istop = int(np.rint((istart_bounded+ilength)*tic_rate_ratio))
                probi = probs[ilabel][prob_istart : prob_istop : prob_decimate_by]
                if len(probi) < round(M.context_width_tic*tic_rate_ratio)+1:
                    npad = round(M.context_width_tic*tic_rate_ratio)+1-len(probi)
                    if istart < 0:
                        probi = np.concatenate((np.full((npad,),0), probi))
                    if istart+round(M.context_width_tic*tic_rate_ratio)>M.file_nframes:
                        probi = np.concatenate((probi, np.full((npad,),0)))
                probi_trimmed = probi[:prob_pix]
                yprob[ilabel] = probi_trimmed / np.iinfo(np.int16).max
                xprob[ilabel]=[(prob_istart+i*prob_decimate_by)/prob_tic_rate \
                                 for i in range(len(probi_trimmed))]

            if M.context_spectrogram:
                # y ticks sit on channel boundaries; label them with the band
                # limits of the adjacent channel(s)
                p_spectrogram.yaxis.formatter = FuncTickFormatter(
                    args=dict(low_hz=[gram_freq[i][x] / M.context_spectrogram_freq_scale \
                                      for i,x in enumerate(ilow)],
                              high_hz=[gram_freq[i][x] / M.context_spectrogram_freq_scale \
                                       for i,x in enumerate(ihigh)]),
                    code="""
                         if (tick==0) {
                             return low_hz[low_hz.length-1] }
                         else if (tick == high_hz.length) {
                             return high_hz[0] }
                         else {
                             return low_hz[low_hz.length-tick-1] + "," + high_hz[high_hz.length-tick] }
                         """)

            # clustered samples that start before the window ends and stop
            # after it begins
            ileft = np.searchsorted(M.clustered_starts_sorted, istart+M.context_width_tic)
            samples_to_plot = set(range(0,ileft))
            iright = np.searchsorted(M.clustered_stops, istart,
                                    sorter=M.iclustered_stops_sorted)
            samples_to_plot &= set([M.iclustered_stops_sorted[i] for i in \
                    range(iright, len(M.iclustered_stops_sorted))])

            tapped_wav_in_view = False
            for isample in samples_to_plot:
                if tapped_sample['file']!=M.clustered_samples[isample]['file']:
                    continue
                # clip the sample's extent to the visible window
                L = np.max([istart, M.clustered_samples[isample]['ticks'][0]])
                R = np.min([istart+M.context_width_tic,
                            M.clustered_samples[isample]['ticks'][1]])
                xlabel_clustered.append((L+R)/2/M.audio_tic_rate)
                tlabel_clustered.append(M.clustered_samples[isample]['kind']+'\n'+\
                              M.clustered_samples[isample]['label'])
                left_clustered.append(L/M.audio_tic_rate)
                right_clustered.append(R/M.audio_tic_rate)
                if tapped_sample==M.clustered_samples[isample] and not np.isnan(M.xcluster):
                    if M.context_waveform:
                        waveform_quad_fuchsia.data.update(left=[L/M.audio_tic_rate],
                                                          right=[R/M.audio_tic_rate],
                                                          top=[1],
                                                          bottom=[0])
                    if M.context_spectrogram:
                        spectrogram_quad_fuchsia.data.update(left=[L/M.audio_tic_rate],
                                                             right=[R/M.audio_tic_rate],
                                                             top=[M.audio_nchannels],
                                                             bottom=[M.audio_nchannels/2])
                    tapped_wav_in_view = True

            if M.context_waveform:
                waveform_span_red.location=xwav[0][0]
                waveform_span_red.visible=True
            if M.context_spectrogram:
                spectrogram_span_red.location=xwav[0][0]
                spectrogram_span_red.visible=True

            if not tapped_wav_in_view:
                if M.context_waveform:
                    waveform_quad_fuchsia.data.update(left=[], right=[], top=[], bottom=[])
                if M.context_spectrogram:
                    spectrogram_quad_fuchsia.data.update(left=[], right=[], top=[], bottom=[])

            # same overlap search for the annotated samples
            if len(M.annotated_starts_sorted)>0:
                ileft = np.searchsorted(M.annotated_starts_sorted,
                                        istart+M.context_width_tic)
                samples_to_plot = set(range(0,ileft))
                iright = np.searchsorted(M.annotated_stops, istart,
                                         sorter=M.iannotated_stops_sorted)
                samples_to_plot &= set([M.iannotated_stops_sorted[i] for i in \
                        range(iright, len(M.iannotated_stops_sorted))])

                for isample in samples_to_plot:
                    if tapped_sample['file']!=M.annotated_samples[isample]['file']:
                        continue
                    L = np.max([istart, M.annotated_samples[isample]['ticks'][0]])
                    R = np.min([istart+M.context_width_tic,
                                M.annotated_samples[isample]['ticks'][1]])
                    xlabel_annotated.append((L+R)/2/M.audio_tic_rate)
                    tlabel_annotated.append(M.annotated_samples[isample]['label'])
                    left_annotated.append(L/M.audio_tic_rate)
                    right_annotated.append(R/M.audio_tic_rate)
    else:
        play.disabled=True
        video_toggle.disabled=True
        zoom_context.disabled=True
        zoom_offset.disabled=True
        zoomin.disabled=True
        zoomout.disabled=True
        reset.disabled=True
        panleft.disabled=True
        panright.disabled=True
        allleft.disabled=True
        allout.disabled=True
        allright.disabled=True
        if M.context_waveform:
            waveform_quad_fuchsia.data.update(left=[], right=[], top=[], bottom=[])
            waveform_span_red.location=0
            waveform_span_red.visible=False
        if M.context_spectrogram:
            spectrogram_quad_fuchsia.data.update(left=[], right=[], top=[], bottom=[])
            spectrogram_span_red.location=0
            spectrogram_span_red.visible=False
        reset_video()

    # push the (possibly placeholder) traces and images to the glyphs
    for ichannel in range(M.audio_nchannels):
        xdata = xwav[ichannel]
        if M.context_waveform:
            # stack the channels vertically within [-1, 1]
            ydata = (ywav[ichannel] + M.audio_nchannels-1-2*ichannel) / M.audio_nchannels
            waveform_source[ichannel].data.update(x=xdata, y=ydata)
            # the line colour encodes the peak amplitude used for normalization
            ipalette = int(np.floor(scales[ichannel] /
                                    np.iinfo(np.int16).max *
                                    len(snippet_palette)))
            waveform_glyph[ichannel].glyph.line_color = snippet_palette[ipalette]
        if M.context_spectrogram and not np.isnan(gram_time[ichannel][0]):
            spectrogram_glyph[ichannel].glyph.x = xdata[0]
            spectrogram_glyph[ichannel].glyph.y = M.audio_nchannels-1 - ichannel
            spectrogram_glyph[ichannel].glyph.dw = xdata[-1] - xdata[0]
            spectrogram_glyph[ichannel].glyph.dh = 1
            spectrogram_source[ichannel].data.update(image=[np.log10( \
                    gram_image[ichannel][ilow[ichannel]:1+ihigh[ichannel],:])])
        else:
            spectrogram_source[ichannel].data.update(image=[])

    probability_source.data.update(xs=xprob, ys=yprob,
                                   colors=cluster_dot_palette[:len(M.clustered_labels)],
                                   labels=list(M.clustered_labels))

    # clustered samples occupy the upper half of each plot, annotated the lower
    if M.context_waveform:
        waveform_quad_grey_clustered.data.update(left=left_clustered,
                                                 right=right_clustered,
                                                 top=[1]*len(left_clustered),
                                                 bottom=[0]*len(left_clustered))
        waveform_quad_grey_annotated.data.update(left=left_annotated,
                                                 right=right_annotated,
                                                 top=[0]*len(left_annotated),
                                                 bottom=[-1]*len(left_annotated))
        waveform_label_source_clustered.data.update(x=xlabel_clustered,
                                                    y=[1]*len(xlabel_clustered),
                                                    text=tlabel_clustered)
        waveform_label_source_annotated.data.update(x=xlabel_annotated,
                                                    y=[-1]*len(xlabel_annotated),
                                                    text=tlabel_annotated)
    if M.context_spectrogram:
        spectrogram_quad_grey_clustered.data.update(left=left_clustered,
                                                    right=right_clustered,
                                                    top=[M.audio_nchannels]*len(left_clustered),
                                                    bottom=[M.audio_nchannels/2]*len(left_clustered))
        spectrogram_quad_grey_annotated.data.update(left=left_annotated,
                                                    right=right_annotated,
                                                    top=[M.audio_nchannels/2]*len(left_annotated),
                                                    bottom=[0]*len(left_annotated))
        spectrogram_label_source_clustered.data.update(x=xlabel_clustered,
                                                       y=[M.audio_nchannels]*len(xlabel_clustered),
                                                       text=tlabel_clustered)
        spectrogram_label_source_annotated.data.update(x=xlabel_annotated,
                                                       y=[0]*len(xlabel_annotated),
                                                       text=tlabel_annotated)

def save_update(n):
    """Show the count `n` on the save indicator and colour it by magnitude.

    The label becomes str(n); the button type is "default" for zero,
    "warning" below ten, and "danger" otherwise.
    """
    save_indicator.label=str(n)
    if n==0:
        severity = "default"
    elif n < 10:
        severity = "warning"
    else:
        severity = "danger"
    save_indicator.button_type = severity

def waitfor_update():
    """Enable the waitfor widget once M.waitfor_job holds at least one entry."""
    if len(M.waitfor_job)==0:
        return
    waitfor.disabled=False

def model_file_update(attr, old, new):
    """Persist the GUI state and refresh the buttons.

    The (attr, old, new) signature looks like a bokeh property-change
    callback -- TODO confirm where it is registered; none of the three
    arguments is used here.
    """
    M.save_state_callback()
    buttons_update()

def cluster_these_layers_update():
    """Populate the layer selector from activations.npz in the ground-truth folder.

    The number of layers is the number of entries in the archive whose name
    starts with 'arr_'.  The first is offered as "input", the last as
    "output", and those in between as "hidden #i".  If the file does not
    exist the option list is emptied.
    """
    activations_file = os.path.join(groundtruth_folder.value,'activations.npz')
    if not os.path.isfile(activations_file):
        cluster_these_layers.options = []
        return

    # The archive object keeps the file open until it is closed, so use it as
    # a context manager.  Only the member names are inspected here.
    # NOTE: allow_pickle=True permits unpickling of object arrays when members
    # are accessed; only point this at trusted files.
    with np.load(activations_file, allow_pickle=True) as npzfile:
        nlayers = sum(1 for name in npzfile.files if name.startswith('arr_'))

    cluster_these_layers.options = [("0", "input"),
                                    *[(str(i), "hidden #"+str(i)) \
                                      for i in range(1,nlayers-1)],
                                    (str(nlayers-1), "output")]

def _groundtruth_update():
    """Do the actual refresh after the ground-truth folder changed.

    Scheduled by groundtruth_update() on the next bokeh tick.  Recomputes the
    word-count table and the layer selector, saves the GUI state, resets the
    groundtruth button's type to "default" (it is left disabled), and
    refreshes the remaining buttons.
    """
    wordcounts_update()
    cluster_these_layers_update()
    M.save_state_callback()
    groundtruth.button_type="default"
    groundtruth.disabled=True
    buttons_update()

def groundtruth_update():
    """Flag the groundtruth button as busy and defer the real update by one tick.

    The button is switched to "warning" and disabled; when a bokeh document is
    available, _groundtruth_update() is queued as a next-tick callback.
    """
    groundtruth.button_type="warning"
    groundtruth.disabled=True
    if not bokeh_document:
        return
    bokeh_document.add_next_tick_callback(_groundtruth_update)

def wantedwords_update_other():
    """Rebuild the wanted-words string from the non-empty label text widgets.

    'other' is appended if it is not already among the labels, and the result
    is written to wantedwords_string as a comma-separated list.
    """
    words = []
    for widget in label_text_widgets:
        if widget.value!='':
            words.append(widget.value)
    if 'other' not in words:
        words.append('other')
    wantedwords_string.value=','.join(words)

def buttons_update():
    """Synchronize every button and text input with the current wizard and action.

    - Wizard and action buttons are highlighted when selected; action buttons
      not belonging to the current wizard are disabled.
    - The labels of the file and model buttons are adapted to the action.
    - Parameter buttons and text inputs are enabled only if the current action
      uses them; some text inputs additionally depend on the chosen
      representation or clustering algorithm.
    - The "doit" button is enabled (and turned "danger") only if an action is
      selected and no required, enabled text input is empty.
    """
    def _disabled_for_settings(widget):
        # Disabled state of a text input that the current action does use.
        if widget==window_ms_string or widget==stride_ms_string:
            return representation.value=='waveform'
        if widget==mel_dct_string:
            return representation.value!='mel-cepstrum'
        if widget==pca_fraction_variance_to_retain_string:
            return cluster_algorithm.value[:4] not in ['tSNE','UMAP']
        if widget==tsne_perplexity_string or widget==tsne_exaggeration_string:
            return not cluster_algorithm.value.startswith('tSNE')
        if widget==umap_neighbors_string or widget==umap_distance_string:
            return not cluster_algorithm.value.startswith('UMAP')
        return False

    for button in wizard_buttons:
        if button==M.wizard:
            button.button_type="success"
        else:
            button.button_type="default"

    for button in action_buttons:
        if button==M.action:
            button.button_type="primary"
        else:
            button.button_type="default"
        button.disabled = button not in wizard2actions[M.wizard]

    if M.action in [detect,classify]:
        files_label='wav files:'
    elif M.action==ethogram:
        files_label='tf files:'
    elif M.action==misses:
        files_label='csv files:'
    else:
        files_label='wav,tf,csv files:'
    wavtfcsvfiles.label=files_label

    if M.action == classify:
        model_label='pb file:'
    elif M.action == ethogram:
        model_label='threshold file:'
    else:
        model_label='checkpoint file:'
    model.label=model_label

    for button in parameter_buttons:
        button.disabled = button not in action2parameterbuttons[M.action]

    okay = bool(M.action)
    for textinput in parameter_textinputs:
        if textinput not in action2parametertextinputs[M.action]:
            textinput.disabled=True
            continue
        textinput.disabled = _disabled_for_settings(textinput)
        if textinput.disabled or textinput.value!='':
            continue
        # an enabled input is empty: fine only if it is optional for this action
        if M.action==classify:
            may_be_empty = [wantedwords_string, prevalences_string]
        elif M.action==congruence:
            may_be_empty = [validationfiles_string, testfiles_string]
        else:
            may_be_empty = [testfiles_string, restore_from_string]
        if textinput not in may_be_empty:
            okay=False

    # cross-field requirements
    if M.action==classify:
        if prevalences_string.value!='' and wantedwords_string.value=='':
            okay=False
    if M.action==congruence:
        if validationfiles_string.value=='' and testfiles_string.value=='':
            okay=False
    if M.action==cluster:
        if len(cluster_these_layers.value)==0:
            okay=False

    doit.button_type="default"
    doit.disabled = not okay
    doit.button_type = "danger" if okay else "default"
    cluster_these_layers.disabled = False  # https://github.com/bokeh/bokeh/issues/10507

def file_dialog_update():
    """Refresh the file dialog table from M.file_dialog_root and M.file_dialog_filter.

    The glob matches, preceded by '.' and '..', are naturally sorted.  Names
    are shown as base names when all matches live in one directory and as the
    full glob result otherwise; directories get a trailing '/'.  Sizes and
    modification dates are looked up for every listed entry.
    """
    pattern = os.path.join(M.file_dialog_root,M.file_dialog_filter)
    matches = glob.glob(pattern)
    parent_dirs = {os.path.dirname(match) for match in matches}
    entries = natsorted(['.', '..', *matches])

    same_folder = len(parent_dirs)==1
    names = []
    for entry in entries:
        shown = os.path.basename(entry) if same_folder else entry
        if os.path.isdir(entry):
            shown = shown + '/'
        names.append(shown)

    sizes = [os.path.getsize(entry) for entry in entries]
    dates = [datetime.fromtimestamp(os.path.getmtime(entry)) for entry in entries]
    file_dialog_source.data = {'names': names, 'sizes': sizes, 'dates': dates}

def wordcounts_update():
    """Rebuild the HTML table counting labels per kind in the ground-truth folder.

    Every non-empty .csv file found one level below groundtruth_folder.value
    is read without a header.  Files that do not have five or six columns are
    reported through the log and skipped.  Column 3 supplies the "kinds"
    (also stored in M.kinds) and column 4 the "words"; the table holds the
    number of rows for each (word, kind) pair plus a TOTAL per kind, and is
    oriented so that the longer of the two lists runs down the rows.

    If the folder does not exist nothing is changed; if no usable file is
    found the table text is cleared.
    """
    root = groundtruth_folder.value
    if not os.path.isdir(root):
        return

    dfs = []
    for subdir in os.listdir(root):
        subdir_path = os.path.join(root, subdir)
        if not os.path.isdir(subdir_path):
            continue
        for csvfile in os.listdir(subdir_path):
            if not csvfile.endswith('.csv'):
                continue
            filepath = os.path.join(subdir_path, csvfile)
            # pandas refuses to parse a zero-byte file
            if os.path.getsize(filepath) > 0:
                df = pd.read_csv(filepath, header=None, index_col=False)
                if 5 <= len(df.columns) <= 6:
                    dfs.append(df)
                else:
                    bokehlog.info("WARNING: "+csvfile+" is not in the correct format")

    if not dfs:
        wordcounts.text = ""
        return

    df = pd.concat(dfs)
    M.kinds = sorted(set(df[3]))
    words = sorted(set(df[4]))

    # one boolean mask per kind, reused for every word
    bkinds = {kind: np.array(df[3]==kind) for kind in M.kinds}
    table = np.empty((1+len(words),len(M.kinds)), dtype=int)
    for iword,word in enumerate(words):
        bword = np.array(df[4]==word)
        for ikind,kind in enumerate(M.kinds):
            table[iword,ikind] = np.sum(np.logical_and(bkinds[kind], bword))
    for ikind,kind in enumerate(M.kinds):
        table[len(words),ikind] = np.sum(bkinds[kind])
    words += ['TOTAL']

    if len(words)>len(M.kinds):
        rows = words
        cols = M.kinds
    else:
        rows = M.kinds
        cols = words
        table = np.transpose(table)

    # str() guards against labels that pandas parsed as numbers
    html = ['<table><tr><th></th><th nowrap>' +
            '</th><th nowrap>'.join(str(col) for col in cols) +
            '</th></tr>']
    for irow,row in enumerate(rows):
        html.append('<tr><th nowrap>'+str(row)+'</th>')
        for icol in range(len(cols)):
            html.append('<td align="center">'+str(table[irow,icol])+'</td>')
        html.append('</tr>')
    html.append('</table>')
    wordcounts.text = ''.join(html)

async def status_ticker_update():
    """Render the job queue into the status ticker as colour-coded HTML.

    Each key of M.status_ticker_queue is wrapped in a <span> whose colour
    reflects its status: gray for "pending", black for "running", blue for
    "succeeded" and red for "failed".  Any other status falls back to black
    instead of raising or inheriting the previous job's colour.  The spans
    are joined with ", " and placed between status_ticker_pre and
    status_ticker_post; an empty queue yields just those two.
    """
    status_colors = {
        "pending": "gray",
        "running": "black",
        "succeeded": "blue",
        "failed": "red",
    }
    spans = []
    for job, status in M.status_ticker_queue.items():
        color = status_colors.get(status, "black")
        spans.append("<span style='color:"+color+"'>"+job+"</span>")
    status_ticker.text = status_ticker_pre+', '.join(spans)+status_ticker_post

def init(_bokeh_document):
    global bokeh_document, cluster_dot_palette, snippet_palette, p_cluster, cluster_dots, p_cluster_dots, precomputed_dots, snippets_dy, p_snippets, snippets_label_sources_clustered, snippets_label_sources_annotated, snippets_wave_sources, snippets_wave_glyphs, snippets_gram_sources, snippets_gram_glyphs, snippets_quad_grey, dot_size_cluster, dot_alpha_cluster, cluster_circle_fuchsia, p_waveform, p_spectrogram, p_probability, probability_source, probability_glyph, spectrogram_source, spectrogram_glyph, waveform_span_red, spectrogram_span_red, waveform_quad_grey_clustered, waveform_quad_grey_annotated, waveform_quad_grey_pan, waveform_quad_fuchsia, spectrogram_quad_grey_clustered, spectrogram_quad_grey_annotated, spectrogram_quad_grey_pan, spectrogram_quad_fuchsia, snippets_quad_fuchsia, waveform_source, waveform_glyph, waveform_label_source_clustered, waveform_label_source_annotated, spectrogram_label_source_clustered, spectrogram_label_source_annotated, which_layer, which_species, which_word, which_nohyphen, which_kind, color_picker, circle_radius, dot_size, dot_alpha, zoom_context, zoom_offset, zoomin, zoomout, reset, panleft, panright, allleft, allout, allright, save_indicator, label_count_widgets, label_text_widgets, play, play_callback, video_toggle, video_div, undo, redo, detect, misses, configuration_file, train, leaveoneout, leaveallout, xvalidate, mistakes, activations, cluster, visualize, accuracy, freeze, classify, ethogram, compare, congruence, status_ticker, waitfor, file_dialog_source, file_dialog_source, configuration_contents, logs, logs_folder, model, model_file, wavtfcsvfiles, wavtfcsvfiles_string, groundtruth, groundtruth_folder, validationfiles, testfiles, validationfiles_string, testfiles_string, wantedwords, wantedwords_string, labeltypes, labeltypes_string, prevalences, prevalences_string, copy, labelsounds, makepredictions, fixfalsepositives, fixfalsenegatives, generalize, tunehyperparameters, findnovellabels, examineerrors, testdensely, 
doit, time_sigma_string, time_smooth_ms_string, frequency_n_ms_string, frequency_nw_string, frequency_p_string, frequency_smooth_ms_string, nsteps_string, restore_from_string, save_and_validate_period_string, validate_percentage_string, mini_batch_string, kfold_string, activations_equalize_ratio_string, activations_max_samples_string, pca_fraction_variance_to_retain_string, tsne_perplexity_string, tsne_exaggeration_string, umap_neighbors_string, umap_distance_string, cluster_algorithm, cluster_these_layers, precision_recall_ratios_string, context_ms_string, shiftby_ms_string, representation, window_ms_string, stride_ms_string, mel_dct_string, optimizer, learning_rate_string, replicates_string, batch_seed_string, weights_seed_string, file_dialog_string, file_dialog_table, readme_contents, wordcounts, wizard_buttons, action_buttons, parameter_buttons, parameter_textinputs, wizard2actions, action2parameterbuttons, action2parametertextinputs, status_ticker_update, status_ticker_pre, status_ticker_post, model_parameters

    bokeh_document = _bokeh_document

    M.cluster_circle_color = M.cluster_circle_color

    if '#' in M.cluster_dot_palette:
      cluster_dot_palette = ast.literal_eval(M.cluster_dot_palette)
    else:
      cluster_dot_palette = getattr(palettes, M.cluster_dot_palette)

    snippet_palette = getattr(palettes, M.snippets_colormap)

    dot_size_cluster = ColumnDataSource(data=dict(ds=[M.state["dot_size"]]))
    dot_alpha_cluster = ColumnDataSource(data=dict(da=[M.state["dot_alpha"]]))

    cluster_dots = ColumnDataSource(data=dict(dx=[], dy=[], dz=[], dl=[], dc=[]))
    cluster_circle_fuchsia = ColumnDataSource(data=dict(cx=[], cy=[], cz=[], cr=[], cc=[]))
    p_cluster = ScatterNd(dx='dx', dy='dy', dz='dz', dl='dl', dc='dc',
                          dots_source=cluster_dots,
                          cx='cx', cy='cy', cz='cz', cr='cr', cc='cc',
                          circle_fuchsia_source=cluster_circle_fuchsia,
                          ds='ds',
                          dot_size_source=dot_size_cluster,
                          da='da',
                          dot_alpha_source=dot_alpha_cluster,
                          width=M.gui_width_pix//2)
    p_cluster.on_change("click_position", lambda a,o,n: C.cluster_tap_callback(n))

    precomputed_dots = None

    snippets_dy = 2*M.snippets_waveform + 2*M.snippets_spectrogram

    p_snippets = figure(plot_width=M.gui_width_pix//2, \
                        background_fill_color='#FFFFFF', toolbar_location=None)
    p_snippets.toolbar.active_drag = None
    p_snippets.grid.visible = False
    p_snippets.xaxis.visible = False
    p_snippets.yaxis.visible = False

    snippets_gram_sources=[None]*(M.snippets_nx*M.snippets_ny)
    snippets_gram_glyphs=[None]*(M.snippets_nx*M.snippets_ny)
    for ixy in range(M.snippets_nx*M.snippets_ny):
        snippets_gram_sources[ixy]=[None]*M.audio_nchannels
        snippets_gram_glyphs[ixy]=[None]*M.audio_nchannels
        for ichannel in range(M.audio_nchannels):
            snippets_gram_sources[ixy][ichannel] = ColumnDataSource(data=dict(image=[]))
            snippets_gram_glyphs[ixy][ichannel] = p_snippets.image('image',
                    source=snippets_gram_sources[ixy][ichannel],
                    palette=M.spectrogram_colormap)

    snippets_quad_grey = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_snippets.quad('left','right','top','bottom',source=snippets_quad_grey,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey")

    snippets_wave_sources=[None]*(M.snippets_nx*M.snippets_ny)
    snippets_wave_glyphs=[None]*(M.snippets_nx*M.snippets_ny)
    for ixy in range(M.snippets_nx*M.snippets_ny):
        snippets_wave_sources[ixy]=[None]*M.audio_nchannels
        snippets_wave_glyphs[ixy]=[None]*M.audio_nchannels
        for ichannel in range(M.audio_nchannels):
            snippets_wave_sources[ixy][ichannel]=ColumnDataSource(data=dict(x=[], y=[]))
            snippets_wave_glyphs[ixy][ichannel]=p_snippets.line(
                    'x', 'y', source=snippets_wave_sources[ixy][ichannel])

    xdata = [(i%M.snippets_nx)*(M.snippets_gap_pix+M.snippets_pix)
             for i in range(M.snippets_nx*M.snippets_ny)]
    ydata = [-(i//M.snippets_nx*snippets_dy-1)
             for i in range(M.snippets_nx*M.snippets_ny)]
    text = ['' for i in range(M.snippets_nx*M.snippets_ny)]
    snippets_label_sources_clustered = ColumnDataSource(data=dict(x=xdata, y=ydata, text=text))
    p_snippets.text('x', 'y', source=snippets_label_sources_clustered, text_font_size='6pt',
                    text_baseline='top',
                    text_color='black' if M.snippets_waveform else 'white')

    xdata = [(i%M.snippets_nx)*(M.snippets_gap_pix+M.snippets_pix)
             for i in range(M.snippets_nx*M.snippets_ny)]
    ydata = [-(i//M.snippets_nx*snippets_dy+1+2*(M.snippets_waveform and M.snippets_spectrogram))
             for i in range(M.snippets_nx*M.snippets_ny)]
    text_annotated = ['' for i in range(M.snippets_nx*M.snippets_ny)]
    snippets_label_sources_annotated = ColumnDataSource(data=dict(x=xdata, y=ydata,
                                                                  text=text_annotated))
    p_snippets.text('x', 'y', source=snippets_label_sources_annotated,
                    text_font_size='6pt',
                    text_color='white' if M.snippets_spectrogram else 'black')

    snippets_quad_fuchsia = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_snippets.quad('left','right','top','bottom',source=snippets_quad_fuchsia,
                fill_color=None, line_color="fuchsia")

    p_snippets.on_event(Tap, C.snippets_tap_callback)
    p_snippets.on_event(DoubleTap, C.snippets_doubletap_callback)

    p_waveform = figure(plot_width=M.gui_width_pix,
                        plot_height=M.context_waveform_height_pix,
                        background_fill_color='#FFFFFF', toolbar_location=None)
    p_waveform.toolbar.active_drag = None
    p_waveform.grid.visible = False
    if M.context_spectrogram:
        p_waveform.xaxis.visible = False
    else:
        p_waveform.xaxis.axis_label = 'Time (sec)'
    p_waveform.yaxis.axis_label = ""
    p_waveform.yaxis.ticker = []
    p_waveform.x_range.range_padding = p_waveform.y_range.range_padding = 0.0
    p_waveform.y_range.start = -1
    p_waveform.y_range.end = 1
    p_waveform.title.text=' '

    waveform_span_red = Span(location=0, dimension='height', line_color='red')
    p_waveform.add_layout(waveform_span_red)
    waveform_span_red.visible=False

    waveform_quad_grey_clustered = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_waveform.quad('left','right','top','bottom',source=waveform_quad_grey_clustered,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    waveform_quad_grey_annotated = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_waveform.quad('left','right','top','bottom',source=waveform_quad_grey_annotated,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    waveform_quad_grey_pan = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_waveform.quad('left','right','top','bottom',source=waveform_quad_grey_pan,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    waveform_quad_fuchsia = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_waveform.quad('left','right','top','bottom',source=waveform_quad_fuchsia,
                fill_color=None, line_color="fuchsia", level='underlay')

    waveform_source=[None]*M.audio_nchannels
    waveform_glyph=[None]*M.audio_nchannels
    for ichannel in range(M.audio_nchannels):
        waveform_source[ichannel] = ColumnDataSource(data=dict(x=[], y=[]))
        waveform_glyph[ichannel] = p_waveform.line('x', 'y', source=waveform_source[ichannel])

    waveform_label_source_clustered = ColumnDataSource(data=dict(x=[], y=[], text=[]))
    p_waveform.text('x', 'y', source=waveform_label_source_clustered,
                   text_font_size='6pt', text_align='center', text_baseline='top',
                   text_line_height=0.8, level='underlay')
    waveform_label_source_annotated = ColumnDataSource(data=dict(x=[], y=[], text=[]))
    p_waveform.text('x', 'y', source=waveform_label_source_annotated,
                   text_font_size='6pt', text_align='center', text_baseline='bottom',
                   text_line_height=0.8, level='underlay')

    p_waveform.on_event(DoubleTap, lambda e: C.context_doubletap_callback(e, 0))

    p_waveform.on_event(PanStart, C.waveform_pan_start_callback)
    p_waveform.on_event(Pan, C.waveform_pan_callback)
    p_waveform.on_event(PanEnd, C.waveform_pan_end_callback)
    p_waveform.on_event(Tap, C.waveform_tap_callback)

    p_spectrogram = figure(plot_width=M.gui_width_pix,
                           plot_height=M.context_spectrogram_height_pix,
                           background_fill_color='#FFFFFF', toolbar_location=None)
    p_spectrogram.toolbar.active_drag = None
    p_spectrogram.x_range.range_padding = p_spectrogram.y_range.range_padding = 0
    p_spectrogram.xgrid.visible = False
    p_spectrogram.ygrid.visible = True
    p_spectrogram.xaxis.axis_label = 'Time (sec)'
    p_spectrogram.yaxis.axis_label = 'Frequency (' + M.context_spectrogram_units + ')'
    p_spectrogram.yaxis.ticker = list(range(1+M.audio_nchannels))

    spectrogram_source = [None]*M.audio_nchannels
    spectrogram_glyph = [None]*M.audio_nchannels
    for ichannel in range(M.audio_nchannels):
        spectrogram_source[ichannel] = ColumnDataSource(data=dict(image=[]))
        spectrogram_glyph[ichannel] = p_spectrogram.image('image',
                                                          source=spectrogram_source[ichannel],
                                                          palette=M.spectrogram_colormap,
                                                          level="image")

    p_spectrogram.on_event(MouseWheel, C.spectrogram_mousewheel_callback)

    p_spectrogram.on_event(DoubleTap,
                           lambda e: C.context_doubletap_callback(e, M.audio_nchannels/2))

    p_spectrogram.on_event(PanStart, C.spectrogram_pan_start_callback)
    p_spectrogram.on_event(Pan, C.spectrogram_pan_callback)
    p_spectrogram.on_event(PanEnd, C.spectrogram_pan_end_callback)
    p_spectrogram.on_event(Tap, C.spectrogram_tap_callback)

    spectrogram_span_red = Span(location=0, dimension='height', line_color='red')
    p_spectrogram.add_layout(spectrogram_span_red)
    spectrogram_span_red.visible=False

    spectrogram_quad_grey_clustered = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_spectrogram.quad('left','right','top','bottom',source=spectrogram_quad_grey_clustered,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    spectrogram_quad_grey_annotated = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_spectrogram.quad('left','right','top','bottom',source=spectrogram_quad_grey_annotated,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    spectrogram_quad_grey_pan = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_spectrogram.quad('left','right','top','bottom',source=spectrogram_quad_grey_pan,
                fill_color="lightgrey", fill_alpha=0.5, line_color="lightgrey",
                level='underlay')
    spectrogram_quad_fuchsia = ColumnDataSource(data=dict(left=[], right=[], top=[], bottom=[]))
    p_spectrogram.quad('left','right','top','bottom',source=spectrogram_quad_fuchsia,
                fill_color=None, line_color="fuchsia", level='underlay')

    spectrogram_label_source_clustered = ColumnDataSource(data=dict(x=[], y=[], text=[]))
    p_spectrogram.text('x', 'y', source=spectrogram_label_source_clustered,
                   text_font_size='6pt', text_align='center', text_baseline='top',
                   text_line_height=0.8, level='underlay', text_color='white')
    spectrogram_label_source_annotated = ColumnDataSource(data=dict(x=[], y=[], text=[]))
    p_spectrogram.text('x', 'y', source=spectrogram_label_source_annotated,
                   text_font_size='6pt', text_align='center', text_baseline='bottom',
                   text_line_height=0.8, level='underlay', text_color='white')

    TOOLTIPS = """
         < div> < div> < span style="color:@colors;">@labels < /span> < /div> < /div>
    """

    p_probability = figure(plot_width=M.gui_width_pix, tooltips=TOOLTIPS,
                           plot_height=M.context_probability_height_pix,
                           background_fill_color='#FFFFFF', toolbar_location=None)
    p_probability.toolbar.active_drag = None
    p_probability.grid.visible = False
    p_probability.yaxis.axis_label = "Probability"
    p_probability.x_range.range_padding = p_probability.y_range.range_padding = 0.0
    p_probability.y_range.start = 0
    p_probability.y_range.end = 1
    p_probability.xaxis.visible = False

    probability_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[], labels=[]))
    probability_glyph = p_probability.multi_line(xs='xs', ys='ys',
                                                 source=probability_source, color='colors')

    probability_span_red = Span(location=0, dimension='height', line_color='red')
    p_probability.add_layout(probability_span_red)
    probability_span_red.visible=False

    which_layer = Select(title="layer:")
    which_layer.on_change('value', lambda a,o,n: C.layer_callback(n))

    which_species = Select(title="species:")
    which_species.on_change('value', lambda a,o,n: C.species_callback(n))

    which_word = Select(title="word:")
    which_word.on_change('value', lambda a,o,n: C.word_callback(n))

    which_nohyphen = Select(title="no hyphen:")
    which_nohyphen.on_change('value', lambda a,o,n: C.nohyphen_callback(n))

    which_kind = Select(title="kind:")
    which_kind.on_change('value', lambda a,o,n: C.kind_callback(n))

    color_picker = ColorPicker(title="color:", disabled=True)
    color_picker.on_change("color", lambda a,o,n: C.color_picker_callback(n))

    circle_radius = Slider(start=0, end=10, step=1, \
                           value=M.state["circle_radius"], \
                           title="circle radius", \
                           disabled=True)
    circle_radius.on_change("value_throttled", C.circle_radius_callback)

    dot_size = Slider(start=1, end=24, step=1, \
                      value=M.state["dot_size"], \
                      title="dot size", \
                      disabled=True)
    dot_size.on_change("value", C.dot_size_callback)

    dot_alpha = Slider(start=0.01, end=1.0, step=0.01, \
                       value=M.state["dot_alpha"], \
                       title="dot alpha", \
                       disabled=True)
    dot_alpha.on_change("value", C.dot_alpha_callback)

    cluster_update()

    zoom_context = TextInput(value=str(M.context_width_ms),
                             title="context (msec):",
                             disabled=True)
    zoom_context.on_change("value", C.zoom_context_callback)

    zoom_offset = TextInput(value=str(M.context_offset_ms),
                            title="offset (msec):",
                            disabled=True)
    zoom_offset.on_change("value", C.zoom_offset_callback)

    zoomin = Button(label='\u2191', disabled=True)
    zoomin.on_click(C.zoomin_callback)

    zoomout = Button(label='\u2193', disabled=True)
    zoomout.on_click(C.zoomout_callback)

    reset = Button(label='\u25ef', disabled=True)
    reset.on_click(C.zero_callback)

    panleft = Button(label='\u2190', disabled=True)
    panleft.on_click(C.panleft_callback)

    panright = Button(label='\u2192', disabled=True)
    panright.on_click(C.panright_callback)

    allleft = Button(label='\u21e4', disabled=True)
    allleft.on_click(C.allleft_callback)

    allout = Button(label='\u2913', disabled=True)
    allout.on_click(C.allout_callback)

    allright = Button(label='\u21e5', disabled=True)
    allright.on_click(C.allright_callback)

    save_indicator = Button(label='0')

    label_count_callbacks=[]
    label_count_widgets=[]
    label_text_callbacks=[]
    label_text_widgets=[]

    for i in range(M.nlabels):
        label_count_callbacks.append(lambda i=i: C.label_count_callback(i))
        label_count_widgets.append(Button(label='0', css_classes=['hide-label'], width=40))
        label_count_widgets[-1].on_click(label_count_callbacks[-1])

        label_text_callbacks.append(lambda a,o,n,i=i: C.label_text_callback(n,i))
        label_text_widgets.append(TextInput(value=M.state['labels'][i],
                                            css_classes=['hide-label']))
        label_text_widgets[-1].on_change("value", label_text_callbacks[-1])

    C.label_count_callback(M.ilabel)

    play = Button(label='play', disabled=True)
    play_callback = CustomJS(args=dict(waveform_span_red=waveform_span_red,
                                       spectrogram_span_red=spectrogram_span_red,
                                       probability_span_red=probability_span_red),
                             code=C.play_callback_code % ("",""))
    play.js_on_event(ButtonClick, play_callback)
    play.on_change('disabled', lambda a,o,n: reset_video())
    play.js_on_change('disabled', play_callback)

    video_toggle = Toggle(label='video', active=False, disabled=True)
    video_toggle.on_click(lambda x: context_update())

    video_div = Div(text=""" < video id="context_video"> < /video>""", width=0, height=0)

    undo = Button(label='undo', disabled=True)
    undo.on_click(C.undo_callback)

    redo = Button(label='redo', disabled=True)
    redo.on_click(C.redo_callback)

    detect = Button(label='detect')
    detect.on_click(lambda: C.action_callback(detect, C.detect_actuate))

    misses = Button(label='misses')
    misses.on_click(lambda: C.action_callback(misses, C.misses_actuate))

    train = Button(label='train')
    train.on_click(lambda: C.action_callback(train, C.train_actuate))

    leaveoneout = Button(label='omit one')
    leaveoneout.on_click(lambda: C.action_callback(leaveoneout,
                                                   lambda: C.leaveout_actuate(False)))

    leaveallout = Button(label='omit all')
    leaveallout.on_click(lambda: C.action_callback(leaveallout,
                                                   lambda: C.leaveout_actuate(True)))

    xvalidate = Button(label='x-validate')
    xvalidate.on_click(lambda: C.action_callback(xvalidate, C.xvalidate_actuate))

    mistakes = Button(label='mistakes')
    mistakes.on_click(lambda: C.action_callback(mistakes, C.mistakes_actuate))

    activations = Button(label='activations')
    activations.on_click(lambda: C.action_callback(activations, C.activations_actuate))

    cluster = Button(label='cluster')
    cluster.on_click(lambda: C.action_callback(cluster, C.cluster_actuate))

    visualize = Button(label='visualize')
    visualize.on_click(lambda: C.action_callback(visualize, C.visualize_actuate))

    accuracy = Button(label='accuracy')
    accuracy.on_click(lambda: C.action_callback(accuracy, C.accuracy_actuate))

    freeze = Button(label='freeze')
    freeze.on_click(lambda: C.action_callback(freeze, C.freeze_actuate))

    classify = Button(label='classify')
    classify.on_click(C.classify_callback)

    ethogram = Button(label='ethogram')
    ethogram.on_click(lambda: C.action_callback(ethogram, C.ethogram_actuate))

    compare = Button(label='compare')
    compare.on_click(lambda: C.action_callback(compare, C.compare_actuate))

    congruence = Button(label='congruence')
    congruence.on_click(lambda: C.action_callback(congruence, C.congruence_actuate))

    status_ticker_pre=" < div style='overflow:auto; white-space:nowrap; width:"+str(M.gui_width_pix-126)+"px'>status: "
    status_ticker_post=" < /div>"
    status_ticker = Div(text=status_ticker_pre+status_ticker_post)

    file_dialog_source = ColumnDataSource(data=dict(names=[], sizes=[], dates=[]))
    file_dialog_source.selected.on_change('indices', C.file_dialog_callback)

    file_dialog_columns = [
        TableColumn(field="names", title="Name", width=M.gui_width_pix//2-50-115-30),
        TableColumn(field="sizes", title="Size", width=50, \
                    formatter=NumberFormatter(format="0 b")),
        TableColumn(field="dates", title="Date", width=115, \
                    formatter=DateFormatter(format="%Y-%m-%d %H:%M:%S")),
    ]
    file_dialog_table = DataTable(source=file_dialog_source, \
                                  columns=file_dialog_columns, \
                                  height=727, width=M.gui_width_pix//2-11, \
                                  index_position=None,
                                  fit_columns=False)

    waitfor = Toggle(label='wait for last job', active=False, disabled=True, width=100)
    waitfor.on_click(C.waitfor_callback)

    logs = Button(label='logs folder:', width=110)
    logs.on_click(C.logs_callback)
    logs_folder = TextInput(value=M.state['logs'], title="", disabled=False)
    logs_folder.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    model = Button(label='checkpoint file:', width=110)
    model.on_click(C.model_callback)
    model_file = TextInput(value=M.state['model'], title="", disabled=False)
    model_file.on_change('value', model_file_update)

    wavtfcsvfiles = Button(label='wav,tf,csv files:', width=110)
    wavtfcsvfiles.on_click(C.wavtfcsvfiles_callback)
    wavtfcsvfiles_string = TextInput(value=M.state['wavtfcsvfiles'], title="", disabled=False)
    wavtfcsvfiles_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    groundtruth = Button(label='ground truth:', width=110)
    groundtruth.on_click(C.groundtruth_callback)
    groundtruth_folder = TextInput(value=M.state['groundtruth'], title="", disabled=False)
    groundtruth_folder.on_change('value', lambda a,o,n: groundtruth_update())

    validationfiles = Button(label='validation files:', width=110)
    validationfiles.on_click(C.validationfiles_callback)
    validationfiles_string = TextInput(value=M.state['validationfiles'], title="", disabled=False)
    validationfiles_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    testfiles = Button(label='test files:', width=110)
    testfiles.on_click(C.testfiles_callback)
    testfiles_string = TextInput(value=M.state['testfiles'], title="", disabled=False)
    testfiles_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    wantedwords = Button(label='wanted words:', width=110)
    wantedwords.on_click(C.wantedwords_callback)
    wantedwords_string = TextInput(value=M.state['wantedwords'], title="", disabled=False)
    wantedwords_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    labeltypes = Button(label='label types:', width=110)
    labeltypes_string = TextInput(value=M.state['labeltypes'], title="", disabled=False)
    labeltypes_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    prevalences = Button(label='prevalences:', width=110)
    prevalences.on_click(C.prevalences_callback)
    prevalences_string = TextInput(value=M.state['prevalences'], title="", disabled=False)
    prevalences_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    copy = Button(label='copy')
    copy.on_click(C.copy_callback)

    labelsounds = Button(label='label sounds')
    labelsounds.on_click(lambda: C.wizard_callback(labelsounds))

    makepredictions = Button(label='make predictions')
    makepredictions.on_click(lambda: C.wizard_callback(makepredictions))

    fixfalsepositives = Button(label='fix false positives')
    fixfalsepositives.on_click(lambda: C.wizard_callback(fixfalsepositives))

    fixfalsenegatives = Button(label='fix false negatives')
    fixfalsenegatives.on_click(lambda: C.wizard_callback(fixfalsenegatives))

    generalize = Button(label='test generalization')
    generalize.on_click(lambda: C.wizard_callback(generalize))

    tunehyperparameters = Button(label='tune h-parameters')
    tunehyperparameters.on_click(lambda: C.wizard_callback(tunehyperparameters))

    findnovellabels = Button(label='find novel labels')
    findnovellabels.on_click(lambda: C.wizard_callback(findnovellabels))

    examineerrors = Button(label='examine errors')
    examineerrors.on_click(lambda: C.wizard_callback(examineerrors))

    testdensely = Button(label='test densely')
    testdensely .on_click(lambda: C.wizard_callback(testdensely))

    doit = Button(label='do it!', disabled=True)
    doit.on_click(C.doit_callback)

    time_sigma_string = TextInput(value=M.state['time_sigma'], \
                                  title="time σ", \
                                  disabled=False)
    time_sigma_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    time_smooth_ms_string = TextInput(value=M.state['time_smooth_ms'], \
                                      title="time smooth", \
                                      disabled=False)
    time_smooth_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    frequency_n_ms_string = TextInput(value=M.state['frequency_n_ms'], \
                                      title="freq N (msec)", \
                                      disabled=False)
    frequency_n_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    frequency_nw_string = TextInput(value=M.state['frequency_nw'], \
                                    title="freq NW", \
                                    disabled=False)
    frequency_nw_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    frequency_p_string = TextInput(value=M.state['frequency_p'], \
                                   title="freq ρ", \
                                   disabled=False)
    frequency_p_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    frequency_smooth_ms_string = TextInput(value=M.state['frequency_smooth_ms'], \
                                           title="freq smooth", \
                                           disabled=False)
    frequency_smooth_ms_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    nsteps_string = TextInput(value=M.state['nsteps'], title="# steps", disabled=False)
    nsteps_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    restore_from_string = TextInput(value=M.state['restore_from'], title="restore from", disabled=False)
    restore_from_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    save_and_validate_period_string = TextInput(value=M.state['save_and_validate_interval'], \
                                                title="validate period", \
                                                disabled=False)
    save_and_validate_period_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    validate_percentage_string = TextInput(value=M.state['validate_percentage'], \
                                           title="validate %", \
                                           disabled=False)
    validate_percentage_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    mini_batch_string = TextInput(value=M.state['mini_batch'], \
                                  title="mini-batch", \
                                  disabled=False)
    mini_batch_string.on_change('value', lambda a,o,n: C.generic_parameters_callback(n))

    def _generic_text_input(state_key, label):
        # Build a TextInput whose initial value is M.state[state_key] and whose
        # 'value' changes are forwarded, as the new text, to
        # C.generic_parameters_callback.
        text_input = TextInput(value=M.state[state_key], title=label, disabled=False)
        text_input.on_change('value',
                             lambda attr, old, new: C.generic_parameters_callback(new))
        return text_input

    # Free-text parameters.  Widgets are created in the same order as before.
    kfold_string = _generic_text_input('kfold', "k-fold")

    activations_equalize_ratio_string = _generic_text_input('activations_equalize_ratio',
                                                            "equalize ratio")
    activations_max_samples_string = _generic_text_input('activations_max_samples',
                                                         "max samples")

    pca_fraction_variance_to_retain_string = _generic_text_input('pca_fraction_variance_to_retain',
                                                                 "PCA fraction")
    tsne_perplexity_string = _generic_text_input('tsne_perplexity', "perplexity")
    tsne_exaggeration_string = _generic_text_input('tsne_exaggeration', "exaggeration")
    umap_neighbors_string = _generic_text_input('umap_neighbors', "neighbors")
    umap_distance_string = _generic_text_input('umap_distance', "distance")

    precision_recall_ratios_string = _generic_text_input('precision_recall_ratios', "P/Rs")

    context_ms_string = _generic_text_input('context_ms', "context (msec)")
    shiftby_ms_string = _generic_text_input('shiftby_ms', "shift by (msec)")

    # Drop-downs pass an empty string to the controller rather than the new
    # selection.
    representation = Select(title="representation",
                            height=50,
                            value=M.state['representation'],
                            options=["waveform", "spectrogram", "mel-cepstrum"])
    representation.on_change('value',
                             lambda attr, old, new: C.generic_parameters_callback(''))

    window_ms_string = _generic_text_input('window_ms', "window (msec)")
    stride_ms_string = _generic_text_input('stride_ms', "stride (msec)")
    mel_dct_string = _generic_text_input('mel&dct', "Mel & DCT")

    optimizer = Select(title="optimizer",
                       height=50,
                       value=M.state['optimizer'],
                       options=[("sgd","SGD"),
                                ("adam","Adam"),
                                ("adagrad","AdaGrad"),
                                ("rmsprop","RMSProp")])
    optimizer.on_change('value',
                        lambda attr, old, new: C.generic_parameters_callback(''))

    learning_rate_string = _generic_text_input('learning_rate', "learning rate")

    # One widget per entry of M.model_parameters, stored under the entry's
    # first field.  The second field is used as the widget title; the third
    # field selects the widget type: '' gives a free-text box, anything else is
    # passed as the options of a drop-down.
    model_parameters = OrderedDict()
    for parameter in M.model_parameters:
        param_key, param_title, param_options = parameter[0], parameter[1], parameter[2]
        if param_options == '':
            thisparameter = TextInput(value=M.state[param_key],
                                      title=param_title,
                                      disabled=False,
                                      width=94)
            thisparameter.on_change('value',
                                    lambda attr, old, new: C.generic_parameters_callback(new))
        else:
            # NOTE(review): unlike the TextInput branch, no on_change handler
            # is attached here -- confirm this is intentional.
            thisparameter = Select(value=M.state[param_key],
                                   title=param_title,
                                   options=param_options,
                                   height=50,
                                   width=94)
        model_parameters[param_key] = thisparameter

    # Read-only text area for the contents of the configuration file.  Starting
    # from 49 rows, three rows are subtracted for every group of up to six
    # entries in model_parameters.
    # The builtin int() is used for the conversion because the np.int alias was
    # deprecated in NumPy 1.20 and removed in 1.24, where `.astype(np.int)`
    # raises AttributeError.
    configuration_contents = TextAreaInput(rows=49-3*int(np.ceil(len(model_parameters)/6)),
                                           max_length=50000,
                                           disabled=True, css_classes=['fixedwidth'])
    # Only populate the widget when a configuration file path is set.
    if M.configuration_file:
        with open(M.configuration_file, 'r') as fid:
            configuration_contents.value = fid.read()


    # Drop-down for the clustering algorithm; changes are reported to the
    # controller with an empty string.
    cluster_algorithm = Select(
        title="cluster",
        height=50,
        value=M.state['cluster_algorithm'],
        options=[
            "PCA 2D", "PCA 3D",
            "tSNE 2D", "tSNE 3D",
            "UMAP 2D", "UMAP 3D",
        ])
    cluster_algorithm.on_change('value',
                                lambda attr, old, new: C.generic_parameters_callback(''))

    # Multi-select of layers.  It is created with no options;
    # cluster_these_layers_update() is called immediately afterwards
    # (presumably to fill them in -- confirm against that function).
    cluster_these_layers = MultiSelect(
        title='layers',
        value=M.state['cluster_these_layers'],
        options=[],
        height=108)
    cluster_these_layers.on_change('value',
                                   lambda attr, old, new: C.generic_parameters_callback(''))
    cluster_these_layers_update()

    # Free-text inputs seeded from M.state; each forwards its new value to
    # C.generic_parameters_callback when edited.
    replicates_string = TextInput(title="replicates",
                                  value=M.state['replicates'],
                                  disabled=False)
    replicates_string.on_change('value',
                                lambda attr, old, new: C.generic_parameters_callback(new))

    batch_seed_string = TextInput(title="batch seed",
                                  value=M.state['batch_seed'],
                                  disabled=False)
    batch_seed_string.on_change('value',
                                lambda attr, old, new: C.generic_parameters_callback(new))

    weights_seed_string = TextInput(title="weights seed",
                                    value=M.state['weights_seed'],
                                    disabled=False)
    weights_seed_string.on_change('value',
                                  lambda attr, old, new: C.generic_parameters_callback(new))

    # Path box of the file dialog.  The change handler is attached before the
    # initial value is assigned; keep this order (it presumably lets the
    # handler see the initial assignment -- confirm).
    file_dialog_string = TextInput(disabled=False)
    file_dialog_string.on_change("value", C.file_dialog_path_callback)
    file_dialog_string.value = M.state['file_dialog_string']

    # Render README.md, located two directories above this file, to HTML and
    # show it in a scrollable Div.
    readme_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               '..', '..', 'README.md')
    with open(readme_path, 'r', encoding='utf-8') as fid:
        contents = fid.read()
    html = markdown.markdown(contents, extensions=['tables','toc'])
    readme_contents = Div(text=html,
                          style={'overflow':'scroll','width':'600px','height':'1397px'})

    # Starts out empty; wordcounts_update() is called right after creation.
    wordcounts = Div(text="")
    wordcounts_update()

    # Groups of widgets, collected into sets.

    # Wizard buttons.
    wizard_buttons = {
        labelsounds, makepredictions, fixfalsepositives, fixfalsenegatives,
        generalize, tunehyperparameters, findnovellabels, examineerrors,
        testdensely,
    }

    # Action buttons.
    action_buttons = {
        detect, train, leaveoneout, leaveallout, xvalidate, mistakes,
        activations, cluster, visualize, accuracy, freeze, classify,
        ethogram, misses, compare, congruence,
    }

    # Buttons associated with the file/folder parameters.
    parameter_buttons = {
        logs, model, wavtfcsvfiles, groundtruth, validationfiles, testfiles,
        wantedwords, labeltypes, prevalences,
    }

    # Every parameter input widget, including the dynamically created
    # model-parameter widgets.
    parameter_textinputs = {
        logs_folder, model_file, wavtfcsvfiles_string, groundtruth_folder,
        validationfiles_string, testfiles_string, wantedwords_string,
        labeltypes_string, prevalences_string,

        time_sigma_string, time_smooth_ms_string, frequency_n_ms_string,
        frequency_nw_string, frequency_p_string, frequency_smooth_ms_string,
        nsteps_string, restore_from_string, save_and_validate_period_string,
        validate_percentage_string, mini_batch_string, kfold_string,
        activations_equalize_ratio_string, activations_max_samples_string,
        pca_fraction_variance_to_retain_string, tsne_perplexity_string,
        tsne_exaggeration_string, umap_neighbors_string, umap_distance_string,
        cluster_algorithm, cluster_these_layers,
        precision_recall_ratios_string, replicates_string, batch_seed_string,
        weights_seed_string,

        context_ms_string, shiftby_ms_string, representation,
        window_ms_string, stride_ms_string, mel_dct_string, optimizer,
        learning_rate_string,

        *model_parameters.values(),
    }

    # Lookup from each wizard button to a list of action buttons.  The None
    # key maps to the complete action_buttons set.
    wizard2actions = {
        labelsounds: [detect, train, activations, cluster, visualize],
        makepredictions: [train, accuracy, freeze, classify, ethogram],
        fixfalsepositives: [activations, cluster, visualize],
        fixfalsenegatives: [detect, misses, activations, cluster, visualize],
        generalize: [leaveoneout, leaveallout, accuracy],
        tunehyperparameters: [xvalidate, accuracy, compare],
        findnovellabels: [detect, train, activations, cluster, visualize],
        examineerrors: [detect, mistakes, activations, cluster, visualize],
        testdensely: [
            detect, activations, cluster, visualize,
            classify, ethogram, congruence,
        ],
        None: action_buttons,
    }

    # Lookup from each action button to a list of parameter buttons.  The None
    # key maps to the complete parameter_buttons set.
    action2parameterbuttons = {
        detect: [wavtfcsvfiles],
        train: [logs, groundtruth, wantedwords, testfiles, labeltypes],
        leaveoneout: [
            logs, groundtruth, validationfiles, testfiles,
            wantedwords, labeltypes,
        ],
        leaveallout: [
            logs, groundtruth, validationfiles, testfiles,
            wantedwords, labeltypes,
        ],
        xvalidate: [logs, groundtruth, testfiles, wantedwords, labeltypes],
        mistakes: [groundtruth],
        activations: [logs, model, groundtruth, wantedwords, labeltypes],
        cluster: [groundtruth],
        visualize: [groundtruth],
        accuracy: [logs],
        freeze: [logs, model],
        classify: [logs, model, wavtfcsvfiles, wantedwords, prevalences],
        ethogram: [model, wavtfcsvfiles],
        misses: [wavtfcsvfiles],
        compare: [logs],
        congruence: [groundtruth, validationfiles, testfiles],
        None: parameter_buttons,
    }

    # Lookup from each action button to a list of parameter input widgets.
    # Several entries append all of the model-parameter widgets.  The None key
    # maps to the complete parameter_textinputs set.
    action2parametertextinputs = {
        detect: [
            wavtfcsvfiles_string,
            time_sigma_string, time_smooth_ms_string,
            frequency_n_ms_string, frequency_nw_string,
            frequency_p_string, frequency_smooth_ms_string,
        ],
        train: [
            context_ms_string, shiftby_ms_string, representation,
            window_ms_string, stride_ms_string, mel_dct_string,
            optimizer, learning_rate_string,
            replicates_string, batch_seed_string, weights_seed_string,
            logs_folder, groundtruth_folder, testfiles_string,
            wantedwords_string, labeltypes_string,
            nsteps_string, restore_from_string,
            save_and_validate_period_string, validate_percentage_string,
            mini_batch_string,
        ] + list(model_parameters.values()),
        leaveoneout: [
            context_ms_string, shiftby_ms_string, representation,
            window_ms_string, stride_ms_string, mel_dct_string,
            optimizer, learning_rate_string,
            batch_seed_string, weights_seed_string,
            logs_folder, groundtruth_folder,
            validationfiles_string, testfiles_string,
            wantedwords_string, labeltypes_string,
            nsteps_string, restore_from_string,
            save_and_validate_period_string,
            mini_batch_string,
        ] + list(model_parameters.values()),
        leaveallout: [
            context_ms_string, shiftby_ms_string, representation,
            window_ms_string, stride_ms_string, mel_dct_string,
            optimizer, learning_rate_string,
            batch_seed_string, weights_seed_string,
            logs_folder, groundtruth_folder,
            validationfiles_string, testfiles_string,
            wantedwords_string, labeltypes_string,
            nsteps_string, restore_from_string,
            save_and_validate_period_string,
            mini_batch_string,
        ] + list(model_parameters.values()),
        xvalidate: [
            context_ms_string, shiftby_ms_string, representation,
            window_ms_string, stride_ms_string, mel_dct_string,
            optimizer, learning_rate_string,
            batch_seed_string, weights_seed_string,
            logs_folder, groundtruth_folder, testfiles_string,
            wantedwords_string, labeltypes_string,
            nsteps_string, restore_from_string,
            save_and_validate_period_string,
            mini_batch_string, kfold_string,
        ] + list(model_parameters.values()),
        mistakes: [groundtruth_folder],
        activations: [
            context_ms_string, shiftby_ms_string, representation,
            window_ms_string, stride_ms_string, mel_dct_string,
            logs_folder, model_file, groundtruth_folder,
            wantedwords_string, labeltypes_string,
            activations_equalize_ratio_string, activations_max_samples_string,
            mini_batch_string,
        ] + list(model_parameters.values()),
        cluster: [
            groundtruth_folder, cluster_algorithm, cluster_these_layers,
            pca_fraction_variance_to_retain_string,
            tsne_perplexity_string, tsne_exaggeration_string,
            umap_neighbors_string, umap_distance_string,
        ],
        visualize: [groundtruth_folder],
        accuracy: [logs_folder, precision_recall_ratios_string],
        freeze: [
            context_ms_string, representation,
            window_ms_string, stride_ms_string, mel_dct_string,
            logs_folder, model_file,
        ] + list(model_parameters.values()),
        classify: [
            context_ms_string, shiftby_ms_string, representation,
            stride_ms_string,
            logs_folder, model_file, wavtfcsvfiles_string,
            wantedwords_string, prevalences_string,
        ] + list(model_parameters.values()),
        ethogram: [model_file, wavtfcsvfiles_string],
        misses: [wavtfcsvfiles_string],
        compare: [logs_folder],
        congruence: [
            groundtruth_folder, validationfiles_string, testfiles_string,
        ],
        None: parameter_textinputs,
    }