import numpy as np
from scipy.optimize import brentq
from copy import deepcopy
#----------------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------------------------
class BUCB(object):
    """UCB bandit whose per-arm indices are clipped to given bounds.

    After each round every played arm ``i`` gets the index::

        emp_mean[i] + sqrt(2 * var * log(1 + t * log(t)**2) / n_i)

    clipped to ``[lb[i], ub[i]]``. Arms that have never been played keep the
    sentinel index ``1e5``, so each arm is tried before the bounds matter.

    Parameters
    ----------
    ub, lb : sequence of float
        Per-arm upper / lower clipping bounds for the index (presumably known
        bounds on each arm's mean reward -- confirm against caller).
    avg : numpy.ndarray
        True arm means. Used only to pick the best arm for regret bookkeeping.
    var : float, optional
        Variance proxy in the exploration bonus. Defaults to 0.25, the value
        that was previously hard-coded.
    """

    def __init__(self, ub, lb, avg, var=0.25):
        self.means = avg
        self.num_arms = avg.size
        self.var = var
        self.ub = ub
        self.lb = lb
        self.best_arm = np.argmax(self.means)
        self.restart()

    def restart(self):
        """Reset time, play counts, empirical means, indices and regret."""
        self.time = 0.0
        self.emp_means = np.zeros(self.num_arms)
        self.num_plays = np.zeros(self.num_arms)
        # Large sentinel index: unplayed arms win argmax and get explored.
        self.ucb_arr = 1e5 * np.ones(self.num_arms)
        self.cum_reg = [0]

    def get_best_arm(self):
        """Return the arm with the largest index (lowest index on ties)."""
        # np.argmax on a 1-D array always yields a single scalar, so the old
        # ``[0]`` fallback branch could never run.
        return np.argmax(self.ucb_arr)

    def update_stats(self, arm, rew):
        """Advance the clock and fold reward ``rew`` into ``arm``'s mean."""
        self.time += 1.0
        self.num_plays[arm] += 1.0
        n = self.num_plays[arm]
        self.emp_means[arm] = (self.emp_means[arm] * (n - 1.0) + rew) / n

    def update_ucb(self):
        """Recompute the clipped index of every arm played at least once.

        Assumes ``self.time >= 1`` (i.e. called after ``update_stats``);
        at ``t == 1`` the ``log(t)`` factor is 0, so the bonus vanishes.
        """
        # Exploration numerator, shared by all arms at this time step.
        func = 2 * self.var * np.log(1 + self.time * (np.log(self.time) ** 2))
        for i in range(self.num_arms):
            if self.num_plays[i] == 0:
                continue  # keep the sentinel so the arm is still explored
            temp = self.emp_means[i] + np.sqrt(func / self.num_plays[i])
            self.ucb_arr[i] = max(self.lb[i], min(temp, self.ub[i]))

    def update_reg(self, arm, rew_vec):
        """Append cumulative regret: best-mean arm's reward minus ``arm``'s."""
        self.cum_reg.append(
            self.cum_reg[-1] + rew_vec[self.best_arm] - rew_vec[arm]
        )

    def iterate(self, rew_vec):
        """Play one round against the per-arm reward vector ``rew_vec``."""
        play = self.get_best_arm()
        self.update_stats(play, rew_vec[play])
        self.update_ucb()
        self.update_reg(play, rew_vec)
