import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict, defaultdict
from scipy.signal import medfilt

import seaborn as sns

sns.set_style('ticks')
cmap = sns.color_palette()


def tsplot(data, label=None, **kw):
    """Draw the mean of ``data`` along axis 0 with a +/- one std band.

    Args:
        data: array whose ``shape[1]`` gives the number of x positions;
            mean and standard deviation are taken along axis 0.
        label: legend label attached to the mean line only.
        **kw: forwarded to both ``plt.fill_between`` and ``plt.plot``.
    """
    steps = np.arange(data.shape[1])
    center = np.mean(data, axis=0)
    spread = np.std(data, axis=0)

    # shaded band first, mean line on top of it
    plt.fill_between(steps, center - spread, center + spread, alpha=0.2, **kw)
    plt.plot(steps, center, label=label, **kw)

    # drop the horizontal padding so the curve touches the axes edges
    plt.margins(x=0)


def prep_data(data):
    """Drop non-finite entries from a series.

    Args:
        data: sequence of numbers, possibly containing NaN or +/-inf.

    Returns:
        Tuple ``(x, values)``: the positions of the finite entries in
        ``data`` and the finite entries themselves, both as numpy arrays.
    """
    keep = np.isfinite(data)
    positions = np.arange(len(data))
    values = np.array(data)
    return positions[keep], values[keep]


def strip_common_prefix(results):
    """Return a copy of ``results`` with the shared key prefix removed.

    The keys are left untouched when they share no prefix, or when
    stripping would leave an empty key (e.g. a single entry); an empty
    experiment name would yield legend labels such as ``_tr``, which
    Matplotlib leaves out of legends.

    Args:
        results: mapping from legend string to result data.

    Returns:
        OrderedDict holding the same values in the same order.
    """
    keys = list(results)
    prefix = os.path.commonprefix(keys)
    # slice instead of str.replace so only the leading occurrence is removed
    short_keys = [k[len(prefix):] for k in keys]
    if not prefix or not all(short_keys):
        return OrderedDict(results)
    return OrderedDict((short, results[k]) for k, short in zip(keys, short_keys))


def find_best_epoch(epochs, values, use_max=False):
    """Locate the best finite value of a curve.

    Args:
        epochs: x positions, one per entry of ``values``.
        values: curve values; non-finite entries are ignored.
        use_max: pick the maximum instead of the minimum.

    Returns:
        Tuple ``(epoch, value)`` of the best finite entry, or ``None`` when
        ``values`` holds no finite entry.
    """
    values = np.asarray(values, dtype=float)
    epochs = np.asarray(epochs)
    finite = np.isfinite(values)
    if not finite.any():
        return None
    pick = np.argmax if use_max else np.argmin
    idx = pick(values[finite])
    return epochs[finite][idx], values[finite][idx]


