import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm.notebook import tqdm
# from tqdm import tqdm
from Environment import Environment
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex
import os

#: Default number of random events (break points) in an environment.
nb_break_points: int = 0
  
class AdSwitch(object):
    """Monte-Carlo simulation of the AdSwitch bandit policy.

    The constructor immediately runs the experiment selected by
    ``ENV.experiment``:

    * ``"M"`` -- final regret for each number of break points in
      ``ENV.nb_break_points_list``;
    * ``"T"`` -- final regret for each horizon in ``ENV.horizon_list``;
    * ``"K"`` -- final regret for each number of arms in ``ENV.nb_arms_list``
      (averaged over ``ENV.num_of_instance`` instances);
    * anything else -- a single run on ``ENV.Env`` / ``ENV.mu`` / ``ENV.mu_max``.

    When ``path`` is given, regret curves, file-name indexes, per-run
    cumulative rewards and arm-pull plots are written below it.
    """

    def __init__(self, configuration, ENV=Environment, max_nb_random_events=None, C=1, sub_sample=50, sub_sample2=20, path=None):
        """Store the parameters and run the experiment selected by ``ENV.experiment``.

        Args:
            configuration: configuration dictionary (for the ``"K"`` experiment,
                a sequence whose first element is that dictionary).
            ENV: environment object providing ``experiment``, ``repetitions``
                and the reward/mean tables used by the selected experiment.
            max_nb_random_events: number of stationary segments; defaults to
                the number of entries in ``environment.params.changePoints``.
            C: constant of the elimination threshold ``sqrt(C log T / (n - 1))``.
            sub_sample: change tests run when ``t_loc % sub_sample == 1``.
            sub_sample2: grid spacing of the windows used by the good-arm test.
            path: output directory, or ``None`` to skip all file output.

        Raises:
            ValueError: if ``max_nb_random_events`` is not positive.
        """
        self.C = C
        self.sub_sample = sub_sample
        self.sub_sample2 = sub_sample2
        self.path = path
        self.regret_versus_M = None
        self.regret = None
        self.cfg = configuration  #: Configuration dictionary
        if ENV.experiment != "K":
            self.nb_break_points = self.cfg.get('nb_break_points')  #: How many random events?
            self.horizon = self.cfg['horizon']  #: Horizon (number of time steps)
            print("Time horizon:", self.horizon)
            self.nb_arms = len(self.cfg['environment']["params"]["listOfMeans"][0])
        self.experiment = ENV.experiment
        self.ENV = ENV
        self.repetitions = self.ENV.repetitions  #: Number of repetitions
        print("Number of repetitions:", self.repetitions)

        if ENV.experiment != "K":
            if max_nb_random_events is None:
                max_nb_random_events = len(self.cfg['environment']["params"]["changePoints"])
            # Explicit raise instead of assert: asserts are stripped under `python -O`.
            if not max_nb_random_events > 0:
                raise ValueError(
                    "Error: for AdSwitch policy the parameter max_nb_random_events "
                    "should be > 0 but it was given as {}.".format(max_nb_random_events))
            self.nb_segs = max_nb_random_events

        if self.experiment == "M":
            self.experimentForM(path=path)
        elif self.experiment == "T":
            self.experimentForT(path=path)
        elif self.experiment == "K":
            self.experimentForK(path=path)
        else:
            self.Env = self.ENV.Env
            self.mu = self.ENV.mu
            self.mu_max = self.ENV.mu_max
            self.mean_reward, self.reward_std = self.mainAlg(path=path)
            self.regret_mean = self.getRegret()
            self.regret_std = self.reward_std

        if path is not None:
            self.__saveRegret(path)
            self.__saveFileName(path)

    def __initEnvironments__(self):
        """Draw Bernoulli reward tables from the configured piecewise-constant means.

        Fills ``self.Env`` (repetitions x arms x horizon, 0/1 rewards),
        ``self.mu`` (arms x horizon) and ``self.mu_max`` (horizon,).
        """
        print(" Create environments ...")
        # NOTE(review): __init__ reads cfg['environment']["params"] while this
        # method reads cfg['environment'][0]["params"]; confirm which layout is used.
        params = self.cfg['environment'][0]["params"]
        list_of_means = np.array(params["listOfMeans"])
        change_point = params["changePoints"]
        self.Env = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype=int)
        # Means are probabilities: an int dtype would truncate them to 0.
        self.mu = np.zeros((self.nb_arms, self.horizon), dtype=float)
        self.mu_max = np.zeros(self.horizon, dtype=float)
        segment = 0
        for t in range(self.horizon):
            # Advance to the segment containing t (a while loop so that every
            # arm sees the same segment even with repeated change points).
            while segment < len(change_point) - 1 and t == change_point[segment + 1]:
                segment += 1
            for k in range(self.nb_arms):
                for i in range(self.repetitions):
                    self.Env[i][k][t] = self.Bernoulli(list_of_means[segment][k])
                self.mu[k][t] = list_of_means[segment][k]
            self.mu_max[t] = np.max(list_of_means[segment, :])

    # ------------------------------------------------------------------
    # Core algorithm
    # ------------------------------------------------------------------

    @staticmethod
    def _update_statistics(DRAWS, SUMS, arm, rew, t_loc):
        """Extend the window statistics to local step ``t_loc - 1`` after pulling ``arm``.

        ``DRAWS[a, s, u]`` / ``SUMS[a, s, u]`` hold the number of pulls / the
        reward sum of arm ``a`` over local steps ``s..u``. Unpulled arms only
        copy their previous column; the pulled arm gains one pull and ``rew``
        on every window ending at the current step.
        """
        col = t_loc - 1
        if col > 0:
            DRAWS[:, :col, col] = DRAWS[:, :col, col - 1]
            SUMS[:, :col, col] = SUMS[:, :col, col - 1]
        # Entry (arm, col, col) is still 0 here, so this also sets the
        # single-step window to (1, rew).
        DRAWS[arm, :col + 1, col] += 1
        SUMS[arm, :col + 1, col] += rew

    def _update_good_arms(self, DRAWS, SUMS, t_loc, good, bad, bad_gaps, bad_means, log_T):
        """Move significantly sub-optimal arms from ``good`` to ``bad``.

        For each window ``[s, t_loc - 1]``, a good arm pulled more than once
        is eliminated when its mean falls below the best such mean by more
        than ``sqrt(C log T / (n - 1))``. Its gap and mean are recorded.
        Mutates ``good``, ``bad``, ``bad_gaps`` and ``bad_means`` in place.
        """
        col = t_loc - 1
        for s in range(t_loc):
            cand = [a for a in good if DRAWS[a, s, col] > 1]
            if not cand:
                continue
            cand_means = [SUMS[a, s, col] / DRAWS[a, s, col] for a in cand]
            mu_max = max(cand_means)
            for arm, mean in zip(cand, cand_means):
                if mu_max - mean > np.sqrt(self.C * log_T / (DRAWS[arm, s, col] - 1)):
                    good.remove(arm)
                    bad.append(arm)
                    bad_gaps[arm] = mu_max - mean
                    bad_means[arm] = mean

    @staticmethod
    def _bad_arm_changed(DRAWS, SUMS, t_loc, bad, bad_gaps, bad_means, log_T):
        """Return True if some bad arm's mean on a window ``[s, t_loc - 1]`` (s >= 1)
        deviates from its stored mean by more than ``gap / 4 + sqrt(2 log T / n)``."""
        col = t_loc - 1
        for s in range(1, t_loc):
            for arm in bad:
                draws = DRAWS[arm, s, col]
                if draws > 1 and abs(SUMS[arm, s, col] / draws - bad_means[arm]) > \
                        bad_gaps[arm] / 4 + np.sqrt(2 * log_T / draws):
                    return True
        return False

    def _good_arm_changed(self, DRAWS, SUMS, t_loc, good, log_T):
        """Return True if some good arm's mean on a sub-sampled window
        ``[s1 - 1, s2 - 1]`` differs from its mean on a trailing window
        ``[s - 1, t_loc - 1]`` by more than the sum of both confidence widths."""
        col = t_loc - 1
        # Window endpoints j with j % sub_sample2 == 1 (independent of s).
        grid = [j for j in range(t_loc) if j % self.sub_sample2 == 1]
        for s in range(1, t_loc):
            for s1 in grid:
                for s2 in [j for j in grid if j >= s1 - 1]:
                    for arm in good:
                        draws1 = DRAWS[arm, s1 - 1, s2 - 1]
                        draws2 = DRAWS[arm, s - 1, col]
                        if draws1 > 1 and draws2 > 1 and abs(
                                SUMS[arm, s1 - 1, s2 - 1] / draws1 -
                                SUMS[arm, s - 1, col] / draws2) > \
                                np.sqrt(2 * log_T / draws1) + np.sqrt(2 * log_T / draws2):
                            return True
        return False

    def _add_obligations(self, bad, bad_gaps, obligations, episode, log_T):
        """Randomly schedule extra exploration of bad arms.

        For each bad arm and each scale ``i`` with ``2**-i >= gap / 16``, with
        probability ``2**-i * sqrt(episode / (K T log T))`` the arm must be
        sampled at least ``floor(2**(2i+1) log T)`` more times.
        """
        scale = np.sqrt(episode / (self.nb_arms * self.horizon * log_T))
        for arm in bad:
            i = 1
            while 1 / (2 ** i) >= bad_gaps[arm] / 16:
                if np.random.rand() < (1 / (2 ** i)) * scale:
                    n = int(np.floor(2 ** (2 * i + 1) * log_T))
                    obligations[arm] = max(obligations[arm], n)
                i += 1

    def mainAlg(self, path=None):
        """Run AdSwitch ``self.repetitions`` times on the reward table ``self.Env``.

        Args:
            path: if not None, the per-run cumulative rewards are saved to a
                CSV file in this directory.

        Returns:
            ``(X_mean, X_std)``: the per-step mean reward over runs and the
            per-step standard deviation of the cumulative reward over runs.
        """
        print("Running ... AdSwitch")
        K, T = self.nb_arms, self.horizon
        log_T = np.log(T)
        A_mem = np.zeros((self.repetitions, K, T), dtype=int)
        X_mem = np.zeros((self.repetitions, T), dtype=float)
        for r in tqdm(range(self.repetitions)):
            episode = 1  # index of the current episode
            change_points = []  # global times of detected changes (kept for inspection)
            t = 0  # global number of samples so far

            while t < T:
                # Fresh statistics for a new episode, indexed by local time.
                SUMS = np.zeros((K, T - t, T - t))
                DRAWS = np.zeros((K, T - t, T - t))
                bad_gaps = np.zeros(K)
                bad_means = np.zeros(K)
                total_number = np.zeros(K, dtype=int)  # pulls per arm in this episode
                good = list(range(K))
                bad = []
                obligations = np.zeros(K, dtype=int)  # remaining forced samples per arm
                t_loc = 0  # number of samples in the current episode
                new_episode = False

                while not new_episode and t < T:
                    # Candidates: good arms plus bad arms with sampling obligations.
                    candidate = good + [a for a in range(K) if obligations[a] > 0]
                    # Least-sampled candidate, ties broken uniformly; randmax
                    # returns a position in `candidate`, so map it to an arm id.
                    arm = candidate[self.randmax(-total_number[candidate])]
                    rew = self.Env[r][arm][t]
                    total_number[arm] += 1
                    X_mem[r, t] = rew
                    A_mem[r, arm, t] = 1
                    if obligations[arm] > 0:
                        obligations[arm] -= 1

                    t += 1
                    t_loc += 1

                    self._update_statistics(DRAWS, SUMS, arm, rew, t_loc)
                    self._update_good_arms(DRAWS, SUMS, t_loc, good, bad, bad_gaps, bad_means, log_T)

                    # Change tests are only run every `sub_sample` local steps.
                    if t_loc % self.sub_sample == 1 and (
                            self._bad_arm_changed(DRAWS, SUMS, t_loc, bad, bad_gaps, bad_means, log_T)
                            or self._good_arm_changed(DRAWS, SUMS, t_loc, good, log_T)):
                        new_episode = True
                        episode += 1
                        change_points.append(t)

                    self._add_obligations(bad, bad_gaps, obligations, episode, log_T)

        X_mem_cum = np.cumsum(X_mem, axis=1)
        print("Roll data saving ...")
        if path is not None:
            filename = path + '/' + 'roll_data' + self.__str__() + str(random.randint(100000, 999999)) + '.csv'
            np.savetxt(filename, X_mem_cum, delimiter=",")
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mem_cum = np.cumsum(A_mem, axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        self.plot_arm(A_mean, A_std)
        return X_mean, X_std

    def getRegret(self):
        """Return the cumulative regret curve ``cumsum(mu_max) - cumsum(mean_reward)``."""
        return np.cumsum(self.mu_max) - np.cumsum(self.mean_reward)

    def __str__(self):
        return r"AdSwitch"

    # ------------------------------------------------------------------
    # KL / change-detection utilities
    # ------------------------------------------------------------------

    def beta(self, nb, delta=None, variant=None):
        """Detection threshold for ``nb`` samples.

        Args:
            nb: number of samples.
            delta: confidence level; defaults to ``1 / sqrt(horizon)``.
            variant: ``None`` or ``0`` for ``-log(delta) + 1.5 log(nb) + log(3)``,
                ``1`` for ``-log(delta) + log(1 + log(nb))``.

        Returns:
            The threshold, or ``+inf`` when it is negative or infinite.

        Raises:
            ValueError: for any other ``variant``.
        """
        if delta is None:
            delta = 1.0 / np.sqrt(self.horizon)
        if variant is None or variant == 0:
            c = -np.log(delta) + (3 / 2) * np.log(nb) + np.log(3)
        elif variant == 1:
            c = -np.log(delta) + np.log(1 + np.log(nb))
        else:
            raise ValueError("Unknown beta variant: {}".format(variant))
        if c < 0 or np.isinf(c):
            c = float('+inf')
        return c

    def CD(self, nb, sums):
        """GLR-style change test on ``nb`` samples; True if any split exceeds ``beta(nb)``.

        Only split points ``s`` with ``s % sub_sample2 == 1`` are tested.
        Presumably ``sums[s]`` is a cumulative reward sum -- TODO confirm.
        """
        # NOTE(review): if sums[s] covers s + 1 samples, draw2 should be nb - s
        # and the pooled mean sums[nb] / (nb + 1); verify against the caller.
        s = 0
        while s < nb:
            if s % self.sub_sample2 == 1:
                draw1 = s + 1
                draw2 = (nb - s) + 1
                mu1 = sums[s] / draw1
                mu2 = (sums[nb] - sums[s]) / draw2
                mu = sums[nb] / nb
                if (draw1 * self.klChange(mu1, mu) + draw2 * self.klChange(mu2, mu)) > self.beta(nb):
                    return True
            s += 1
        return False

    def klChange(self, p, q):
        """Binary relative entropy d(p, q), with ``p`` clipped into (0, 1).

        ``q`` is not clipped, so ``q`` in {0, 1} yields inf/nan.
        """
        res = 0
        if p != q:
            if p <= 0:
                p = np.finfo(float).eps
            if p >= 1:
                p = 1 - np.finfo(float).eps
            res = p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))
        return res

    def klUp(self, p, level):
        """KL upper confidence bound: ``uM > p`` with ``d(p, uM) ~= level`` (16 bisection steps)."""
        lM = p
        uM = np.minimum(1.0, p + np.sqrt(level / 2))
        for _ in range(16):
            qM = (uM + lM) / 2
            if self.klChange(p, qM) > level:
                uM = qM
            else:
                lM = qM
        return uM

    # ------------------------------------------------------------------
    # Experiments
    # ------------------------------------------------------------------

    @staticmethod
    def _make_plot_dir(plot_dir):
        """Create ``plot_dir`` unless it exists; raise ValueError if it is a file."""
        if os.path.isdir(plot_dir):
            print("{} is already a directory here...".format(plot_dir))
        elif os.path.isfile(plot_dir):
            raise ValueError("[ERROR] {} is a file, cannot use it as a directory !".format(plot_dir))
        else:
            os.mkdir(plot_dir)

    def experimentForM(self, path):
        """Record final regret (and its spread) for each number of break points."""
        self.regret_versus_M = []
        self.std_regret_versus_M = []
        for m in self.ENV.nb_break_points_list:
            plot_dir = path + "/M" + str(m)
            self._make_plot_dir(plot_dir)
            self.nb_segs = m
            print("Running ...{} for M = {}".format(self.__str__(), m))
            self.Env = self.ENV.Env_for_M[m]
            self.mu = self.ENV.mu_for_M[m]
            self.mu_max = self.ENV.mu_max_for_M[m]
            self.mean_reward, self.reward_std = self.mainAlg(path=plot_dir)
            regret = self.getRegret()
            regret_std = self.reward_std
            self.regret_versus_M.append(regret[self.horizon - 1])
            self.std_regret_versus_M.append(regret_std[self.horizon - 1])

    def experimentForT(self, path):
        """Record final regret (and its spread) for each horizon in ``ENV.horizon_list``."""
        self.regret_versus_T = []
        self.std_regret_versus_T = []
        self.nb_segs = self.cfg.get('nb_break_points')  #: How many random events?

        for T in self.ENV.horizon_list:
            plot_dir = path + "/T" + str(T)
            self._make_plot_dir(plot_dir)
            self.horizon = T
            print("Running ...{} for T = {}".format(self.__str__(), T))

            self.Env = self.ENV.Env_for_T[T]
            self.mu = self.ENV.mu_for_T[T]
            self.mu_max = self.ENV.mu_max_for_T[T]
            self.mean_reward, self.reward_std = self.mainAlg(path=plot_dir)
            regret = self.getRegret()
            regret_std = self.reward_std
            self.regret_versus_T.append(regret[self.horizon - 1])
            self.std_regret_versus_T.append(regret_std[self.horizon - 1])

    def experimentForK(self, path):
        """Record final regret, averaged over instances, for each number of arms."""
        self.regret_versus_K = []
        self.std_regret_versus_K = []
        self.nb_break_points = self.cfg[0].get('nb_break_points')  #: How many random events?
        self.horizon = self.cfg[0]['horizon']  #: Horizon (number of time steps)

        nb_instances = self.ENV.num_of_instance
        regret_for_instance = np.zeros((nb_instances, self.horizon), dtype=float)
        var_regret_for_instance = np.zeros((nb_instances, self.horizon), dtype=float)
        for K in self.ENV.nb_arms_list:
            plot_dir = path + "/K" + str(K)
            self._make_plot_dir(plot_dir)
            self.nb_arms = K
            print("Running ...{} for K = {}".format(self.__str__(), K))

            for i in tqdm(range(nb_instances)):
                self.Env = self.ENV.Env_for_K[i][K - 1]
                self.mu = self.ENV.mu_for_K[i][K - 1]
                self.mu_max = self.ENV.mu_max_for_K[i][K - 1]
                self.mean_reward, self.reward_std = self.mainAlg(path=plot_dir)
                regret_for_instance[i] = self.getRegret()
                var_regret_for_instance[i] = self.reward_std ** 2
            regret = np.mean(regret_for_instance, axis=0)
            # NOTE(review): despite the name this is a variance (mean variance /
            # repetitions); no square root is taken. Confirm intent.
            std_regret = np.mean(var_regret_for_instance, axis=0) / self.repetitions
            self.regret_versus_K.append(regret[self.horizon - 1])
            self.std_regret_versus_K.append(std_regret[self.horizon - 1])

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------

    def _result_suffix(self):
        """File-name suffix for the current experiment: ``_versus_<X>`` or ``""``."""
        if self.experiment in ("M", "K", "T"):
            return "_versus_" + self.experiment
        return ""

    def __saveRegret(self, path):
        """Save the regret values and their spread as CSV files in ``path``."""
        print("Regret saving ...")
        suffix = self._result_suffix()
        if suffix:
            regret = getattr(self, "regret_versus_" + self.experiment)
            std = getattr(self, "std_regret_versus_" + self.experiment)
        else:
            regret, std = self.regret_mean, self.regret_std
        np.savetxt(path + '/' + 'Regrets' + suffix + '_' + self.__str__() + '.csv', regret)
        np.savetxt(path + '/' + 'std_Regrets' + suffix + '_' + self.__str__() + '.csv', std)

    def __saveFileName(self, path):
        """Append this policy's result file names to the index files in ``path``."""
        print("FileName saving ...")
        suffix = self._result_suffix()
        for prefix in ('Regrets', 'std_Regrets'):
            with open(path + '/' + prefix + suffix + '_file_name.txt', 'a') as f:
                f.write(prefix + suffix + '_' + self.__str__() + '.csv\n')

    def randmax(self, vector, rank=1):
        """Return the index of an entry equal to the ``rank``-th largest value
        of ``vector``, chosen uniformly at random among ties."""
        sorted_vector = np.sort(vector)[::-1]
        m = sorted_vector[rank - 1]
        indices = np.where(vector == m)[0]
        return indices[np.random.randint(0, len(indices))]

    def Bernoulli(self, p):
        """Return 1 with probability ``p``, else 0."""
        return 1 if np.random.rand() < p else 0

    def plot_arm(self, mean, std):
        """Plot the mean cumulative number of pulls per arm, with a +/- std band.

        Only plots (and saves via ``show_and_save``) when ``self.path`` is set.

        Returns:
            The matplotlib figure, or ``None`` when ``self.path`` is ``None``.
        """
        fig = None
        if self.path is None:
            return fig
        color = ["#1F77B4", "#FF7F0E", "#2CA02C"]
        plt.rcParams['font.family'] = "sans-serif"
        plt.rcParams['font.sans-serif'] = "DejaVu Sans"
        plt.rcParams['mathtext.fontset'] = "cm"
        plt.rcParams['mathtext.rm'] = "serif"
        fig = plt.figure()
        x = np.linspace(0, self.horizon - 1, self.horizon)
        lw = 3
        # The fixed palette has 3 colours; beyond that use matplotlib's cycle.
        use_palette = self.nb_arms < 4
        for i in range(self.nb_arms):
            if use_palette:
                plt.plot(mean[i], label="arm {}".format(i + 1), color=color[i], lw=lw)
                plt.fill_between(x, mean[i] - std[i], mean[i] + std[i], color=color[i], alpha=0.2)
            else:
                plt.plot(mean[i], label="arm {}".format(i + 1), lw=lw)
                plt.fill_between(x, mean[i] - std[i], mean[i] + std[i], alpha=0.2)
        plt.grid(True)
        plt.legend(loc='upper left')
        plt.xlabel(r"T")
        plt.ylabel(r"Number of arm $i$ pulled")
        if self.experiment == "M":
            savefig = self.path + '/' + 'arm_' + self.__str__() + "_M=" + str(self.nb_segs)
        elif self.experiment == "K":
            savefig = self.path + '/' + 'arm_' + self.__str__() + "_K=" + str(self.nb_arms)
        elif self.experiment == "T":
            savefig = self.path + '/' + 'arm_' + self.__str__() + "_T=" + str(self.horizon)
        else:
            savefig = self.path + '/' + 'arm_' + self.__str__()
        show_and_save(False, savefig=savefig, fig=fig, pickleit=None)
        return fig