import numpy as np
from Base_LB import Base_LB
from Base_GLB import Base_GLB
from Base_SCB import Base_SCB
from math import log, log2, sqrt
from BanditAlgorithm import BanditAlgorithm

class MASTER(BanditAlgorithm):
    """Multi-scale restart wrapper around base contextual-bandit learners.

    Time is split into blocks of length ``2**n`` starting at round ``tn``.
    At the start of each block ``procedure1`` schedules a random set of base
    learners (``Base_LB``, ``Base_GLB`` or ``Base_SCB`` depending on
    ``model``). Each learner covers an offset window of length ``2**m`` for
    some ``m`` in ``0..n``. Every round, the shortest learner whose window
    contains the current offset chooses the arm.

    After every reward, two tests (``test1``/``test2``) compare observed
    rewards against the learners' reported estimates ``g_tilde``.
    - If either test fires, ``n`` is reset to ``n_init``, the round is
      appended to ``ChangePoints`` and a new block is scheduled.
    - Otherwise, when the current block is exhausted, ``n`` is incremented
      and a block twice as long is scheduled.

    Base learners are expected to expose ``s``, ``e`` and ``len``
    attributes, ``select_arm(arms) -> (arm, f_tilde, theta_hat)`` and
    ``update_state(arm, reward)``, as used below.
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, model, k_mu = 1, c_mu = 1):
        """Create the wrapper and schedule the first block.

        Args:
            num_actions: forwarded to ``BanditAlgorithm`` and stored.
            horizon: total number of rounds ``T``.
            d: dimension of the parameter estimate ``theta_hat``.
            delta: confidence parameter (enters ``c = log(T / delta)``).
            r_lambda, S, L, R: forwarded unchanged to the base learners.
            model: ``'LB'``, ``'GLB'`` or ``'SCB'``; selects the base learner
                class. Any other value schedules no learners.
            k_mu, c_mu: forwarded to ``Base_GLB`` / ``Base_SCB`` only.
        """
        super().__init__(num_actions, horizon)

        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.model = model
        self.k_mu = k_mu
        self.c_mu = c_mu

        self.num_actions = num_actions
        self.T = horizon
        # Initial block order. Clamp at 0: for T < 16 the raw value is
        # negative, 2**n would then be a float, and range(2**n) in
        # procedure1 would raise TypeError during construction.
        self.n_init = max(0, int(np.floor(np.log2(self.T))) - 4)

        # Constructor arguments, kept so an identical instance can be rebuilt.
        self.init_params = {
            'num_actions': num_actions,
            'horizon': horizon,
            'd': d,
            'delta': delta,
            'r_lambda': r_lambda,
            'S': S,
            'L': L,
            'R': R,
            'model': model,
            'k_mu': k_mu,
            'c_mu': c_mu,
        }

        self.c = log(self.T / delta)
        self.theta_hat = np.zeros(self.d)
        self.n = self.n_init
        # Start round of the current block; self.t comes from BanditAlgorithm
        # (presumably the current round counter - confirm in the base class).
        self.tn = self.t
        self.n_hat = log2(self.T) + 1
        # rho(x) decreases like 1/sqrt(x); rho_hat scales it by 6 * n_hat * c.
        self.rho = lambda x: 4 * sqrt(self.d * self.c) * sqrt(self.d * self.c / x)
        self.rho_hat = lambda x: 6 * self.n_hat * self.c * self.rho(x)
        self.policy = []
        self.alg_index = 0
        self.g_tilde = np.array([])
        self.alge = np.array([])
        self.reward = np.array([])
        # Rounds at which a test-triggered restart happened. Previously this
        # was only created in re_init(), so the first restart on a freshly
        # constructed instance raised AttributeError.
        self.ChangePoints = []
        self.procedure1()

    def select_arm(self, arms):
        """Choose an arm using the shortest base learner active this round.

        Among scheduled learners whose window ``[alg.s, alg.e]`` contains the
        current offset ``t - tn``, the one with the smallest ``len`` is used.
        Ties go to the later-scheduled learner. The learner's ``f_tilde`` is
        appended to ``g_tilde`` and its ``theta_hat`` is stored on ``self``.

        Args:
            arms: candidate arms; kept so ``update_statistics`` can look up
                the played arm by index.

        Returns:
            The arm returned by the chosen base learner.
        """
        self.arms = arms
        elapsed = self.t - self.tn
        # Default 0 is the full-block learner, which procedure1 always
        # schedules first (its inclusion probability is exactly 1).
        alg_index = 0
        run_len = 2**self.n
        for i, alg in enumerate(self.policy):
            if alg.s <= elapsed <= alg.e and alg.len <= run_len:
                alg_index = i
                run_len = alg.len
        self.alg_index = alg_index
        chosen_arm, f_tilde, theta_hat = self.policy[self.alg_index].select_arm(arms)
        self.g_tilde = np.append(self.g_tilde, f_tilde)
        self.theta_hat = theta_hat
        return chosen_arm

    def update_statistics(self, x, y):
        """Record reward ``y`` for arm index ``x`` and run the restart tests.

        Args:
            x: index into the ``arms`` passed to the last ``select_arm`` call.
            y: observed reward.
        """
        self.reward = np.append(self.reward, y)
        self.policy[self.alg_index].update_state(self.arms[x], y)
        t1 = self.test1()
        t2 = self.test2()
        if t1 == 0 or t2 == 0:
            # A test detected a discrepancy: start over at the smallest scale.
            self.n = self.n_init
            print('MASTER HAS RESTARTED')
            self.ChangePoints.append(self.t)
            self.restart()
        elif self.t + 1 >= self.tn + 2**self.n:
            # Current block of length 2**n is exhausted (assumes self.t is
            # advanced to t + 1 after this call - confirm in BanditAlgorithm):
            # double the block length.
            self.n += 1
            self.restart()

    def restart(self):
        """Clear per-block history and schedule a new block starting at ``t``."""
        self.reward = np.array([])
        self.g_tilde = np.array([])
        self.tn = self.t
        self.procedure1()

    def re_init(self):
        """Reset the wrapper (and the base class state) to a fresh run."""
        super().re_init()
        self.ChangePoints = []
        self.n = self.n_init
        self.tn = self.t
        self.policy = []
        self.alg_index = 0
        self.g_tilde = np.array([])
        self.alge = np.array([])
        self.reward = np.array([])
        self.procedure1()

    def procedure1(self):
        """Schedule a fresh random set of base learners for the current block.

        Offsets ``tau`` run over ``[0, 2**n)`` and orders ``m`` run from ``n``
        down to 0. For each pair with ``tau % 2**m == 0``, a learner covering
        ``[tau, tau + 2**m - 1]`` is created with probability
        ``rho(2**n) / rho(2**m)``.

        For ``m == n`` that ratio is exactly 1, so the full-block learner at
        ``tau == 0`` is always ``policy[0]`` when ``model`` is recognised.
        """
        self.policy = []
        # alge mirrors the end offsets of the scheduled learners; rebuild it
        # with policy instead of accumulating stale entries across restarts.
        self.alge = np.array([])
        rho_block = self.rho(2**self.n)
        for tau in range(0, 2**self.n):
            for i in range(0, self.n + 1):
                m = self.n - i
                # The Bernoulli draw only happens for aligned offsets, which
                # keeps the RNG consumption identical to the original order.
                if tau % 2**m == 0 and np.random.binomial(n=1, p=rho_block / self.rho(2**m)):
                    algs = tau
                    alge = tau + 2**m - 1
                    if self.model == 'LB':
                        self.policy.append(Base_LB(self.d, self.delta, self.r_lambda, self.S, self.L, self.R, algs, alge))
                    elif self.model == 'GLB':
                        self.policy.append(
                            Base_GLB(self.d, self.delta, self.r_lambda, self.S, self.L, self.R, self.k_mu, self.c_mu, algs, alge))
                    elif self.model == 'SCB':
                        self.policy.append(
                            Base_SCB(self.d, self.delta, self.r_lambda, self.S, self.L, self.R, self.k_mu, self.c_mu, algs, alge))
                    self.alge = np.append(self.alge, alge)

    def test1(self):
        """Per-learner test, run when a learner's window ends this round.

        Fails when a learner ending at the current offset has an average
        reward over its window at least ``U + 9 * rho_hat(len)``. Here ``U``
        is the minimum of ``g_tilde`` since the block started.

        Assumes ``reward[k]`` is the reward at offset ``k`` of the block
        (one ``update_statistics`` call per round since ``tn``).

        Returns:
            1 if the test passes, 0 if it fails.
        """
        t1 = 1
        U = min(self.g_tilde)
        elapsed = self.t - self.tn
        for alg in self.policy:
            if elapsed == alg.e:
                R_sum = self.reward[alg.s:alg.e + 1].sum()
                if R_sum / alg.len >= U + 9 * self.rho_hat(alg.len):
                    t1 = 0
        return t1

    def test2(self):
        """Block-level test on the average gap ``g_tilde - reward``.

        Fails when the mean of ``g_tilde - reward`` since ``tn`` is at least
        ``3 * rho_hat(t - tn + 1)``.

        Returns:
            1 if the test passes, 0 if it fails.
        """
        t2 = 1
        count = self.t - self.tn + 1
        a = self.g_tilde - self.reward
        if a.sum() / count >= 3 * self.rho_hat(count):
            t2 = 0
        return t2

    def __str__(self):
        """Display name, depending on the base-learner model."""
        name = 'MASTER'
        if self.model == 'LB':
            name = 'LB-MASTER-n=' + str(self.n_init)
        elif self.model == 'GLB':
            name = 'GLB-MASTER'
        elif self.model == 'SCB':
            name = 'SCB-MASTER'
        return name