import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex


def Bernoulli(p):
    """Draw a single Bernoulli(p) sample from NumPy's global RNG.

    Returns 1 when a uniform draw on [0, 1) falls below ``p``, else 0.
    """
    return 1 if np.random.rand() < p else 0

class Rho(object):
    """Confidence width rho(t) = sqrt(L / t) + L / t, where L = K * log(T / delta).

    K is ``nb_arms``, T is ``horizon`` and ``delta`` is the confidence level.
    """

    def __init__(self, nb_arms, horizon, delta = 0.1):
        self.nb_arms = nb_arms
        self.horizon = horizon
        self.delta = delta

    def __call__(self, t):
        scaled_log = self.nb_arms*np.log(self.horizon/self.delta)/t
        return np.sqrt(scaled_log) + scaled_log
    
class Rho_hat(object):
    """Scaled confidence width used by MASTER's tests.

    rho_hat(t) = 6 * n_hat * log(T / delta) * rho(t), with n_hat = log2(T) + 1,
    where rho is ``Rho`` evaluated with the same ``nb_arms``, ``horizon`` and
    ``delta``.
    """

    def __init__(self, nb_arms, horizon, delta = 0.1):
        self.nb_arms = nb_arms
        self.horizon = horizon
        self.n_hat = np.log2(horizon)+1
        self.delta = delta

    def __call__(self, t):
        # Forward delta so both log(T / delta) factors use the same confidence
        # level; otherwise the inner Rho would silently fall back to 0.1.
        rho = Rho(nb_arms = self.nb_arms, horizon = self.horizon, delta = self.delta)
        return 6*self.n_hat*np.log(self.horizon/self.delta)*rho(t)
    
def schedule(n, t, horizon, nb_arms):
    """Randomly schedule MALG sub-instances inside the block [t, t + 2**n - 1].

    For every offset ``tau`` in the block and every order ``m`` from ``n``
    down to 0 with ``tau`` divisible by ``2**m``, an order-``m`` instance
    covering ``[t + tau, t + tau + 2**m - 1]`` is kept with probability
    ``rho(2**n) / rho(2**m)`` (so the order-``n`` instance at ``tau = 0`` is
    always kept).

    Returns:
        list of dicts with keys "alg.s" (first round), "alg.e" (last round)
        and "order-m", sorted by start offset and, for a shared start, by
        decreasing order.
    """
    rho = Rho(nb_arms=nb_arms, horizon=horizon)
    block_len = 2**n
    scheduled = []
    for offset in range(block_len):
        for order in reversed(range(n + 1)):
            if offset % (2**order) != 0:
                continue
            keep_prob = rho(t=block_len)/rho(t=2**order)
            if Bernoulli(keep_prob) == 1:
                scheduled.append({
                    "alg.s": offset + t,
                    "alg.e": offset + 2**order + t - 1,
                    "order-m": order,
                })
    return scheduled

class Malg(object):
    """Multi-scale learner (MALG) for one block of length 2**n.

    Construction draws the random sub-instance schedule via ``schedule``;
    calling the object performs one UCB decision for a given ``instance``.
    """

    def __init__(self, n, t, horizon, nb_arms):
        self.n = n
        self.schedule_list = schedule(n, t, horizon, nb_arms)

    def __call__(self, alg):
        """Return ``(arm, index)`` maximizing the UCB index of ``alg``.

        An arm never pulled by ``alg`` gets the optimistic index 10000;
        otherwise its index is its empirical mean plus
        sqrt(2 * log(2**alg.order_m) / pulls).
        """
        indices = np.full(alg.nb_arms, 10000, dtype=float)
        log_scale = np.log(2**alg.order_m)
        for arm in range(alg.nb_arms):
            pulls = alg.n[arm]
            if pulls <= 0:
                continue
            mean = np.sum(alg.Z[arm][0:pulls])/pulls
            indices[arm] = mean + np.sqrt(2*log_scale/pulls)
        return np.argmax(indices), np.max(indices)
    
class instance(object):
    """State of one order-``order_m`` UCB sub-instance.

    Attributes (documented as comments below):
        order_m: order of the instance; its scheduled interval is 2**order_m rounds.
        reward: rewards received while this instance was active, in order.
        nb_arms: number of arms.
        Z: Z[k][j] is the reward of the j-th pull of arm k (2**order_m slots per arm).
        n: n[k] is how many times this instance has pulled arm k.
    """
    def __init__(self, nb_arms , order_m):
        self.order_m = order_m
        self.reward = []
        self.nb_arms = nb_arms
        # float buffer: an int buffer would silently truncate non-integer
        # rewards and skew the empirical means used by the UCB index.
        self.Z = np.zeros((nb_arms, 2**order_m), dtype = float)
        self.n = np.zeros(nb_arms, dtype = int)
        
