import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex
from useful_func import log_message

# Module-level defaults for the CUSUM-UCB hyper-parameters.
# ``None`` means "derive the value inside ``CusumUCB.__init__``".
H = None          # default detection threshold h
EPSILION = None   # default drift tolerance epsilon (historical spelling, kept for importers)
M = 50            # default number of warm-up samples used to estimate the pre-change mean


class CusumUCB(object):
    r"""CUSUM-UCB simulation for piecewise-stationary multi-armed bandits.

    The policy plays UCB on the samples gathered since the last restart,
    interleaves forced exploration (constant rate ``gamma`` or a diminishing
    schedule driven by ``alpha``), and monitors the played arm with a two-sided
    CUSUM test (:meth:`CD`).  When the test fires, the statistics are restarted,
    unless the optional skipping mechanism decides to ignore the alarm.

    Parameters
    ----------
    repetitions : int
        Number of independent runs.
    nb_arms : int
        Number of arms.
    nb_break_points : int
        Number of change points; used for the default ``gamma`` and ``h``.
    horizon : int
        Number of time steps per run.
    Env : object
        Object whose ``Env`` attribute holds the rewards, indexed as
        ``Env.Env[repetition][arm][t]``.
    klUCB : bool
        Only affects the label returned by :meth:`__str__` in this class.
    path : object
        Stored as ``self.path``; not used by the code in this class.
        The default is the ``str`` type itself (kept for compatibility).
    diminishing : bool
        Select the diminishing-exploration variant.
    skip : bool
        Enable the skipping mechanism (delegated to ``skip_mech``).
    alpha : float or None
        Diminishing-exploration parameter; defaults to
        ``sqrt(log(horizon) * nb_arms)``.  Must be ``> 0``.
    gamma : float or None
        Constant exploration rate; defaults to
        ``sqrt(nb_break_points * nb_arms * log(horizon) / horizon)``.
        Must be ``> 0`` whenever the constant-rate variant is run.
    h : float, ``None`` or ``'auto'``
        CUSUM threshold; ``None``/``'auto'`` gives
        ``log(horizon / nb_break_points)``.
    epsilon : float, ``None`` or ``'auto'``
        CUSUM drift tolerance; ``None``/``'auto'`` gives ``0.05``.
    m : int or None
        Number of warm-up samples of the CUSUM test; ``None`` gives ``M``.
    skip_uncertainty : number or None
        Forwarded to ``skip_mech``; ``None`` is treated as ``0``.
    """

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        klUCB = False,
        path = str,
        diminishing = False,
        skip = False,
        alpha = None,
        gamma = None,
        h = H,
        epsilon = EPSILION,
        m = M,
        skip_uncertainty = 0
    ):
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env

        self.klUCB = klUCB
        self.diminishing = diminishing
        self.skip = skip
        self.path = path

        self.skip_uncertainty = 0 if skip_uncertainty is None else skip_uncertainty

        if alpha is None:
            alpha = np.sqrt(np.log(self.horizon) * self.nb_arms)
        assert alpha > 0, "Error: for CUSUM-UCB policy the parameter alpha should be > 0 but it was given as {}.".format(alpha)  # DEBUG
        self.alpha = alpha  #: Parameter :math:`\alpha` of the diminishing exploration schedule.

        # Always prepared (it only depends on alpha) so that the variant can
        # still be switched later through ``__call__(diminishing=True)``.
        self.u = np.ceil((2*self.alpha - self.nb_arms/(4*self.alpha))**2)

        # A user-supplied gamma is now kept (it used to be dropped, which made
        # mainAlg fail with AttributeError).  The default is only derived when
        # the constant-rate variant is actually used.
        self.gamma = gamma  #: Exploration rate :math:`\gamma` (``None`` until resolved).
        if not self.diminishing:
            self._resolve_gamma()

        if h is None or h == 'auto':
            h = np.log(self.horizon/self.nb_break_points)
            assert h > 0, "Error: for CUSUM-UCB policy the parameter h should be > 0 but it was given as {}.".format(h)  # DEBUG
        self.threshold_h = h  #: Threshold :math:`h` of the CUSUM test.

        if epsilon is None or epsilon == 'auto':
            epsilon = 0.05
        self.epsilon = epsilon  #: Parameter :math:`\varepsilon` of the CUSUM test.

        self.M = M if m is None else m  #: Number of warm-up samples of the CUSUM test.
        print(self.M)  # console output kept for backward compatibility

    def _resolve_gamma(self):
        """Return a validated exploration rate, deriving the default on first use."""
        message = "Error: for CUSUM-UCB policy the parameter gamma should be > 0 but it was given as {}."
        if self.gamma is not None:
            assert self.gamma > 0, message.format(self.gamma)  # DEBUG
            return self.gamma
        gamma = np.sqrt(self.nb_break_points*self.nb_arms*np.log(self.horizon)/self.horizon)
        assert gamma > 0, message.format(gamma)  # DEBUG
        self.gamma = gamma
        print("exploration rate = {}".format(self.gamma))
        return gamma

    def __call__(self, diminishing = None):
        """Run the simulation, optionally overriding the exploration variant.

        Returns the tuple ``(X_mean, X_std, A_mean, A_std)`` produced by
        :meth:`mainAlgWithDiminishing` or :meth:`mainAlg`.
        """
        if diminishing is not None:
            self.diminishing = diminishing
        if self.diminishing:
            return self.mainAlgWithDiminishing()
        return self.mainAlg()

    def __str__(self):
        """Return the label of the policy, e.g. ``CUSUM-klUCB (with diminishing)``."""
        args = "$M={:g}$".format(self.M) if self.M != M else ""
        if self.diminishing and self.skip:
            args = args + f"with diminishing and skipping mechanism B = {self.skip_uncertainty}"
        elif self.diminishing:
            args = args + "with diminishing"
        elif self.skip:
            args = args + f"with skipping mechanism B = {self.skip_uncertainty}"
        suffix = " ({})".format(args) if args else ""
        kl = "kl" if self.klUCB else ""
        return "CUSUM-{}UCB{}".format(kl, suffix)

    @staticmethod
    def _ucb_arm(Z, n, elapsed):
        """Return the arm maximising the UCB index over the current samples.

        An arm without samples is chosen first (lowest index).  This is the
        choice the former ``argmax`` over NaN indices produced, but without the
        division-by-zero warnings.
        """
        unplayed = np.flatnonzero(n == 0)
        if unplayed.size:
            return int(unplayed[0])
        means = np.array([Z[k, :n[k]].sum() / n[k] for k in range(len(n))])
        return int(np.argmax(means + np.sqrt(2*np.log(elapsed)/n)))

    def _play_and_detect(self, i, t, arm, X_mem, A_mem, Z, n, n_for_skip,
                         skip_duration, skip_window):
        """Play ``arm`` at time ``t`` of run ``i``, then run the change detector.

        ``Z``, ``n`` and ``n_for_skip`` are updated in place.  Returns
        ``(restart, skip_duration)`` where ``restart`` tells the caller that the
        statistics were reset at this step.
        """
        X_mem[i][t] = self.Env[i][arm][t]
        A_mem[i][arm][t] = 1
        Z[arm][n[arm]] = X_mem[i][t]
        n[arm] += 1

        # Bookkeeping of the window opened by an ignored alarm.
        if skip_duration:
            n_for_skip[arm] += 1
        if np.min(n_for_skip) > skip_window:
            n_for_skip[:] = 0
            skip_duration = False

        if not self.CD(Z[arm, :n[arm]]):
            return False, skip_duration

        if self.skip:
            # NOTE(review): the module constant M (not self.M) is forwarded, as
            # before; confirm against skip_mech which one it expects.
            # The explicit comparison is kept on purpose: only a verdict equal
            # to False triggers the restart.
            verdict = skip_mech(Z, n, n_for_skip, M, arm, self.skip_uncertainty)
            if verdict == False:  # noqa: E712
                n_for_skip[:] = 0
                n[:] = 0
                Z[arm, :] = 0
                return True, False
            return False, True

        # Only the counters of the other arms are reset; their stale samples
        # are never read again because every read is bounded by n[k].
        n[:] = 0
        Z[arm, :] = 0
        return True, skip_duration

    @staticmethod
    def _summarize(X_mem, A_mem):
        """Aggregate the runs into ``(X_mean, X_std, A_mean, A_std)``.

        ``X_mean`` is the mean instantaneous reward per step, ``X_std`` the
        standard deviation of the cumulative reward, and ``A_mean``/``A_std``
        are statistics of the cumulative pull counts per arm.
        """
        X_mem_cum = np.cumsum(X_mem, axis=1)
        A_mem_cum = np.cumsum(A_mem, axis=2)
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        A_std = np.std(A_mem_cum, axis=0)
        return X_mean, X_std, A_mean, A_std

    def mainAlg(self):
        """Run CUSUM-UCB with a constant exploration rate ``gamma``.

        Within each period of ``floor(nb_arms / gamma)`` steps since the last
        restart, the first ``nb_arms`` steps play the arms in order; every
        other step plays the UCB arm.

        Returns ``(X_mean, X_std, A_mean, A_std)``, see :meth:`_summarize`.
        """
        gamma = self._resolve_gamma()
        print("Running ...CUSUM-UCB with gamma = {}".format(gamma))
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        period = np.floor(self.nb_arms / gamma)  # loop invariant
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            tau = 0
            # Float buffer: an integer buffer silently truncated real-valued rewards.
            Z = np.zeros((self.nb_arms, self.horizon), dtype = float)
            n = np.zeros(self.nb_arms, dtype = int)
            n_for_skip = np.zeros(self.nb_arms, dtype = int)
            skip_duration = False
            for t in range(self.horizon):
                a = (t - tau) % period
                if a < self.nb_arms:
                    arm = int(a)
                else:
                    arm = self._ucb_arm(Z, n, t - tau)
                # NOTE(review): this variant closes the skip window at M + 1
                # while the diminishing variant uses 50; both kept as they were.
                restart, skip_duration = self._play_and_detect(
                    i, t, arm, X_mem, A_mem, Z, n, n_for_skip, skip_duration, M + 1)
                if restart:
                    cd_true += 1
                    tau = t
        print("Average alarm {} times".format(cd_true/self.repetitions))
        return self._summarize(X_mem, A_mem)

    def mainAlgWithDiminishing(self):
        """Run CUSUM-UCB with the diminishing exploration schedule.

        The arms are played in order while ``u <= t - tau < u + nb_arms``;
        afterwards ``u`` is pushed further away, so forced exploration becomes
        rarer.  A restart resets ``u`` to its initial value ``self.u``.

        Returns ``(X_mean, X_std, A_mean, A_std)``, see :meth:`_summarize`.
        """
        print("Running ...CUSUM-UCB with diminishing")
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            tau = 0
            u = self.u
            # Float buffer: an integer buffer silently truncated real-valued rewards.
            Z = np.zeros((self.nb_arms, self.horizon), dtype = float)
            n = np.zeros(self.nb_arms, dtype = int)
            n_for_skip = np.zeros(self.nb_arms, dtype = int)
            skip_duration = False
            for t in range(self.horizon):
                elapsed = t - tau
                if u <= elapsed < u + self.nb_arms:
                    arm = int(elapsed - u)
                else:
                    if elapsed == u + self.nb_arms:
                        u = np.ceil(u+np.sqrt(u)*self.nb_arms/self.alpha+(self.nb_arms/(2*self.alpha))**2)
                    arm = self._ucb_arm(Z, n, elapsed)
                # NOTE(review): literal 50 kept; the constant-rate variant uses M + 1.
                restart, skip_duration = self._play_and_detect(
                    i, t, arm, X_mem, A_mem, Z, n, n_for_skip, skip_duration, 50)
                if restart:
                    cd_true += 1
                    tau = t
                    # Restart the exploration schedule for every reset; the
                    # skip branch used to leave u untouched.
                    u = self.u
        print("Average alarm {} times".format(cd_true/self.repetitions))
        return self._summarize(X_mem, A_mem)

    def CD(self, data_y):
        r""" Detect a change in the current arm, using the two-sided CUSUM algorithm [Page, 1954].

        - For each *data* k, compute:

        .. math::

            s_k^- &= (y_k - \hat{u}_0 - \varepsilon) 1(k > M),\\
            s_k^+ &= (\hat{u}_0 - y_k - \varepsilon) 1(k > M),\\
            g_k^+ &= \max(0, g_{k-1}^+ + s_k^+),\\
            g_k^- &= \max(0, g_{k-1}^- + s_k^-).

        - The change is detected if :math:`\max(g_k^+, g_k^-) \geq h`, where :attr:`threshold_h` is the threshold of the test,
        - And :math:`\hat{u}_0 = \frac{1}{M} \sum_{k=1}^{M} y_k` is the mean of the first M samples, where M is :attr:`M` the min number of observation between change points.

        :param data_y: samples of one arm since the last restart, oldest first.
        :return: ``True`` if a change is detected, ``False`` otherwise (always
            ``False`` while there are at most :attr:`M` samples).
        """
        warmup = int(self.M)  # int() so that a float-valued M can still slice
        if len(data_y) <= warmup:
            return False
        # The instance value M is used (not the module default), so that a
        # custom ``m`` applies to both the length check and the mean estimate.
        u0hat = np.mean(data_y[:warmup])
        gp, gm = 0, 0
        # Only samples after the warm-up contribute: this is the 1(k > M) factor.
        for y_k in data_y[warmup:]:
            gp = max(0, gp + (u0hat - y_k - self.epsilon))
            gm = max(0, gm + (y_k - u0hat - self.epsilon))
            if gp >= self.threshold_h or gm >= self.threshold_h:
                return True
        return False