import numpy as np
import matplotlib.pyplot as plt

import pickle
import os

#### Style metadata ####

# Matplotlib color used for each algorithm key when drawing its curves/bars.
colors = {
    'ts': 'tab:orange',
    'ucb': 'tab:green',
    'g1': 'tab:red',
    'gpt35': 'tab:purple',
    'gpt4': 'tab:blue',
    'llama13b': 'tab:pink'
}
# Human-readable label for each algorithm key, used in plot legends.
names = {
    'ts': 'TS',
    'ucb': 'UCB',
    'g1': 'Greedy',
    'gpt35': 'GPT-3.5',
    'gpt4': 'GPT-4',
    'llama13b': 'Llama-2-13b'
}

# Bar layout for the arm-count histograms: 'opt_act_width' is the bar width
# and 'opt_act_offset' is the horizontal shift added for each successive
# algorithm so that their bars sit side by side.
options ={
    'opt_act_width': 0.15,
    'opt_act_offset': 0.15,
    }

def get_name(alg,meta_data):
    """Return the legend label for ``alg``.

    Baselines ('ts', 'ucb', 'g1') use their plain display name.  Every other
    algorithm gets the display name followed by '-', the upper-cased first
    letters of the 'suggestive', 'summarized' and 'cot' settings, and finally
    either 'D' (when ``meta_data['dist'] == 'dist'``) or the temperature.
    """
    label = names[alg]
    if alg in ('ts', 'ucb', 'g1'):
        return label

    initials = ''.join(
        meta_data[field][0].upper()
        for field in ('suggestive', 'summarized', 'cot')
    )
    suffix = 'D' if meta_data['dist'] == 'dist' else str(meta_data['temp'])
    return label + '-' + initials + suffix

#### File system metadata ####
# File-name stem for each algorithm key.  Result files are looked up as
# "<stem>_<rep>.pkl" for baselines and "<stem>_t=<temp>_<rep>.pkl" for LLMs.
f_sub_names = {
    'ts': 'ts',
    'ucb': 'ucb',
    'g1': 'g1',
    'gpt35': 'gpt35',
    'gpt4': 'gpt4',
    'llama13b': 'llama13b'
}

# Root directory containing one sub-directory per experiment configuration.
base_dir = "../results/"
# Leading component of the sub-directory name for baseline (non-LLM) runs.
baseline_pref = 'baselines'


#### Ingesting subroutine ####
def get_baseline_results(meta_data):
    """Load the pickled baseline runs described by ``meta_data``.

    Args:
        meta_data: dict with keys 'algs' and 'deltas', plus 'K', 'T', 'n0',
            'eta' and 'reps' when at least one baseline algorithm is listed.

    Returns:
        ``{alg: {"<delta:0.1f>": [run, ...]}}``.  LLM algorithms ('gpt4',
        'gpt35', 'llama13b') are given empty lists here; their runs are
        loaded separately.  Missing replicate files are skipped silently.
    """
    llm_algs = ('gpt4', 'gpt35', 'llama13b')
    all_results = {alg: {} for alg in meta_data['algs']}

    for delta in meta_data['deltas']:
        delta_key = f"{delta:0.1f}"
        for alg in meta_data['algs']:
            all_results[alg][delta_key] = []
            if alg in llm_algs:
                # Only the placeholder list is created for LLM algorithms.
                continue
            dir_name = base_dir+f"{baseline_pref}_K={meta_data['K']}_T={meta_data['T']}_delta={delta:0.1f}_n0={meta_data['n0']}_eta={meta_data['eta']}/"
            fpref = dir_name + f"{f_sub_names[alg]}_"
            for rep in range(meta_data['reps']):
                fname = fpref+f"{rep}.pkl"
                if os.path.isfile(fname):
                    # NOTE: pickle can execute arbitrary code on load; only
                    # point this at result files you produced yourself.
                    with open(fname, 'rb') as fh:
                        data = pickle.load(fh)
                    all_results[alg][delta_key].append(data)
    return all_results

