import numpy as np
import dataloader, util
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import stats

# Module import side effect: switch seaborn's global axes style for all later plots.
sns.set_style(style='whitegrid')

# Sampling points used by get_table_data when building the summary table columns.
stat_meta = dict(
    mr_time_steps=[100],      # time steps for the MedR/MuR/MedP/MuP columns
    fail_thresholds=[5, 10],  # optimal-action count thresholds for the OC columns
    sf_time_steps=[50, 80],   # time steps for the SF (suffix failure) columns
)

def _table_column_names():
    """Return the column labels of the summary table, in row-entry order."""
    names = []
    for x in stat_meta['mr_time_steps']:
        names.extend([f"MedR@{x}", f"MuR@{x}", f"MedP@{x}", f"MuP@{x}"])
    names.extend(f"OC@{x}" for x in stat_meta['fail_thresholds'])
    names.extend(f"SF@{x}" for x in stat_meta['sf_time_steps'])
    names.extend(["MinFrac", "GrdyFrac", "Revs", "Fails", "Runs"])
    return names


def _reward_columns(dl, alg):
    """Return the normalized MedR/MuR/MedP/MuP values for every mr time step.

    For each getter, element ``x - 1`` of the first returned series is divided by
    ``x`` and then affinely rescaled so that ``0.5 - delta/2`` maps to 0 and
    ``0.5 + delta/2`` maps to 1. (Presumably the series is a cumulative reward,
    making this a per-step average -- confirm against the dataloader.)
    A getter returning None yields None for every time step.
    Values are ordered time-step-major: all four stats for the first step, then
    the next step, matching _table_column_names.
    """
    steps = stat_meta['mr_time_steps']
    delta = float(dl.delta)
    low, high = 0.5 - delta / 2, 0.5 + delta / 2
    # Called lazily and in this order so the loader is queried exactly as before.
    getters = (
        lambda: dl.get_median_quantile(alg, 0.1),
        lambda: dl.get_mean_std(alg),
        lambda: dl.get_median_quantile(alg, 0.1, pseudo=True),
        lambda: dl.get_mean_std(alg, pseudo=True),
    )
    per_stat = []
    for getter in getters:
        result = getter()
        if result is None:
            per_stat.append([None] * len(steps))
        else:
            series = result[0]
            per_stat.append([(series[x - 1] / x - low) / (high - low) for x in steps])
    # Transpose (stat, step) -> (step, stat) and flatten, i.e. column-major order.
    return [value for per_step in zip(*per_stat) for value in per_step]


def _opt_count_columns(dl, alg):
    """Return, per fail threshold t, the fraction of opt-action-matrix rows whose sum is < t."""
    thresholds = stat_meta['fail_thresholds']
    mat = dl.get_opt_action_matrix(alg)
    if mat is None:
        return [None] * len(thresholds)
    opt_counts = np.sum(mat, axis=1)
    return [np.count_nonzero(opt_counts < t) / len(opt_counts) for t in thresholds]


def _suffix_failure_columns(dl, alg):
    """Return the column-mean of the suffix-count matrix at each sf time step (1-based)."""
    steps = stat_meta['sf_time_steps']
    mat = dl.get_suffix_counts(alg)
    if mat is None:
        return [None] * len(steps)
    mean_per_step = np.mean(mat, axis=0)
    return [mean_per_step[x - 1] for x in steps]


def _failure_columns(dl, alg):
    """Return [Revs, Fails]: mean of 'reverted' divided by dl.T, and sum of 'failed'."""
    if alg not in dl.failures:
        return [0, 0]
    record = dl.failures[alg]
    revs = record['reverted']
    failures = record['failed']
    if len(revs) == 0:
        # NOTE(review): an earlier comment said such rows are thrown out later;
        # nothing in get_table_data drops them -- confirm downstream handling.
        return [0, 0]
    return [np.mean(revs) / dl.T, np.sum(failures)]


