import numpy as np
import matplotlib.pyplot as plt
import postprocess.util as util
import matplotlib.pylab as pylab
import statsmodels.api as sm
import postprocess.stats as stats

# Render legend text, axis labels, titles and tick labels of every figure in
# this module at matplotlib's named 'large' size.
params = {
    rc_key: 'large'
    for rc_key in (
        'legend.fontsize',
        'axes.labelsize',
        'axes.titlesize',
        'xtick.labelsize',
        'ytick.labelsize',
    )
}
pylab.rcParams.update(params)

def get_legend_list(algs, prompt_type):
    """Return compact legend labels for ``algs``, in the same order.

    Each algorithm's display name (from ``util.get_name``) is collapsed to a
    short family label when it contains one of the known substrings below;
    otherwise the display name is used unchanged.

    Args:
        algs: iterable of algorithm identifiers accepted by ``util.get_name``.
        prompt_type: prompt variant (e.g. 'adverts' or 'buttons'). Currently
            unused by the labelling; kept so existing callers keep working.

    Returns:
        list[str]: one label per element of ``algs``.
    """
    # Checked in order; the first substring contained in the name wins.
    aliases = (
        ('UCB', 'UCB'),
        ('Greedy', 'Greedy'),
        ('GPT-4', 'GPT4'),
        ('GPT-3.5', 'GPT3.5'),
        ('Llama-2-13b', 'Llama2'),
    )

    legend_list = []
    for alg in algs:
        name = util.get_name(alg)  # look the display name up once per algorithm
        label = next((short for pattern, short in aliases if pattern in name), name)
        legend_list.append(label)
    return legend_list

# Shared numeric font size. No function in this part of the module reads it
# (the rcParams above use the named size 'large' instead).
# NOTE(review): confirm whether other modules import this before removing it.
font_size = 12

def plot_longitudinal(dl, algs=None, ax=None, pseudo=False, median=False, bands=True):
    """Plot each algorithm's summary curve divided by t ('Time-averaged reward').

    For every algorithm the (main, low, high) summary from ``dl`` is divided
    element-wise by the time step t = 1..T. An optional shaded band spans
    low/t to high/t. A black horizontal line at 0.5 + delta/2 is labelled
    'OPT'. Algorithms for which ``dl`` returns no summary are skipped and left
    out of the legend.

    Args:
        dl: results object providing ``alg_names``, ``T``, ``K``, ``delta``,
            ``llm_pref``, ``get_mean_std`` and ``get_median_quantile``.
        algs: algorithms to plot; defaults to ``dl.alg_names``. Not modified.
        ax: matplotlib Axes to draw on; a new 6x4 figure is created if None.
        pseudo: forwarded to the ``dl`` summary getters.
        median: if True use ``dl.get_median_quantile(alg, 0.05, ...)``,
            otherwise ``dl.get_mean_std(alg, ...)``.
        bands: if True, shade between the low and high curves.

    Returns:
        The Axes that was drawn on.
    """
    if algs is None:
        algs = dl.alg_names
    if ax is None:
        (_, ax) = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))

    T = dl.T
    delta = float(dl.delta)
    xline = np.arange(1, T + 1)

    # Track plotted algorithms separately: calling algs.remove() inside the
    # loop skipped the following element and mutated the caller's list.
    plotted, handles = [], []
    for alg in algs:
        if median:
            tmp = dl.get_median_quantile(alg, 0.05, pseudo=pseudo)
        else:
            tmp = dl.get_mean_std(alg, pseudo=pseudo)
        if tmp is None:
            continue
        (main, low, high) = tmp
        color = util.get_color(alg)
        (line,) = ax.plot(xline, main / xline, color=color, linewidth=2, alpha=1.0)
        if bands:
            ax.fill_between(xline, low / xline, high / xline, color=color, alpha=0.3)
        plotted.append(alg)
        handles.append(line)

    ax.set_title(f"{dl.llm_pref}, {dl.K}-arms, Delta={dl.delta}")
    (opt_line,) = ax.plot(xline, np.full(T, 0.5 + delta / 2), color='black')

    # Pass handles explicitly: given labels only, matplotlib assigns them to
    # artists in creation order, which would label the shaded bands.
    legend_list = get_legend_list(plotted, dl.llm_pref)
    ax.legend(handles + [opt_line], legend_list + ['OPT'],
              loc='upper left', ncol=2, bbox_to_anchor=(0, 0.95))
    ax.set_xlabel('Time step (t)')
    ax.set_ylim([(0.5 - delta / 2) - 0.001, (0.5 + delta / 2) + 0.005])
    ax.set_yticks(np.arange(0.5 - delta / 2, 0.5 + delta / 2 + 0.005, 0.05))
    ax.set_xlim([0, T])
    ax.set_ylabel('Time-averaged reward')
    return ax
    