def load_llm_results(all_results,meta_data,debug=True):
    """Load pickled LLM runs into ``all_results`` and collect failure flags.

    ``all_results`` is modified in place: the entries for 'gpt35', 'gpt4' and
    'llama13b' are reset to empty dicts and then filled, per delta, for every
    LLM algorithm listed in ``meta_data['algs']``.

    Args:
        all_results: dict to fill, shaped ``{alg: {"<delta:0.1f>": [run, ...]}}``.
        meta_data: dict with keys 'algs' and 'deltas', plus 'llm_pref',
            'suggestive', 'summarized', 'dist', 'cot', 'K', 'T', 'temp' and
            'reps' when at least one LLM algorithm is listed.
        debug: when true, print a per-delta summary of the runs found.

    Returns:
        ``{alg: {"<delta:0.1f>": {'reverted': [...], 'failed': [...]}}}`` for
        the three LLM keys (empty dict for those not in ``meta_data['algs']``).
    """
    # Algorithm key -> label used in the debug summary (also fixes the
    # reporting order: gpt4, gpt35, llama13b).
    llm_labels = {'gpt4': 'gpt-4', 'gpt35': 'gpt-3.5', 'llama13b': 'llama13b'}

    ## Clear previous results
    for alg in llm_labels:
        all_results[alg] = {}
    failures = {'gpt35': {}, 'gpt4': {}, 'llama13b': {}}

    for delta in meta_data['deltas']:
        delta_key = f"{delta:0.1f}"
        for alg in meta_data['algs']:
            if alg not in llm_labels:
                continue
            all_results[alg][delta_key] = []
            failures[alg][delta_key] = {'reverted': [], 'failed': []}
            dir_name = base_dir+f"{meta_data['llm_pref']}_{meta_data['suggestive']}_{meta_data['summarized']}_{meta_data['dist']}_{meta_data['cot']}_K={meta_data['K']}_T={meta_data['T']}_delta={delta:0.1f}/"
            fpref = dir_name + f"{f_sub_names[alg]}_t={meta_data['temp']}_"
            for rep in range(meta_data['reps']):
                fname = fpref+f"{rep}.pkl"
                if not os.path.isfile(fname):
                    continue
                # NOTE: pickle can execute arbitrary code on load; only point
                # this at result files you produced yourself.
                with open(fname, 'rb') as fh:
                    data = pickle.load(fh)
                # A 3-element payload is treated as (run, reverted, failed);
                # anything else is stored as the run itself.
                # NOTE(review): a bare run of length 3 would be misread here.
                if len(data) == 3:
                    all_results[alg][delta_key].append(data[0])
                    failures[alg][delta_key]['reverted'].append(data[1])
                    failures[alg][delta_key]['failed'].append(data[2])
                else:
                    all_results[alg][delta_key].append(data)

    for delta in meta_data['deltas']:
        delta_key = f"{delta:0.1f}"
        for alg, label in llm_labels.items():
            if alg not in meta_data['algs']:
                continue
            num_runs = len(all_results[alg][delta_key])
            num_reverted = sum(1 for x in failures[alg][delta_key]['reverted'] if x is True)
            num_failed = sum(1 for x in failures[alg][delta_key]['failed'] if x is True)
            if debug:
                print(f"Delta {delta_key}, {label} runs. Attempted: {num_runs}, Reverted: {num_reverted}, Failed: {num_failed}", flush=True)
    return failures

def get_results(meta_data):
    """Load every pickled run (baselines and LLMs) described by ``meta_data``.

    Args:
        meta_data: dict with keys 'algs', 'deltas', 'K', 'T' and 'reps';
            LLM algorithms additionally need 'llm_pref', 'suggestive',
            'summarized', 'cot' and 'temp'.

    Returns:
        ``{alg: {"<delta:0.1f>": [run, ...]}}``.  Missing replicate files are
        skipped.  The number of runs found is printed for each LLM algorithm.
    """
    # Algorithm key -> label used in the printed summary (also fixes the
    # reporting order: gpt4, gpt35, llama13b).
    llm_labels = {'gpt4': 'gpt-4', 'gpt35': 'gpt-3.5', 'llama13b': 'llama13b'}
    all_results = {alg: {} for alg in meta_data['algs']}

    for delta in meta_data['deltas']:
        delta_key = f"{delta:0.1f}"
        for alg in meta_data['algs']:
            all_results[alg][delta_key] = []
            if alg in llm_labels:
                dir_name = base_dir+f"{meta_data['llm_pref']}_{meta_data['suggestive']}_{meta_data['summarized']}_{meta_data['cot']}_K={meta_data['K']}_T={meta_data['T']}_delta={delta:0.1f}/"
                fpref = dir_name + f"{f_sub_names[alg]}_t={meta_data['temp']}_"
            else:
                dir_name = base_dir+f"{baseline_pref}_K={meta_data['K']}_T={meta_data['T']}_delta={delta:0.1f}/"
                fpref = dir_name + f"{f_sub_names[alg]}_"
            for rep in range(meta_data['reps']):
                fname = fpref+f"{rep}.pkl"
                if os.path.isfile(fname):
                    # NOTE: pickle can execute arbitrary code on load; only
                    # point this at result files you produced yourself.
                    with open(fname, 'rb') as fh:
                        data = pickle.load(fh)
                    all_results[alg][delta_key].append(data)

    for delta in meta_data['deltas']:
        delta_key = f"{delta:0.1f}"
        for alg, label in llm_labels.items():
            if alg in meta_data['algs']:
                num_runs = len(all_results[alg][delta_key])
                print(f"Delta {delta_key}: Number of {label} runs: {num_runs}", flush=True)
    return all_results


