import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex

# Module-level constant forwarded as the ``M`` argument of ``skip_mech`` by the
# GLRUCB simulation loops. Its meaning is defined inside ``skip_mech`` (not visible
# here) -- NOTE(review): presumably a sample-count/window parameter; confirm.
M: int = 50

class GLRUCB(object):
    """Monte-Carlo simulator of the GLR-UCB policy for piecewise-stationary bandits.

    Each call simulates ``repetitions`` independent trajectories of ``horizon`` steps
    against the precomputed reward table ``Env.Env[repetition][arm][t]`` and returns
    aggregate statistics (see ``_summarize``). On steps selected by ``subSample``, a
    GLR change detector (``CD``, built on the binary relative entropy ``klChange``) is
    run on the samples of the arm just pulled. On an alarm, all arm statistics are
    restarted; when ``skip`` is set, the restart is first gated by ``skip_mech``.

    Exploration schedules (chosen by ``__call__``):

    * ``mainAlg``: uniform forced exploration at rate ``gamma``.
    * ``mainAlgWithIncreasing``: forced-exploration rate ``sqrt(episode) * alpha0``,
      where ``episode`` is 1 plus the number of restarts so far.
    * ``mainAlgWithDiminishing``: forced-exploration rounds at times ``u`` that
      are pushed further apart by a rule driven by ``alpha``.
    """

    # Index given to arms with no sample in the current episode, so they are pulled
    # before any arm that already has data.
    _UNPULLED_INDEX = 10000.0

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        path = None,
        klUCB = False,
        diminishing = False,
        skip = False,
        alpha = None,
        gamma = None,
        nb_break_points_known = False,
        subSample = 11,
        subSample2 = 7,
        skip_uncertainty = 0
    ):
        """Store the experiment configuration and derive the exploration parameters.

        Args:
            repetitions: number of independent trajectories to simulate.
            nb_arms: number of arms K.
            nb_break_points: number of break points; used in the default ``gamma``.
            horizon: number of steps T per trajectory.
            Env: object whose ``.Env`` attribute is indexed ``[repetition][arm][t]``.
            path: stored as ``self.path``; not read by this class.
            klUCB: use the kl-UCB index (``klUp``) instead of the UCB1 index.
            diminishing: run ``mainAlgWithDiminishing``.
            skip: gate restarts through ``skip_mech``.
            alpha: diminishing-schedule parameter; default ``sqrt(log(T) * K)``.
            gamma: forced-exploration rate of ``mainAlg``;
                default ``sqrt(nb_break_points * K * log(T) / T)``.
            nb_break_points_known: when not diminishing, run ``mainAlg`` (True)
                or ``mainAlgWithIncreasing`` (False).
            subSample: the detector runs on episode steps 1, 1+subSample, ...
            subSample2: stride between candidate split points inside ``CD``.
            skip_uncertainty: forwarded to ``skip_mech``; ``None`` is treated as 0.
        """
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env  # reward table, indexed [repetition][arm][t]

        self.klUCB = klUCB
        self.diminishing = diminishing
        self.skip = skip
        self.path = path
        self.subSample = subSample
        self.subSample2 = subSample2
        self.nb_break_points_known = nb_break_points_known
        self.skip_uncertainty = 0 if skip_uncertainty is None else skip_uncertainty

        if alpha is None:
            alpha = np.sqrt(np.log(self.horizon) * self.nb_arms)
        assert alpha > 0, "Error: for Monitored_UCB policy the parameter alpha should be > 0 but it was given as {}.".format(alpha)  # DEBUG
        self.alpha = alpha  #: Parameter :math:`alpha` for the M-UCB algorithm.

        # Every schedule's parameters are derived up front, so __call__ may later
        # switch schedule without hitting a missing attribute.
        self.alpha0 = np.sqrt(np.log(self.horizon) / self.horizon)  # base rate of mainAlgWithIncreasing
        self.u = np.ceil((2 * self.alpha - self.nb_arms / (4 * self.alpha)) ** 2)  # first round of mainAlgWithDiminishing

        if gamma is None:
            gamma = np.sqrt(self.nb_break_points * self.nb_arms * np.log(self.horizon) / self.horizon)
        # Stored even when passed explicitly (it used to be set only when defaulted).
        self.gamma = gamma  #: Parameter :math:`gamma` for the M-UCB algorithm.
        if not self.diminishing:
            assert gamma > 0, "Error: for Monitored_UCB policy the parameter gamma should be > 0 but it was given as {}.".format(gamma)  # DEBUG
            print("exploration rate = {}".format(self.gamma))

    def __call__(self, diminishing = None):
        """Run the schedule selected by the flags and return ``(X_mean, X_std, A_mean, A_std)``.

        A non-None ``diminishing`` overrides (and persists into) ``self.diminishing``.
        """
        if diminishing is not None:
            self.diminishing = diminishing
        if self.diminishing:
            return self.mainAlgWithDiminishing()
        if self.nb_break_points_known:
            return self.mainAlg()
        return self.mainAlgWithIncreasing()

    def __str__(self):
        """Return the plot label, e.g. ``GLR-klUCB (with diminishing)``."""
        kl = "kl" if self.klUCB else ""
        if self.diminishing and self.skip:
            args = f"with diminishing and skipping mechanism B = {self.skip_uncertainty}"
        elif self.diminishing:
            args = "with diminishing"
        elif self.skip:
            args = f"with skipping mechanism B = {self.skip_uncertainty}"
        else:
            args = ""
        args = " ({})".format(args) if args else ""
        return r"GLR-{}UCB{}".format(kl, args)

    # ------------------------------------------------------------------ helpers

    def _new_episode_buffers(self):
        """Return fresh ``(Z, n, totals)``: sample store, pull counts and reward sums per arm.

        ``Z`` is float so fractional rewards are not truncated.
        """
        Z = np.zeros((self.nb_arms, self.horizon), dtype = float)
        n = np.zeros(self.nb_arms, dtype = int)
        totals = np.zeros(self.nb_arms, dtype = float)
        return Z, n, totals

    def _record(self, i, t, arm, X_mem, A_mem, Z, n, totals):
        """Pull ``arm`` at step ``t`` of repetition ``i`` and update all statistics in place."""
        X_mem[i, t] = self.Env[i][arm][t]
        A_mem[i, arm, t] = 1
        Z[arm, n[arm]] = X_mem[i, t]
        totals[arm] += X_mem[i, t]
        n[arm] += 1

    def _select_by_index(self, totals, n, t, tau):
        """Return the arm with the largest optimistic index in the current episode.

        Arms with samples get ``mean + sqrt(2 log(t - tau) / n)`` (UCB1), or
        ``klUp(mean, log(t - tau) / n)`` when ``self.klUCB`` is set. Arms without a
        sample since the last restart get ``_UNPULLED_INDEX``. Indices are recomputed
        every step, so values from a previous episode cannot survive a restart.
        """
        index = np.full(self.nb_arms, self._UNPULLED_INDEX, dtype = float)
        for k in range(self.nb_arms):
            if n[k] > 0:
                mean = totals[k] / n[k]
                if self.klUCB:
                    index[k] = self.klUp(mean, np.log(t - tau) / n[k])
                else:
                    index[k] = mean + np.sqrt(2 * np.log(t - tau) / n[k])
        return int(np.argmax(index))

    def _detection_due(self, tloc):
        """True on episode steps 1, 1 + subSample, 1 + 2*subSample, ...

        Written as ``(tloc - 1) % subSample == 0`` so that ``subSample == 1`` checks
        every step (``tloc % 1 == 1`` never holds).
        """
        return (tloc - 1) % self.subSample == 0

    def _alarm(self, Z, n, arm):
        """Run ``CD`` on the samples of ``arm`` collected in the current episode."""
        nb = n[arm]
        return self.CD(nb, np.cumsum(Z[arm, :nb]))

    @staticmethod
    def _restart(Z, n, totals, arm):
        """Forget all samples: the next episode starts from empty statistics."""
        n[:] = 0
        totals[:] = 0.0
        # Only the alarmed arm's row is zeroed; the other rows are overwritten
        # from index 0 because n restarts at 0.
        Z[arm, :] = 0

    @staticmethod
    def _summarize(X_mem, A_mem):
        """Aggregate raw trajectories over repetitions (axis 0).

        Returns:
            X_mean: per-step mean reward, shape (horizon,) -- not cumulative.
            X_std: std of the cumulative reward, shape (horizon,).
            A_mean, A_std: mean / std of the cumulative pull counts,
                shape (nb_arms, horizon).
        """
        X_mem_cum = np.cumsum(X_mem, axis = 1)
        A_mem_cum = np.cumsum(A_mem, axis = 2)
        X_mean = np.mean(X_mem, axis = 0)
        X_std = np.std(X_mem_cum, axis = 0)
        A_mean = np.mean(A_mem_cum, axis = 0)
        A_std = np.std(A_mem_cum, axis = 0)
        return X_mean, X_std, A_mean, A_std

    # --------------------------------------------------------------- schedules

    def mainAlg(self):
        """Simulate GLR-UCB with uniform forced exploration at rate ``self.gamma``.

        At step ``t`` of the current episode (``tau`` = step of the last restart),
        arm ``(t - tau) mod floor(nb_arms / gamma)`` is pulled when that value is
        below ``nb_arms``; otherwise the arm with the largest index is pulled.

        Returns:
            ``(X_mean, X_std, A_mean, A_std)`` as produced by ``_summarize``.
        """
        print("Running ...GLR-UCB with gamma = {}".format(self.gamma))
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        period = np.floor(self.nb_arms / self.gamma)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            Z, n, totals = self._new_episode_buffers()
            n_for_skip = np.zeros(self.nb_arms, dtype = int)
            skip_duration = False
            tau = 0
            tloc = 0  # number of samples in current episode
            for t in range(self.horizon):
                # NOTE(review): after a restart at step t, the next step has t - tau == 1,
                # so arm 0 is not force-explored again until the next period.
                a = (t - tau) % period
                if a < self.nb_arms:
                    arm = int(a)
                else:
                    arm = self._select_by_index(totals, n, t, tau)
                self._record(i, t, arm, X_mem, A_mem, Z, n, totals)
                if skip_duration:
                    n_for_skip[arm] += 1
                tloc += 1
                if not (self._detection_due(tloc) and self._alarm(Z, n, arm)):
                    continue
                if self.skip:
                    skip_duration = True
                    # `!= False` mirrors the original `== False` test for any return type
                    # (numpy bools, None, ...).
                    if skip_mech(Z, n, n_for_skip, M, arm, self.skip_uncertainty) != False:  # noqa: E712
                        continue  # the skip mechanism chose to ignore this alarm
                    n_for_skip = np.zeros(self.nb_arms, dtype = int)
                    skip_duration = False
                cd_true += 1
                tau = t
                tloc = 0
                self._restart(Z, n, totals, arm)
        print("Average alarm {} times".format(cd_true / self.repetitions))
        return self._summarize(X_mem, A_mem)

    def mainAlgWithDiminishing(self):
        """Simulate GLR-UCB whose forced exploration happens in increasingly sparse rounds.

        While ``u <= t - tau < u + nb_arms``, arm ``t - tau - u`` is pulled. At
        ``t - tau == u + nb_arms``, ``u`` advances to
        ``ceil(u + sqrt(u) * K / alpha + (K / (2 alpha))**2)``. ``u`` returns to
        ``self.u`` on every restart.

        Returns:
            ``(X_mean, X_std, A_mean, A_std)`` as produced by ``_summarize``.
        """
        print("Running ...GLR-UCB with diminishing")
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            Z, n, totals = self._new_episode_buffers()
            n_for_skip = np.zeros(self.nb_arms, dtype = int)
            skip_duration = False
            tau = 0
            u = self.u
            tloc = 0  # number of samples in current episode
            for t in range(self.horizon):
                elapsed = t - tau
                if u <= elapsed < u + self.nb_arms:
                    arm = int(elapsed - u)
                else:
                    if elapsed == u + self.nb_arms:
                        u = np.ceil(u + np.sqrt(u) * self.nb_arms / self.alpha + (self.nb_arms / (2 * self.alpha)) ** 2)
                    arm = self._select_by_index(totals, n, t, tau)
                self._record(i, t, arm, X_mem, A_mem, Z, n, totals)
                if skip_duration:
                    n_for_skip[arm] += 1
                tloc += 1
                if not (self._detection_due(tloc) and self._alarm(Z, n, arm)):
                    continue
                if self.skip:
                    skip_duration = True
                    if skip_mech(Z, n, n_for_skip, M, arm, self.skip_uncertainty) != False:  # noqa: E712
                        continue  # the skip mechanism chose to ignore this alarm
                    # Clear the skip bookkeeping on a confirmed restart, as mainAlg does.
                    n_for_skip = np.zeros(self.nb_arms, dtype = int)
                    skip_duration = False
                cd_true += 1
                tau = t
                u = self.u
                tloc = 0
                self._restart(Z, n, totals, arm)
        print("Average alarm {} times".format(cd_true / self.repetitions))
        return self._summarize(X_mem, A_mem)

    def mainAlgWithIncreasing(self):
        """Simulate GLR-UCB whose forced-exploration rate grows with the number of restarts.

        The rate for episode ``e`` (1 for the first, +1 per restart) is
        ``sqrt(e) * alpha0``. It is used like ``gamma`` in ``mainAlg``. This schedule
        does not consult ``self.skip``.

        Returns:
            ``(X_mean, X_std, A_mean, A_std)`` as produced by ``_summarize``.
        """
        print("Running ...GLR-UCB with alpha0 = {}".format(self.alpha0))
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            Z, n, totals = self._new_episode_buffers()
            tau = 0
            seg_th = 1  # index of the current episode
            # A local rate, so self.alpha (the diminishing-schedule parameter) is not overwritten.
            rate = np.sqrt(seg_th) * self.alpha0
            period = np.floor(self.nb_arms / rate)
            tloc = 0  # number of samples in current episode
            for t in range(self.horizon):
                a = (t - tau) % period
                if a < self.nb_arms:
                    arm = int(a)
                else:
                    arm = self._select_by_index(totals, n, t, tau)
                self._record(i, t, arm, X_mem, A_mem, Z, n, totals)
                tloc += 1
                if self._detection_due(tloc) and self._alarm(Z, n, arm):
                    tau = t
                    tloc = 0
                    self._restart(Z, n, totals, arm)
                    seg_th += 1
                    rate = np.sqrt(seg_th) * self.alpha0
                    period = np.floor(self.nb_arms / rate)
                    cd_true += 1
        print("Average alarm {} times".format(cd_true / self.repetitions))
        return self._summarize(X_mem, A_mem)

    # ------------------------------------------------------- detector pieces

    def beta(self, nb, delta=None, variant=None):
        """Return the GLR detection threshold for ``nb`` samples.

        Args:
            nb: number of samples tested.
            delta: confidence level; default ``1 / sqrt(horizon)``.
            variant: ``None`` or 0 -> ``-log(delta) + 1.5 log(nb) + log(3)``;
                1 -> ``-log(delta) + log(1 + log(nb))``.

        Returns:
            The threshold. Negative, infinite or NaN values become ``+inf``,
            which disables detection.

        Raises:
            ValueError: for any other ``variant``.
        """
        if delta is None:
            delta = 1.0 / np.sqrt(self.horizon)
        if variant is None or variant == 0:
            c = -np.log(delta) + (3 / 2) * np.log(nb) + np.log(3)
        elif variant == 1:
            c = -np.log(delta) + np.log(1 + np.log(nb))
        else:
            raise ValueError("beta: unknown variant {!r} (expected None, 0 or 1)".format(variant))
        if c < 0 or np.isinf(c) or np.isnan(c):
            c = float('+inf')
        return c

    def CD(self, nb, sums):
        """Return True when the GLR test detects a change among the first ``nb`` samples.

        ``sums[j]`` must be the sum of samples ``0..j``. Candidate splits ``s`` run
        over ``1, 1 + subSample2, 1 + 2*subSample2, ...``. Split ``s`` compares the
        first ``s + 1`` samples with the remaining ``nb - s - 1``. An alarm is raised
        when ``(s+1) kl(mu1, mu) + (nb-s-1) kl(mu2, mu)`` exceeds ``beta(nb)``.
        """
        if nb < 3:
            return False  # no split with a non-empty right segment exists
        threshold = self.beta(nb)
        total = sums[nb - 1]
        mu = total / nb
        for s in range(1, nb - 1, self.subSample2):
            draw1 = s + 1
            draw2 = nb - s - 1  # number of samples s+1 .. nb-1
            mu1 = sums[s] / draw1
            mu2 = (total - sums[s]) / draw2
            if draw1 * self.klChange(mu1, mu) + draw2 * self.klChange(mu2, mu) > threshold:
                return True
        return False

    def klChange(self, p, q):
        """Binary relative entropy kl(p, q); both arguments are clipped into (0, 1) to stay finite."""
        res = 0
        if p != q:
            eps = np.finfo(float).eps
            if p <= 0:
                p = eps
            elif p >= 1:
                p = 1 - eps
            if q <= 0:
                q = eps
            elif q >= 1:
                q = 1 - eps
            res = p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))
        return res

    def klUp(self, p, level):
        """KL upper confidence bound: approximately the largest ``q >= p`` with ``kl(p, q) <= level``.

        Uses 16 bisection steps on ``[p, min(1, p + sqrt(level / 2))]``. By Pinsker's
        inequality, the right end of that interval already satisfies ``kl >= level``.
        """
        lower = p
        upper = min(1.0, p + np.sqrt(level / 2))
        for _ in range(16):
            mid = (upper + lower) / 2
            if self.klChange(p, mid) > level:
                upper = mid
            else:
                lower = mid
        return upper