def plot_suffix_failure(dl, algs=None, ax=None):
    """Plot the mean of ``dl.get_suffix_counts(alg)`` over axis 0 against t.

    The y-axis is labelled 'Suffix Failure Frequency @ t'. Presumably axis 0
    indexes replicates, so each curve is the fraction of replicates that fail
    on [t, T] (TODO confirm against ``get_suffix_counts``). Algorithms for
    which ``dl`` returns None are skipped and left out of the legend.

    Args:
        dl: results object providing ``alg_names``, ``T``, ``K``, ``delta``,
            ``llm_pref`` and ``get_suffix_counts``.
        algs: algorithms to plot; defaults to ``dl.alg_names``. Not modified.
        ax: matplotlib Axes to draw on; a new 6x4 figure is created if None.

    Returns:
        The Axes that was drawn on.
    """
    if algs is None:
        algs = dl.alg_names
    if ax is None:
        (_, ax) = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))

    # Collect instead of algs.remove() inside the loop, which skipped elements
    # and mutated the caller's list.
    plotted, handles = [], []
    for alg in algs:
        arr = dl.get_suffix_counts(alg)
        if arr is None:
            continue
        (line,) = ax.plot(np.mean(arr, axis=0), color=util.get_color(alg), linewidth=2, alpha=1.0)
        plotted.append(alg)
        handles.append(line)
    ax.set_title(f"{dl.llm_pref}, {dl.K}-arms, Delta={dl.delta}, T={dl.T}")

    legend_list = get_legend_list(plotted, dl.llm_pref)
    ax.legend(handles, legend_list, ncol=2)

    ax.set_xlabel('Time step (t)')
    ax.set_xlim([0, int(0.8 * dl.T)])
    ax.set_yticks(np.arange(0, 0.81, 0.2))
    ax.set_ylabel('Suffix Failure Frequency @ t')
    return ax

# Bar geometry shared by the histogram plots: width of each bar, and the
# horizontal shift applied between consecutive algorithms within a bin.
options = {
    'opt_act_width': 0.15,
    'opt_act_offset': 0.15,
}


