import numpy as np
import os
# from MUCB import MUCB
from select_alg import select_alg
import matplotlib.pyplot as plt
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex
import time 

"""
Expected keys of the ``cfg`` dict consumed by ``run`` (the values shown are
placeholder variable names):

cfg = {
        'alg' : alg,
        'diminishing' : diminishing,
        'skip' : skip,
        'alg_arg' : alg_arg,
        'repetitions': repetitions,
        'nb_arms': nb_arms,
        'nb_break_points': nb_break_points,
        'horizon': horizon,
        'dict_index_list' : dict_index_list,
        'env_samples_dict' : env_samples_dict,
        'nb_of_instances' : nb_of_instances,
        'experiment' : experiment,
        'plot_dir' : plot_dir
    }
"""
def run(cfg):
    alg                 = cfg.get("alg")
    diminishing         = cfg.get("diminishing")
    skip                = cfg.get("skip")
    alg_arg             = cfg.get("alg_arg")
    repetitions         = cfg.get("repetitions")
    nb_arms             = cfg.get("nb_arms")
    nb_break_points     = cfg.get("nb_break_points")
    horizon             = cfg.get("horizon")
    dict_index_list     = cfg.get("dict_index_list")
    env_samples_dict    = cfg.get("env_samples_dict")
    nb_of_instances     = cfg.get("nb_of_instances")
    experiment          = cfg.get("experiment") 
    plot_dir            = cfg.get("plot_dir") 


    last_mean_regret = []
    last_std_regret = []
    for index in dict_index_list:
        match experiment:
            case "t":
                print(f"experiment = t")
            case "M":
                print(f"experiment = M")
                nb_break_points = index
                print(f"nb_break_points = {nb_break_points}")
            case "T":
                print(f"experiment = T")
            case "K":
                print(f"experiment = K")
        alg = select_alg(
            repetitions = repetitions,
            nb_arms = nb_arms,
            nb_break_points = nb_break_points,
            horizon = horizon,
            env = env_samples_dict[index],
            alg = alg, 
            diminishing = diminishing,
            skip = skip,
            arg = alg_arg,
            path = plot_dir,
        )
        

        start_time = time.time()
        ################### Running algorithm ###################
        reward_mean, reward_std, action_mean, action_std = alg()
        end_time = time.time()
        execution_time = end_time - start_time
        with open(plot_dir+"/execution_time_log_rep_"+str(repetitions)+alg.__str__()+".txt", "a") as file:
            file.write(f"{execution_time}\n")
        with open(plot_dir+"/execution_time_log_mean"+alg.__str__()+".txt", "a") as file:
            file.write(f"{execution_time/repetitions}\n")


        plot_arm(
            path = plot_dir,
            mean = action_mean,
            std = action_std,
            alg_str = alg.__str__(),
            horizon = horizon,
            nb_arms = nb_arms,
            nb_break_points = nb_break_points
        )
        regret = getRegret(mu_max=env_samples_dict[index].mu_max, mean_reward=reward_mean)
        last_mean_regret.append(regret[horizon-1])
        last_std_regret.append(reward_std[horizon-1])

    saveFileName(
        plot_dir,
        alg.__str__(),
        experiment,
        last_mean_regret,
        last_std_regret,
        regret,
        reward_std
    )

def getRegret(mu_max, mean_reward):
    """Return the cumulative regret curve.

    The curve is the running sum of ``mu_max`` minus the running sum of
    ``mean_reward``, taken element-wise.
    """
    optimal_cumulative = np.cumsum(mu_max)
    achieved_cumulative = np.cumsum(mean_reward)
    return optimal_cumulative - achieved_cumulative

