from Environment import *
from Algorithms import *
import numpy as np
import matplotlib.pyplot as plt
import pickle
from pathlib import Path
import matplotlib.gridspec as gridspec


# Algorithms compared in every panel, in plotting order.
_ALGORITHMS = ('Elimination', 'ETC-GS', 'UCB-GS')
# Legend text for algorithms whose result-file name differs from the name shown in the figure.
_LEGEND_NAMES = {'Elimination': 'SMB'}
_COLORS = ['violet', 'royalblue', 'limegreen', 'lightsalmon', 'gold', 'tomato']
_MARKERS = ['^', 's', 'v', 'o', 'P', 'D']
_FONTSIZE = 22


def _load_regret_runs(result_dir, algorithm, T, d, N, K, L, repeat):
    """Load all repetitions of one configuration and return per-step (mean, std).

    Reads ``{algorithm}T{T}d{d}N{N}K{K}L{L}repeat{i}regret.txt`` for
    ``i in range(repeat)`` from *result_dir*. The first pickled object in each
    file must be a length-``T`` regret curve. Both returned arrays have shape
    ``(T,)``; ``std`` is the population standard deviation (``ddof=0``).
    """
    runs = np.zeros((repeat, T), float)
    for i in range(repeat):
        filename = f"{algorithm}T{T}d{d}N{N}K{K}L{L}repeat{i}regret.txt"
        # NOTE: unpickling can execute arbitrary code -- only point result_dir
        # at files this project produced itself.
        with open(Path(result_dir) / filename, "rb") as fh:
            runs[i, :] = pickle.load(fh)
    return runs.mean(axis=0), runs.std(axis=0)


def _draw_panel(ax, curves, T, repeat, step, with_labels):
    """Draw every algorithm's mean regret curve with error bars on *ax*.

    *curves* maps algorithm name -> (mean, std) arrays of length T. Markers and
    error bars are drawn every *step* points. Legend labels are attached only
    when *with_labels* is true, so the shared figure legend lists each
    algorithm once.
    """
    x = range(T)
    for i, algorithm in enumerate(_ALGORITHMS):
        mean, std = curves[algorithm]
        color = _COLORS[i]
        label_kw = {'label': _LEGEND_NAMES.get(algorithm, algorithm)} if with_labels else {}
        # Earlier algorithms get a higher zorder, so they are drawn on top.
        ax.plot(x, mean, color=color, marker=_MARKERS[i], markersize=10,
                markevery=step, zorder=6 - i, **label_kw)
        # Error-bar half-width 1.96 * std / sqrt(repeat): normal-approximation 95% interval.
        ax.errorbar(x, mean, yerr=1.96 * std / np.sqrt(repeat), color=color,
                    errorevery=step, capsize=6, zorder=6 - i)


def plot(T, repeat, d, L, N_values=(4, 5), K_values=(2, 3, 4),
         result_dir='./result', plot_dir='./plot'):
    """Plot mean regret of SMB (Elimination), ETC-GS and UCB-GS on a grid of panels.

    One panel is drawn per (N, K) pair: rows follow *N_values*, columns follow
    *K_values*. Each panel shows every algorithm's mean regret over *repeat*
    runs, with error bars of half-width ``1.96 * std / sqrt(repeat)``. The
    figure is saved to ``{plot_dir}/T{T}d{d}L{L}repeat{repeat}.pdf`` and then
    shown. Side effect: sets matplotlib's global ``legend.fontsize`` rc value.

    Parameters
    ----------
    T : int
        Horizon; every stored regret curve must have length T.
    repeat : int
        Number of runs per configuration (files ``repeat0`` .. ``repeat{repeat-1}``).
    d, L : int
        Experiment parameters; used only to locate result files and name the PDF.
    N_values, K_values : iterable of int
        Grid values; the defaults reproduce the original 2x3 figure.
    result_dir, plot_dir : str or Path
        Directory holding the pickled regret curves / directory for the PDF.

    Raises
    ------
    ValueError
        If ``repeat < 1`` (the mean would divide by zero).
    FileNotFoundError
        If any expected result file is missing.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be at least 1, got {repeat}")
    N_values = list(N_values)
    K_values = list(K_values)

    # Load everything up front so a missing result file fails before a figure is built.
    curves = {
        (N, K): {alg: _load_regret_runs(result_dir, alg, T, d, N, K, L, repeat)
                 for alg in _ALGORITHMS}
        for N in N_values for K in K_values
    }

    # Markers/error bars every tenth of the horizon; errorevery must be >= 1.
    step = max(1, T // 10)
    last_row = len(N_values) - 1

    plt.rc('legend', fontsize=_FONTSIZE)
    gs = gridspec.GridSpec(len(N_values), len(K_values))
    fig = plt.figure(figsize=(6 * len(K_values), 4.5 * len(N_values)))
    fig.subplots_adjust(bottom=0.20)

    for n, N in enumerate(N_values):
        for k, K in enumerate(K_values):
            ax = fig.add_subplot(gs[n, k])
            ax.tick_params(labelsize=_FONTSIZE)
            ax.yaxis.get_offset_text().set_fontsize(_FONTSIZE)
            ax.xaxis.get_offset_text().set_fontsize(_FONTSIZE)
            ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

            _draw_panel(ax, curves[(N, K)], T, repeat, step,
                        with_labels=(n == 0 and k == 0))

            ax.set_title(f'N={N}, K={K}', fontsize=_FONTSIZE)
            if k == 0:
                ax.set_ylabel(r'$\mathcal{R}(t)$', fontsize=_FONTSIZE)
            if n == last_row:
                ax.set_xlabel('Time step ' + r'$t$', fontsize=_FONTSIZE)

    # Only the first panel carries labels; list them in reverse plotting order.
    handles, labels = fig.axes[0].get_legend_handles_labels()
    fig.legend(handles[::-1], labels[::-1], loc='upper center',
               ncol=max(1, len(labels)), bbox_to_anchor=(0.5, 1.13))
    fig.tight_layout()

    out_dir = Path(plot_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f"T{T}d{d}L{L}repeat{repeat}.pdf", bbox_inches="tight")
    plt.show()
    

if __name__ == '__main__':
    # Experiment configuration; these values must match those encoded in the
    # result file names under ./result for plot() to find its inputs.
    d, L = 3, 2
    T, repeat = 1000, 10
    plot(T, repeat, d, L)
   


