from Environment import *
from Algorithms import *
import numpy as np
import matplotlib.pyplot as plt
import pickle
from pathlib import Path
import matplotlib.gridspec as gridspec


def plot(T,repeat,d,L,N,K):
    """Plot the mean regret curve of each algorithm with error bars and save it.

    For every algorithm in ``algorithms`` (below), ``repeat`` regret curves are
    read from ``./result/<name>T<T>d<d>N<N>K<K>L<L>repeat<i>regret.txt``.
    These are pickle files; only the first pickled object in each is used.
    The curves are averaged, and the mean is drawn with error bars of
    half-width ``1.96 * std / sqrt(repeat)`` (a normal-approximation ~95%
    interval). The figure is saved to
    ``./plot/T<T>d<d>N<N>K<K>L<L>repeat<repeat>.pdf`` and then shown.

    Args:
        T: Horizon length; each stored regret curve must fit a row of length T.
        repeat: Number of independent runs per algorithm (must be >= 1).
        d, L, N, K: Experiment parameters; used here only to build file names.

    Raises:
        ValueError: If ``repeat < 1`` (the mean/std would be undefined).
        OSError: If a result file cannot be opened (e.g. FileNotFoundError).
        EOFError: If a result file is empty.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    # Plot order; colour, marker and zorder are chosen by position in this list.
    algorithms = ['Elimination', 'Elimination2', 'UCB', 'TS-QMB', 'UCB-QMB']
    # Legend text for algorithms whose display name differs from the file key.
    display_names = {
        'Elimination': 'B-SMB (Algorithm1)',
        'Elimination2': r'B-SMB$^+$(Algorithm2)',
        'UCB': r'OFU-MNL$^+$',
    }
    color = ['violet', 'limegreen', 'royalblue', 'lightsalmon', 'gold', 'tomato']
    marker_list = ['^', 's', 'v', 'o', 'P', 'D']

    # Load all runs first so a missing file fails before any figure is created.
    mean_regret = {}
    ci_halfwidth = {}
    for name in algorithms:
        runs = np.zeros((repeat, T), float)
        for i in range(repeat):
            path = f'./result/{name}T{T}d{d}N{N}K{K}L{L}repeat{i}regret.txt'
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # result files produced by this project.
            with open(path, "rb") as fh:
                runs[i, :] = pickle.load(fh)
        mean_regret[name] = runs.mean(axis=0)
        ci_halfwidth[name] = 1.96 * runs.std(axis=0) / np.sqrt(repeat)

    # Markers/error caps at roughly 10 points along the curve; clamp to 1 so a
    # short horizon (T < 10) never passes an invalid step of 0 to matplotlib.
    step = max(1, T // 10)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.tick_params(labelsize=22)
    ax.yaxis.get_offset_text().set_fontsize(22)
    ax.xaxis.get_offset_text().set_fontsize(22)
    fig.subplots_adjust(left=0.2)
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    for i, name in enumerate(algorithms):
        # Elimination2 is drawn lower in the stacking order than the others.
        z = (4 if name == 'Elimination2' else 6) - i
        ax.plot(range(T), mean_regret[name], color=color[i], marker=marker_list[i],
                label=display_names.get(name, name), markersize=10,
                markevery=step, zorder=z)
        ax.errorbar(range(T), mean_regret[name], yerr=ci_halfwidth[name],
                    color=color[i], errorevery=step, capsize=6, zorder=z)

    ax.set_title('Regret per Algorithm', fontsize=22, fontweight='bold')
    ax.set_xlabel(r'Time $t$', fontsize=22)
    ax.set_ylabel(r'$\mathcal{R}(t)$', fontsize=22)

    Path("./plot").mkdir(parents=True, exist_ok=True)
    # Reverse so the last-plotted algorithm is listed first; works for any
    # number of algorithms (previously hard-coded to exactly five entries).
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles[::-1], labels[::-1], loc="upper left",
               bbox_to_anchor=(0.13, 0.9), fontsize=18)
    fig.tight_layout()
    plt.savefig(f'./plot/T{T}d{d}N{N}K{K}L{L}repeat{repeat}.pdf', bbox_inches="tight")
    plt.show()
    


def plot2(T,repeat,d,L,N,K):
    """Draw a bar chart of each algorithm's mean runtime with error bars and save it.

    For every algorithm in ``algorithms`` (below), ``repeat`` runtimes are read
    from ``./result/<name>_N_<N>K_<K>_repeat<i>_time.txt``. These are pickle
    files; only the first pickled object in each is used. The runtimes are
    averaged, and each bar gets an error bar of half-width
    ``1.96 * std / sqrt(repeat)``. The chart is saved to
    ``./plot/T<T>d<d>N<N>K<K>L<L>repeat<repeat>runtime_bar_graph.pdf`` and
    then shown.

    Args:
        T, d, L: Used only in the output file name (input files depend on N, K).
        repeat: Number of runs per algorithm (must be >= 1).
        N, K: Experiment parameters used in the input and output file names.

    Raises:
        ValueError: If ``repeat < 1``.
        OSError: If a result file cannot be opened (e.g. FileNotFoundError).
        EOFError: If a result file is empty.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    # Bar order; bar colours are assigned by position in this list.
    algorithms = ['Elimination', 'Elimination2', 'UCB', 'TS-QMB', 'UCB-QMB']
    # Tick labels for algorithms whose display name differs from the file key.
    display_names = {
        'UCB': 'OFU-MNL$^+$',
        'Elimination': 'B-SMB\n(Alg1)',
        'Elimination2': 'B-SMB$^+$\n(Alg2)',
    }

    mean_time = {}
    ci_halfwidth = {}
    for name in algorithms:
        # One column per run, matching how each stored runtime is assigned.
        runs = np.zeros((repeat, 1), float)
        for i in range(repeat):
            path = './result/' + name + '_N_' + str(N) + 'K_' + str(K) + '_repeat' + str(i) + '_time.txt'
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # result files produced by this project.
            with open(path, "rb") as fh:
                runs[i, :] = pickle.load(fh)
        mean_time[name] = runs.mean(axis=0).item()
        ci_halfwidth[name] = (1.96 * runs.std(axis=0) / np.sqrt(repeat)).item()

    print(algorithms)

    labels = [display_names.get(name, name) for name in algorithms]
    times = [mean_time[name] for name in algorithms]
    stds = [ci_halfwidth[name] for name in algorithms]
    bar_colors = ['violet', 'limegreen', 'royalblue', 'lightsalmon', 'gold', 'tomato']

    # A single figure: the old code also created an unused figure (and axes on
    # it via plt.ticklabel_format), which plt.show() displayed as an empty window.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(labels, times, yerr=stds, capsize=5, color=bar_colors[:len(labels)],
           edgecolor='black', width=0.6)
    ax.set_ylabel('Time (seconds)', fontsize=22)
    ax.set_xlabel('Algorithms', fontsize=22)
    ax.set_title('Runtime per Algorithm', fontsize=22, fontweight='bold')
    ax.tick_params(axis='both', labelsize=22)

    Path("./plot").mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    plt.savefig('./plot/'+'T'+str(T)+'d'+str(d)+'N'+str(N)+'K'+str(K)+'L'+str(L)+'repeat'+str(repeat)+'runtime_bar_graph.pdf', bbox_inches="tight")
    plt.show()

if __name__ == '__main__':
    # Experiment configuration; both figures are drawn from the same runs.
    config = {'T': 5000, 'repeat': 10, 'd': 2, 'L': 2, 'N': 7, 'K': 4}
    plot(**config)

    plot2(**config)
   


