import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import postprocess.stats, postprocess.util

def get_table_data(dl,algs=None):
    """Collect per-algorithm summary statistics for the results table.

    Parameters
    ----------
    dl : data-loader object
        Read here: ``alg_names``, ``all_results`` (indexed by algorithm name,
        ``len`` gives the number of runs), ``delta``, ``T``, ``K``,
        ``get_median_quantile(alg, q)`` and ``get_suffix_counts(alg)``.
    algs : iterable of str, optional
        Algorithms to tabulate; defaults to ``dl.alg_names``.

    Returns
    -------
    tuple
        ``(output, row_names, col_names)``. ``output`` has one row per
        algorithm with at least one run. Each row is
        ``[MedianReward, SuffFailFreq(T/2), MinFrac, GreedyFrac, Runs]``.
        ``SuffFailFreq(T/2)`` is ``None`` when ``dl.get_suffix_counts``
        returns ``None``. ``row_names[i]`` is the display name of row ``i``
        (from ``postprocess.util.get_name``). ``col_names`` are the column
        headers above.
    """
    if algs is None:
        algs = dl.alg_names

    col_names = ['MedianReward', 'SuffFailFreq(T/2)', 'MinFrac', 'GreedyFrac', 'Runs']
    output = []
    row_names = []

    # Bounds used to rescale the per-round median reward so that `low` -> 0
    # and `high` -> 1. They depend only on dl.delta, so compute them once.
    low = 0.5 - float(dl.delta) / 2
    high = 0.5 + float(dl.delta) / 2

    for alg in algs:
        runs = len(dl.all_results[alg])
        if runs == 0:
            # No replicates for this algorithm: it gets no row.
            continue

        row = []

        ## Median reward: take the median-quantile curve at its last index
        ## (T-1), divide by T, then rescale into [low, high] -> [0, 1].
        ## (Presumably the curve is cumulative reward -- TODO confirm.)
        rewards = dl.get_median_quantile(alg, 0.5)
        row.append((rewards[0][dl.T - 1] / dl.T - low) / (high - low))

        ## Suffix failure: column-mean of the suffix-count matrix taken at
        ## position T/2 - 1, or None when no suffix counts are available.
        mat = dl.get_suffix_counts(alg)
        if mat is None:
            row.append(None)
        else:
            row.append(np.mean(mat, axis=0)[int(dl.T / 2) - 1])

        ## Min fraction, scaled by the number of arms K.
        row.append(dl.K * postprocess.stats.get_min_frac(dl, alg))

        ## Greedy fraction.
        row.append(postprocess.stats.get_greedy_frac(dl, alg))

        ## Number of runs (replicates).
        row.append(runs)

        output.append(row)
        row_names.append(postprocess.util.get_name(alg))

    return (output, row_names, col_names)

def render_table(dl,alg,width):
    """Draw the summary table as a heatmap with one column per algorithm.

    Rows 0-2 of ``get_table_data(dl)`` are always shown (presumably the
    baselines -- TODO confirm). They are followed by every row whose raw name
    contains ``alg``. The top panel shows the metrics, with metrics as rows
    and algorithms as columns. Cell colour comes from a flipped copy of the
    data, so that yellow means good and blue means bad on 'viridis'. The text
    shows the unflipped values. The bottom panel is a bordered strip holding
    the integer replicate counts.

    Parameters
    ----------
    dl : data-loader object
        Passed to ``get_table_data``. ``dl.llm_pref`` ('buttons' / 'adverts')
        adds a 'B' / 'A' prefix to the labels of the ``alg`` rows.
    alg : str
        Substring selecting the non-baseline rows. The ``alg + '-'`` prefix
        is stripped from their labels.
    width : float
        Figure width in inches; the height is fixed at 4.

    Draws on a new figure and returns nothing. The current axes is left on
    the bottom (replicates) panel.
    """
    (output, row_names, col_names) = get_table_data(dl)

    # Display names for two of the columns.
    renames = {'MinFrac': 'K*MinFrac', 'Runs': 'Replicates'}
    col_names = [renames.get(c, c) for c in col_names]

    # Rows 0-2 plus every row matching `alg`. Matching is done on the raw
    # names, before the relabelling below.
    inds = np.array([0, 1, 2] + [i for i, name in enumerate(row_names) if alg in name])

    # Relabel rows for display.
    for idx, row_name in enumerate(row_names):
        if row_names[idx] == 'UCB-1.0':
            row_names[idx] = 'UCB'
        if row_names[idx] == 'Greedy-1':
            row_names[idx] = 'Greedy'

        if alg in row_names[idx]:
            # Strip the "<alg>-" prefix from the raw name.
            label = row_name.replace(alg + '-', '')

            # Extra letter for buttons/adverts.
            if dl.llm_pref == 'buttons':
                label = 'B' + label
            if dl.llm_pref == 'adverts':
                label = 'A' + label
            if 'tC' in label:
                # Raw string: '\w' is not a valid escape sequence
                # (SyntaxWarning on Python 3.12+).
                label = label.replace('tC', r'$\widetilde{C}$')
            row_names[idx] = label

    ## The numbers to visualize. Missing entries (None) become NaN.
    values = np.asarray(output, dtype=float)[inds, :]
    df1 = pd.DataFrame(values, index=[row_names[i] for i in inds], columns=col_names)

    ## Flip every column except those starting with 'M' (MedianReward).
    ## This makes the colormap read yellow=good, blue=bad.
    df2 = df1.copy()
    for col in col_names:
        if col.startswith('M'):
            continue
        df2[col] = 1 - df2[col]
    # The replicate strip is uniformly coloured (lowest value of 'Greys').
    df2.loc[:, 'Replicates'] = 0

    to_drop = ['Replicates']
    (_fig, axes) = plt.subplots(nrows=2, ncols=1, figsize=(width, 4), gridspec_kw={'height_ratios': [4, 1]})

    # Top panel: metrics (rows) x algorithms (columns).
    plt.sca(axes[0])
    top = sns.heatmap(df2.drop(to_drop, axis=1).T, annot=df1.drop(to_drop, axis=1).T, cmap='viridis', vmin=0, vmax=1, linewidth=0.5, cbar=False, square=False, fmt='0.2f', annot_kws={"size": 14, 'fontweight': 'bold'})
    top.xaxis.tick_top()
    top.tick_params(length=0)
    top.set_yticklabels(top.get_yticklabels(), rotation=0)

    # Bottom panel: integer replicate counts, one cell per algorithm.
    plt.sca(axes[1])
    bottom = sns.heatmap(df2[to_drop].T, annot=df1[to_drop].T.astype(int), cmap='Greys', vmin=0, vmax=1, linewidth=0.5, linecolor='black', cbar=False, square=False, fmt='d', annot_kws={"size": 14, 'fontweight': 'bold'}, xticklabels=False)
    bottom.xaxis.tick_top()
    bottom.tick_params(length=0)
    bottom.set_yticklabels(bottom.get_yticklabels(), rotation=0)
    plt.subplots_adjust(hspace=0.01)

    # Thin black border around the replicate strip.
    for spine in bottom.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(0.5)


