import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm.notebook import tqdm
# from tqdm import tqdm
from Environment import Environment
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex
import os

# Module-level defaults.
# The MUCB class below reads the number of break points from its configuration
# dictionary rather than from this name.
nb_break_points = 0  #: Default nb of random events
# MUCB.__str__ compares the instance's window size against this value to decide
# whether the label should show w.
WINDOW_SIZE = None
# Default detection threshold; it is not read anywhere in the visible code.
THRESHOLD_B = None
class MUCB(object):
    
    def __init__(self, configuration, ENV = Environment, max_nb_random_events = None, alpha=None, w=None, b=None, gamma=None, path = None, klUCB = None, delta=None):
        """Configure the M-UCB policy and immediately run the experiment selected by ``ENV.experiment``.

        ``ENV.experiment`` equal to ``"M"``, ``"T"`` or ``"K"`` dispatches to
        :meth:`experimentForM`, :meth:`experimentForT` or :meth:`experimentForK`;
        any other value runs a single experiment on ``ENV.Env`` and stores
        ``regret_mean`` / ``regret_std``.

        :param configuration: Configuration dictionary. Unless the experiment is ``"K"``,
            ``'nb_break_points'``, ``'horizon'`` and
            ``'environment' -> 'params' -> 'listOfMeans' / 'changePoints'`` are read here.
        :param ENV: Environment object; ``experiment`` and ``repetitions`` are read here.
        :param max_nb_random_events: Number of segments; defaults to the number of
            configured change points. Must be > 0.
        :param alpha: Parameter alpha; ``None``/``'auto'`` gives ``sqrt(K * log(T))``.
        :param w: Window size; ``None``/``'auto'`` estimates it from ``delta`` and rounds
            it up to an even integer.
        :param b: Detection threshold; ``None``/``'auto'`` gives ``sqrt(w/2 * log(2 K T^2))``.
        :param gamma: Exploration rate, ``'diminishing'``/``'d'`` for the diminishing
            schedule, or ``None``/``'auto'`` for ``sqrt(M K log(T) / T)``.
        :param path: Output directory. When ``None`` this constructor saves no regret files.
        :param klUCB: Truthy to run the kl-UCB variants of the algorithms.
        :param delta: Positive value used in the automatic window-size formula. It is
            only needed when ``w`` is ``None``/``'auto'``.
        :raises ValueError: if ``w`` must be estimated but ``delta`` is missing or not > 0.
        """
        self.path = path
        self.klUCB = klUCB
        self.delta = delta  #: Value used by the automatic window-size formula (may be None).
        self.regret_versus_M = None
        self.regret = None
        self.cfg = configuration  #: Configuration dictionary
        if ENV.experiment != "K":
            self.nb_break_points = self.cfg.get('nb_break_points')  #: How many random events?
            self.horizon = self.cfg['horizon']  #: Horizon (number of time steps)
            print("Time horizon:", self.horizon)
            self.nb_arms = len(self.cfg['environment']["params"]["listOfMeans"][0])
        self.experiment = ENV.experiment
        self.ENV = ENV
        self.repetitions = self.ENV.repetitions  #: Number of repetitions
        print("Number of repetitions:", self.repetitions)

        # For the "K" experiment the number of arms varies, so every parameter that
        # depends on it is left to experimentForK.
        if ENV.experiment != "K":
            if max_nb_random_events is None:
                max_nb_random_events = len(self.cfg['environment']["params"]["changePoints"])
            assert max_nb_random_events > 0, "Error: for GLR_UCB policy the parameter max_nb_random_events should be > 0 but it was given as {}.".format(max_nb_random_events)
            self.nb_segs = max_nb_random_events

            if w is None or w == 'auto':
                # The formula below needs delta; the original code referenced an
                # undefined name here and always failed with a NameError.
                if delta is None or delta <= 0:
                    raise ValueError("Error: for Monitored_UCB policy the window size w can only be estimated automatically when delta > 0 is given, but delta was {}.".format(delta))
                # XXX Estimate w from Remark 1
                w = (4/delta**2) * (np.sqrt(np.log(2 * self.nb_arms * self.horizon**2)) + np.sqrt(np.log(2 * self.horizon)))**2
                w = int(np.ceil(w))
                # The change detector splits the window in two halves, so keep w even.
                if w % 2 != 0:
                    w = 2*(1 + w//2)
            assert w > 0, "Error: for Monitored_UCB policy the parameter w should be > 0 but it was given as {}.".format(w)  # DEBUG
            self.window_size = w  #: Parameter :math:`w` for the M-UCB algorithm.

            if b is None or b == 'auto':
                # XXX compute b from the formula from Theorem 6.1
                b = np.sqrt((w/2) * np.log(2 * self.nb_arms * self.horizon**2))
            assert b > 0, "Error: for Monitored_UCB policy the parameter b should be > 0 but it was given as {}.".format(b)  # DEBUG
            self.threshold_b = b  #: Parameter :math:`b` for the M-UCB algorithm.

            if alpha is None or alpha == 'auto':
                alpha = np.sqrt(np.log(self.horizon) * self.nb_arms)
            assert alpha > 0, "Error: for Monitored_UCB policy the parameter alpha should be > 0 but it was given as {}.".format(alpha)  # DEBUG
            self.alpha = alpha  #: Parameter :math:`alpha` for the M-UCB algorithm.

        if self.experiment == "M":
            if gamma == 'diminishing' or gamma == 'd':
                self.u = np.ceil((2*alpha - self.nb_arms/(4*alpha))**2)
                self.alpha = alpha
            self.experimentForM(gamma, path = path)

        elif self.experiment == "T":
            self.experimentForT(gamma, w, alpha, path = path)

        elif self.experiment == "K":
            self.experimentForK(gamma, w, alpha, path = path)

        else:
            self.Env = self.ENV.Env
            self.mu = self.ENV.mu
            self.mu_max = self.ENV.mu_max
            if gamma == 'diminishing' or gamma == 'd':
                # NOTE(review): the other experiments start the schedule from
                # (2*alpha - K/(4*alpha))**2; confirm which form is intended here.
                self.u = np.ceil((alpha - self.nb_arms/(4*alpha))**2)
                self.alpha = alpha
                self.gamma = gamma
                if self.klUCB:
                    self.mean_reward, self.reward_std = self.mainAlgWithDiminishing_kl(path = path)
                else:
                    self.mean_reward, self.reward_std = self.mainAlgWithDiminishing(path = path)
            else:
                if gamma is None or gamma == 'auto':
                    gamma = np.sqrt(self.nb_segs*self.nb_arms*np.log(self.horizon)/self.horizon)
                assert gamma > 0, "Error: for Monitored_UCB policy the parameter gamma should be > 0 but it was given as {}.".format(gamma)  # DEBUG
                self.gamma = gamma  #: Parameter :math:`gamma` for the M-UCB algorithm.
                print("exploration rate = {}".format(self.gamma))
                if self.klUCB:
                    self.mean_reward, self.reward_std = self.mainAlg_kl(path = path)
                else:
                    self.mean_reward, self.reward_std = self.mainAlg(path = path)

            self.regret_mean = self.getRegret()
            self.regret_std = self.reward_std

        if path is not None:
            self.__saveRegret(path)
            self.__saveFileName(path)
        

    def __initEnvironments__(self):
        """ Create environments."""
        print(" Create environments ...")
        list_of_means = np.array(self.cfg['environment'][0]["params"]["listOfMeans"])
        change_point = self.cfg['environment'][0]["params"]["changePoints"]
        self.Env = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype=int)
        self.mu = np.zeros((self.nb_arms, self.horizon), dtype=int)
        self.mu_max = np.zeros(self.horizon ,dtype=float)
        segment = 0
        for t in range(self.horizon):
            for k in range(self.nb_arms):
                if segment < len(change_point) - 1:
                    if t == change_point[segment + 1]:
                        segment += 1
                for i in range(self.repetitions):    
                    self.Env[i][k][t] = self.Bernoulli(list_of_means[segment][k])
                self.mu[k][t] = list_of_means[segment][k]
                self.mu_max[t] = np.max(list_of_means[segment,:])
                          

    def mainAlg(self, path = None):
        print("Running ...M-UCB with gamma = {}".format(self.gamma))
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            tau = 0
            UCB = np.zeros(self.nb_arms,dtype = float)
            UCB[:] = 10000
            Z = np.zeros((self.nb_arms, self.horizon), dtype= int )
            X = np.zeros( self.horizon, dtype = float)
            n = np.zeros(self.nb_arms, dtype = int)
            A = np.zeros(self.horizon, dtype = int)
            for t in range(self.horizon):
                a = (t-tau) % np.floor( self.nb_arms  /self.gamma )
                if a < self.nb_arms : 
                    A[t] = a
                else:
                    for k in range(self.nb_arms):
                        if n[k]>0:
                            UCB[k] = np.sum(Z[k][0:n[k]])/n[k]+np.sqrt(2*np.log(t-tau)/n[k])
                    A[t] = np.argmax(UCB)
                X_mem[i][t] = self.Env[i][A[t]][t]
                A_mem[i][A[t]][t] = 1
                Z[A[t]][n[A[t]]] = X_mem[i][t]
                n[A[t]] = n[A[t]] + 1
                if n[A[t]] >= self.window_size:
                    if self.CD(Z[A[t],n[A[t]]-self.window_size:n[A[t]]-1]):
                        cd_true += 1
                        tau = t
                        n[:] = 0
                        Z[A[t],:] = 0
        print("Average alarm {} times".format(cd_true/self.repetitions))
        X_mem_cum = np.cumsum(X_mem,axis=1)
        print("Roll data saving ...")
        if path != None:
            filename = path+'/'+'roll_data'+self.__str__()+str(random.randint(100000,999999))+'.csv'
            np.savetxt(filename, X_mem_cum ,delimiter=",")
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mem_cum = np.cumsum(A_mem,axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        self.plot_arm(A_mean, A_std)
        return X_mean, X_std
    
    def mainAlg_kl(self):
        print("Running ...M-klUCB with gamma = {}".format(self.gamma))
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            tau = 0
            tloc = 0 # number of samples in current episode 
            # UCB = np.zeros(self.nb_arms,dtype = float)
            # UCB[:] = 10000
            Z = np.zeros((self.nb_arms, self.horizon), dtype= int )
            X = np.zeros( self.horizon, dtype = float)
            n = np.zeros(self.nb_arms, dtype = int)
            A = np.zeros(self.horizon, dtype = int)
            tloc = 0
            for t in range(self.horizon):
                a = (t-tau) % np.floor( self.nb_arms  /self.gamma )
                if a < self.nb_arms : 
                    A[t] = a
                else:
                    indices = np.ones(self.nb_arms, dtype = float)
                    for k in range(self.nb_arms):
                        if n[k]>0:
                            indices[k] = self.klUp(np.sum(Z[k][0:n[k]])/n[k], np.log(tloc)/n[k])
                    A[t] = np.argmax(indices)
                X_mem[i][t] = self.Env[i][A[t]][t]
                A_mem[i][A[t]][t] = 1
                Z[A[t]][n[A[t]]] = X_mem[i][t]
                n[A[t]] = n[A[t]] + 1
                tloc=tloc+1
                if n[A[t]] >= self.window_size:
                    if self.CD(Z[A[t],n[A[t]]-self.window_size:n[A[t]]-1]):
                        cd_true += 1
                        tau = t
                        tloc = 0
                        n[:] = 0
                        Z[A[t],:] = 0
        print("Average alarm {} times".format(cd_true/self.repetitions))
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem, axis=0)
        A_mem_cum = np.cumsum(A_mem,axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        self.plot_arm(A_mean, A_std)
        return X_mean, X_std

    def mainAlgWithDiminishing(self, path = None):
        print("Running ...M-UCB with diminishing")
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            tau = 0
            u = self.u
            UCB = np.zeros(self.nb_arms,dtype = float)
            UCB[:] = 10000
            Z = np.zeros((self.nb_arms, self.horizon), dtype= int )
            X = np.zeros(self.horizon, dtype = float)     
            n = np.zeros(self.nb_arms, dtype = int)
            A = np.zeros(self.horizon, dtype = int)
            for t in range(self.horizon):
                if u <= (t - tau) and (t - tau) < u + self.nb_arms:
                    A[t] = (t - tau) - u 
                else:
                    if (t - tau) == u + self.nb_arms:
                        u = np.ceil(u+np.sqrt(u)*self.nb_arms/self.alpha+(self.nb_arms/(2*self.alpha))**2)
                    for k in range(self.nb_arms):
                        UCB[k] = np.sum(Z[k][0:n[k]])/n[k]+np.sqrt(2*np.log(t-tau)/n[k])
                    A[t] = np.argmax(UCB)
                X_mem[i][t] = self.Env[i][A[t]][t]
                A_mem[i][A[t]][t] = 1
                Z[A[t]][n[A[t]]] = X_mem[i][t]
                n[A[t]] = n[A[t]] + 1
                if n[A[t]] >= self.window_size:
                    if self.CD(Z[A[t],n[A[t]]-self.window_size:n[A[t]]-1]):
                        cd_true += 1
                        tau = t
                        u = self.u
                        n[:] = 0
                        Z[A[t],:] = 0
        print("Average alarm {} times".format(cd_true/self.repetitions))
        X_mem_cum = np.cumsum(X_mem,axis=1)
        print("Roll data saving ...")
        if path != None:
            filename = path+'/'+'roll_data'+self.__str__()+str(random.randint(100000,999999))+'.csv'
            np.savetxt(filename, X_mem_cum ,delimiter=",")
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mem_cum = np.cumsum(A_mem,axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        self.plot_arm(A_mean, A_std)
        return X_mean, X_std

    def mainAlgWithDiminishing_kl(self):
        print("Running ...M-klUCB with diminishing")
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            tau = 0
            u = self.u
            tloc = 0 # number of samples in current episode 
            # UCB = np.zeros(self.nb_arms,dtype = float)
            # UCB[:] = 10000
            Z = np.zeros((self.nb_arms, self.horizon), dtype= int )
            X = np.zeros( self.horizon, dtype = float)
            n = np.zeros(self.nb_arms, dtype = int)
            A = np.zeros(self.horizon, dtype = int)
            tloc = 0
            for t in range(self.horizon):
                if u <= (t - tau) and (t - tau) < u + self.nb_arms:
                    A[t] = (t - tau) - u 
                else:
                    if (t - tau) == u + self.nb_arms:
                        u = np.ceil(u+np.sqrt(u)*self.nb_arms/self.alpha+(self.nb_arms/(2*self.alpha))**2)
                    indices = np.ones(self.nb_arms, dtype = float)
                    for k in range(self.nb_arms):
                        if n[k]>0:
                            indices[k] = self.klUp(np.sum(Z[k][0:n[k]])/n[k], np.log(tloc)/n[k])
                    A[t] = np.argmax(indices)
                X_mem[t] = self.Env[i][A[t]][t]
                A_mem[i][A[t]][t] = 1
                Z[A[t]][n[A[t]]] = X_mem[i][t]
                n[A[t]] = n[A[t]] + 1
                tloc=tloc+1
                if n[A[t]] >= self.window_size:
                    if self.CD(Z[A[t],n[A[t]]-self.window_size:n[A[t]]-1]):
                        cd_true += 1
                        tau = t
                        n[:] = 0
                        u = self.u
                        Z[A[t],:] = 0
        print("Average alarm {} times".format(cd_true/self.repetitions))
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem, axis=0)
        A_mem_cum = np.cumsum(A_mem,axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        self.plot_arm(A_mean, A_std)
        return X_mean, X_std

    def experimentForM(self, gamma, path):
        self.regret_versus_M = [] 
        self.std_regret_versus_M = []
        for m in self.ENV.nb_break_points_list:
            plot_dir = path +"/M"+str(m)
            # Create the sub folder
            if os.path.isdir(plot_dir):
                print("{} is already a directory here...".format(plot_dir))
            elif os.path.isfile(plot_dir):
                raise ValueError("[ERROR] {} is a file, cannot use it as a directory !".format(plot_dir))
            else:
                os.mkdir(plot_dir)

            self.nb_segs = m
            print("Running ...{} for M = {}".format(self.__str__, m))
            self.Env = self.ENV.Env_for_M[m] 
            self.mu = self.ENV.mu_for_M[m] 
            self.mu_max = self.ENV.mu_max_for_M[m]
            if gamma == 'diminishing' or gamma == 'd':
                self.gamma = gamma
                if self.klUCB:
                    self.mean_reward, self.reward_std = self.mainAlgWithDiminishing_kl(path = plot_dir)
                else:
                    self.mean_reward, self.reward_std = self.mainAlgWithDiminishing(path = plot_dir)          
            else:
                if gamma is None or gamma == 'auto':
                    # XXX Estimate w from Remark 1
                    _gamma = np.sqrt(m*self.nb_arms*np.log(self.horizon)/self.horizon)
                    assert _gamma > 0, "Error: for Monitored_UCB policy the parameter gamma should be > 0 but it was given as {}.".format(gamma)  # DEBUG
                    self.gamma = _gamma  #: Parameter :math:`gamma` for the M-UCB algorithm.
                else:
                    self.gamma = gamma
                print("exploration rate = {}".format(self.gamma))
                if self.klUCB:
                    self.mean_reward, self.reward_std = self.mainAlg_kl(path = plot_dir)
                else:
                    self.mean_reward, self.reward_std = self.mainAlg(path = plot_dir)
            regret = self.getRegret()
            regret_std = self.reward_std
            self.regret_versus_M.append(regret[self.horizon-1])
            self.std_regret_versus_M.append(regret_std[self.horizon-1])

    def experimentForT(self, gamma, w , alpha, path):
        self.regret_versus_T = [] 
        self.std_regret_versus_T = []
        self.nb_segs = self.cfg.get('nb_break_points')  #: How many random events?

        if w is None or w == 'auto':
            # XXX Estimate w from Remark 1
            w = (4/delta**2) * (np.sqrt(np.log(2 * self.nb_arms * self.horizon**2)) + np.sqrt(np.log(2 * self.horizon)))**2
            w = int(np.ceil(w))
            if w % 2 != 0:
                w = 2*(1 + w//2)
            WINDOW_SIZE = w
            assert w > 0, "Error: for Monitored_UCB policy the parameter w should be > 0 but it was given as {}.".format(w)  # DEBUG
        self.window_size = w  #: Parameter :math:`w` for the M-UCB algorithm.

        for T in self.ENV.horizon_list:
            plot_dir = path +"/T"+str(T)
            # Create the sub folder
            if os.path.isdir(plot_dir):
                print("{} is already a directory here...".format(plot_dir))
            elif os.path.isfile(plot_dir):
                raise ValueError("[ERROR] {} is a file, cannot use it as a directory !".format(plot_dir))
            else:
                os.mkdir(plot_dir)
            self.horizon = T
            print("Running ...{} for T = {}".format(self.__str__, T))

            # XXX compute b from the formula from Theorem 6.1
            b = np.sqrt((self.window_size/2) * np.log(2 * self.nb_arms * self.horizon**2))
            THRESHOLD_B = b
            assert b > 0, "Error: for Monitored_UCB policy the parameter b should be > 0 but it was given as {}.".format(b)  # DEBUG
            self.threshold_b = b  #: Parameter :math:`b` for the M-UCB algorithm.

            self.Env = self.ENV.Env_for_T[T] 
            self.mu = self.ENV.mu_for_T[T] 
            self.mu_max = self.ENV.mu_max_for_T[T]
            if gamma == 'diminishing' or gamma == 'd':
                self.gamma = gamma
                self.u = np.ceil((2*alpha - self.nb_arms/(4*alpha))**2)
                self.alpha = alpha 
                if self.klUCB:
                    self.mean_reward, self.reward_std = self.mainAlgWithDiminishing_kl(plot_dir)
                else:
                    self.mean_reward, self.reward_std = self.mainAlgWithDiminishing(plot_dir)          
            else:
                if gamma is None or gamma == 'auto':
                    # XXX Estimate w from Remark 1
                    _gamma = np.sqrt(self.nb_segs*self.nb_arms*np.log(self.horizon)/self.horizon)
                    assert _gamma > 0, "Error: for Monitored_UCB policy the parameter gamma should be > 0 but it was given as {}.".format(gamma)  # DEBUG
                    self.gamma = _gamma  #: Parameter :math:`gamma` for the M-UCB algorithm.
                else:
                    self.gamma = gamma
                print("exploration rate = {}".format(self.gamma))
                if self.klUCB:
                    self.mean_reward, self.reward_std = self.mainAlg_kl(plot_dir)
                else:
                    self.mean_reward, self.reward_std = self.mainAlg(plot_dir)
            regret = self.getRegret()
            regret_std = self.reward_std
            self.regret_versus_T.append(regret[self.horizon-1])
            self.std_regret_versus_T.append(regret_std[self.horizon-1])

    def experimentForK(self, gamma, w, alpha, path):
        self.regret_versus_K = [] 
        self.std_regret_versus_K = []
        self.nb_break_points = self.cfg[0].get('nb_break_points')  #: How many random events?
        self.horizon = self.cfg[0]['horizon']  #: Horizon (number of time steps)routermode
        if w is None or w == 'auto':
            # XXX Estimate w from Remark 1
            w = (4/delta**2) * (np.sqrt(np.log(2 * self.nb_arms * self.horizon**2)) + np.sqrt(np.log(2 * self.horizon)))**2
            w = int(np.ceil(w))
            if w % 2 != 0:
                w = 2*(1 + w//2)
            WINDOW_SIZE = w
            assert w > 0, "Error: for Monitored_UCB policy the parameter w should be > 0 but it was given as {}.".format(w)  # DEBUG
        self.window_size = w  #: Parameter :math:`w` for the M-UCB algorithm.

        regret_for_instance = np.zeros((self.ENV.num_of_instance, self.horizon), dtype=float)
        var_regret_for_instance = np.zeros((self.ENV.num_of_instance, self.horizon), dtype=float)
        for K in self.ENV.nb_arms_list:
            plot_dir = path +"/K"+str(K)
            # Create the sub folder
            if os.path.isdir(plot_dir):
                print("{} is already a directory here...".format(plot_dir))
            elif os.path.isfile(plot_dir):
                raise ValueError("[ERROR] {} is a file, cannot use it as a directory !".format(plot_dir))
            else:
                os.mkdir(plot_dir)
            self.nb_arms = K
            print("Running ...{} for K = {}".format(self.__str__, K))

            # XXX compute b from the formula from Theorem 6.1
            b = np.sqrt((self.window_size/2) * np.log(2 * self.nb_arms * self.horizon**2))
            THRESHOLD_B = b
            assert b > 0, "Error: for Monitored_UCB policy the parameter b should be > 0 but it was given as {}.".format(b)  # DEBUG
            self.threshold_b = b  #: Parameter :math:`b` for the M-UCB algorithm.

            for i in tqdm(range(self.ENV.num_of_instance)):
                self.Env = self.ENV.Env_for_K[i][K-1] 
                self.mu = self.ENV.mu_for_K[i][K-1] 
                self.mu_max = self.ENV.mu_max_for_K[i][K-1] 
                
                if gamma == 'diminishing' or gamma == 'd':
                    self.gamma = gamma
                    self.u = np.ceil((2*alpha - self.nb_arms/(4*alpha))**2)
                    self.alpha = alpha 
                    if self.klUCB:
                        self.mean_reward, self.reward_std = self.mainAlgWithDiminishing_kl(plot_dir)
                    else:
                        self.mean_reward, self.reward_std = self.mainAlgWithDiminishing(plot_dir)          
                else:
                    if gamma is None or gamma == 'auto':
                        # XXX Estimate w from Remark 1
                        _gamma = np.sqrt(self.nb_break_points*self.nb_arms*np.log(self.horizon)/self.horizon)
                        assert _gamma > 0, "Error: for Monitored_UCB policy the parameter gamma should be > 0 but it was given as {}.".format(gamma)  # DEBUG
                        self.gamma = _gamma  #: Parameter :math:`gamma` for the M-UCB algorithm.
                    else:
                        self.gamma = gamma
                    print("exploration rate = {}".format(self.gamma))
                    if self.klUCB:
                        self.mean_reward, self.reward_std = self.mainAlg_kl(plot_dir)
                    else:
                        self.mean_reward, self.reward_std = self.mainAlg(plot_dir)
                regret_for_instance[i] = self.getRegret()
                var_regret_for_instance[i] = self.reward_std**2
                # regret_std = self.reward_std
            regret = np.mean(regret_for_instance, axis = 0)
            std_regret = np.mean(var_regret_for_instance, axis = 0)/self.repetitions
            self.regret_versus_K.append(regret[self.horizon-1])
            self.std_regret_versus_K.append(std_regret[self.horizon-1])


    def getRegret(self):    
        return np.cumsum(self.mu_max)-np.cumsum(self.mean_reward)

    def __str__(self):
        """Build the label of the policy, e.g. ``M-klUCB ($w=10$, with diminishing)``.

        The window size is shown only when it differs from the module-level
        ``WINDOW_SIZE``; the label is also used to build output file names.
        """
        window_part = ""
        if self.window_size != WINDOW_SIZE:
            window_part = "$w={:g}$".format(self.window_size)
        schedule_part = ""
        if self.gamma == "diminishing" or self.gamma == "d":
            schedule_part = ", with diminishing"
        details = window_part + schedule_part
        suffix = " ({})".format(details) if details else ""
        variant = "kl" if self.klUCB else ""
        return "M-" + variant + "UCB" + suffix

    def CD(self,Y):
        #print("w=",w,"b=",b," ",np.abs(np.sum(Y[int(w/2):w])-np.sum(Y[0:int(w/2)])))
        a = np.sum(Y[int(self.window_size/2):self.window_size])
        c = np.sum(Y[0:int(self.window_size/2)])
        if np.abs(a-c)>self.threshold_b:
            return True
        else: 
            return False

    def __saveRegret(self, path):
        if self.experiment == "M":
            print("Regret saving ...")
            filename = path+'/'+'Regrets_versus_M_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_versus_M)
            filename = path+'/'+'std_Regrets_versus_M_'+self.__str__()+'.csv'
            np.savetxt(filename, self.std_regret_versus_M)
        elif self.experiment == "K":
            print("Regret saving ...")
            filename = path+'/'+'Regrets_versus_K_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_versus_K)
            filename = path+'/'+'std_Regrets_versus_K_'+self.__str__()+'.csv'
            np.savetxt(filename, self.std_regret_versus_K)
        else:
            print("Regret saving ...")
            filename = path+'/'+'Regrets_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_mean)
            filename = path+'/'+'std_Regrets_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_std)

    def __saveRegret(self, path):
        if self.experiment == "M":
            print("Regret saving ...")
            filename = path+'/'+'Regrets_versus_M_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_versus_M)
            filename = path+'/'+'std_Regrets_versus_M_'+self.__str__()+'.csv'
            np.savetxt(filename, self.std_regret_versus_M)
        elif self.experiment == "K":
            print("Regret saving ...")
            filename = path+'/'+'Regrets_versus_K_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_versus_K)
            filename = path+'/'+'std_Regrets_versus_K_'+self.__str__()+'.csv'
            np.savetxt(filename, self.std_regret_versus_K)
        elif self.experiment == "T":
            print("Regret saving ...")
            filename = path+'/'+'Regrets_versus_T_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_versus_T)
            filename = path+'/'+'std_Regrets_versus_T_'+self.__str__()+'.csv'
            np.savetxt(filename, self.std_regret_versus_T)
        else:
            print("Regret saving ...")
            filename = path+'/'+'Regrets_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_mean)
            filename = path+'/'+'std_Regrets_'+self.__str__()+'.csv'
            np.savetxt(filename, self.regret_std)

    def __saveFileName(self, path):
        if self.experiment == "M":
            print("FileName saving ...")
            filename = path+'/Regrets_versus_M_file_name.txt'
            f = open(filename,'a')
            f.write('Regrets_versus_M_'+self.__str__()+'.csv\n')
            f.close()
            filename = path+'/std_Regrets_versus_M_file_name.txt'
            f = open(filename,'a')
            f.write('std_Regrets_versus_M_'+self.__str__()+'.csv\n')
            f.close()
        elif self.experiment == "K":
            print("FileName saving ...")
            filename = path+'/Regrets_versus_K_file_name.txt'
            f = open(filename,'a')
            f.write('Regrets_versus_K_'+self.__str__()+'.csv\n')
            f.close()
            filename = path+'/std_Regrets_versus_K_file_name.txt'
            f = open(filename,'a')
            f.write('std_Regrets_versus_K_'+self.__str__()+'.csv\n')
            f.close()
        elif self.experiment == "T":
            print("FileName saving ...")
            filename = path+'/Regrets_versus_T_file_name.txt'
            f = open(filename,'a')
            f.write('Regrets_versus_T_'+self.__str__()+'.csv\n')
            f.close()
            filename = path+'/std_Regrets_versus_T_file_name.txt'
            f = open(filename,'a')
            f.write('std_Regrets_versus_T_'+self.__str__()+'.csv\n')
            f.close()
        else:
            print("FileName saving ...")
            filename = path+'/Regrets_file_name.txt'
            f = open(filename,'a')
            f.write('Regrets_'+self.__str__()+'.csv\n')
            f.close()
            filename = path+'/std_Regrets_file_name.txt'
            f = open(filename,'a')
            f.write('std_Regrets_'+self.__str__()+'.csv\n')
            f.close()
            
    def klUp(self,p,level):
        """KL upper confidence bound.

        Searches by 16 bisection steps, between ``p`` and ``min(1, p + sqrt(level/2))``,
        for the point where :meth:`klChange` ``(p, .)`` crosses ``level``, and
        returns the upper end of the final bracket.
        """
        lower = p
        upper = np.minimum(1, p + np.sqrt(level/2))
        for _ in range(16):
            middle = (upper + lower)/2
            if self.klChange(p, middle) > level:
                upper = middle
            else:
                lower = middle
        return upper

    def klChange(self, p, q):
        """Binary relative entropy between Bernoulli parameters ``p`` and ``q``.

        Returns 0 when ``p`` equals ``q``. Otherwise ``p`` is first moved inside
        the open interval (0, 1) by one machine epsilon; ``q`` is used as given.
        """
        if p == q:
            return 0
        eps = np.finfo(float).eps
        if p <= 0:
            p = eps
        if p >= 1:
            p = 1 - eps
        first_term = p * np.log(p / q)
        second_term = (1 - p) * np.log((1 - p) / (1 - q))
        return first_term + second_term

    def Bernoulli(self, p):
        if np.random.rand()<p:
            return 1
        else :
            return 0
        
    def plot_arm(self, mean, std):
        if self.path is not None:
            color = ["#1F77B4", "#FF7F0E", "#2CA02C"]
            formats = ('png', 'pdf', 'eps')
            # plt.rcParams['figure.figsize'] = (12,8)
            # plt.rcParams['figure.dpi'] = 400
            # plt.rcParams['figure.figsize'] = (20,10)
            # plt.rcParams['figure.dpi'] = 200
            plt.rcParams['font.family'] = "sans-serif"
            plt.rcParams['font.sans-serif'] = "DejaVu Sans"
            plt.rcParams['mathtext.fontset'] = "cm"
            plt.rcParams['mathtext.rm'] = "serif"
            fig = plt.figure()
            x = np.linspace(0, self.horizon-1, self.horizon)
            lw = 3
            if self.nb_arms < 4:
                for i in range(self.nb_arms):
                    plt.plot(mean[i], label = "arm {}".format(i+1),color = color[i], lw=lw)
                    plt.fill_between(x, mean[i]-std[i], mean[i]+std[i],color = color[i], alpha = 0.2)   
            else:        
                for i in range(self.nb_arms):
                    plt.plot(mean[i], label = "arm {}".format(i+1), lw=lw)
                    plt.fill_between(x, mean[i]-std[i], mean[i]+std[i], alpha = 0.2)  
            plt.grid(True)
            plt.legend(loc = 'upper left')
            plt.xlabel(r"T")
            # plt.xlabel(r"Time steps $t = 1...T$")
            plt.ylabel(r"Number of arm $i$ pulled")
            if self.experiment == "M":
                savefig = self.path+'/'+'arm_'+self.__str__()+"_M="+str(self.nb_segs)
            elif self.experiment == "K":
                savefig = self.path+'/'+'arm_'+self.__str__()+"_K="+str(self.nb_arms)
            elif self.experiment == "T":
                savefig = self.path+'/'+'arm_'+self.__str__()+"_T="+str(self.horizon)
            else: 
                savefig = self.path+'/'+'arm_'+self.__str__()
            show_and_save(False , savefig=savefig, fig=fig, pickleit=None)
        return fig