import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech, skip_mech_MUCB
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex

# Reference window size: MUCB.__str__ leaves the "$w=...$" tag out of its label
# when an instance's window_size equals this value.
WINDOW_SIZE = None
# Not referenced in the visible part of this file; presumably the matching
# default for the change-detection threshold b -- TODO confirm.
THRESHOLD_B = None

class MUCB(object):
    """Monte-Carlo simulation of the M-UCB (Monitored UCB) bandit policy.

    An instance stores the problem dimensions, a table of pre-drawn rewards
    and the tuning parameters.  Calling the instance runs ``repetitions``
    independent simulations of ``horizon`` steps and returns per-step
    statistics (see :meth:`__call__`).

    Two forced-exploration schedules are implemented:

    * :meth:`mainAlg` -- arms ``0 .. nb_arms-1`` are played in order at the
      start of every cycle of ``floor(nb_arms / gamma)`` steps;
    * :meth:`mainAlgWithDiminishing` -- an exploration round starts when the
      time since the last restart reaches a threshold ``u`` that grows after
      each round.

    In both schedules the last ``window_size`` rewards of the arm just played
    are handed to :meth:`CD`.  On an alarm the statistics are restarted; when
    ``skip`` is set the restart only happens if ``skip_mech_MUCB(...)``
    returns ``False``.
    """

    # Optimistic index given to arms with no sample since the last restart, so
    # that ``argmax`` selects them before any arm with a real index.
    _UNPLAYED_INDEX = 10000

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        path = str,
        klUCB = False,
        diminishing = False,
        skip = False,
        alpha = None,
        gamma = None,
        w = None,
        b = None,
        skip_uncertainty = 0,
        delta = 0.1,
    ):
        """Store the configuration and derive any parameter left as ``None``.

        Args:
            repetitions: Number of independent simulation runs.
            nb_arms: Number of arms.
            nb_break_points: Number of break points; only used to derive the
                default ``gamma``.
            horizon: Number of time steps per run.
            Env: Object exposing an ``Env`` attribute that is indexed as
                ``Env.Env[repetition][arm][t]`` to read rewards.
            path: Stored on the instance; not used by the visible code.
                NOTE(review): the default is the ``str`` type itself, which
                looks unintended -- left untouched for compatibility.
            klUCB: Only affects the label produced by :meth:`__str__`.
            diminishing: Selects the schedule used by :meth:`__call__`.
            skip: Enables the ``skip_mech_MUCB`` check before a restart.
            alpha: Parameter of the diminishing schedule; defaults to
                ``sqrt(nb_arms * log(horizon))``.
            gamma: Forced-exploration rate of :meth:`mainAlg`; defaults to
                ``sqrt(nb_break_points * nb_arms * log(horizon) / horizon)``.
            w: Detection window size; derived from ``delta`` when ``None``
                and rounded up to an even integer.
            b: Detection threshold; defaults to
                ``sqrt((w/2) * log(2 * nb_arms * horizon**2))``.
            skip_uncertainty: Forwarded to ``skip_mech_MUCB``; ``None`` is
                treated as ``0``.
            delta: Value used for ``delta`` in the automatic window-size
                formula (only read when ``w`` is ``None``).  The original code
                referenced an undefined name here; ``0.1`` is assumed to match
                the reference implementation this formula appears to come
                from -- NOTE(review): confirm the intended value.

        Raises:
            AssertionError: If ``alpha``, ``w`` or ``b`` is not positive, if
                ``delta`` is not positive when it is needed, or if ``gamma``
                is not positive in the non-diminishing mode.
        """
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env

        self.klUCB = klUCB
        self.diminishing = diminishing
        self.skip = skip
        self.path = path
        self.skip_uncertainty = 0 if skip_uncertainty is None else skip_uncertainty

        if alpha is None:
            alpha = np.sqrt(np.log(self.horizon) * self.nb_arms)
        assert alpha > 0, "Error: for Monitored_UCB policy the parameter alpha should be > 0 but it was given as {}.".format(alpha)  # DEBUG
        self.alpha = alpha  #: Parameter :math:`alpha` for the M-UCB algorithm.

        if w is None:
            assert delta > 0, "Error: for Monitored_UCB policy the parameter delta should be > 0 but it was given as {}.".format(delta)  # DEBUG
            w = (4/delta**2) * (np.sqrt(np.log(2 * self.nb_arms * self.horizon**2)) + np.sqrt(np.log(2 * self.horizon)))**2
            w = int(np.ceil(w))
            # CD compares two halves of the window, so keep its size even.
            if w % 2 != 0:
                w = 2*(1 + w//2)
        assert w > 0, "Error: for Monitored_UCB policy the parameter w should be > 0 but it was given as {}.".format(w)  # DEBUG
        self.window_size = w

        if b is None:
            b = np.sqrt((w/2) * np.log(2 * self.nb_arms * self.horizon**2))
        assert b > 0, "Error: for Monitored_UCB policy the parameter b should be > 0 but it was given as {}.".format(b)  # DEBUG
        self.threshold_b = b

        # Both schedules are prepared unconditionally so that
        # __call__(diminishing=...) can switch mode after construction.
        self.u = np.ceil((2*self.alpha - self.nb_arms/(4*self.alpha))**2)

        if gamma is None:
            gamma = np.sqrt(self.nb_break_points*self.nb_arms*np.log(self.horizon)/self.horizon)
        if not self.diminishing:
            # Only validated in the mode that was previously the only one to
            # compute gamma, so diminishing-mode configurations keep working.
            assert gamma > 0, "Error: for Monitored_UCB policy the parameter gamma should be > 0 but it was given as {}.".format(gamma)  # DEBUG
        self.gamma = gamma  #: Parameter :math:`gamma` for the M-UCB algorithm.
        if not self.diminishing:
            print("exploration rate = {}".format(self.gamma))

    def __call__(self, diminishing = None):
        """Run the simulation and return its statistics.

        Args:
            diminishing: When not ``None``, overwrites ``self.diminishing``
                (the change persists on the instance) before running.

        Returns:
            tuple: ``(X_mean, X_std, A_mean, A_std)`` as returned by
            :meth:`mainAlgWithDiminishing` or :meth:`mainAlg`.
        """
        if diminishing is not None:
            self.diminishing = diminishing
        if self.diminishing:
            X_mean, X_std, A_mean, A_std = self.mainAlgWithDiminishing()
        else:
            X_mean, X_std, A_mean, A_std = self.mainAlg()

        return X_mean, X_std, A_mean, A_std

    def __str__(self):
        """Return a label such as ``M-UCB ($w=100$, with diminishing)``."""
        args = "$w={:g}$".format(self.window_size) if self.window_size != WINDOW_SIZE else ""
        if self.diminishing and self.skip:
            args = args + f", with diminishing and skipping mechanism B = {self.skip_uncertainty}"
        elif self.diminishing:
            args = args + ", with diminishing"
        elif self.skip:
            args = args + f", with skipping mechanism B = {self.skip_uncertainty}"
        args = " ({})".format(args) if args else ""
        kl = "kl" if self.klUCB else ""
        return "M-{}UCB{}".format(kl, args)

    def _refresh_ucb(self, UCB, Z, n, elapsed):
        """Update in place the index of every arm with at least one sample.

        Arms with ``n[k] == 0`` keep their current entry (the optimistic
        sentinel after initialisation or a restart), which avoids a 0/0
        division for them.
        """
        for k in range(self.nb_arms):
            if n[k] > 0:
                UCB[k] = np.sum(Z[k][0:n[k]])/n[k] + np.sqrt(2*np.log(elapsed)/n[k])

    def _summarise(self, X_mem, A_mem, cd_true):
        """Print the average number of restarts and aggregate over repetitions.

        Returns:
            tuple: mean and standard deviation (over repetitions) of the
            per-step reward, then mean and standard deviation of the
            cumulative pull count of each arm.
        """
        print("Average alarm {} times".format(cd_true/self.repetitions))
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem, axis=0)
        A_mem_cum = np.cumsum(A_mem, axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        return X_mean, X_std, A_mean, A_std

    def mainAlg(self):
        """Simulate M-UCB with the fixed exploration rate ``self.gamma``.

        Returns:
            tuple: ``(X_mean, X_std, A_mean, A_std)``; see :meth:`_summarise`.
        """
        print("Running ...M-UCB with gamma = {}".format(self.gamma))
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        # Length of one forced-exploration cycle; it does not change in the loop.
        period = np.floor(self.nb_arms / self.gamma)
        for i in tqdm(range(self.repetitions)):
            tau = 0  # time of the last restart
            UCB = np.full(self.nb_arms, self._UNPLAYED_INDEX, dtype = float)
            # Rewards observed per arm since the last restart.  Stored as float
            # so non-integer rewards are not truncated (0/1 rewards are
            # unaffected).
            Z = np.zeros((self.nb_arms, self.horizon), dtype = float)
            n = np.zeros(self.nb_arms, dtype = int)
            # n_for_skip / skip_duration are bookkeeping only: no decision in
            # this method reads them.
            n_for_skip = np.zeros(self.nb_arms, dtype = int)
            skip_duration = False
            A = np.zeros(self.horizon, dtype = int)
            for t in range(self.horizon):
                a = (t - tau) % period
                if a < self.nb_arms:
                    A[t] = a
                else:
                    self._refresh_ucb(UCB, Z, n, t - tau)
                    A[t] = np.argmax(UCB)
                arm = A[t]
                X_mem[i][t] = self.Env[i][arm][t]
                A_mem[i][arm][t] = 1
                Z[arm][n[arm]] = X_mem[i][t]
                n[arm] = n[arm] + 1
                if skip_duration:
                    n_for_skip[arm] = n_for_skip[arm] + 1
                if np.min(n_for_skip) > self.window_size:
                    n_for_skip = np.zeros(self.nb_arms, dtype = int)
                    skip_duration = False
                if n[arm] >= self.window_size:
                    # Slice ends are exclusive: stop at n[arm] so that the
                    # newest reward is part of the window_size-long window.
                    if self.CD(Z[arm, n[arm]-self.window_size:n[arm]]):
                        skip_duration = True
                        if self.skip:
                            # `== False` is kept on purpose: a None return from
                            # the helper must not be treated as False.
                            restart = skip_mech_MUCB(Z, n, self.window_size/2, arm, self.skip_uncertainty) == False
                            if restart:
                                n_for_skip = np.zeros(self.nb_arms, dtype = int)
                                skip_duration = False
                        else:
                            restart = True
                        if restart:
                            cd_true += 1
                            tau = t
                            n[:] = 0
                            Z[arm,:] = 0
                            # Drop indices computed from pre-restart data.
                            UCB[:] = self._UNPLAYED_INDEX

        # self.plot_arm(A_mean, A_std)
        return self._summarise(X_mem, A_mem, cd_true)

    def mainAlgWithDiminishing(self):
        """Simulate M-UCB with the diminishing exploration schedule.

        An exploration round plays arms ``0 .. nb_arms-1`` while the time
        since the last restart lies in ``[u, u + nb_arms)``; ``u`` is then
        increased.  A restart resets ``u`` to ``self.u``.

        Returns:
            tuple: ``(X_mean, X_std, A_mean, A_std)``; see :meth:`_summarise`.
        """
        print("Running ...M-UCB with diminishing")
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        cd_true = 0
        for i in tqdm(range(self.repetitions)):
            tau = 0  # time of the last restart
            u = self.u
            UCB = np.full(self.nb_arms, self._UNPLAYED_INDEX, dtype = float)
            # Stored as float so non-integer rewards are not truncated.
            Z = np.zeros((self.nb_arms, self.horizon), dtype = float)
            n = np.zeros(self.nb_arms, dtype = int)
            # Bookkeeping only: no decision in this method reads these two.
            n_for_skip = np.zeros(self.nb_arms, dtype = int)
            skip_duration = False
            A = np.zeros(self.horizon, dtype = int)
            for t in range(self.horizon):
                elapsed = t - tau
                if u <= elapsed < u + self.nb_arms:
                    A[t] = elapsed - u
                else:
                    if elapsed == u + self.nb_arms:
                        u = np.ceil(u+np.sqrt(u)*self.nb_arms/self.alpha+(self.nb_arms/(2*self.alpha))**2)
                    # Unplayed arms keep the optimistic sentinel, so argmax
                    # picks the first of them without going through NaN.
                    self._refresh_ucb(UCB, Z, n, elapsed)
                    A[t] = np.argmax(UCB)
                arm = A[t]
                X_mem[i][t] = self.Env[i][arm][t]
                A_mem[i][arm][t] = 1
                Z[arm][n[arm]] = X_mem[i][t]
                n[arm] = n[arm] + 1
                if skip_duration:
                    n_for_skip[arm] = n_for_skip[arm] + 1
                if np.min(n_for_skip) > self.window_size:
                    n_for_skip = np.zeros(self.nb_arms, dtype = int)
                    skip_duration = False
                if n[arm] >= self.window_size:
                    # Slice ends are exclusive: stop at n[arm] so that the
                    # newest reward is part of the window_size-long window.
                    if self.CD(Z[arm, n[arm]-self.window_size:n[arm]]):
                        skip_duration = True
                        if self.skip:
                            # `== False` is kept on purpose: a None return from
                            # the helper must not be treated as False.
                            restart = skip_mech_MUCB(Z, n, self.window_size/2, arm, self.skip_uncertainty) == False
                            if restart:
                                n_for_skip = np.zeros(self.nb_arms, dtype = int)
                                skip_duration = False
                        else:
                            restart = True
                        if restart:
                            cd_true += 1
                            tau = t
                            u = self.u
                            n[:] = 0
                            Z[arm,:] = 0
                            # Drop indices computed from pre-restart data.
                            UCB[:] = self._UNPLAYED_INDEX
        # self.plot_arm(A_mean, A_std)
        return self._summarise(X_mem, A_mem, cd_true)

    def CD(self,Y):
        """Return ``True`` when the two halves of window ``Y`` differ too much.

        The sum of ``Y[window_size/2 : window_size]`` is compared with the sum
        of ``Y[0 : window_size/2]``; an alarm is raised when the absolute
        difference exceeds ``self.threshold_b``.
        """
        half = int(self.window_size/2)
        recent_sum = np.sum(Y[half:self.window_size])
        older_sum = np.sum(Y[0:half])
        if np.abs(recent_sum - older_sum) > self.threshold_b:
            return True
        return False

    
    
        