def plot_and_save_all(data,meta_data):
    """Draw the five summary figures and save each one as a PDF.

    Files are written under ``../figs/`` with a prefix built from
    ``meta_data['llm_pref']``, ``meta_data['K']`` and ``meta_data['T']``.
    Figures are left open after saving.
    """
    llm_pref = meta_data['llm_pref']
    T = meta_data['T']
    K = meta_data['K']

    pref = f"../figs/{llm_pref}_K={K}_T={T}_"

    # (figure builder, file-name suffix), in the order they are produced.
    jobs = (
        (plot_mean_std_pseudoregret, "mean_std.pdf"),
        (plot_opt_freqs, "opt_freqs.pdf"),
        (plot_suffix_failure, "suffix_failure.pdf"),
        (plot_min_action_hist, "min_action_hist.pdf"),
        (plot_opt_action_hist, "opt_action_hist.pdf"),
    )
    for build_figure, suffix in jobs:
        build_figure(data, meta_data)
        # savefig writes whichever figure is current, i.e. the one just built.
        plt.savefig(pref+suffix, format="pdf", dpi=100)

def plot_triple(data,meta_data):
    """Draw a 1x3 figure: optimal-arm histogram, suffix failure, mean reward.

    Each panel is drawn by pointing ``meta_data['axes']`` at a single axis and
    calling the corresponding plot function, so ``meta_data['deltas']`` is
    expected to hold exactly one value.  The caller's ``meta_data`` is left as
    it was: any pre-existing 'axes' entry is restored afterwards, otherwise
    the temporary entry is removed (a stale entry would redirect later plot
    calls onto this figure's axes).

    Returns:
        The created matplotlib figure.
    """
    (f,axs) = plt.subplots(nrows=1, ncols=3, figsize=(6*3,4))
    panels = (plot_opt_action_hist, plot_suffix_failure, plot_mean_std)

    had_axes = 'axes' in meta_data
    previous_axes = meta_data.get('axes')
    try:
        for ax, draw_panel in zip(axs, panels):
            meta_data['axes'] = [ax]
            draw_panel(data, meta_data)
    finally:
        if had_axes:
            meta_data['axes'] = previous_axes
        else:
            meta_data.pop('axes', None)

    f.suptitle(f"{meta_data['llm_pref'][1:]} prompt, {meta_data['K']}-arms, T={meta_data['T']}",fontsize=18)
    # The per-panel titles set by the plot functions are replaced by the
    # shared super-title above.
    for ax in axs:
        ax.set_title('')
    f.subplots_adjust(hspace=0.4, top=0.9)
    return f

def get_reward_matrix(data):
    """Stack the rewards of all full-length runs into a 2-D array.

    Each run is a sequence of steps whose element 1 is the reward.  Runs
    shorter than the longest one are dropped.  Returns an array with one row
    per retained run and one column per time step.
    """
    per_run = [[step[1] for step in run] for run in data]
    longest = np.max([len(row) for row in per_run])
    complete = [row for row in per_run if len(row) == longest]
    return np.vstack(complete)

def get_opt_action_matrix(data):
    """Return a 0/1 matrix marking the steps at which arm 0 was played.

    Rows are runs, columns are time steps up to the longest run.  Entry
    ``[i, t]`` is 1 when element 0 of step ``t`` of run ``i`` equals 0; steps
    beyond the end of a shorter run stay 0.
    """
    horizon = np.max([len(run) for run in data])
    indicator = np.zeros((len(data), horizon))
    for row, run in enumerate(data):
        for col, step in enumerate(run):
            indicator[row, col] = (step[0] == 0)
    return indicator