def saveFileName(
        path,
        algname,
        experiment,
        last_mean_regret,
        last_std_regret,
        regret,
        reward_std
    ):
    """Save regret results for ``algname`` and register the produced file names.

    For experiment ``'t'`` the full arrays ``regret`` and ``reward_std`` are
    saved; for every other experiment the per-instance final values
    ``last_mean_regret`` and ``last_std_regret`` are saved. Either way they go
    to ``<path>/mean_regrets_<algname>.csv`` and
    ``<path>/std_regrets_<algname>.csv`` via ``np.savetxt``.

    Afterwards the two CSV names and ``algname`` are each appended (only if not
    already listed) to ``mean_regrets_file_name.txt``,
    ``std_regrets_file_name.txt`` and ``alg_str.txt`` inside ``path``.
    """
    # Only the data differs between experiment kinds; the file names do not.
    if experiment == 't':
        mean_data, std_data = regret, reward_std
    else:
        mean_data, std_data = last_mean_regret, last_std_regret

    print("Regret saving ...")
    mean_name = 'mean_regrets_' + algname + '.csv'
    std_name = 'std_regrets_' + algname + '.csv'
    np.savetxt(path + '/' + mean_name, mean_data)
    np.savetxt(path + '/' + std_name, std_data)

    print("FileName saving ...")
    # Index file listing every mean-regret CSV.
    append_if_not_exists(path + '/' + 'mean_regrets_file_name.txt', mean_name + '\n')
    # Index file listing every std-regret CSV.
    append_if_not_exists(path + '/' + 'std_regrets_file_name.txt', std_name + '\n')
    # Index file listing every algorithm name.
    append_if_not_exists(path + '/' + 'alg_str.txt', algname + '\n')

                
def append_if_not_exists(filename, entry):
    """Append ``entry`` to ``filename`` unless the file already contains that line.

    Lines are compared with their trailing newline removed, so ``entry`` also
    matches a final line that is not newline-terminated. A missing file is
    treated as empty and gets created. If the existing file does not end with
    a newline, one is written first so that ``entry`` starts on its own line.

    Args:
        filename: path of the text file to update.
        entry: the line to add, normally including its trailing newline.
    """
    try:
        with open(filename, 'r') as f:
            content = f.readlines()
    except FileNotFoundError:
        content = []

    existing = {line.rstrip('\n') for line in content}
    if entry.rstrip('\n') in existing:
        return

    # Without a separator the new entry would be glued onto an unterminated
    # last line.
    needs_separator = bool(content) and not content[-1].endswith('\n')
    with open(filename, 'a') as f:
        if needs_separator:
            f.write('\n')
        f.write(entry)

def plot_arm(
        path,
        mean,
        std,
        alg_str,
        horizon,
        nb_arms,
        nb_break_points
    ):
    """Plot per-arm mean curves with a ±std band and save the figure.

    Args:
        path: output directory; when ``None`` nothing is plotted or saved.
        mean: per-arm curves, indexed ``mean[i]`` for arm ``i``, each plotted
            against time steps ``0 .. horizon-1``.
        std: per-arm spreads indexed like ``mean``; drawn as the shaded band
            ``mean[i] - std[i]`` .. ``mean[i] + std[i]``.
        alg_str: algorithm label appended to the output file name.
        horizon: number of time steps on the x axis.
        nb_arms: number of arms to draw. With fewer than 4 arms a fixed
            three-colour palette is used, otherwise matplotlib's colour cycle.
        nb_break_points: only used in the output file name.

    Returns:
        The created matplotlib figure, or ``None`` when ``path`` is ``None``.
    """
    # No output directory: skip plotting entirely (there is no figure).
    if path is None:
        return None

    arm_colors = ["#1F77B4", "#FF7F0E", "#2CA02C"]
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "DejaVu Sans"
    plt.rcParams['mathtext.fontset'] = "cm"
    plt.rcParams['mathtext.rm'] = "serif"
    fig = plt.figure()
    x = np.linspace(0, horizon-1, horizon)
    lw = 3
    for i in range(nb_arms):
        # Omit the colour kwarg entirely (rather than color=None) for many arms
        # so both plot and fill_between fall back to matplotlib's own cycling.
        style = {"color": arm_colors[i]} if nb_arms < 4 else {}
        plt.plot(x, mean[i], label="arm {}".format(i+1), lw=lw, **style)
        plt.fill_between(x, mean[i]-std[i], mean[i]+std[i], alpha=0.2, **style)
    plt.grid(True)
    plt.legend(loc='upper left')
    plt.xlabel(r"T")
    plt.ylabel(r"Number of arm $i$ pulled")
    savefig = path+'/'+'arm_'+"_M=" + str(nb_break_points) +"_K=" + str(nb_arms) +"_T="+ str(horizon) + alg_str
    # NOTE(review): the figure is returned, not closed; unless show_and_save
    # closes it, repeated calls accumulate open figures — confirm.
    show_and_save(False, savefig=savefig, fig=fig, pickleit=None)
    return fig

    