

class evaluate(object):

    def __init__(self):
        """Create the evaluator and print a short greeting banner.

        NOTE(review): the attributes read by the other methods (``path``,
        ``horizon``, ``nb_arms``, ``experiment``, ``regret``,
        ``regret_versus_M``) are not initialised here; presumably a subclass
        or the caller assigns them -- TODO confirm.
        """
        banner = "^w^"
        print(banner)


    def __saveRegret(self, path, experiment, string, ):
        """Dump a regret array to a CSV file inside directory *path*.

        When ``experiment == "M"``, ``self.regret_versus_M`` is written to
        ``<path>/Regrets_versus_M_<string>.csv``. For any other value,
        ``self.regret`` is written to ``<path>/Regrets_<str(self)>.csv``.
        Both cases print a progress message first and write with
        ``np.savetxt``.

        NOTE(review): the "M" file name is built from *string*, whereas
        ``__saveFileName`` records ``str(self)`` for the same experiment.
        The two only agree if callers pass ``str(self)`` as *string* --
        confirm.
        """
        versus_m = experiment == "M"
        print("Regret saving ...")
        if versus_m:
            target = 'Regrets_versus_M_' + string + '.csv'
            data = self.regret_versus_M
        else:
            target = 'Regrets_' + self.__str__() + '.csv'
            data = self.regret
        np.savetxt(path + '/' + target, data)

    def __saveFileName(self, path):
        """Append the name of this run's regret CSV to an index file in *path*.

        The index file and the recorded entry depend on ``self.experiment``:

        * ``"M"``: ``<path>/Regrets_versus_M_file_name.txt`` gets the line
          ``Regrets_versus_M_<str(self)>.csv``.
        * anything else: ``<path>/Regrets_file_name.txt`` gets the line
          ``Regrets_<str(self)>.csv``.

        The index is opened in append mode, so each call adds one line.
        """
        if self.experiment == "M":
            index_name = 'Regrets_versus_M_file_name.txt'
            entry_prefix = 'Regrets_versus_M_'
        else:
            index_name = 'Regrets_file_name.txt'
            entry_prefix = 'Regrets_'
        print("FileName saving ...")
        # ``with`` guarantees the handle is closed even if the write fails.
        # The former explicit open()/close() pair leaked it on an exception.
        with open(path + '/' + index_name, 'a') as f:
            f.write(entry_prefix + self.__str__() + '.csv\n')

    def plot_arm(self, mean, std):
        """Plot one curve per arm with a shaded band around it and save it.

        Nothing is drawn when ``self.path`` is ``None``. Otherwise:

        * a new figure is created;
        * for every arm ``i`` it draws ``mean[i]`` (labelled ``arm i+1``) and
          a band from ``mean[i] - std[i]`` to ``mean[i] + std[i]``;
        * the figure is passed to ``show_and_save`` with the base name
          ``<self.path>/arm_<str(self)>``;
        * the figure is then closed.

        Reads ``self.path``, ``self.horizon`` and ``self.nb_arms``. As a side
        effect this sets several global ``plt.rcParams`` font entries.

        Args:
            mean: indexable per arm. Presumably ``mean[i]`` is a NumPy array
                of length ``self.horizon``, since the band is filled against
                an x-axis of that length -- TODO confirm.
            std: indexable like ``mean``; gives the half-width of each band.

        Returns:
            The (already closed) matplotlib figure, or ``None`` when
            ``self.path`` is ``None``.
        """
        if self.path is None:
            # Previously this fell through to ``return fig`` with ``fig``
            # never assigned, which raised UnboundLocalError.
            return None

        palette = ["#1F77B4", "#FF7F0E", "#2CA02C"]
        plt.rcParams['font.family'] = "sans-serif"
        plt.rcParams['font.sans-serif'] = "DejaVu Sans"
        plt.rcParams['mathtext.fontset'] = "cm"
        plt.rcParams['mathtext.rm'] = "serif"
        fig = plt.figure()
        x = np.linspace(0, self.horizon - 1, self.horizon)
        lw = 3
        # Use the fixed palette only when it has a colour for every arm.
        # Otherwise fall back to matplotlib's default colour cycle.
        use_palette = self.nb_arms <= len(palette)
        for i in range(self.nb_arms):
            colour_kw = {'color': palette[i]} if use_palette else {}
            plt.plot(mean[i], label="arm {}".format(i + 1), lw=lw, **colour_kw)
            plt.fill_between(x, mean[i] - std[i], mean[i] + std[i],
                             alpha=0.2, **colour_kw)
        plt.grid(True)
        plt.legend(loc='upper left')
        plt.xlabel(r"T")
        plt.ylabel(r"Number of arm $i$ pulled")
        savefig = self.path + '/' + 'arm_' + self.__str__()
        # NOTE(review): show_and_save is defined outside this view. The
        # leading False presumably disables interactive display -- confirm.
        show_and_save(False, savefig=savefig, fig=fig, pickleit=None)
        # Close exactly the figure created here, not whatever pyplot
        # considers "current" after show_and_save returns.
        plt.close(fig)
        return fig