def _plot_mean_std_sub(data,color,stdev=True):
    """Plot the time-averaged cumulative reward on the current axes.

    The line is the mean over runs of (cumulative reward / t).  When ``stdev``
    is true a band of +/- 2 standard errors is shaded around it.
    """
    rewards = get_reward_matrix(data)
    num_runs, horizon = rewards.shape
    running_total = np.cumsum(rewards, axis=1)
    means = np.mean(running_total, axis=0)
    stds = np.std(running_total, axis=0)
    xline = list(range(1, horizon+1))

    plt.plot(xline, means/xline, color=color, linewidth=2, alpha=1.0)
    if not stdev:
        return
    half_width = 2*stds/np.sqrt(num_runs)
    plt.fill_between(xline, (means - half_width)/xline, (means + half_width)/xline, color=color, alpha=0.3)

def _plot_mean_std_pseudoregret_sub(data,color,delta,stdev=True):
    """Plot the time-averaged cumulative pseudo-reward on the current axes.

    Each step contributes ``0.5 + delta/2`` when arm 0 was played and
    ``0.5 - delta/2`` otherwise.  The line is the mean over runs of the
    cumulative sum divided by t; when ``stdev`` is true a band of +/- 2
    standard errors is shaded around it.
    """
    M = get_opt_action_matrix(data)
    # Take the dimensions from the matrix itself: its width is the longest
    # run, which need not be the length of the first run.
    N, T = M.shape
    M = M*delta + (0.5 - delta/2)
    cum_rewards = np.cumsum(M,axis=1)
    means = np.mean(cum_rewards, axis=0)
    stds = np.std(cum_rewards, axis=0)
    xline = [i+1 for i in range(T)]
    plt.plot(xline,means/xline,color=color,linewidth=2,alpha=1.0)
    if stdev:
        plt.fill_between(xline, (means - 2*stds/np.sqrt(N))/xline, (means + 2*stds/np.sqrt(N))/xline, color=color, alpha=0.3)

def _plot_median_quantile_pseudoregret_sub(data,color,delta,q=0.05):
    """Plot the median time-averaged pseudo-reward with a quantile band.

    Each step contributes ``0.5 + delta/2`` when arm 0 was played and
    ``0.5 - delta/2`` otherwise.  The line is the across-run median of the
    cumulative sum divided by t; the shaded band spans the ``q`` and ``1-q``
    quantiles.
    """
    M = get_opt_action_matrix(data)
    # The matrix width is the longest run, which need not be the length of
    # the first run, so the x-axis is sized from the matrix.
    T = M.shape[1]
    M = M*delta + (0.5 - delta/2)
    cum_rewards = np.cumsum(M,axis=1)
    meds = np.median(cum_rewards, axis=0)
    qlow = np.quantile(cum_rewards, q, axis=0)
    qhigh = np.quantile(cum_rewards, 1-q, axis=0)
    xline = [i+1 for i in range(T)]
    plt.plot(xline,meds/xline,color=color,linewidth=2,alpha=1.0)
    plt.fill_between(xline, qlow/xline, qhigh/xline, color=color, alpha=0.3)

def _plot_lower_quantile_pseudoregret_sub(data,color,delta):
    """Plot the 0.5, 0.3 and 0.1 quantiles of time-averaged pseudo-reward.

    Each step contributes ``0.5 + delta/2`` when arm 0 was played and
    ``0.5 - delta/2`` otherwise.  One line is drawn per quantile with
    decreasing opacity; only the first (median) line gets a legend entry.
    """
    M = get_opt_action_matrix(data)
    # The matrix width is the longest run, which need not be the length of
    # the first run, so the x-axis is sized from the matrix.
    T = M.shape[1]
    M = M*delta + (0.5 - delta/2)
    cum_rewards = np.cumsum(M,axis=1)
    quantiles = [0.5, 0.3, 0.1]
    xline = [i+1 for i in range(T)]
    for i, level in enumerate(quantiles):
        quant = np.quantile(cum_rewards, level, axis=0)
        line_kwargs = dict(color=color, linewidth=2, alpha=(1.0-i/len(quantiles)))
        if i > 0:
            # Keep the lower quantiles out of the legend.
            line_kwargs['label'] = "_nolegend_"
        plt.plot(xline, quant/xline, **line_kwargs)