def get_table_data(dl, algs=None):
    """Collect the per-algorithm summary statistics displayed by render_table.

    Args:
        dl: loader object. Read here: ``alg_names``, ``all_results``, ``delta``,
            ``K``, ``T``, ``failures`` and the methods ``get_median_quantile``,
            ``get_mean_std``, ``get_opt_action_matrix``, ``get_suffix_counts``.
        algs: algorithm names to include; defaults to ``dl.alg_names``.
            Algorithms with no entries in ``dl.all_results`` are skipped.

    Returns:
        ``(output, row_names, col_names)`` where ``output`` holds one row per
        kept algorithm whose entries line up with ``col_names`` (None where the
        loader returned no data for a reward/OC/SF statistic), and
        ``row_names[i]`` is ``util.get_name`` of that row's algorithm.
    """
    if algs is None:
        algs = dl.alg_names

    col_names = _table_column_names()
    output = []
    row_names = []
    for alg in algs:
        runs = len(dl.all_results[alg])
        if runs == 0:
            continue

        row = []
        row.extend(_reward_columns(dl, alg))
        row.extend(_opt_count_columns(dl, alg))
        row.extend(_suffix_failure_columns(dl, alg))
        # MinFrac is multiplied by K; the original intent is to rescale it into [0, 1].
        row.append(dl.K * stats.get_min_frac(dl, alg))
        row.append(stats.get_greedy_frac(dl, alg))
        row.extend(_failure_columns(dl, alg))
        row.append(runs)

        output.append(row)
        row_names.append(util.get_name(alg))

    return (output, row_names, col_names)


def render_table(dl, algs=None, save=False):
    """Draw the summary table from get_table_data as two annotated heatmaps.

    The left heatmap holds every column except Revs/Fails/Runs; the right one
    holds those three. Cell text shows the raw values. Cell colour uses a copy in
    which selected columns are replaced by ``1 - value``, so that, per the
    original intent, yellow reads as good and blue as bad under viridis, with
    colours clipped to [0, 1]. None entries are shown as NaN.

    Args:
        dl: loader passed to get_table_data; ``llm_pref``, ``delta`` and
            ``algs`` are also read when saving.
        algs: optional subset of algorithm names, forwarded to get_table_data.
        save: if True, write ``../figs/table_<llm_pref>_<delta>[_<llm>...].pdf``.

    Raises:
        ValueError: if no algorithm produced a table row.
    """
    output, row_names, col_names = get_table_data(dl, algs=algs)
    if not output:
        raise ValueError("render_table: no algorithm has results to tabulate")

    # dtype=float turns None entries into NaN.
    values = np.asarray(output, dtype=float)

    # Flip every column not starting with 'M', plus MinFrac. A positional mask
    # flips each column exactly once, even if a label repeats.
    flip = np.array([name == 'MinFrac' or not name.startswith('M') for name in col_names])
    colours = values.copy()
    colours[:, flip] = 1 - colours[:, flip]

    df_text = pd.DataFrame(values, index=row_names, columns=col_names)
    df_colour = pd.DataFrame(colours, index=row_names, columns=col_names)

    fig, (ax_main, ax_side) = plt.subplots(
        nrows=1,
        ncols=2,
        figsize=(len(col_names), 0.5 * len(row_names)),
        gridspec_kw={'width_ratios': [2, 1]},
    )
    heatmap_kw = dict(cmap="viridis", vmin=0, vmax=1, linewidths=.5, cbar=False,
                      fmt='0.2f', annot_kws={"size": 14})
    side_cols = ['Revs', 'Fails', 'Runs']

    sns.heatmap(df_colour.drop(side_cols, axis=1), annot=df_text.drop(side_cols, axis=1),
                square=False, ax=ax_main, **heatmap_kw)
    sns.heatmap(df_colour[side_cols], annot=df_text[side_cols], ax=ax_side, **heatmap_kw)

    for ax in (ax_main, ax_side):
        ax.xaxis.tick_top()
        ax.tick_params(length=0)
    ax_side.set_yticks([])
    # NOTE(review): only the right-hand axes gets labelsize 14 (as before) -- confirm intent.
    ax_side.tick_params(axis='both', which='major', labelsize=14)

    # plot_name = f"{dl.llm_pref} Template, T={dl.T}, K={dl.K}, Delta={dl.delta}"
    # fig.suptitle(plot_name,fontsize=16)
    # fig.tight_layout()
    fig.subplots_adjust(top=0.94)

    if save:
        # NOTE(review): reads dl.algs here whereas get_table_data uses dl.alg_names -- confirm.
        fig_name = f"table_{dl.llm_pref}_{dl.delta}"
        fig_name += "".join(f"_{llm}" for llm in dataloader.llms if llm in dl.algs)
        fig.savefig(f"../figs/{fig_name}.pdf", dpi=100, format="pdf")