#----------------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------------------------
class UCBImproved(object):
    """Bounded UCB with per-arm Bernoulli sub-Gaussian variance proxies.

    Arm ``i`` uses variance proxy ``sg[i]``:

    * ``get_sg(lb[i])`` if ``lb[i] > 0.5``,
    * ``get_sg(ub[i])`` if ``ub[i] < 0.5``,
    * ``0.25`` otherwise.

    Indices are clipped to ``[lb[i], ub[i]]``. After the first update, an
    unplayed arm's index is its upper bound ``ub[i]``.

    Parameters
    ----------
    ub, lb : sequence of float
        Per-arm upper / lower bounds (presumably on the arm means -- confirm).
    avg : numpy.ndarray
        True arm means. Used only to identify the best arm for regret.
    """

    def __init__(self, ub, lb, avg):
        self.means = avg
        self.num_arms = avg.size
        self.ub = ub
        self.lb = lb
        self.sg = 0.25 * np.ones(self.num_arms)
        for i in range(self.num_arms):
            if self.lb[i] > 0.5:
                self.sg[i] = self.get_sg(self.lb[i])
            elif self.ub[i] < 0.5:
                self.sg[i] = self.get_sg(self.ub[i])
        self.best_arm = np.argmax(self.means)
        self.restart()

    def func(self, x, m):
        """Stationarity condition used by ``get_sg``.

        With ``f(x) = m*exp(x*(1-m)) + (1-m)*exp(-x*m)`` (the MGF of a
        centred Bernoulli(m) variable), return
        ``f'(x)/f(x) - (2/x)*log(f(x))``. This is zero exactly where
        ``2*log(f(x))/x**2`` has zero derivative in ``x``.
        """
        fx = m * np.exp(x * (1 - m)) + (1 - m) * np.exp(-x * m)
        gx = m * (1 - m) * (np.exp((1 - m) * x) - np.exp(-m * x)) / fx - (2 / x) * np.log(fx)
        return gx

    def get_sg(self, m):
        """Return ``2*log(f(x))/x**2`` at the root of ``func`` in [0.0005, 100].

        This is the (presumably tightest) sub-Gaussian variance proxy of a
        Bernoulli(m) variable. The proxy is unchanged under ``m -> 1-m``, so
        ``m`` is folded into ``[0, 0.5]`` and only positive ``x`` is searched.
        Raises ``ValueError`` from ``brentq`` if ``func`` has no sign change on
        the bracket.
        """
        if m > 0.5:
            m = 1 - m
        x = brentq(self.func, 0.0005, 100, args=(m,))
        return (2 / x ** 2) * np.log(m * np.exp((1 - m) * x) + (1 - m) * np.exp(-m * x))

    def restart(self):
        """Reset time, play counts, empirical means, indices and regret."""
        self.time = 0.0
        self.emp_means = np.zeros(self.num_arms)
        self.num_plays = np.zeros(self.num_arms)
        # Sentinel so the very first pick is arm 0 (argmax of equal values).
        self.ucb_arr = 1e5 * np.ones(self.num_arms)
        self.cum_reg = [0]

    def get_best_arm(self):
        """Return the arm with the largest index (lowest index on ties)."""
        # np.argmax on a 1-D array always yields a scalar; the old ``[0]``
        # fallback branch was unreachable.
        return np.argmax(self.ucb_arr)

    def update_stats(self, arm, rew):
        """Advance the clock and fold reward ``rew`` into ``arm``'s mean."""
        self.time += 1.0
        self.num_plays[arm] += 1.0
        n = self.num_plays[arm]
        self.emp_means[arm] = (self.emp_means[arm] * (n - 1.0) + rew) / n

    def update_ucb(self):
        """Recompute every arm's clipped index.

        Assumes ``self.time >= 1`` (i.e. called after ``update_stats``).
        """
        # log(1 + t*log(t)^2) is the same for every arm, so compute it once.
        log_term = np.log(1 + self.time * (np.log(self.time) ** 2))
        for i in range(self.num_arms):
            if self.num_plays[i] == 0:
                self.ucb_arr[i] = self.ub[i]
                continue
            func = 2 * self.sg[i] * log_term
            temp = self.emp_means[i] + np.sqrt(func / self.num_plays[i])
            self.ucb_arr[i] = max(self.lb[i], min(temp, self.ub[i]))

    def update_reg(self, arm, rew_vec):
        """Append cumulative regret: best-mean arm's reward minus ``arm``'s."""
        self.cum_reg.append(
            self.cum_reg[-1] + rew_vec[self.best_arm] - rew_vec[arm]
        )

    def iterate(self, rew_vec):
        """Play one round against the per-arm reward vector ``rew_vec``."""
        play = self.get_best_arm()
        self.update_stats(play, rew_vec[play])
        self.update_ucb()
        self.update_reg(play, rew_vec)
