python/oskopek/mvae/mt/visualization/generate_plots.py

# Copyright 2019 Ondrej Skopek.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
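"""Generate evaluation plots and LaTeX tables from training log files.

Reads the log files matched by ``--glob``, aggregates them per run or per model,
and writes Bokeh line/box plots plus a LaTeX results table to ``--out_dir``.

Example flags (values are the argparse defaults or illustrative):
    --glob 'lsf.*' --plot models --out_dir ./plots --statistics ll,mi,bce
"""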

import argparse
import glob
import itertools
import math
import os
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import warnings

import bokeh.io
import bokeh.models
import bokeh.palettes
import bokeh.plotting
import bokeh.resources
import numpy as np
import pandas as pd

from . import read_log, utils

BASE_HEIGHT = 800
HEIGHT = 1500
WIDTH = 2 * HEIGHT

stat_title = {
    "ll": "Log Likelihood",
    "elbo": "Evidence Lower BOund",
    "kl": "Kullback-Leibler Divergence",
    "bce": "Binary Cross-Entropy",
    "mi": "Mutual Information",
    "cov_norm": "Frobenius norm of the cross-covariance matrix"
}

axis_title = {
    "ll": "ln(likelihood)",
    "elbo": "ELBO",
    "kl": "KL",
    "bce": "BCE",
    "mi": "MI",
    "cov_norm": "Cross-Covariance norm"
}


def set_font_size(p: bokeh.plotting.figure, font_size: int = 40, small_font_size: int = 35) -> None:
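    """Enlarge title, axis, tick, and legend fonts (and thicken axis lines) so exported images stay legible."""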
    font_size_s = f"{font_size}pt"
    small_font_size_s = f"{small_font_size}pt"
    if p.title is not None:
        p.title.text_font_size = font_size_s
        p.title.text_color = "black"

    p.xaxis.axis_line_width = 5
    p.xaxis.major_tick_line_width = 5
    p.xaxis.minor_tick_line_width = 5
    p.yaxis.axis_line_width = 5
    p.yaxis.major_tick_line_width = 5
    p.yaxis.minor_tick_line_width = 5

    p.xaxis.axis_label_text_font_size = font_size_s
    p.xaxis.major_label_text_color = "black"
    p.xaxis.axis_label_text_color = "black"
    p.xaxis.major_label_text_font_size = small_font_size_s
    p.yaxis.axis_label_text_font_size = font_size_s
    p.yaxis.major_label_text_font_size = small_font_size_s
    p.yaxis.major_label_text_color = "black"
    p.yaxis.axis_label_text_color = "black"
    if p.legend:
        p.legend.label_text_font_size = small_font_size_s
        p.legend.label_text_color = "black"
        p.legend.glyph_height = font_size
        p.legend.glyph_width = font_size


def export_plots(p: bokeh.plotting.figure,
                 filename: str,
                 title: str,
                 width: int = WIDTH,
                 height: int = HEIGHT,
                 box: bool = False,
                 show_title: bool = False,
                 y_range_start: Optional[float] = None,
                 y_range_end: Optional[float] = None) -> None:
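    """Export ``p`` twice: as an interactive HTML file and as a large fixed-size PNG.

    The y-range can optionally be clamped before the PNG export; SVG/PDF export is left commented out below.
    """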
    # HTML
    if not show_title:
        p.title = None
    bokeh.plotting.save(p, title=title, filename=filename + ".html", resources=bokeh.resources.CDN)

    # PNG
    if y_range_start is not None:
        p.y_range.start = y_range_start
    if y_range_end is not None:
        p.y_range.end = y_range_end

    set_font_size(p)
    p.sizing_mode = "fixed"
    p.width = width
    if box:
        p.height = width
    else:
        p.height = height
    p.toolbar_location = None
    bokeh.io.export_png(p, filename=filename + ".png", height=HEIGHT, width=WIDTH)

    # SVG:
    # p.output_backend = "svg"
    # bokeh.io.export_svgs(p, filename=filename + ".svg")
    #
    # os.system(f"inkscape --without-gui --export-pdf={filename}.pdf {filename}.svg")


def box_whiskers_plot(df: pd.DataFrame, out_folder: str, statistic: str = "ll", subtitle: str = "") -> None:
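    """Draw a box-and-whisker plot of ``statistic`` per model (quartiles, 1.5 IQR whiskers, outliers) and export it."""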
    title = f"{stat_title[statistic]} ({subtitle})"

    # find the quartiles and IQR for each category
    groups = df.groupby('model')
    q1 = groups.quantile(q=0.25)
    q2 = groups.quantile(q=0.5)
    q3 = groups.quantile(q=0.75)
    iqr = q3 - q1
    upper = q3 + 1.5 * iqr
    lower = q1 - 1.5 * iqr

    cats = sorted(df['model'].unique())
    p = bokeh.plotting.figure(title=title,
                              x_range=cats,
                              x_axis_label="Model",
                              tools="",
                              toolbar_location=None,
                              y_axis_label=axis_title[statistic],
                              plot_height=BASE_HEIGHT,
                              sizing_mode="stretch_width")

    p.xaxis.major_label_orientation = math.pi / 4  # math.pi / 6
    p.xaxis.major_label_standoff = 20
    p.yaxis.major_label_standoff = 15

    # if no outliers, shrink lengths of stems to be no longer than the minimums or maximums
    qmin = groups.quantile(q=0.00)
    qmax = groups.quantile(q=1.00)
    upper[statistic] = [min([x, y]) for (x, y) in zip(list(qmax.loc[:, statistic]), upper[statistic])]
    lower[statistic] = [max([x, y]) for (x, y) in zip(list(qmin.loc[:, statistic]), lower[statistic])]

    lw = 2

    # stems
    p.segment(cats, upper[statistic], cats, q3[statistic], line_color="black", line_width=lw)
    p.segment(cats, lower[statistic], cats, q1[statistic], line_color="black", line_width=lw)

    # boxes
    p.vbar(cats, 0.7, q2[statistic], q3[statistic], line_width=lw, fill_color="#E08E79", line_color="black")
    p.vbar(cats, 0.7, q1[statistic], q2[statistic], line_width=lw, fill_color="#3B8686", line_color="black")

    # whiskers (almost-0 height rects simpler than segments)
    p.rect(cats, lower[statistic], 0.2, 0.001, line_color="black", line_width=lw, fill_color="black")
    p.rect(cats, upper[statistic], 0.2, 0.001, line_color="black", line_width=lw, fill_color="black")

    # outliers
    def outliers(group: pd.DataFrame) -> pd.Series:
        cat = group.name
        return group[(group[statistic] > upper.loc[cat][statistic]) |
                     (group[statistic] < lower.loc[cat][statistic])][statistic]

    out = groups.apply(outliers).dropna()
    if not out.empty:
        outx = []
        outy = []
        for keys in out.index:
            outx.append(keys[0])
            outy.append(out[keys[0]][keys[1]])
        p.circle(outx, outy, size=6, color="#F38630", fill_alpha=0.6)

    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None

    export_plots(p, filename=os.path.join(out_folder, f"model_boxplot_{statistic}"), title=title, box=False)


def line_plot(run_table: pd.DataFrame,
              std_table: pd.DataFrame,
              out_folder: str,
              statistic: str = "ll",
              color_column: str = "run",
              y_range_start: Optional[float] = None,
              y_range_end: Optional[float] = None,
              subtitle: str = "") -> None:
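    """Plot ``statistic`` over epochs, one line per value of ``color_column``, with optional std-dev bands.

    Clicking a legend entry hides the corresponding line together with its band.
    """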
    title = f"{stat_title[statistic]} ({subtitle})"

    # create a new plot
    tooltips = [("x", "@x"), ("y", "@y"), ("label", "@label")]
    if std_table is not None:
        tooltips.append(("Stddev", f"@{statistic}_std"))
    p = bokeh.plotting.figure(tools="pan,crosshair,reset,save,wheel_zoom",
                              title=title,
                              x_axis_label="Epoch",
                              y_axis_label=axis_title[statistic],
                              tooltips=tooltips,
                              toolbar_location="above",
                              plot_height=BASE_HEIGHT,
                              sizing_mode="stretch_width")

    colors = bokeh.palettes.Category20[20]
    items = []
    for color, line in zip(colors, sorted(run_table[color_column].unique())):
        df = pd.DataFrame(run_table[run_table[color_column] == line])
        if std_table is not None:
            std = std_table[std_table[color_column] == line]
            std_key = f"{statistic}_std"
            df[std_key] = std[statistic]
            df["lower"] = df[statistic] - df[std_key]
            df["upper"] = df[statistic] + df[std_key]

        source = bokeh.models.ColumnDataSource(df)
        plotted_line = p.line(x="epoch", y=statistic, source=source, color=color, line_width=2)
        # plotted_points = p.scatter(x="epoch", y=statistic, source=source, color=color)
        renderers = [plotted_line]  # plotted_points
        if std_table is not None:
            band = bokeh.models.Band(base="epoch",
                                     lower="lower",
                                     upper="upper",
                                     source=source,
                                     level='underlay',
                                     fill_alpha=0.1,
                                     fill_color=color)
            p.add_layout(band)
            callback = bokeh.models.CustomJS(args=dict(band=band),
                                             code="""
            if (band.visible == false)
                band.visible = true;
            else
                band.visible = false; """)
            plotted_line.js_on_change('visible', callback)

        items.append(bokeh.models.LegendItem(label=line, renderers=renderers))

    legend = bokeh.models.Legend(items=items, location="center", orientation="vertical", click_policy="hide")
    p.add_layout(legend, "right")
    p.x_range.start = 0
    p.x_range.end = run_table["epoch"].max()

    # show the results
    export_plots(p,
                 filename=os.path.join(out_folder, f"{color_column}_lineplot_{statistic}"),
                 title=title,
                 y_range_start=y_range_start,
                 y_range_end=y_range_end)


def curvature_line_plot(run_table: pd.DataFrame,
                        std_table: pd.DataFrame,
                        out_folder: str,
                        color_column: str = "run",
                        subtitle: str = "") -> None:
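    """Plot the learned curvature of every component over epochs, one line per run/component pair.

    Values of ``color_column`` that contain "fixed" (fixed-curvature runs) are skipped.
    """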

    def filter_nan_cols(df: pd.DataFrame, ccval: str) -> pd.DataFrame:
        df = df[df[color_column] == ccval]
        df = df.dropna(axis=1, how="all")
        df = df[sorted([col for col in df if col.endswith("/curvature") or col.lower() == "epoch"])]
        return df

    tooltips = [("Run-Component", "@label"), ("Epoch", "@epoch"), ("Curvature", "@curvature")]
    if std_table is not None:
        tooltips.append(("StdDev", "@curvature_std"))
    title = f"Learned curvature of components ({subtitle})"
    p = bokeh.plotting.figure(
        tools="pan,crosshair,reset,save,wheel_zoom",
        #  title=title,
        x_axis_label="Epoch",
        y_axis_label="Curvature",
        tooltips=tooltips,
        toolbar_location="above",
        plot_height=BASE_HEIGHT,
        sizing_mode="stretch_width")

    items = []
    colors = bokeh.palettes.Category20[20]
    color_iter = itertools.cycle(colors)

    for ccval in run_table[color_column].unique():
        if "fixed" in ccval:
            continue
        df = filter_nan_cols(run_table, ccval)
        if std_table is not None:
            std = filter_nan_cols(std_table, ccval)

        for curvature_col in [col for col in df if col.endswith("/curvature")]:
            label = f"{ccval}-{curvature_col[:curvature_col.find('/')]}"
            curvature_df = df.rename(index=str, columns={curvature_col: "curvature"})
            curvature_df['label'] = label
            if std_table is not None and curvature_col in std:
                std_key = "curvature_std"
                curvature_std = std.rename(index=str, columns={curvature_col: std_key})
                curvature_df[std_key] = curvature_std[std_key]
                curvature_df["lower"] = curvature_df["curvature"] - curvature_df[std_key]
                curvature_df["upper"] = curvature_df["curvature"] + curvature_df[std_key]

            source = bokeh.models.ColumnDataSource(curvature_df)
            color = next(color_iter)
            plotted_line = p.line(x="epoch", y="curvature", source=source, color=color)
            plotted_points = p.scatter(x="epoch", y="curvature", source=source, color=color)
            renderers = [plotted_line, plotted_points]
            if std_table is not None and curvature_col in std:
                band = bokeh.models.Band(base="epoch",
                                         lower="lower",
                                         upper="upper",
                                         source=source,
                                         level='underlay',
                                         fill_alpha=0.1,
                                         fill_color=color)
                p.add_layout(band)
                callback = bokeh.models.CustomJS(args=dict(band=band),
                                                 code="""
                if (band.visible == false)
                    band.visible = true;
                else
                    band.visible = false; """)
                plotted_line.js_on_change('visible', callback)

            items.append(bokeh.models.LegendItem(label=label, renderers=renderers))

    legend = bokeh.models.Legend(items=items, location="center", orientation="vertical", click_policy="hide")
    p.add_layout(legend, "right")
    p.x_range.start = 0
    p.x_range.end = run_table["epoch"].max()

    export_plots(p, filename=os.path.join(out_folder, f"{color_column}_lineplot_curvature"), title=title)


def models_latex_table(mean: pd.DataFrame, std: pd.DataFrame, show_curvature: bool = False) -> str:
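    """Render the final (last-epoch) per-model statistics as a LaTeX ``tabular`` of mean ± std values.

    If ``show_curvature`` is set and learned curvatures are present, a curvature column is appended.
    """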

    def _last(df: pd.DataFrame) -> pd.DataFrame:
        return df.groupby("model").last().reset_index()

    model_mean = _last(mean)
    model_std = _last(std)
    rows = [[f"{x} ${utils.texify_components(x)}$" for x in model_mean["model"]]]
    for key in ["ll"]:  # , "elbo", "bce", "kl"]:  # "mi", "cov_norm"]:
        if key not in model_mean.keys():
            warnings.warn(f"Key {key} not in data frame, skipping.")
            continue
        rows.append([])
        for m, s in zip(model_mean[key], model_std[key]):
            rows[-1].append(f"${m:0.2f}$" + "{\\scriptsize" + f"$\\pm {s:0.2f}$" + "}")

    def filter_nan_cols(df: pd.DataFrame, model: str) -> Tuple[np.ndarray, ...]:
        df = df[df["model"] == model]
        df = df.dropna(axis=1)
        npa = df[sorted([col for col in df if col.endswith("/curvature")])].to_numpy()
        return tuple(*npa)

    def print_tuple(t: Tuple[np.ndarray, ...], fmt: str = "0.3f") -> str:
        s = ", ".join((("{:" + fmt + "}").format(x) for x in t))
        return f"({s})"

    if show_curvature:
        rows.append([])
        for model in model_mean["model"]:
            model_mean_ = filter_nan_cols(model_mean, model)
            model_std_ = filter_nan_cols(model_std, model)

            if len(model_mean_) > 0:
                rows[-1].append(f"${print_tuple(model_mean_)} \\pm {print_tuple(model_std_)}$")

        curvature = True
        if not rows[-1]:
            curvature = False
            del rows[-1]

    rows_transposed = sorted(list(zip(*rows)), key=lambda x: x[0])
    rows_str = "\n".join(" & ".join(str(element) for element in row) + "\\\\" for row in rows_transposed)

    if show_curvature and curvature:
        return """
      \\begin{tabular}{l|rrrrrl}
        \\toprule
        \\textbf{Model} & LL & ELBO & BCE & KL & Curvature\\\\
        \\midrule
        """ + rows_str + """
        \\bottomrule
      \\end{tabular}
        """
    else:
        return """
              \\begin{tabular}{l|rrrrr}
                \\toprule
                \\textbf{Model} & LL & ELBO & BCE & KL \\\\
                \\midrule
                """ + rows_str + """
                \\bottomrule
              \\end{tabular}
                """


def merge_runs(runs: Dict[str, List[Tuple[datetime, pd.DataFrame]]]) -> pd.DataFrame:
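    """Concatenate all runs of all models into one data frame, adding ``time``, ``run``, and ``model`` columns."""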
    total_dfs = []
    for model, model_runs in runs.items():
        model_dfs = []
        for i, (time, df) in enumerate(sorted(model_runs, key=lambda x: x[0])):
            df["time"] = [time] * len(df)
            df["run"] = [f"{model} (run {i})"] * len(df)
            model_dfs.append(df)
        model_df = pd.concat(model_dfs, sort=True)
        model_df["model"] = [model] * len(model_df)
        total_dfs.append(model_df)
    return pd.concat(total_dfs, sort=False)


def mean_models(run_table: pd.DataFrame, by: List[str] = ["epoch", "model"]) -> pd.DataFrame:
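    """Mean over runs, grouped by ``by`` (per epoch and model by default)."""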
    return run_table.groupby(by=by).mean().reset_index()


def std_models(run_table: pd.DataFrame, by: List[str] = ["epoch", "model"]) -> pd.DataFrame:
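    """Standard deviation over runs, grouped by ``by`` (per epoch and model by default)."""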
    return run_table.groupby(by=by).std().reset_index()


def read_runs(pattern: str) -> Dict[str, List[Tuple[datetime, pd.DataFrame, pd.DataFrame]]]:
    """
    :param pattern: the glob of logs to read.
    :return: A dictionary of modelname: [(time, eval_df, train_df)]
    """
    runs: Dict[str, List[Tuple[datetime, pd.DataFrame, pd.DataFrame]]] = defaultdict(list)
    for filename in glob.iglob(pattern, recursive=True):
        try:
            model_name, time, df, df_train = read_log.log_to_pd(filename)
        except NameError:  # For NaNs
            print(f"Couldn't read file '{filename}', skipping (probably a failed run).")
            continue

        print(f"Found model: '{model_name}' at time '{time}'.")
        runs[model_name].append((time, df, df_train))
    if not runs:
        raise ValueError(f"No runs found for glob '{pattern}'.")
    return runs


def early_stopping(
        dft: pd.DataFrame,
        dfe: pd.DataFrame,
        lookahead: int = 50,
        warmup: int = 500,  # with the default warmup, ES is effectively skipped here; it is already done during training
        stat: str = "elbo",
        difference: float = 0.0) -> pd.DataFrame:
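    """Apply post-hoc early stopping: truncate each run in ``dfe`` at the stopping epoch found in ``dft``.

    The stopping epoch is the first epoch after ``warmup`` whose ``stat`` value is at least ``difference``
    better than all values in the following ``lookahead`` epochs; if no such epoch exists, the epoch with
    the best ``stat`` value is used. Runs whose name contains "u" (universal model) are left untouched.
    """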
    for run in dft["run"].unique():
        if "u" in run:
            continue  # Ignore universal model runs.
        r = dft[dft["run"] == run].set_index("epoch").reset_index()

        epochs = list(r["epoch"])
        stats = list(r[stat])
        assert len(epochs) == len(stats)

        should_stop = max(zip(epochs, stats), key=lambda x: x[1])[0]
        for epoch, val in zip(epochs, stats):
            if epoch < warmup:
                continue
            if epoch + lookahead >= len(stats):
                break
            if val >= max(stats[epoch + 1:epoch + lookahead + 1]) + difference:
                should_stop = epoch
                break
        # drop the epochs of this run that lie beyond the stopping epoch
        dfe = dfe[(dfe["epoch"] <= should_stop) | (dfe["run"] != run)]

    return dfe


def main() -> None:
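    """Read all logs matching ``--glob`` and export line and curvature plots (plus, for ``--plot models``,
    box plots and a LaTeX summary table) into ``--out_dir``."""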
    parser = argparse.ArgumentParser()
    parser.add_argument("--glob", type=str, default="lsf.*", help="Which folder to search for log files.")
    parser.add_argument("--plot", type=str, default="runs", help="Type of plots ('runs', 'models').")
    parser.add_argument("--out_dir", type=str, default="./plots", help="Output dir.")
    parser.add_argument("--exp", type=str, default=None)
    parser.add_argument("--statistics", type=str, default="ll,mi,bce")
    args = parser.parse_args()

    statistics = [stat.strip() for stat in args.statistics.split(",")]

    bounds = {
        "ll": {
            # const
            "bdp_z6_const_1000": (-55., -60.1),
            "mnist_z6_const_200": (-96.5, -101.5),
            "mnist_z6_const_300": (-95.5, -99.0),
            "mnist_z6_const_500": (-95.5, -99.0),
            "mnist_z15_const_200": (-77., -87.1),
            "mnist_z30_const_200": (-75., -86.1),
            # learn
            "bdp_z6_learn_1000": (-55., -60.1),
            "mnist_z6_learn_300": (-95.5, -99.0)
        },
        "mi": {
            # const
            "bdp_z6_const_1000": (7., 3.),
            "mnist_z6_const_200": (13., 9.),
            "mnist_z6_const_300": (13., 9.),
            "mnist_z6_const_500": (13., 9.),
            "mnist_z15_const_200": (21., 16.),
            "mnist_z30_const_200": (26., 17.),
            # learn
            "bdp_z6_learn_1000": (7., 3.),
            "mnist_z6_learn_300": (13.0, 9.)
        }
    }

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    runs = read_runs(args.glob)
    df = merge_runs({k: [(dt, dfev) for dt, dfev, _ in runs[k]] for k in runs})
    df_train = merge_runs({k: [(dt, dftr) for dt, _, dftr in runs[k]] for k in runs})

    if args.plot == "runs":  # Single runs.
        color_column = "run"
        subtitle = "Run comparison"
        std = None
    elif args.plot == "models":  # Average runs of a model.
        subtitle = "Comparison across runs"

        df = early_stopping(df_train, df, warmup=100, lookahead=50, stat="elbo", difference=0.0)
        by_runs = df.groupby("run").last().reset_index()
        for statistic in statistics:
            box_whiskers_plot(by_runs, args.out_dir, statistic=statistic, subtitle=subtitle)

        std = std_models(by_runs, by=["model"])
        df = mean_models(by_runs, by=["model"])
        color_column = "model"

        table_str = models_latex_table(df, std, show_curvature=False)
        table_path = os.path.join(args.out_dir, "model_recon_table.tex")
        with open(table_path, "w") as f:
            print(table_str, file=f)
        print("Saved to", table_path)
    else:
        raise NotImplementedError(f"Invalid plot type {args.plot}.")

    for statistic in statistics:
        start = None
        end = None
        if statistic in bounds:
            if args.exp in bounds[statistic]:
                end, start = bounds[statistic][args.exp]
        line_plot(df,
                  std,
                  args.out_dir,
                  statistic=statistic,
                  color_column=color_column,
                  subtitle=subtitle,
                  y_range_start=start,
                  y_range_end=end)

    # Drop runs whose learned curvature diverged (any |curvature| > 10).
    for run in df_train["run"].unique():
        if (df_train[df_train["run"] == run].loc[:, df_train.columns.str.endswith("curvature")].abs() >
                1e1).any().any():
            print("Removing", run)
            df_train = df_train[df_train["run"] != run]

    dft_mean = mean_models(df_train)
    dft_std = std_models(df_train)
    curvature_line_plot(dft_mean, dft_std, args.out_dir, color_column=color_column, subtitle=subtitle)


if __name__ == "__main__":
    main()