def _plot_trajectories_sub(data,color):
    """Plot every run's time-averaged cumulative reward as a faint line.

    The lines are drawn on the current axes and excluded from the legend.
    """
    running_total = np.cumsum(get_reward_matrix(data), axis=1)
    horizon = running_total.shape[1]
    xline = list(range(1, horizon+1))
    for trajectory in running_total:
        plt.plot(xline, trajectory/xline, color=color, alpha=0.2, label="_nolegend_")

def plot_mean_std(data, meta_data):
    """Plot time-averaged reward (mean with error band), one panel per delta.

    Panels are drawn on ``meta_data['axes']`` when that key is present;
    otherwise a new figure with one column per delta is created.  A black
    line marks the value ``0.5 + delta/2`` and is labelled 'OPT'.

    Returns:
        The new figure, or None when the caller supplied the axes.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']
    K = meta_data['K']

    if 'axes' in meta_data:
        axs = meta_data['axes']
        f = None
    else:
        (f,axs) = plt.subplots(nrows=1, ncols=len(deltas), figsize=(6*len(deltas),4))
        # A single subplot comes back as a bare Axes; wrap it for indexing.
        if not isinstance(axs, (list, np.ndarray)):
            axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        low = 0.5-delta/2
        high = 0.5+delta/2
        for alg in algs:
            _plot_mean_std_sub(data[alg][delta_key], colors[alg], stdev=True)
        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta={delta:0.1f}")
        plt.plot(list(range(1, T+1)), [high]*T, color='black')
        plt.legend([names[alg] for alg in algs] + ['OPT'],loc='upper left', ncol=2, bbox_to_anchor=(0,0.95))
        plt.xlabel('Time step (t)')
        plt.ylim([low-0.001, high+0.005])
        plt.yticks(np.arange(low, high+0.005, 0.05))
        plt.xlim([0,T])
    plt.sca(axs[0])
    plt.ylabel('Time-averaged reward')
    return f

def plot_mean_std_pseudoregret(data,meta_data):
    """Plot time-averaged pseudo-reward (mean with error band) per delta.

    Creates a figure with one column per delta.  A black line marks the value
    ``0.5 + delta/2`` and is labelled 'OPT'.  Returns the figure.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']
    K = meta_data['K']

    (f,axs) = plt.subplots(nrows=1, ncols=len(deltas), figsize=(6*len(deltas),4))
    # A single subplot comes back as a bare Axes; wrap it for indexing.
    if not isinstance(axs, (list, np.ndarray)):
        axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        low = 0.5-delta/2
        high = 0.5+delta/2
        for alg in algs:
            _plot_mean_std_pseudoregret_sub(data[alg][delta_key], colors[alg], delta, stdev=True)
        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta={delta:0.1f}")
        plt.plot(list(range(1, T+1)), [high]*T, color='black')
        plt.legend([names[alg] for alg in algs] + ['OPT'])
        plt.xlabel('Time step')
        plt.ylim([low-0.05, high+0.05])
    plt.sca(axs[0])
    plt.ylabel('Cumulative average pseudo-reward')
    return f

def plot_median_quantile_pseudoregret(data,meta_data,q=0.05):
    """Plot median time-averaged pseudo-reward with a quantile band per delta.

    Creates a figure with one column per delta; the band spans the ``q`` and
    ``1-q`` quantiles.  A black line marks the value ``0.5 + delta/2`` and is
    labelled 'OPT'.  Returns the figure.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']
    K = meta_data['K']

    (f,axs) = plt.subplots(nrows=1, ncols=len(deltas), figsize=(6*len(deltas),4))
    # A single subplot comes back as a bare Axes; wrap it for indexing.
    if not isinstance(axs, (list, np.ndarray)):
        axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        low = 0.5-delta/2
        high = 0.5+delta/2
        for alg in algs:
            _plot_median_quantile_pseudoregret_sub(data[alg][delta_key], colors[alg], delta, q=q)
        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta={delta:0.1f}")
        plt.plot(list(range(1, T+1)), [high]*T, color='black')
        plt.legend([names[alg] for alg in algs] + ['OPT'])
        plt.xlabel('Time step')
        plt.ylim([low-0.05, high+0.05])
    plt.sca(axs[0])
    plt.ylabel('Cumulative average reward')
    return f
    
def plot_lower_quantile_pseudoregret(data,meta_data):
    """Plot lower quantiles of time-averaged pseudo-reward, one panel per delta.

    Creates a figure with one column per delta.  A black line marks the value
    ``0.5 + delta/2`` and is labelled 'OPT'.  Returns the figure.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']
    K = meta_data['K']

    (f,axs) = plt.subplots(nrows=1, ncols=len(deltas), figsize=(6*len(deltas),4))
    # A single subplot comes back as a bare Axes; wrap it for indexing.
    if not isinstance(axs, (list, np.ndarray)):
        axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        low = 0.5-delta/2
        high = 0.5+delta/2
        for alg in algs:
            _plot_lower_quantile_pseudoregret_sub(data[alg][delta_key], colors[alg], delta)
        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta={delta:0.1f}")
        plt.plot(list(range(1, T+1)), [high]*T, color='black')
        plt.legend([names[alg] for alg in algs] + ['OPT'])
        plt.xlabel('Time step')
        plt.ylim([low-0.05, high+0.05])
    plt.sca(axs[0])
    plt.ylabel('Cumulative average reward')
    return f