#----------------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------------------------
class BKL(object):
    """KL-UCB bandit whose indices are capped by per-arm upper bounds.

    Each played arm's index is the bisection KL-UCB bound
    ``max{q : kl(emp_mean, q) <= log(1 + t*log(t)^2) / n_i}``, capped at
    ``ub[i]``. Unplayed arms get the sentinel index ``1e5``.

    Parameters
    ----------
    ub, lb : sequence of float
        Per-arm upper / lower bounds (presumably on the arm means -- confirm).
        Only ``ub`` is used when computing indices.
    avg : numpy.ndarray
        True arm means. Used only for regret bookkeeping.
    """

    def __init__(self, ub, lb, avg):
        self.means = avg
        self.num_arms = avg.size
        self.ub = ub
        self.lb = lb
        self.best_arm = np.argmax(self.means)
        self.max_rew = np.max(self.means)
        self.restart()

    def restart(self):
        """Reset time, play counts, empirical means, indices and regret."""
        self.time = 0.0
        self.emp_means = np.zeros(self.num_arms)
        self.plays = np.zeros(self.num_arms)
        self.ucb_arr = 1e5 * np.ones(self.num_arms)
        self.cum_reg = [0]

    def get_best_arm(self):
        """Return the arm with the largest index (lowest index on ties)."""
        # np.argmax on a 1-D array always yields a scalar; the old ``[0]``
        # fallback branch was unreachable.
        return np.argmax(self.ucb_arr)

    def update_stats(self, arm, rew):
        """Fold ``rew`` into ``arm``'s mean and return the exploration level.

        Returns ``log(1 + t*log(t)**2)`` for the updated time ``t``. This is
        the KL budget numerator that ``update_ucb`` divides by the play counts.
        """
        self.time += 1.0
        self.plays[arm] += 1.0
        self.emp_means[arm] = (self.emp_means[arm] * (self.plays[arm] - 1.0) + rew) / (1.0 * self.plays[arm])
        return np.log(1 + self.time * (np.log(self.time) ** 2))

    def klfunc(self, p, q):
        """Bernoulli KL divergence ``kl(p, q)``.

        ``p == 0`` and ``p == 1`` are special-cased to avoid ``0 * log(0)``.
        Expects ``0 < q < 1``; ``klucb`` only evaluates interior ``q``.
        """
        if p == 0:
            return np.log(1.0 / (1.0 - q))
        if p == 1:
            return np.log(1.0 / q)
        return p * np.log(p / q) + (1.0 - p) * np.log((1.0 - p) / (1.0 - q))

    def klucb(self, x, d, upper=1, lower=0, precision=1e-6, max_iterations=50):
        """Bisection search for the largest ``q`` with ``kl(x, q) <= d``.

        Searches ``[max(x, lower), upper]`` and stops after
        ``max_iterations`` halvings or once the bracket is no wider than
        ``precision``. Returns the midpoint of the final bracket.
        """
        value = max(x, lower)
        u = upper
        _count_iteration = 0
        while _count_iteration < max_iterations and u - value > precision:
            _count_iteration += 1
            m = (value + u) / 2.0
            if self.klfunc(x, m) > d:
                u = m
            else:
                value = m
        return (value + u) / 2.0

    def update_ucb(self, exp_num):
        """Recompute every arm's index from the exploration level ``exp_num``."""
        for i in range(self.num_arms):
            if self.plays[i] == 0:
                self.ucb_arr[i] = 1e5
            else:
                temp = self.klucb(self.emp_means[i], exp_num / self.plays[i])
                # NOTE(review): unlike the other bounded algorithms in this
                # file, the index is not floored at lb[i]; confirm intentional.
                self.ucb_arr[i] = min(temp, self.ub[i])

    def update_reg(self, arm, rew_vec):
        """Append cumulative regret: best-mean arm's reward minus ``arm``'s."""
        self.cum_reg.append(
            self.cum_reg[-1] + rew_vec[self.best_arm] - rew_vec[arm]
        )

    def iterate(self, rew_vec):
        """Play one round against the per-arm reward vector ``rew_vec``."""
        play = self.get_best_arm()
        exp_num = self.update_stats(play, rew_vec[play])
        self.update_ucb(exp_num)
        self.update_reg(play, rew_vec)
