import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import xlsxwriter as xlsx
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex
import tikzplotlib




class Plot(object):
    """Plot regret curves that were saved as text files by experiment scripts.

    Each public ``plot*`` method reads a list of result filenames from the
    directory ``path``, loads every listed file with :func:`numpy.loadtxt`
    (plus, when shading is enabled, its ``"std_"``-prefixed companion),
    draws one curve per file on a fresh figure, exports the figure with
    ``tikzplotlib`` and finally hands it to ``show_and_save``.
    """

    def __init__(self):
        # Marker size used for every curve is ``int(self.C * linewidth)``.
        self.C = 5

    def plot(self, savefig=None, path=None, experiment=None):
        """Dispatch to the plotting routine selected by ``experiment``.

        :param savefig: base output path (no extension) forwarded to the
            selected routine.
        :param path: directory containing the result files.
        :param experiment: ``None`` for regret over time, ``"M"``, ``"K"`` or
            ``"T"`` for final regret versus that quantity.
        :returns: the figure produced by the selected routine, or ``None``
            (after printing a warning) when ``experiment`` is not recognised.
        """
        dispatch = {
            None: self.plotMyRegret,
            "M": self.plotLastRegretVersusM,
            "K": self.plotLastRegretVersusK,
            "T": self.plotLastRegretVersusT,
        }
        try:
            plotter = dispatch[experiment]
        except (KeyError, TypeError):  # TypeError: unhashable experiment value
            print("WARNING: unknown experiment {!r}; nothing plotted.".format(experiment))
            return None
        return plotter(savefig=savefig, path=path)

    def plotMyRegret(self, savefig=None, path=None):
        """Plot regret against the time step, one curve per algorithm.

        Reads ``Regrets_file_name.txt`` (and, for logging only,
        ``std_Regrets_file_name.txt``) from ``path``. The x axis is the row
        index of each loaded array, and a +/- one-std band is shaded using
        ``std_<filename>``.

        :returns: the (already closed) matplotlib figure.
        """
        return self._plot_regret_curves(
            savefig=savefig,
            path=path,
            names_file="Regrets_file_name.txt",
            std_names_file="std_Regrets_file_name.txt",
            x_file=None,
            label_slice=slice(8, -4),
            # NOTE(review): the label ends in a dangling "horizon" (its value
            # was dropped from an older format string) -- confirm wording.
            xlabel=r"Time steps $t = 1...T$, horizon",
            shade_std=True,
        )

    def plotLastRegretVersusM(self, savefig=None, path=None):
        """Plot final regret against M, read from ``nb_break_points.csv``.

        Reads ``Regrets_versus_M_file_name.txt`` (and, for logging only,
        ``std_Regrets_versus_M_file_name.txt``) and shades a +/- one-std band.

        :returns: the (already closed) matplotlib figure.
        """
        return self._plot_regret_curves(
            savefig=savefig,
            path=path,
            names_file="Regrets_versus_M_file_name.txt",
            std_names_file="std_Regrets_versus_M_file_name.txt",
            x_file="nb_break_points.csv",
            label_slice=slice(17, -4),
            xlabel=r"$M$",
            shade_std=True,
        )

    def plotLastRegretVersusK(self, savefig=None, path=None):
        """Plot final regret against K, read from ``nb_arms.csv``.

        Reads ``Regrets_versus_K_file_name.txt``. No std band is drawn for
        this plot, so the ``std_*`` files are not loaded.

        :returns: the (already closed) matplotlib figure.
        """
        return self._plot_regret_curves(
            savefig=savefig,
            path=path,
            names_file="Regrets_versus_K_file_name.txt",
            std_names_file=None,
            x_file="nb_arms.csv",
            label_slice=slice(17, -4),
            xlabel=r"$K$",
            shade_std=False,
        )

    def plotLastRegretVersusT(self, savefig=None, path=None):
        """Plot final regret against the horizon T, read from ``T.csv``.

        Reads ``Regrets_versus_T_file_name.txt`` and shades a +/- one-std band.

        :returns: the (already closed) matplotlib figure.
        """
        return self._plot_regret_curves(
            savefig=savefig,
            path=path,
            names_file="Regrets_versus_T_file_name.txt",
            std_names_file=None,
            x_file="T.csv",
            label_slice=slice(17, -4),
            xlabel=r"$T$",
            shade_std=True,
        )

    @staticmethod
    def _set_rc_params():
        """Apply the font settings shared by every plot (global rcParams)."""
        plt.rcParams['font.family'] = "sans-serif"
        plt.rcParams['font.sans-serif'] = "DejaVu Sans"
        plt.rcParams['mathtext.fontset'] = "cm"
        plt.rcParams['mathtext.rm'] = "serif"

    @staticmethod
    def _read_name_list(list_path):
        """Return the non-empty lines of the text file at ``list_path``."""
        with open(list_path) as handle:
            return [name for name in handle.read().split("\n") if name != ""]

    def _plot_regret_curves(self, savefig, path, names_file, std_names_file,
                            x_file, label_slice, xlabel, shade_std):
        """Shared implementation behind every public ``plot*`` method.

        :param savefig: base output path; ``savefig + ".tex"`` receives the
            TikZ export (skipped when ``savefig`` is None). The value is
            forwarded unchanged to ``show_and_save``.
        :param path: directory containing every file read here.
        :param names_file: file listing one result filename per line.
        :param std_names_file: optional std-list file. It is only read and
            printed; the std files actually loaded are ``"std_" + filename``.
        :param x_file: file with x-axis values, or None to use row indices.
        :param label_slice: slice applied to each filename for its legend.
        :param xlabel: x-axis label.
        :param shade_std: whether to load ``std_<filename>`` and shade a
            mean +/- std band.
        :returns: the matplotlib figure, closed before returning.
        """
        self._set_rc_params()
        names = self._read_name_list(f"{path}/{names_file}")
        std_names = None
        if std_names_file is not None:
            std_names = self._read_name_list(f"{path}/{std_names_file}")
        print(names)
        if std_names is not None:
            print(std_names)

        regret_data = {}
        lw = 1
        fig = plt.figure()
        try:
            plot_params = self._plotParameter(names)
            x_values = None if x_file is None else np.loadtxt(f"{path}/{x_file}")
            for i, filename in enumerate(names):
                mean = np.loadtxt(f"{path}/{filename}")
                std = np.loadtxt(f"{path}/std_{filename}") if shade_std else None
                regret_data[filename] = mean
                params = plot_params[filename]
                # Same float indices matplotlib would use for plt.plot(mean).
                x = np.arange(len(mean), dtype=float) if x_values is None else x_values
                plt.plot(x, mean, label=filename[label_slice],
                         color=params["color"], marker=params["marker"],
                         # Offset markers per curve so overlapping lines stay
                         # distinguishable.
                         markevery=(i / 40., 0.1), ls=params["ls"],
                         lw=lw, ms=int(self.C * lw))
                if std is not None:
                    plt.fill_between(x, mean - std, mean + std,
                                     color=params["color"], alpha=0.2)
            print(regret_data)
            plt.grid(True)
            plt.legend(loc='upper left', fontsize='x-small')
            plt.xlabel(xlabel, fontsize=16)
            plt.ylabel("Regret", fontsize=16)
            if savefig is not None:
                tikzplotlib.save(savefig + ".tex")
            show_and_save(False, savefig=savefig, fig=fig, pickleit=None)
        finally:
            # Release the figure even when loading or plotting fails.
            plt.close(fig)
        return fig

    def _plotParameter(self, Regrets_file_name):
        """Map each result filename to its color, marker and line style.

        The algorithm is recognised by substring; the first matching rule
        wins, so more specific names precede names they contain (e.g.
        "CUSUM-UCB" before "M-UCB", "Oracle kl-UCB" before "kl-UCB").
        Algorithms flagged for markers take the next marker from a fixed
        list, wrapping around when the list is exhausted. A filename
        containing "diminishing" is always drawn dashed.

        :param Regrets_file_name: iterable of result filenames.
        :returns: ``{filename: {"color": ..., "marker": ..., "ls": ...}}``;
            unset entries are None (matplotlib then uses its defaults).
        """
        allmarkers = ['o', 'D', 'v', 'p', 's', '*', 'h', '>', "x"]
        # (substring, color, linestyle, gets_marker), checked in order.
        rules = (
            ("CUSUM-UCB", "orange", None, True),
            ("M-UCB", "gray", None, True),
            ("GLR-UCB", "olive", None, True),
            ("Oracle kl-UCB", "#1E76B4", None, False),
            ("kl-UCB", "#FF7E0D", (0, (2, 3)), False),
            ("CUSUM-klUCB", "orange", None, False),
            ("M-klUCB", "#8B564A", None, False),
            ("AdSwitch", "sienna", None, True),
            ("ArmSwitch", "cyan", None, True),
            ("Discounted-TS", "royalblue", None, True),
            ("Discounted-klUCB", "magenta", None, True),
            ("Master+UCB1", "springgreen", None, True),
        )
        # M-UCB variants are colored by their window size (first match wins).
        mucb_window_colors = (
            ("M-UCB(w=80", "cyan"),
            ("M-UCB(w=160", "deepskyblue"),
            ("M-UCB(w=320", "aquamarine"),
            ("M-UCB(w=720", "greenyellow"),
        )
        plotParams = {}
        next_marker = 0
        for filename in Regrets_file_name:
            color, marker, ls = None, None, None
            for key, rule_color, rule_ls, gets_marker in rules:
                if key not in filename:
                    continue
                color, ls = rule_color, rule_ls
                if key == "M-UCB":
                    color = next((c for w, c in mucb_window_colors if w in filename), color)
                if gets_marker:
                    marker = allmarkers[next_marker % len(allmarkers)]
                    next_marker += 1
                break
            if "diminishing" in filename:
                ls = "--"
            plotParams[filename] = {"color": color, "marker": marker, "ls": ls}
        return plotParams