def plot_mean_and_trajectories(data, main_alg, meta_data):
    """Plot mean time-averaged reward per algorithm plus raw runs of one.

    Creates a figure with one column per delta.  Every individual run of
    ``main_alg`` is drawn as a faint line, then the mean curve of each
    algorithm (without error band) and a black 'OPT' line at
    ``0.5 + delta/2``.

    Returns:
        The created matplotlib figure.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']
    K = meta_data['K']

    (f,axs) = plt.subplots(nrows=1, ncols=len(deltas), figsize=(6*len(deltas),4))
    # A single subplot comes back as a bare Axes; wrap it for indexing.
    if not isinstance(axs, (list, np.ndarray)):
        axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        _plot_trajectories_sub(data[main_alg][delta_key],colors[main_alg])
        for alg in algs:
            _plot_mean_std_sub(data[alg][delta_key],colors[alg],stdev=False)
        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta={delta:0.1f}")
        x = [j+1 for j in range(T)]
        plt.plot(x, [(0.5+delta/2) for j in range(T)], color='black')
        plt.legend([names[alg] for alg in algs] + ['OPT'])
        plt.xlabel('Time step')
        plt.ylim([0,1])
    plt.sca(axs[0])
    plt.ylabel('Cumulative average reward')
    # Hand the figure back like the other figure builders do.
    return f
    
def _plot_opt_freqs_sub(data,color):
    """Plot, per time step, the fraction of runs that played arm 0.

    Element 0 of each step is the arm played.  Drawn on the current axes.
    """
    arms_played = np.vstack([[step[0] for step in run] for run in data])
    num_runs = arms_played.shape[0]
    opt_plays = np.count_nonzero(arms_played == 0, axis=0)
    plt.plot(opt_plays/num_runs, color=color)

def plot_opt_freqs(data, meta_data):
    """Plot the per-step fraction of runs playing arm 0, one panel per delta.

    Creates a figure with one column per delta and one curve per algorithm.

    Returns:
        The created matplotlib figure.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    K = meta_data['K']

    (f, axs) = plt.subplots(nrows=1,ncols=len(deltas), figsize=(6*len(deltas),4))
    # A single subplot comes back as a bare Axes; wrap it for indexing.
    if not isinstance(axs, (list, np.ndarray)):
        axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        for alg in algs:
            _plot_opt_freqs_sub(data[alg][delta_key], colors[alg])

        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta = {delta:0.1f}")
        plt.legend([names[alg] for alg in algs])
        plt.xlabel('Time step')
    plt.sca(axs[0])
    plt.ylabel('Fraction of replicates pulling optimal arm')
    # Hand the figure back so callers can save or close it.
    return f

def _plot_suffix_failure_sub(data, color):
    """Plot, per time step t, the fraction of runs never playing arm 0 from t on.

    Drawn on the current axes.
    """
    M = get_opt_action_matrix(data)
    # Reverse cumulative sum: entry [i, t] counts plays of arm 0 in run i
    # from step t to the end.
    plays_remaining = np.fliplr(np.cumsum(np.fliplr(M), axis=1))
    never_again = (plays_remaining == 0)
    # Pass the colour by keyword so it is never parsed as a format string or,
    # for non-string colours, mistaken for data.
    plt.plot(np.mean(never_again, axis=0), color=color)