def _binned_fractions(values, bin_width, num_bins):
    """Return the fraction of ``values`` that falls into each of ``num_bins`` bins.

    Bin i covers [bin_width*i, bin_width*(i+1)). Every value at or beyond the
    start of the last bin is counted in the last bin. This absorbs the maximum
    possible count and any remainder when the range is not a multiple of
    ``bin_width``. Values are assumed to be non-negative integer counts.
    An empty input yields all-zero fractions.
    """
    values = np.asarray(values).ravel()
    if values.size == 0:
        return np.zeros(num_bins)
    idx = np.minimum(values // bin_width, num_bins - 1).astype(int)
    return np.bincount(idx, minlength=num_bins) / values.size


def plot_opt_action_hist(dl, algs=None, ax=None, bin_width=10, start=0):
    """Histogram, per algorithm, of best-arm plays in rounds ``start`` onward.

    ``dl.get_opt_action_matrix(alg)`` is truncated to columns ``start:`` and
    summed along axis 1. The per-row totals are binned in widths of
    ``bin_width`` and shown as the fraction of rows per bin. Rows are
    presumably replicates, per the y label. Bars of different algorithms sit
    side by side within each bin. Algorithms with no matrix are skipped.

    Args:
        dl: results object providing ``alg_names``, ``T``, ``K``, ``delta``,
            ``llm_pref`` and ``get_opt_action_matrix``.
        algs: algorithms to plot; defaults to ``dl.alg_names``. Not modified.
        ax: matplotlib Axes to draw on; a new 6x4 figure is created if None.
        bin_width: number of play counts per histogram bin.
        start: first round (column) included in the count.

    Returns:
        The Axes that was drawn on.
    """
    if algs is None:
        algs = dl.alg_names
    if ax is None:
        (_, ax) = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))

    offset = 0
    width = options['opt_act_width']
    # At least one bin, so a window shorter than bin_width cannot leave us
    # with nothing to index.
    num_bins = max(1, int((dl.T - start) / bin_width))
    positions = np.arange(num_bins)

    plotted, handles = [], []
    for alg in algs:
        mat = dl.get_opt_action_matrix(alg)
        if mat is None:
            continue
        opt_counts = np.sum(mat[:, start:], axis=1)
        bars = ax.bar(positions + offset, _binned_fractions(opt_counts, bin_width, num_bins),
                      width=width, align='edge', color=util.get_color(alg))
        offset += options['opt_act_offset']
        plotted.append(alg)
        handles.append(bars)

    ax.set_title(f"{dl.llm_pref}, {dl.K}-arms, Delta = {dl.delta}, T={dl.T}")

    legend_list = get_legend_list(plotted, dl.llm_pref)
    ax.legend(handles, legend_list, ncol=2)

    ax.set_xlabel(f'Plays of the best arm in rounds [{start},{dl.T}]')
    ax.set_xticks(positions)
    ax.set_xticklabels(bin_width * positions)
    ax.vlines(range(num_bins + 1), 0, 1, linestyle='--', color='black', linewidth=1, alpha=0.5, label="_nolegend_")
    ax.set_ylim([0, 1])
    ax.set_xlim([0, num_bins])
    ax.set_ylabel('Fraction of replicates')
    return ax


def plot_min_action_hist(dl, algs=None, ax=None, bin_width=10):
    """Histogram, per algorithm, of the per-row minimum of the action frequencies.

    ``dl.get_action_freqs(alg)`` is minimised along axis 1. Per the x label,
    this is the number of plays of the least-played arm. The minima are binned
    in widths of ``bin_width`` and shown as the fraction of rows per bin.
    Algorithms with no frequencies are skipped.

    Args:
        dl: results object providing ``alg_names``, ``T``, ``K``, ``delta``,
            ``llm_pref`` and ``get_action_freqs``.
        algs: algorithms to plot; defaults to ``dl.alg_names``. Not modified.
        ax: matplotlib Axes to draw on; a new 6x4 figure is created if None.
        bin_width: number of play counts per histogram bin.

    Returns:
        The Axes that was drawn on.
    """
    if algs is None:
        algs = dl.alg_names
    if ax is None:
        (_, ax) = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))

    offset = 0
    width = options['opt_act_width']
    num_bins = max(1, int(dl.T / bin_width))  # at least one bin to draw into
    positions = np.arange(num_bins)

    plotted, handles = [], []
    for alg in algs:
        freqs = dl.get_action_freqs(alg)
        if freqs is None:
            continue
        min_counts = np.min(freqs, axis=1)
        bars = ax.bar(positions + offset, _binned_fractions(min_counts, bin_width, num_bins),
                      width=width, align='edge', color=util.get_color(alg))
        offset += options['opt_act_offset']
        plotted.append(alg)
        handles.append(bars)

    ax.set_title(f"{dl.llm_pref}, {dl.K}-arms, Delta = {dl.delta}, T={dl.T}")

    legend_list = get_legend_list(plotted, dl.llm_pref)
    ax.legend(handles, legend_list, ncol=2)

    ax.set_xlabel('Plays of the least-played arm')
    ax.set_xticks(positions)
    ax.set_xticklabels(bin_width * positions)
    ax.vlines(range(num_bins), 0, 1, linestyle='--', color='black', linewidth=1, alpha=0.5, label="_nolegend_")
    ax.set_ylim([0, 1])
    ax.set_ylabel('Fraction of replicates')
    return ax
    