#----------------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------------------------
class GLUE(object):
    """Bounded UCB whose variance proxies are driven by the largest lower bound.

    Variance proxies are chosen as follows:

    * If ``lmax = max(lb) > 0.5``, every arm uses ``get_sg(lmax)``.
    * Otherwise, arm ``i`` uses ``get_sg(ub[i])`` when ``ub[i] < 0.5`` and
      ``0.25`` when it does not.

    Indices are clipped to ``[lb[i], ub[i]]``. After the first update, an
    unplayed arm's index is its upper bound ``ub[i]``.

    Parameters
    ----------
    ub, lb : sequence of float
        Per-arm upper / lower bounds (presumably on the arm means -- confirm).
    avg : numpy.ndarray
        True arm means. Used only to identify the best arm for regret.
    """

    def __init__(self, ub, lb, avg):
        self.means = avg
        self.num_arms = avg.size
        self.ub = ub
        self.lb = lb
        self.sg = 0.25 * np.ones(self.num_arms)
        # np.max without an axis always returns a scalar, so the former
        # ``if self.lmax.size > 1`` fix-up was dead code.
        self.lmax = np.max(self.lb)
        if self.lmax > 0.5:
            self.sg = self.get_sg(self.lmax) * np.ones(self.num_arms)
        else:
            for i in range(self.num_arms):
                if self.ub[i] < 0.5:
                    self.sg[i] = self.get_sg(self.ub[i])
        self.best_arm = np.argmax(self.means)
        self.restart()

    def func(self, x, m):
        """Stationarity condition used by ``get_sg``.

        With ``f(x) = m*exp(x*(1-m)) + (1-m)*exp(-x*m)`` (the MGF of a
        centred Bernoulli(m) variable), return
        ``f'(x)/f(x) - (2/x)*log(f(x))``. This is zero exactly where
        ``2*log(f(x))/x**2`` has zero derivative in ``x``.
        """
        fx = m * np.exp(x * (1 - m)) + (1 - m) * np.exp(-x * m)
        gx = m * (1 - m) * (np.exp((1 - m) * x) - np.exp(-m * x)) / fx - (2 / x) * np.log(fx)
        return gx

    def get_sg(self, m):
        """Return ``2*log(f(x))/x**2`` at the root of ``func`` in [0.0005, 100].

        This is the (presumably tightest) sub-Gaussian variance proxy of a
        Bernoulli(m) variable. ``m`` is folded into ``[0, 0.5]`` because the
        proxy is unchanged under ``m -> 1-m``. Raises ``ValueError`` from
        ``brentq`` if ``func`` has no sign change on the bracket.
        """
        if m > 0.5:
            m = 1 - m
        x = brentq(self.func, 0.0005, 100, args=(m,))
        return (2 / x ** 2) * np.log(m * np.exp((1 - m) * x) + (1 - m) * np.exp(-m * x))

    def restart(self):
        """Reset time, play counts, empirical means, indices and regret."""
        self.time = 0.0
        self.emp_means = np.zeros(self.num_arms)
        self.num_plays = np.zeros(self.num_arms)
        # Sentinel so the very first pick is arm 0 (argmax of equal values).
        self.ucb_arr = 1e5 * np.ones(self.num_arms)
        self.cum_reg = [0]

    def get_best_arm(self):
        """Return the arm with the largest index (lowest index on ties)."""
        # np.argmax on a 1-D array always yields a scalar; the old ``[0]``
        # fallback branch was unreachable.
        return np.argmax(self.ucb_arr)

    def update_stats(self, arm, rew):
        """Advance the clock and fold reward ``rew`` into ``arm``'s mean."""
        self.time += 1.0
        self.num_plays[arm] += 1.0
        n = self.num_plays[arm]
        self.emp_means[arm] = (self.emp_means[arm] * (n - 1.0) + rew) / n

    def update_ucb(self):
        """Recompute every arm's clipped index.

        Assumes ``self.time >= 1`` (i.e. called after ``update_stats``).
        """
        # log(1 + t*log(t)^2) is the same for every arm, so compute it once.
        log_term = np.log(1 + self.time * (np.log(self.time) ** 2))
        for i in range(self.num_arms):
            if self.num_plays[i] == 0:
                self.ucb_arr[i] = self.ub[i]
                continue
            func = 2 * self.sg[i] * log_term
            temp = self.emp_means[i] + np.sqrt(func / self.num_plays[i])
            self.ucb_arr[i] = max(self.lb[i], min(temp, self.ub[i]))

    def update_reg(self, arm, rew_vec):
        """Append cumulative regret: best-mean arm's reward minus ``arm``'s."""
        self.cum_reg.append(
            self.cum_reg[-1] + rew_vec[self.best_arm] - rew_vec[arm]
        )

    def iterate(self, rew_vec):
        """Play one round against the per-arm reward vector ``rew_vec``."""
        play = self.get_best_arm()
        self.update_stats(play, rew_vec[play])
        self.update_ucb()
        self.update_reg(play, rew_vec)