if __name__ == '__main__':
    # Plot model evolution

    # add argument parser
    parser = argparse.ArgumentParser(description='Show evaluation plot.')
    parser.add_argument('results', metavar='N', type=str, nargs='+', help='result.npy files.')
    parser.add_argument('--max_epoch', type=int, default=None, help='last epoch to plot.')
    parser.add_argument('--ymin', help='minimum y value.', type=float, default=None)
    parser.add_argument('--ymax', help='maximum y value.', type=float, default=None)
    parser.add_argument('--watch', help='refresh plot.', type=int, default=None)
    parser.add_argument('--key', help='key for evaluation.', type=str, default="loss")
    parser.add_argument('--max', help='used for highlighting the best value.', action='store_true')
    parser.add_argument('--folds_avg', help='plot average of different folds.', action='store_true')
    parser.add_argument('--list_keys', help='list available keys.', action='store_true')
    parser.add_argument('--logarithmic', help='logarithmic scaled y-axis.', action='store_true')
    parser.add_argument('--smooth', help='apply outlier smoothing in result plots.', action='store_true')
    parser.add_argument('--strip_legend', help='strip legend log files strings.', action='store_true')
    args = parser.parse_args()

    # the value text sits above the marker for maxima and below it for minima
    va = "bottom" if args.max else "top"

    # with --watch the files are reloaded and the figure redrawn on every pass
    while True:

        # load results
        all_results = OrderedDict()
        fold_results = OrderedDict()
        for result in sorted(args.results):
            # os.path also copes with a bare file name, where indexing the
            # os.sep-split path at [-2] would raise an IndexError
            dir_name = os.path.basename(os.path.dirname(result)) or "."
            exp_name = os.path.basename(result).split('.npy')[0]
            exp_name = '_'.join([dir_name, exp_name])

            # NOTE(security): allow_pickle=True unpickles arbitrary objects;
            # only point this script at result files you trust.
            exp_res = np.load(result, allow_pickle=True).item(0)
            all_results[exp_name] = exp_res

            # list available keys (of the first file only) and quit
            if args.list_keys:
                for key in exp_res.keys():
                    print(key)
                raise SystemExit(0)

            # collect results for fold averaging, grouped by parent directory
            if dir_name not in fold_results:
                fold_results[dir_name] = defaultdict(list)

            for key in exp_res.keys():
                fold_results[dir_name][key].append(exp_res[key])

        # collect fold results
        if args.folds_avg:
            for model in fold_results.keys():
                for key in fold_results[model].keys():
                    # truncate every fold to the shortest one so they stack into one array
                    min_samples = np.min([len(r) for r in fold_results[model][key]])
                    fold_results[model][key] = np.asarray([r[0:min_samples] for r in fold_results[model][key]])

            all_results = fold_results

        # strip long legend strings
        if args.strip_legend:
            all_results = strip_common_prefix(all_results)

        # present results
        plt.figure("Model Evolution")
        plt.clf()
        ax = plt.subplot(111)
        plt.subplots_adjust(bottom=0.15, left=0.15, right=0.9, top=0.95)

        for i, (exp_name, exp_res) in enumerate(all_results.items()):

            color = cmap[i % len(cmap)]

            key_tr = 'tr_' + args.key
            key_va = 'va_' + args.key
            label = args.key
            suf_tr, suf_va = "_tr", "_va"

            # fall back to the plain key when there is no "va_"-prefixed series
            if key_va not in exp_res.keys() and args.key in exp_res.keys():
                key_va = args.key
                suf_tr, suf_va = "", ""

            # special case: running train losses
            un_running_keys = [k.replace("_running", "") for k in exp_res.keys() if "running" in k]
            if key_tr not in exp_res.keys() and key_tr in un_running_keys:
                key_tr = key_tr + "_running"
                label = "%s (running)" % label

            # (epoch, value) of the best validation point, if there is one
            best = None

            if args.folds_avg:
                if key_tr in exp_res:
                    tsplot(exp_res[key_tr], label=exp_name + suf_tr, color=color, linewidth=2)
                if key_va in exp_res:
                    tsplot(exp_res[key_va], label=exp_name + suf_va, color=color, linewidth=1)
                    mean_vals = exp_res[key_va].mean(0)
                    best = find_best_epoch(np.arange(len(mean_vals)), mean_vals, args.max)
            else:
                if key_tr in exp_res:
                    tr_x, tr_y = prep_data(exp_res[key_tr])
                    if args.smooth:
                        tr_y = medfilt(tr_y, kernel_size=5)
                    plt.plot(tr_x, tr_y, '-', color=color, linewidth=3, alpha=0.6,
                             label=exp_name + suf_tr)
                if key_va in exp_res:
                    va_x, va_y = prep_data(exp_res[key_va])
                    if args.smooth:
                        va_y = medfilt(va_y, kernel_size=5)
                    plt.plot(va_x, va_y, '-', color=color, linewidth=2, label=exp_name + suf_va)
                    # search the plotted (finite, possibly smoothed) curve so the
                    # marker lands on the line and NaN entries cannot win
                    best = find_best_epoch(va_x, va_y, args.max)

            # highlight best epoch
            if best is not None:
                best_epoch, best_value = best
                plt.plot([0, best_epoch], [best_value] * 2, '--', color=color, alpha=0.5)
                plt.text(best_epoch, best_value, ('%.5f' % best_value), va=va, ha='right',
                         color=color)
                plt.plot(best_epoch, best_value, 'o', color=color)

        plt.ylabel(label.upper(), fontsize=16)
        plt.legend(loc="best", fontsize=10).set_draggable(True)

        # a bound left at None keeps the automatic limit on that side
        if args.ymin is not None or args.ymax is not None:
            ax.set_ylim(args.ymin, args.ymax)

        if args.max_epoch is not None:
            plt.xlim([0, args.max_epoch])

        plt.xlabel("Epoch", fontsize=16)
        plt.grid(True)

        ax.tick_params(axis='x', labelsize=16)
        ax.tick_params(axis='y', labelsize=16)

        if args.logarithmic:
            try:
                ax.set_yscale("log", nonpositive='clip')
            except (TypeError, ValueError):
                # presumably an older Matplotlib that only knows the previous
                # keyword name - TODO confirm the minimum supported version
                ax.set_yscale("log", nonposy='clip')

        plt.draw()

        if args.watch:
            plt.pause(args.watch)
        else:
            plt.show(block=True)
            break