def render_table_vert(dl,alg,height,save=True,width=18):
    """Draw the summary table as a heatmap with one row per algorithm.

    Rows 0-2 of ``get_table_data(dl)`` are always shown (presumably the
    baselines -- TODO confirm). They are followed by every row whose raw name
    contains ``alg``. The left panel shows the metrics, with algorithms as
    rows and metrics as columns. Cell colour comes from a flipped copy of the
    data, so that yellow means good and blue means bad on 'viridis'. The text
    shows the unflipped values. The right panel is a narrow bordered column
    holding the integer replicate counts.

    Parameters
    ----------
    dl : data-loader object
        Passed to ``get_table_data``. ``dl.llm_pref`` ('buttons' / 'adverts')
        adds a 'B' / 'A' prefix to the labels of the ``alg`` rows.
    alg : str
        Substring selecting the non-baseline rows. The ``alg + '-'`` prefix
        is stripped from their labels.
    height : float
        Figure height in inches.
    save : bool
        Accepted but not read by this function.
    width : float, default 18
        Figure width in inches (previously hard-coded to 18).

    Draws on a new figure and returns nothing. The current axes is left on
    the right (replicates) panel.
    """
    (output, row_names, col_names) = get_table_data(dl)

    # Display names for two of the columns.
    renames = {'MinFrac': 'K*MinFrac', 'Runs': 'Replicates'}
    col_names = [renames.get(c, c) for c in col_names]

    # Rows 0-2 plus every row matching `alg`. Matching is done on the raw
    # names, before the relabelling below.
    inds = np.array([0, 1, 2] + [i for i, name in enumerate(row_names) if alg in name])

    # Relabel rows for display.
    for idx, row_name in enumerate(row_names):
        if row_names[idx] == 'UCB-1.0':
            row_names[idx] = 'UCB'
        if row_names[idx] == 'Greedy-1':
            row_names[idx] = 'Greedy'

        if alg in row_names[idx]:
            # Strip the "<alg>-" prefix from the raw name.
            label = row_name.replace(alg + '-', '')

            # Extra letter for buttons/adverts.
            if dl.llm_pref == 'buttons':
                label = 'B' + label
            if dl.llm_pref == 'adverts':
                label = 'A' + label
            row_names[idx] = label

    ## The numbers to visualize. Missing entries (None) become NaN.
    values = np.asarray(output, dtype=float)[inds, :]
    df1 = pd.DataFrame(values, index=[row_names[i] for i in inds], columns=col_names)

    ## Flip every column except those starting with 'M' (MedianReward).
    ## This makes the colormap read yellow=good, blue=bad. 'K*MinFrac' starts
    ## with 'K', so it is flipped.
    df2 = df1.copy()
    for col in col_names:
        if col.startswith('M'):
            continue
        df2[col] = 1 - df2[col]
    # The replicate column is uniformly coloured (lowest value of 'Greys').
    df2.loc[:, 'Replicates'] = 0

    to_drop = ['Replicates']
    (_fig, axes) = plt.subplots(nrows=1, ncols=2, figsize=(width, height), gridspec_kw={'width_ratios': [4, 1]})

    # Left panel: algorithms (rows) x metrics (columns).
    plt.sca(axes[0])
    left = sns.heatmap(df2.drop(to_drop, axis=1), annot=df1.drop(to_drop, axis=1), cmap='viridis', vmin=0, vmax=1, linewidth=0.5, cbar=False, square=False, fmt='0.2f', annot_kws={"size": 24, 'fontweight': 'bold'})  # font size of text in box
    left.xaxis.tick_top()
    left.tick_params(length=0, labelsize=20)
    left.set_yticklabels(left.get_yticklabels(), rotation=0)

    # Right panel: integer replicate counts, one cell per algorithm.
    plt.sca(axes[1])
    right = sns.heatmap(df2[to_drop], annot=df1[to_drop].astype(int), cmap='Greys', vmin=0, vmax=1, linewidth=0.5, linecolor='black', cbar=False, square=False, fmt='d', annot_kws={"size": 24, 'fontweight': 'bold'}, yticklabels=False)
    right.xaxis.tick_top()
    right.tick_params(length=0, labelsize=20)
    right.set_yticklabels(right.get_yticklabels(), rotation=0)
    plt.subplots_adjust(wspace=0.0)

    # Thin black border around the replicate column.
    for spine in right.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(0.5)