class Master(object):
    """MASTER wrapper around the UCB-based ``Malg`` for non-stationary bandits.

    Simulates ``repetitions`` runs of ``horizon`` rounds against the reward
    table ``Env.Env`` (indexed ``[repetition][arm][t]``) and returns summary
    statistics of rewards and arm pulls.
    """

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        path = str
    ):
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env
        # NOTE(review): the default is the *type* ``str`` (presumably meant to
        # be None); ``self.path`` is not read anywhere in this class.
        self.path = path

    def __call__(self):
        X_mean, X_std, A_mean, A_std = self.mainAlg()
        return X_mean, X_std, A_mean, A_std

    @staticmethod
    def _index_by_time(entries):
        """Map each start / end round to the indices of the schedule entries at it.

        Index lists preserve the order of ``entries``; for entries sharing a
        start round that is decreasing order, as produced by ``schedule``.
        """
        starts_at, ends_at = {}, {}
        for idx, entry in enumerate(entries):
            starts_at.setdefault(entry["alg.s"], []).append(idx)
            ends_at.setdefault(entry["alg.e"], []).append(idx)
        return starts_at, ends_at

    @staticmethod
    def _test1_failed(ending, Rt, U_t, rho_hat):
        """Test 1: fail if an order-m instance ending now averaged >= U_t + 9*rho_hat(2**m).

        ``ending`` holds the schedule entries whose interval ends at the
        current round; the average is over the Master's rewards ``Rt`` in
        that interval.
        """
        for entry in ending:
            length = 2**entry["order-m"]
            average = np.sum(Rt[entry["alg.s"]:entry["alg.e"] + 1])/length
            if average >= U_t + 9*rho_hat(length):
                return True
        return False

    def mainAlg(self):
        """Simulate MASTER+UCB for every repetition and summarize the runs.

        Rounds are grouped into blocks of length 2**n (n = 0, 1, ...). Each
        block gets a fresh ``Malg`` schedule; at every round the active
        sub-instance is the smallest-order scheduled instance covering that
        round. A block ends early when Test 1 or Test 2 fails.

        Returns:
            X_mean: mean over repetitions of the per-round reward, shape (horizon,).
            X_std: std over repetitions of the cumulative reward, shape (horizon,).
            A_mean: mean over repetitions of each arm's cumulative pull count,
                shape (nb_arms, horizon).
            A_std: matching std, shape (nb_arms, horizon).
        """
        print("Running ... Master")
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        rho_hat = Rho_hat(nb_arms=self.nb_arms, horizon=self.horizon)

        for r in tqdm(range(self.repetitions)):
            t = 0
            n = 0
            gt = np.zeros(self.horizon, dtype = float)  # index g reported by MALG at round t
            Rt = np.zeros(self.horizon, dtype = float)  # reward received at round t
            while t < self.horizon:
                t_n = t
                MALG = Malg(n=n, t=t, horizon=self.horizon, nb_arms=self.nb_arms)
                entries = MALG.schedule_list
                learners = [instance(self.nb_arms, entry["order-m"]) for entry in entries]
                starts_at, ends_at = self._index_by_time(entries)
                # Scheduled intervals are dyadic and therefore nested, so the
                # instances covering t form a stack (outermost at the bottom).
                stack = []
                U_t = np.inf  # running min of g over the current block
                while t < t_n + 2**n and t < self.horizon:
                    # Retire instances whose interval is over, then start the
                    # ones beginning now (largest order first, smallest on top).
                    while stack and entries[stack[-1]]["alg.e"] < t:
                        stack.pop()
                    stack.extend(starts_at.get(t, ()))
                    learner = learners[stack[-1]]

                    I, g = MALG(learner)
                    rew = self.Env[r][I][t]
                    gt[t] = g
                    Rt[t] = rew
                    A_mem[r][I][t] = 1
                    X_mem[r][t] = rew
                    learner.reward.append(rew)
                    learner.Z[I][learner.n[I]] = rew
                    learner.n[I] += 1
                    U_t = min(U_t, g)

                    # Both tests are recomputed every round so no stale result
                    # from an earlier round or block can trigger a break.
                    test1_failed = self._test1_failed(
                        [entries[i] for i in ends_at.get(t, ())], Rt, U_t, rho_hat)
                    span = t - t_n + 1
                    test2_failed = (np.sum(gt[t_n:t+1]) - np.sum(Rt[t_n:t+1]))/span >= 3*rho_hat(span)
                    t = t + 1
                    if test1_failed or test2_failed:
                        break
                # NOTE(review): in Wei & Luo's MASTER a failed test restarts the
                # block index from n = 0; here n keeps growing. Confirm intent.
                n = n + 1

        X_mem_cum = np.cumsum(X_mem, axis=1)
        # NOTE(review): X_mean averages per-round rewards while X_std is taken
        # over cumulative rewards; kept as-is for existing callers — confirm.
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mem_cum = np.cumsum(A_mem, axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        return X_mean, X_std, A_mean, A_std

    def __str__(self):
        return r"Master+UCB"