def plot_suffix_failure(data, meta_data):
    """Plot the suffix-failure curve of each algorithm, one panel per delta.

    Panels are drawn on ``meta_data['axes']`` when that key is present;
    otherwise a new figure with one column per delta is created.  The x-axis
    is limited to the first 80% of the horizon ``T``.

    Returns:
        The new figure, or None when the caller supplied the axes.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']
    K = meta_data['K']

    if 'axes' in meta_data:
        axs = meta_data['axes']
        f = None
    else:
        (f, axs) = plt.subplots(nrows=1,ncols=len(deltas), figsize=(6*len(deltas),4))
        # A single subplot comes back as a bare Axes; wrap it for indexing.
        if not isinstance(axs, (list, np.ndarray)):
            axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        for alg in algs:
            _plot_suffix_failure_sub(data[alg][delta_key], colors[alg])
        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta = {delta:0.1f}, T={T}, eta={meta_data['eta']}, n0={meta_data['n0']}")
        plt.legend([names[alg] for alg in algs],ncol=2)
        plt.xlabel('Time step (t)')
        plt.xlim([0,int(0.8*T)])
        plt.yticks(np.arange(0, 0.81, 0.2))
    plt.sca(axs[0])
    plt.ylabel('Fraction of replicates that fail on [t,T]')
    # Hand the figure back so callers can save or close it.
    return f

def _plot_min_action_hist_sub(data,color,meta_data,offset=0):
    """Draw a bar histogram of each run's least-played arm count.

    For every run the number of plays of each of the ``K`` arms is counted
    and the minimum is taken.  The minima are grouped into ``T/bin_width``
    bins of width ``bin_width`` and drawn, as fractions of runs, on the
    current axes with bars shifted right by ``offset``.
    """
    K = meta_data['K']
    T = meta_data['T']
    bin_width = meta_data['bin_width']
    num_bins = int(T/bin_width)

    per_run_counts = []
    for run in data:
        one_hots = []
        for step in run:
            indicator = np.zeros(K)
            indicator[step[0]] = 1
            one_hots.append(indicator)
        per_run_counts.append(np.sum(np.vstack(one_hots), axis=0))
    per_run_counts = np.vstack(per_run_counts)
    min_counts = np.min(per_run_counts, axis=1)

    # Bin b collects the counts bin_width*b, ..., bin_width*(b+1)-1.
    bins = [
        sum(np.where(min_counts == bin_width*b+j)[0].shape[0] for j in range(bin_width))
        for b in range(num_bins)
    ]
    plt.bar(np.array(range(num_bins))+offset, np.array(bins)/len(min_counts),width=options['opt_act_width'],align='edge',color=color)

def plot_min_action_hist(data, meta_data):
    """Plot histograms of the least-played arm count, one panel per delta.

    Creates a figure with one column per delta; the bars of successive
    algorithms are shifted by ``options['opt_act_offset']`` so they sit side
    by side.  Dashed vertical lines mark the bin edges.

    Returns:
        The created matplotlib figure.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']
    K = meta_data['K']
    bin_width = meta_data['bin_width']
    num_bins = int(T/bin_width)

    (f, axs) = plt.subplots(nrows=1,ncols=len(deltas), figsize=(6*len(deltas),4))
    # A single subplot comes back as a bare Axes; wrap it for indexing.
    if not isinstance(axs, (list, np.ndarray)):
        axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        offset = 0
        for alg in algs:
            _plot_min_action_hist_sub(data[alg][delta_key],colors[alg],meta_data,offset=offset)
            offset += options['opt_act_offset']

        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta = {delta:0.1f}, T={T}")
        plt.legend([names[alg] for alg in algs])
        plt.xlabel('smallest arm count (min_a N_T(a))')
        plt.xticks(np.array(range(num_bins)),labels=bin_width*np.array(range(num_bins)))
        plt.vlines(range(num_bins), 0, 1, linestyle='--', color='black', linewidth=1, alpha=0.5)
        plt.ylim([0,1])
    plt.sca(axs[0])
    plt.ylabel('Fraction of replicates')
    # Hand the figure back so callers can save or close it.
    return f


