import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import xlsxwriter as xlsx
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex
import tikzplotlib
import argparse


def _read_nonempty_lines(file_path):
    """Return the lines of the text file at *file_path*, dropping empty lines.

    The file is opened in a ``with`` block so the handle is closed before
    returning.
    """
    with open(file_path) as handle:
        return [line for line in handle.read().split("\n") if line != ""]


def main(variant):
    """Plot the mean regret of every algorithm with a +/- 1 std band.

    Parameters
    ----------
    variant : dict
        Must provide ``'path'``, the directory holding the input files and
        receiving the output figure. It must also provide ``'statement'``, a
        suffix appended to the algorithm-list file name and to the figure name.

    Files read from ``path``:
        * ``mean_regrets_file_name.txt`` / ``std_regrets_file_name.txt``.
          These are only echoed to stdout.
        * ``alg_str<statement>.txt``, one algorithm label per line (empty
          lines ignored).
        * ``mean_regrets_<label>.csv`` / ``std_regrets_<label>.csv`` for each
          label, loaded with ``np.loadtxt``.

    Output:
        * ``<path>/main<statement>.tex`` is written via ``tikzplotlib.save``.
        * The figure is handed to ``show_and_save`` from ``plotsettings``.
          That helper presumably writes image files under
          ``<path>/main<statement>``; confirm in plotsettings.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn. It has already been closed.
    """
    path = variant.get('path')
    statement = variant.get('statement')
    savefig = path + "/main" + statement

    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "DejaVu Sans"
    plt.rcParams['mathtext.fontset'] = "cm"
    plt.rcParams['mathtext.rm'] = "serif"

    # The two file-name lists are only printed for debugging. The CSV paths
    # used below are derived from the algorithm labels instead.
    mean_Regrets_file_name = _read_nonempty_lines(path + "/mean_regrets_file_name.txt")
    std_Regrets_file_name = _read_nonempty_lines(path + "/std_regrets_file_name.txt")
    alg_str = _read_nonempty_lines(path + "/alg_str" + statement + ".txt")
    print(mean_Regrets_file_name)
    print(std_Regrets_file_name)

    Regret_data = {}
    Regret_std = {}
    fig = plt.figure()
    plotParams = plotParameter(alg_str)
    lw = 2
    for alg in alg_str:
        style = plotParams[alg]
        Regret_data[alg] = np.loadtxt(path + "/mean_regrets_" + alg + ".csv")
        Regret_std[alg] = np.loadtxt(path + "/std_regrets_" + alg + ".csv")
        mean = Regret_data[alg]
        std = Regret_std[alg]
        plt.plot(mean, label=alg, color=style["color"], marker=style["marker"],
                 markevery=(1 / 40., 0.1), ls=style["ls"], lw=lw, ms=int(3.5 * lw))
        # x = 0..len-1 matches the implicit x of plt.plot above.
        # This assumes the mean and std files have the same length.
        x = np.linspace(0, len(std) - 1, len(std))
        plt.fill_between(x, mean - std, mean + std, color=style["color"], alpha=0.2)
    print(Regret_data)

    plt.grid(True)
    plt.legend(loc='upper left', fontsize='x-small')
    plt.xlabel(r"Time steps $t = 1...T$, horizon", fontsize=16)
    plt.ylabel("Regret", fontsize=16)

    tikzplotlib.save(savefig + ".tex")
    show_and_save(False, savefig=savefig, fig=fig, pickleit=None)
    plt.close(fig)
    return fig