def plot_min_action_longitudinal(dl, algs=None, ax=None, median=False):
    """Plot K * MinFrac(t) over time for each algorithm.

    The tensor ``dl.get_action_tensor(alg)`` goes through four steps:

    1. Cumulatively summed along axis 1.
    2. Minimised over axis 2.
    3. Aggregated over axis 0 by mean (or median).
    4. Multiplied by ``dl.K`` and divided by t = 1..T.

    NOTE(review): this presumably assumes a (replicates, T, K) tensor of
    per-round action indicators, so the curve is K times the fraction of
    plays so far of the least-played arm. Confirm against
    ``get_action_tensor``.

    Args:
        dl: results object providing ``alg_names``, ``T``, ``K``, ``delta``,
            ``llm_pref`` and ``get_action_tensor``.
        algs: algorithms to plot; defaults to ``dl.alg_names``. Not modified.
        ax: matplotlib Axes to draw on; a new 6x3.5 figure is created if None.
        median: aggregate over axis 0 with the median instead of the mean.

    Returns:
        The Axes that was drawn on.
    """
    if algs is None:
        algs = dl.alg_names
    if ax is None:
        (_, ax) = plt.subplots(nrows=1, ncols=1, figsize=(6, 3.5))

    xline = np.arange(1, dl.T + 1)  # loop-invariant time axis
    aggregate = np.median if median else np.mean

    plotted, handles = [], []
    for alg in algs:
        tens = dl.get_action_tensor(alg)
        if tens is None:
            # Skip missing data, as the other plots in this module do for
            # their getters' None results.
            continue
        min_cumulative = np.min(np.cumsum(tens, axis=1), axis=2)
        mins = aggregate(min_cumulative, axis=0)
        (line,) = ax.plot(xline, dl.K * mins / xline, color=util.get_color(alg), linewidth=2, alpha=1.0)
        plotted.append(alg)
        handles.append(line)

    ax.set_title(f"{dl.llm_pref}, {dl.K}-arms, Delta={dl.delta}")

    legend_list = get_legend_list(plotted, dl.llm_pref)
    ax.legend(handles, legend_list, loc='upper right', ncol=2)

    ax.set_xlabel('Time step (t)')
    ax.set_ylim([0, 1])
    ax.set_xlim([0, dl.T])
    ax.set_ylabel('K * MinFrac(t)')
    return ax

def plot_cdf(dl, fn=stats.ave_reward, algs=None, ax=None):
    """Plot, per algorithm, the empirical CDF of a per-result statistic.

    ``fn`` is applied to every entry of ``dl.all_results[alg]``. The ECDF of
    the resulting values is drawn as a step curve evaluated on 1000 evenly
    spaced points in [0, 1].

    Args:
        dl: results object providing ``alg_names``, ``all_results`` and
            ``llm_pref``.
        fn: statistic computed from one result entry; defaults to
            ``stats.ave_reward``.
        algs: algorithms to plot; defaults to ``dl.alg_names``.
        ax: matplotlib Axes to draw on; a new 6x3.5 figure is created if None.

    Returns:
        The Axes that was drawn on.
    """
    if algs is None:
        algs = dl.alg_names
    if ax is None:
        (_, ax) = plt.subplots(nrows=1, ncols=1, figsize=(6, 3.5))

    x = np.linspace(0, 1, num=1000)  # same evaluation grid for every algorithm
    handles = []
    for alg in algs:
        data = [fn(item) for item in dl.all_results[alg]]
        ecdf = sm.distributions.ECDF(data)
        (line,) = ax.step(x, ecdf(x), color=util.get_color(alg), linewidth=2, alpha=1.0)
        handles.append(line)

    legend_list = get_legend_list(algs, dl.llm_pref)
    ax.legend(handles, legend_list, loc='upper left', ncol=2)

    ax.set_ylim([0, 1])
    ax.set_ylabel('Fraction of replicates')
    return ax