def _plot_opt_action_hist_sub(data,color,meta_data,offset=0.0):
    """Draw a bar histogram of how often each run played arm 0.

    For every run the number of plays of arm 0 is counted.  The counts are
    grouped into ``T/bin_width`` bins of width ``bin_width``; runs with
    exactly ``T`` plays of arm 0 are added to the last bin.  Bars show
    fractions of runs on the current axes, shifted right by ``offset``.
    """
    K = meta_data['K']
    T = meta_data['T']
    bin_width = meta_data['bin_width']
    num_bins = int(T/bin_width)

    per_run_counts = []
    for run in data:
        one_hots = []
        for step in run:
            indicator = np.zeros(K)
            indicator[step[0]] = 1
            one_hots.append(indicator)
        per_run_counts.append(np.sum(np.vstack(one_hots), axis=0))
    per_run_counts = np.vstack(per_run_counts)
    opt_counts = per_run_counts[:,0]

    # Bin b collects the counts bin_width*b, ..., bin_width*(b+1)-1.
    bins = [
        sum(np.where(opt_counts == bin_width*b+j)[0].shape[0] for j in range(bin_width))
        for b in range(num_bins)
    ]
    # A count of exactly T lies just past the last bin; fold it in.
    bins[-1] += np.where(opt_counts == T)[0].shape[0]
    plt.bar(np.array(range(num_bins))+offset, np.array(bins)/len(opt_counts),width=options['opt_act_width'],align='edge',color=color)
    
def plot_opt_action_hist(data, meta_data):
    """Plot histograms of plays of arm 0, one panel per delta.

    Panels are drawn on ``meta_data['axes']`` when that key is present;
    otherwise a new figure with one column per delta is created.  The bars of
    successive algorithms are shifted by ``options['opt_act_offset']`` so
    they sit side by side, and dashed vertical lines mark the bin edges.

    Returns:
        The new figure, or None when the caller supplied the axes.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']
    K = meta_data['K']
    bin_width = meta_data['bin_width']
    num_bins = int(T/bin_width)

    if 'axes' in meta_data:
        axs = meta_data['axes']
        f = None
    else:
        (f, axs) = plt.subplots(nrows=1,ncols=len(deltas), figsize=(6*len(deltas),4))
        # A single subplot comes back as a bare Axes; wrap it for indexing.
        if not isinstance(axs, (list, np.ndarray)):
            axs = [axs]

    for i, delta in enumerate(deltas):
        plt.sca(axs[i])
        delta_key = f"{delta:0.1f}"
        offset = 0
        for alg in algs:
            _plot_opt_action_hist_sub(data[alg][delta_key],colors[alg],meta_data,offset=offset)
            offset += options['opt_act_offset']

        plt.title(f"{meta_data['llm_pref']}, {K}-arms, Delta = {delta:0.1f}, T={T}, eta={meta_data['eta']}, n0={meta_data['n0']}")
        plt.legend([get_name(alg,meta_data) for alg in algs],ncol=2)
        plt.xlabel('Plays of the best arm')
        plt.xticks(np.array(range(num_bins)),labels=bin_width*np.array(range(num_bins)))
        plt.vlines(range(num_bins), 0, 1, linestyle='--', color='black', linewidth=1, alpha=0.5, label="_nolegend_")
        plt.ylim([0,1])

    plt.sca(axs[0])
    plt.ylabel('Fraction of replicates')
    # Hand the figure back so callers can save or close it.
    return f

def get_opt_action_matrices(data,meta_data):
    """Build a 0/1 "played arm 0" matrix for every algorithm and delta.

    Args:
        data: ``{alg: {"<delta:0.1f>": [run, ...]}}`` where each run is a
            sequence of steps whose element 0 is the arm played.
        meta_data: dict with keys 'algs', 'deltas' and 'T'.

    Returns:
        ``{alg: {"<delta:0.1f>": array}}`` where each array has one row per
        run and ``T`` columns; entry ``[i, t]`` is 1 when run ``i`` played
        arm 0 at step ``t``.  Steps past the end of a shorter run stay 0.
    """
    algs = meta_data['algs']
    deltas = meta_data['deltas']
    T = meta_data['T']

    Ms = {alg: {} for alg in algs}
    for delta in deltas:
        delta_key = f"{delta:0.1f}"
        for alg in algs:
            subdata = data[alg][delta_key]
            mat = np.zeros((len(subdata), T))
            for i, run in enumerate(subdata):
                for t, step in enumerate(run):
                    mat[i,t] = (step[0] == 0)
            Ms[alg][delta_key] = mat
    return Ms