def plotParameter(Regrets_file_name):
    """Choose a colour, marker and line style for every algorithm label.

    Parameters
    ----------
    Regrets_file_name : iterable of str
        Algorithm labels. Styling is decided by substring matches on each label.

    Returns
    -------
    dict
        ``{label: {"color": ..., "marker": ..., "ls": ...}}``. ``color`` is
        ``None`` when no known algorithm family matches.

    Styling is applied in four passes, each of which may overwrite the previous:

    1. Algorithm family (first matching substring wins) sets the colour. Some
       families also get a cycling marker or a dashed line style.
    2. The ``"diminishing"`` / ``"skipping mechanism"`` tags set the line style.
    3. The ``"uniform"`` tag forces the colour to ``"lime"``.
    4. A ``"B = <value>"`` tag sets the marker. Labels without a recognised tag
       get marker ``'s'`` and line style ``'-'``.

    NOTE(review): passes 2 and 4 always assign a value. As a result:
      * the pass-1 markers never survive;
      * the kl-UCB dashed style is always replaced;
      * the pass-2 line styles only survive on labels with a ``"B = ..."`` tag.
    This is kept to preserve the current figures; confirm whether it is intended.
    """
    allmarkers = ['o', 'D', 'v', 'p', 's', '*', 'h', '>']
    # Window-specific colours for M-UCB, checked in order (first match wins).
    m_ucb_window_colors = (
        ("M-UCB(w=80", "cyan"),
        ("M-UCB(w=160", "deepskyblue"),
        ("M-UCB(w=320", "aquamarine"),
        ("M-UCB(w=720", "greenyellow"),
    )
    # Marker per "B = <value>" tag, checked in order (first match wins).
    b_markers = (
        ("B = 1", 'o'),
        ("B = 0.5", 'D'),
        ("B = -0.5", 'v'),
        ("B = -1", 'p'),
        ("B = -0.25", '*'),
        ("B = 0.25", 'h'),
    )

    plotParams = {}
    i = 0  # position in allmarkers for the families that cycle markers
    for filename in Regrets_file_name:
        params = {"color": None, "marker": None, "ls": None}
        plotParams[filename] = params

        # Pass 1: algorithm family. Order matters: "CUSUM-UCB" contains
        # "M-UCB", so it must be tested first.
        if "CUSUM-UCB" in filename:
            params["color"] = "orange"
            params["marker"] = allmarkers[i % len(allmarkers)]
            i += 1
        elif "M-UCB" in filename:
            params["color"] = "gray"
            for tag, color in m_ucb_window_colors:
                if tag in filename:
                    params["color"] = color
                    break
            params["marker"] = allmarkers[i % len(allmarkers)]
            i += 1
        elif "GLR-klUCB(with " in filename:
            params["color"] = "gray"
            i += 1
        elif "GLR-UCB" in filename:
            params["color"] = "olive"
            params["marker"] = allmarkers[i % len(allmarkers)]
            i += 1
        elif "Oracle kl-UCB" in filename:
            params["color"] = "#1E76B4"
        elif "kl-UCB" in filename:
            params["color"] = "#FF7E0D"
            params["ls"] = (0, (2, 3))
        elif "CUSUM-klUCB" in filename:
            params["color"] = "orange"
        elif "M-klUCB" in filename:
            params["color"] = "#8B564A"
        elif "AdSwitch" in filename:
            params["color"] = "lime"
        elif "Discounted-TS" in filename:
            params["color"] = "royalblue"
            params["marker"] = allmarkers[i % len(allmarkers)]
            i += 1
        elif "Discounted-klUCB" in filename:
            params["color"] = "magenta"
            params["marker"] = allmarkers[i % len(allmarkers)]
            i += 1

        # Pass 2: variant tags select the line style.
        diminishing = "diminishing" in filename
        skipping = "skipping mechanism" in filename
        if diminishing and skipping:
            params["ls"] = "-."
        elif diminishing:
            params["ls"] = "--"
        elif skipping:
            params["ls"] = ":"
        else:
            params["ls"] = "-"

        # Pass 3: uniform runs are always drawn in lime.
        if "uniform" in filename:
            params["color"] = "lime"

        # Pass 4: the "B = ..." tag selects the marker. Without a tag, the
        # marker and line style are reset.
        for tag, marker in b_markers:
            if tag in filename:
                params["marker"] = marker
                break
        else:
            params["marker"] = 's'
            params["ls"] = "-"
    return plotParams
    
if __name__ == '__main__':
    # Command-line entry point: the parsed options are passed to main() as a dict.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--path',
        type=str,
        default="C:/Users/USER/Code/diminishing-exploration/plot/normal/K3_T20000_N100_M5_envId0_seed10",
    )
    cli.add_argument('--statement', type=str, default="")
    options = vars(cli.parse_args())
    main(variant=options)