import numpy as np
from scipy.optimize import brentq
from copy import deepcopy
from functools import reduce
#----------------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------------------------
class BUCB(object):
    """UCB policy whose per-arm indices are clamped to known reward bounds.

    Every arm ``i`` comes with a lower bound ``lb[i]`` and an upper bound
    ``ub[i]``.  The index of a played arm is its empirical mean plus an
    exploration bonus, clamped into ``[lb[i], ub[i]]``.  Cumulative regret
    against the arm with the largest entry of ``avg`` is kept in ``cum_reg``.
    """

    def __init__(self, ub, lb, avg):
        """Store the problem description and reset all learning state.

        Args:
            ub: per-arm upper bounds, indexable by arm id.
            lb: per-arm lower bounds, indexable by arm id.
            avg: numpy array of true arm means; its size fixes the number of
                arms and its argmax is the reference arm for regret.
        """
        self.means = avg
        self.num_arms = avg.size
        # Variance term of the exploration bonus.
        # NOTE(review): presumably the sub-Gaussian proxy of a [0, 1]
        # reward -- confirm.
        self.var = 0.25
        self.ub = ub
        self.lb = lb
        self.best_arm = np.argmax(self.means)
        self.restart()

    def restart(self):
        """Forget everything learned so far (time, counts, means, regret)."""
        self.time = 0.0
        self.emp_means = np.zeros(self.num_arms)
        self.num_plays = np.zeros(self.num_arms)
        # A huge initial index forces every arm to be tried once.
        self.ucb_arr = 1e5*np.ones(self.num_arms)
        self.cum_reg = [0]

    def get_best_arm(self):
        """Return the index of the arm with the largest current UCB index.

        ``np.argmax`` always yields a single index (the first maximiser on
        ties), so no extra tie handling is required.
        """
        return np.argmax(self.ucb_arr)

    def update_stats(self, arm, rew):
        """Advance time by one round and fold ``rew`` into ``arm``'s mean."""
        self.time += 1.0
        self.num_plays[arm] += 1.0
        plays = self.num_plays[arm]
        self.emp_means[arm] = (self.emp_means[arm]*(plays - 1.0) + rew)/plays

    def update_ucb(self):
        """Recompute the clamped UCB index of every arm played at least once.

        Arms that were never played keep their current index.
        """
        if self.time <= 0:
            # Nothing has been played yet; evaluating log(0) below would only
            # produce runtime warnings and no index would be written.
            return None
        bonus = 2*self.var*np.log(1 + self.time*(np.log(self.time)**2))
        for i in range(self.num_arms):
            if self.num_plays[i] == 0:
                continue
            raw = self.emp_means[i] + np.sqrt(bonus/self.num_plays[i])
            self.ucb_arr[i] = max(self.lb[i], min(raw, self.ub[i]))
        return None

    def update_reg(self, arm, rew_vec):
        """Append the new cumulative regret after playing ``arm``.

        The per-round regret is ``rew_vec[best_arm] - rew_vec[arm]``.
        """
        self.cum_reg.append(self.cum_reg[-1] + rew_vec[self.best_arm] - rew_vec[arm])

    def iterate(self, rew_vec):
        """Play one round given the vector of this round's rewards per arm."""
        play = self.get_best_arm()
        self.update_stats(play, rew_vec[play])
        self.update_ucb()
        self.update_reg(play, rew_vec)
#----------------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------------------------
class UCBImproved(object):
    """Bound-clamped UCB policy with a per-arm variance term ``sg``.

    ``sg[i]`` defaults to 0.25.  When arm ``i``'s bounds place it entirely
    above 0.5 (``lb[i] > 0.5``) or entirely below 0.5 (``ub[i] < 0.5``), it is
    replaced by ``get_sg`` of the bound closest to 0.5.
    """

    def __init__(self, ub, lb, avg):
        """Store the problem, derive per-arm ``sg`` and reset the learner.

        Args:
            ub: per-arm upper bounds, indexable by arm id.
            lb: per-arm lower bounds, indexable by arm id.
            avg: numpy array of true arm means; its size fixes the number of
                arms and its argmax is the reference arm for regret.
        """
        self.means = avg
        self.num_arms = avg.size
        self.ub = ub
        self.lb = lb
        self.sg = 0.25*np.ones(self.num_arms)
        for i in range(self.num_arms):
            if self.lb[i] > 0.5:
                self.sg[i] = self.get_sg(self.lb[i])
            elif self.ub[i] < 0.5:
                self.sg[i] = self.get_sg(self.ub[i])
        self.best_arm = np.argmax(self.means)
        self.restart()

    def func(self, x, m):
        """Function of ``x`` whose root ``get_sg`` searches for.

        With ``f(x) = m*exp(x*(1-m)) + (1-m)*exp(-x*m)`` this returns
        ``f'(x)/f(x) - (2/x)*log(f(x))``, which is zero exactly where
        ``(2/x**2)*log(f(x))`` is stationary.
        """
        fx = m*np.exp(x*(1-m)) + (1-m)*np.exp(-x*m)
        dfx = m*(1-m)*(np.exp((1-m)*x) - np.exp(-m*x))
        return dfx/fx - (2/x)*np.log(fx)

    def get_sg(self, m):
        """Return ``(2/x**2)*log(f(x))`` at the root of ``func`` for mean ``m``.

        ``m`` is mirrored about 0.5 first, so ``get_sg(m) == get_sg(1 - m)``.

        Raises:
            ValueError: raised by ``brentq`` when ``func`` does not change
                sign on the search bracket ``[0.0005, 100]``.
        """
        if m > 0.5:
            m = 1-m
        # brentq documents ``args`` as a tuple of extra arguments.
        x = brentq(self.func, 0.0005, 100, args=(m,))
        return (2/x**2)*np.log(m*np.exp((1-m)*x) + (1-m)*np.exp(-m*x))

    def restart(self):
        """Forget everything learned so far (time, counts, means, regret)."""
        self.time = 0.0
        self.emp_means = np.zeros(self.num_arms)
        self.num_plays = np.zeros(self.num_arms)
        # A huge initial index forces the first arm choice before any update.
        self.ucb_arr = 1e5*np.ones(self.num_arms)
        self.cum_reg = [0]

    def get_best_arm(self):
        """Return the index of the arm with the largest current UCB index.

        ``np.argmax`` always yields a single index (the first maximiser on
        ties), so no extra tie handling is required.
        """
        return np.argmax(self.ucb_arr)

    def update_stats(self, arm, rew):
        """Advance time by one round and fold ``rew`` into ``arm``'s mean."""
        self.time += 1.0
        self.num_plays[arm] += 1.0
        plays = self.num_plays[arm]
        self.emp_means[arm] = (self.emp_means[arm]*(plays - 1.0) + rew)/plays

    def update_ucb(self):
        """Recompute every arm's index.

        Unplayed arms get their upper bound; played arms get the empirical
        mean plus bonus, clamped into ``[lb[i], ub[i]]``.
        """
        for i in range(self.num_arms):
            if self.num_plays[i] == 0:
                self.ucb_arr[i] = self.ub[i]
            else:
                bonus = 2*self.sg[i]*np.log(1 + self.time*(np.log(self.time)**2))
                raw = self.emp_means[i] + np.sqrt(bonus/self.num_plays[i])
                self.ucb_arr[i] = max(self.lb[i], min(raw, self.ub[i]))

    def update_reg(self, arm, rew_vec):
        """Append the new cumulative regret after playing ``arm``.

        The per-round regret is ``rew_vec[best_arm] - rew_vec[arm]``.
        """
        self.cum_reg.append(self.cum_reg[-1] + rew_vec[self.best_arm] - rew_vec[arm])

    def iterate(self, rew_vec):
        """Play one round given the vector of this round's rewards per arm."""
        play = self.get_best_arm()
        self.update_stats(play, rew_vec[play])
        self.update_ucb()
        self.update_reg(play, rew_vec)
#----------------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------------------------
class BKL(object):
    """KL-UCB style policy whose indices are capped at per-arm upper bounds.

    The index of a played arm is the largest ``q`` whose Bernoulli KL
    divergence from the empirical mean stays within the exploration budget,
    then capped at ``ub[i]``.  ``lb`` is stored but not used by the index.
    """

    def __init__(self, ub, lb, avg):
        """Store the problem description and reset all learning state.

        Args:
            ub: per-arm upper bounds, indexable by arm id.
            lb: per-arm lower bounds (stored only).
            avg: numpy array of true arm means; its size fixes the number of
                arms and its argmax is the reference arm for regret.
        """
        self.means = avg
        self.num_arms = avg.size
        self.ub = ub
        self.lb = lb
        self.best_arm = np.argmax(self.means)
        self.max_rew = np.max(self.means)
        self.restart()

    def restart(self):
        """Forget everything learned so far (time, counts, means, regret)."""
        self.time = 0.0
        self.emp_means = np.zeros(self.num_arms)
        self.plays = np.zeros(self.num_arms)
        # A huge initial index forces every arm to be tried once.
        self.ucb_arr = 1e5*np.ones(self.num_arms)
        self.cum_reg = [0]

    def get_best_arm(self):
        """Return the index of the arm with the largest current index.

        ``np.argmax`` always yields a single index (the first maximiser on
        ties), so no extra tie handling is required.
        """
        return np.argmax(self.ucb_arr)

    def update_stats(self, arm, rew):
        """Record one play of ``arm`` and return the exploration budget.

        Returns:
            ``log(1 + t*log(t)**2)`` for the updated round counter ``t``.
        """
        self.time += 1.0
        self.plays[arm] += 1.0
        count = self.plays[arm]
        self.emp_means[arm] = (self.emp_means[arm]*(count - 1.0) + rew)/(1.0*count)
        return np.log(1 + self.time*(np.log(self.time)**2))

    def klfunc(self, p, q):
        """Bernoulli KL divergence between means ``p`` and ``q``.

        ``p == 0`` and ``p == 1`` are special-cased so the vanishing
        ``0*log(0)`` term is never evaluated.
        """
        if p == 0:
            return np.log(1.0/(1.0-q))
        if p == 1:
            return np.log(1.0/q)
        return p*np.log(p/q) + (1.0-p)*np.log((1.0-p)/(1.0-q))

    def klucb(self, x, d, upper=1, lower=0, precision=1e-6, max_iterations=50):
        """Bisect for the largest ``q`` in ``[max(x, lower), upper]`` with ``klfunc(x, q) <= d``.

        Stops after ``max_iterations`` halvings or once the bracket is
        narrower than ``precision``; returns the bracket midpoint.
        """
        value = max(x, lower)
        u = upper
        steps = 0
        while steps < max_iterations and u - value > precision:
            steps += 1
            mid = (value + u) / 2.0
            if self.klfunc(x, mid) > d:
                u = mid
            else:
                value = mid
        return (value + u) / 2.0

    def update_ucb(self, exp_num):
        """Recompute every arm's index for exploration budget ``exp_num``.

        Unplayed arms get 1e5; played arms get the KL index with threshold
        ``exp_num / plays[i]``, capped at ``ub[i]``.
        """
        for i in range(self.num_arms):
            if self.plays[i] == 0:
                self.ucb_arr[i] = 1e5
            else:
                index = self.klucb(self.emp_means[i], exp_num/self.plays[i])
                self.ucb_arr[i] = min(index, self.ub[i])

    def update_reg(self, arm, rew_vec):
        """Append the new cumulative regret after playing ``arm``.

        The per-round regret is ``rew_vec[best_arm] - rew_vec[arm]``.
        """
        self.cum_reg.append(self.cum_reg[-1] + rew_vec[self.best_arm] - rew_vec[arm])

    def iterate(self, rew_vec):
        """Play one round given the vector of this round's rewards per arm."""
        play = self.get_best_arm()
        exp_num = self.update_stats(play, rew_vec[play])
        self.update_ucb(exp_num)
        self.update_reg(play, rew_vec)
#----------------------------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------------------------
class GLUE(object):
    """Bound-clamped UCB policy whose variance term depends on the best lower bound.

    If the largest lower bound ``lmax`` exceeds 0.5, every arm uses
    ``get_sg(lmax)``.  Otherwise each arm uses 0.25, except arms with
    ``ub[i] < 0.5`` which use ``get_sg(ub[i])``.
    """

    def __init__(self, ub, lb, avg):
        """Store the problem, derive per-arm ``sg`` and reset the learner.

        Args:
            ub: per-arm upper bounds, indexable by arm id.
            lb: per-arm lower bounds, indexable by arm id.
            avg: numpy array of true arm means; its size fixes the number of
                arms and its argmax is the reference arm for regret.
        """
        self.means = avg
        self.num_arms = avg.size
        self.ub = ub
        self.lb = lb
        self.sg = 0.25*np.ones(self.num_arms)
        # np.max already reduces to a single scalar.
        self.lmax = np.max(self.lb)
        if self.lmax > 0.5:
            self.sg = self.get_sg(self.lmax)*np.ones(self.num_arms)
        else:
            for i in range(self.num_arms):
                if self.ub[i] < 0.5:
                    self.sg[i] = self.get_sg(self.ub[i])
        self.best_arm = np.argmax(self.means)
        self.restart()

    def func(self, x, m):
        """Function of ``x`` whose root ``get_sg`` searches for.

        With ``f(x) = m*exp(x*(1-m)) + (1-m)*exp(-x*m)`` this returns
        ``f'(x)/f(x) - (2/x)*log(f(x))``, which is zero exactly where
        ``(2/x**2)*log(f(x))`` is stationary.
        """
        fx = m*np.exp(x*(1-m)) + (1-m)*np.exp(-x*m)
        dfx = m*(1-m)*(np.exp((1-m)*x) - np.exp(-m*x))
        return dfx/fx - (2/x)*np.log(fx)

    def get_sg(self, m):
        """Return ``(2/x**2)*log(f(x))`` at the root of ``func`` for mean ``m``.

        ``m`` is mirrored about 0.5 first, so ``get_sg(m) == get_sg(1 - m)``.

        Raises:
            ValueError: raised by ``brentq`` when ``func`` does not change
                sign on the search bracket ``[0.0005, 100]``.
        """
        if m > 0.5:
            m = 1-m
        # brentq documents ``args`` as a tuple of extra arguments.
        x = brentq(self.func, 0.0005, 100, args=(m,))
        return (2/x**2)*np.log(m*np.exp((1-m)*x) + (1-m)*np.exp(-m*x))

    def restart(self):
        """Forget everything learned so far (time, counts, means, regret)."""
        self.time = 0.0
        self.emp_means = np.zeros(self.num_arms)
        self.num_plays = np.zeros(self.num_arms)
        # A huge initial index forces the first arm choice before any update.
        self.ucb_arr = 1e5*np.ones(self.num_arms)
        self.cum_reg = [0]

    def get_best_arm(self):
        """Return the index of the arm with the largest current UCB index.

        ``np.argmax`` always yields a single index (the first maximiser on
        ties), so no extra tie handling is required.
        """
        return np.argmax(self.ucb_arr)

    def update_stats(self, arm, rew):
        """Advance time by one round and fold ``rew`` into ``arm``'s mean."""
        self.time += 1.0
        self.num_plays[arm] += 1.0
        plays = self.num_plays[arm]
        self.emp_means[arm] = (self.emp_means[arm]*(plays - 1.0) + rew)/plays

    def update_ucb(self):
        """Recompute every arm's index.

        Unplayed arms get their upper bound; played arms get the empirical
        mean plus bonus, clamped into ``[lb[i], ub[i]]``.
        """
        for i in range(self.num_arms):
            if self.num_plays[i] == 0:
                self.ucb_arr[i] = self.ub[i]
            else:
                bonus = 2*self.sg[i]*np.log(1 + self.time*(np.log(self.time)**2))
                raw = self.emp_means[i] + np.sqrt(bonus/self.num_plays[i])
                self.ucb_arr[i] = max(self.lb[i], min(raw, self.ub[i]))

    def update_reg(self, arm, rew_vec):
        """Append the new cumulative regret after playing ``arm``.

        The per-round regret is ``rew_vec[best_arm] - rew_vec[arm]``.
        """
        self.cum_reg.append(self.cum_reg[-1] + rew_vec[self.best_arm] - rew_vec[arm])

    def iterate(self, rew_vec):
        """Play one round given the vector of this round's rewards per arm."""
        play = self.get_best_arm()
        self.update_stats(play, rew_vec[play])
        self.update_ucb()
        self.update_reg(play, rew_vec)
    
def mat_mul(*matrices):
    """Chain ``np.dot`` over the arguments from left to right.

    ``mat_mul(A, B, C)`` evaluates ``np.dot(np.dot(A, B), C)``; a single
    argument is returned unchanged.
    """
    factors = iter(matrices)
    return reduce(np.dot, factors)

class Vector:
    """Regularised least-squares estimate of a ``d``-dimensional parameter.

    Keeps the design matrix ``V`` (identity plus the sum of ``a a^T``), the
    stacked actions ``X`` (one column per query, after an initial zero
    column), the stacked signals ``Y`` and the current estimate
    ``theta_hat``.  Actions are expected as ``(d, 1)`` column arrays.
    """

    def __init__(self, d, sigma, delta):
        self.theta_hat = np.zeros((d, 1)) # the center (empirical estimate) of the ball
        self.V = np.eye(d)
        self.X = np.zeros((d, 1))
        self.Y = np.zeros((1, 1))
        self.sigma = sigma
        self.delta = delta
        self.theta_norm = 1
        self.d = d
        self.num_queries = 0

    def beta(self):
        """Return ``sigma*sqrt(d*log((1 + num_queries)/delta)) + theta_norm``."""
        return self.sigma*np.sqrt(self.d*np.log((1+self.num_queries)/self.delta))+self.theta_norm

    def _quad_form(self, a):
        """Return ``a^T V a`` as a plain Python float."""
        # .item() extracts the single entry of the (1, 1) product; calling
        # float() on such an array is deprecated by NumPy.
        return np.dot(np.dot(a.T, self.V), a).item()

    def update(self, a, signal):
        """Record the action/signal pair and refit ``theta_hat``.

        Args:
            a: ``(d, 1)`` action column.
            signal: scalar response observed for ``a``.
        """
        self.V += np.matmul(a, a.T)
        self.X = np.append(self.X, a, axis = 1)
        self.Y = np.vstack((self.Y, np.array(signal)))
        self.num_queries += 1
        # Solve V theta = X Y directly instead of forming the inverse of V.
        self.theta_hat = np.linalg.solve(self.V, np.dot(self.X, self.Y))

    def tilde(self, a):
        """Return ``theta_hat + sqrt(beta / (a^T V a)) * a`` as a ``(d, 1)`` array."""
        scale = np.sqrt(self.beta()/self._quad_form(a))
        return self.theta_hat + scale*a

    def uncertainty(self, a):
        """Return the scalar ``sqrt(beta / (a^T V a))``."""
        return np.sqrt(self.beta()/self._quad_form(a))

class RestrictedVector():
    """Regularised least-squares estimate whose noise scale is supplied per call.

    Keeps the design matrix ``V`` (identity plus the sum of ``a a^T``), the
    stacked actions ``X`` (one column per query, after an initial zero
    column), the stacked signals ``Y`` and the current estimate
    ``theta_hat``.  Actions are expected as ``(d, 1)`` column arrays.
    """

    def __init__(self, d, delta):
        self.theta_hat = np.zeros((d, 1)) # the center (empirical estimate) of the ball
        self.V = np.eye(d)
        self.X = np.zeros((d, 1))
        self.Y = np.zeros((1, 1))
        self.delta = delta
        self.theta_norm = 1
        self.d = d
        self.num_queries = 0

    def beta(self, sigma):
        """Return ``sigma*sqrt(d*log((1 + num_queries)/delta)) + theta_norm``."""
        return sigma*np.sqrt(self.d*np.log((1+self.num_queries)/self.delta))+self.theta_norm

    def _quad_form(self, a):
        """Return ``a^T V a`` as a plain Python float."""
        # .item() extracts the single entry of the (1, 1) product; calling
        # float() on such an array is deprecated by NumPy.
        return np.dot(np.dot(a.T, self.V), a).item()

    def update(self, a, signal):
        """Record the action/signal pair and refit ``theta_hat``.

        Args:
            a: ``(d, 1)`` action column.
            signal: scalar response observed for ``a``.
        """
        self.V += np.matmul(a, a.T)
        self.X = np.append(self.X, a, axis = 1)
        self.Y = np.vstack((self.Y, np.array(signal)))
        self.num_queries += 1
        # Solve V theta = X Y directly instead of forming the inverse of V.
        self.theta_hat = np.linalg.solve(self.V, np.dot(self.X, self.Y))

    def tilde(self, a, sigma):
        """Return ``theta_hat + sqrt(beta(sigma) / (a^T V a)) * a`` as a ``(d, 1)`` array."""
        scale = np.sqrt(self.beta(sigma)/self._quad_form(a))
        return self.theta_hat + scale*a

    def uncertainty(self, a, sigma):
        """Return the scalar ``sqrt(beta(sigma) / (a^T V a))``."""
        return np.sqrt(self.beta(sigma)/self._quad_form(a))

class LinUCB():
    """Linear bandit agent that plays the arm with the largest optimistic score.

    ``count2`` counts rounds in which a non-best arm was played, ``count3``
    collects the regret of those rounds and ``regret`` holds the running
    cumulative regret (starting at 0).
    """

    def __init__(self, dim, sigma, delta):
        self.d, self.sigma, self.delta = dim, sigma, delta
        self.restart()

    def restart(self):
        """Create a fresh estimator and clear all counters."""
        self.theta = Vector(self.d, self.sigma, self.delta)
        self.regret = [0]
        self.count3 = []
        self.count2 = 0

    def update_theta(self, arm, reward):
        """Feed one (arm, reward) observation to the estimator."""
        self.theta.update(arm, reward)

    def get_UCB(self, arm):
        """Return the dot product of the estimator's ``tilde(arm)`` with ``arm``."""
        optimistic = self.theta.tilde(arm)
        return np.dot(optimistic.T, arm)

    def get_best_arm(self, arms):
        """Return the index of the arm with the highest ``get_UCB`` score."""
        scores = []
        for candidate in arms:
            scores.append(self.get_UCB(candidate))
        return np.argmax(scores)

    def iterate(self, arms, rewards, true_best_arm_idx):
        """Play one round: choose an arm, learn from its reward, log regret."""
        chosen = self.get_best_arm(arms)
        payoff = rewards[chosen]
        self.update_theta(arms[chosen], payoff)
        gap = rewards[true_best_arm_idx] - payoff
        self.regret.append(self.regret[-1] + gap)
        if chosen == true_best_arm_idx:
            return
        self.count2 += 1.0
        self.count3.append(gap)

    def get_regret(self):
        """Return the list of cumulative regrets, one entry per round plus the initial 0."""
        return self.regret
    

class ClippedLinUCB():
    """Linear bandit agent whose optimistic scores are clipped to per-arm bounds.

    ``count2`` counts rounds in which a non-best arm was played, ``count3``
    collects the regret of those rounds and ``regret`` holds the running
    cumulative regret (starting at 0).
    """

    def __init__(self, dim, sigma, delta):
        self.d = dim
        self.sigma = sigma
        self.delta = delta
        self.restart()

    def restart(self):
        """Create a fresh estimator and clear all counters."""
        self.theta = Vector(self.d, self.sigma, self.delta)
        self.count2 = 0
        self.count3 = []
        self.regret = [0]

    def update_theta(self, arm, reward):
        """Feed one (arm, reward) observation to the estimator."""
        self.theta.update(arm, reward)

    def get_UCB(self, arm):
        """Return the dot product of the estimator's ``tilde(arm)`` with ``arm``."""
        return np.dot(self.theta.tilde(arm).T, arm)

    def get_best_arm(self, arms, lower, upper):
        """Return the index of the arm with the highest clipped score.

        Args:
            arms: sequence of arm vectors.
            lower: per-arm lower bounds used as the clip floor.
            upper: per-arm upper bounds used as the clip ceiling.
        """
        evals = np.empty(len(arms))
        for i, arm in enumerate(arms):
            # get_UCB yields a one-element array; take the scalar explicitly,
            # since implicit array-to-scalar assignment is deprecated by NumPy.
            score = np.asarray(self.get_UCB(arm)).item()
            evals[i] = np.clip(score, lower[i], upper[i])
        # np.argmax always returns a single index (first maximiser on ties).
        return np.argmax(evals)

    def iterate(self, arms, upper, lower, rewards, true_best_arm_idx):
        """Play one round: choose an arm, learn from its reward, log regret.

        Note the argument order here is ``upper`` then ``lower``.
        """
        action_idx = self.get_best_arm(arms, lower, upper)
        reward = rewards[action_idx]
        self.update_theta(arms[action_idx], reward)
        instantaneous_regret = rewards[true_best_arm_idx] - reward
        self.regret.append(self.regret[-1] + instantaneous_regret)
        if action_idx != true_best_arm_idx:
            self.count2 += 1.0
            self.count3.append(instantaneous_regret)

    def get_regret(self):
        """Return the list of cumulative regrets, one entry per round plus the initial 0."""
        return self.regret

class RestrictedLinUCB():
    """Linear bandit agent that prunes arms using reward bounds before playing.

    Each round the candidate set is reduced by ``prune``, a variance term is
    looked up for every surviving arm, and the running maximum of its square
    root (``max_sg_over_time``) is used as the noise scale of the estimator.

    Attributes:
        count: sizes of the candidate set for rounds with several candidates.
        count2: number of rounds in which a non-best arm was played.
        count3: regrets of those rounds (multi-candidate rounds only).
        regret: running cumulative regret, starting at 0.
    """

    def __init__(self, dim, delta, sg_dict):
        self.d = dim
        self.delta = delta
        self.sg_dict = sg_dict
        self.restart()

    def restart(self):
        """Create a fresh estimator and clear all counters."""
        self.theta = RestrictedVector(self.d, self.delta)
        self.max_sg_over_time = -10
        self.count = []
        self.count2 = 0
        self.count3 = []
        self.regret = [0]

    def get_sg(self, m):
        """Look up the variance term for mean ``m`` in ``sg_dict``.

        ``m`` is mirrored about 0.5 and rounded to the nearest multiple of
        0.0005 (4 decimals) to form the dictionary key.

        Raises:
            KeyError: if the rounded key is missing from ``sg_dict``.
        """
        if m > 0.5: m = 1-m
        def myround(x, prec=4, base=.0005):
            return round(base * round(float(x)/base),prec)
        m = myround(m)
        return self.sg_dict[m]

    def update_theta(self, arm, reward):
        """Feed one (arm, reward) observation to the estimator."""
        self.theta.update(arm, reward)

    def get_UCB(self, arm):
        """Return the dot product of ``tilde(arm, max_sg_over_time)`` with ``arm``."""
        return np.dot(self.theta.tilde(arm, self.max_sg_over_time).T, arm)

    def get_best_arm(self, arms, lower, upper, pruned_ids):
        """Return the index of the candidate arm with the highest clipped score.

        Arms whose index is not in ``pruned_ids`` keep the sentinel score
        -100 and are therefore not selected when any candidate scores higher.
        """
        candidates = set(pruned_ids)  # O(1) membership instead of a list scan
        evals = -100.0*np.ones(len(arms))
        for i, arm in enumerate(arms):
            if i in candidates:
                # get_UCB yields a one-element array; take the scalar
                # explicitly, since implicit array-to-scalar assignment is
                # deprecated by NumPy.
                score = np.asarray(self.get_UCB(arm)).item()
                evals[i] = np.clip(score, lower[i], upper[i])
        # np.argmax always returns a single index (first maximiser on ties).
        return np.argmax(evals)

    def prune(self, arms, upper, lower):
        """Return the indices of arms that may still be optimal.

        An arm survives when (1) its upper bound is at least the largest
        lower bound ``lmax`` and (2) its angle to the arm attaining ``lmax``
        is at most ``2*arccos(lmax)``.  The anchor arm itself is exempt from
        the angle test so rounding can never remove it.
        """
        anchor = np.argmax(lower)
        threshold = lower[anchor]
        good_arm = arms[anchor]
        good_norm = np.linalg.norm(good_arm)
        # Clip guards arccos against bounds that overshoot [-1, 1] by rounding.
        max_angle = 2*np.arccos(np.clip(threshold, -1, 1))

        def angle_to_anchor(arm):
            # Cosine of the angle: dot product divided by both vector norms.
            denom = np.linalg.norm(arm)*good_norm
            if denom == 0:
                return np.inf  # a zero vector has no direction; never keep it
            cosine = np.asarray(np.dot(arm.T, good_arm)).item()/denom
            return np.arccos(np.clip(cosine, -1, 1))

        return [
            i for i in range(len(arms))
            if upper[i] >= threshold
            and (i == anchor or angle_to_anchor(arms[i]) <= max_angle)
        ]

    def _arm_sg(self, lower, upper):
        """Variance term for one arm from the absolute values of its bounds."""
        lo, hi = np.abs(lower), np.abs(upper)
        if lo > 0.5:
            return self.get_sg(lo)
        if hi < 0.5:
            return self.get_sg(hi)
        return 0.25

    def iterate(self, all_arms, upper_bounds, lower_bounds, rewards, true_best_arm_idx):
        """Play one round: prune, pick an arm, learn from its reward, log regret.

        Raises:
            IndexError: if pruning leaves no candidate arm.
        """
        pruned_ids = self.prune(all_arms, upper_bounds, lower_bounds)
        if not pruned_ids:
            raise IndexError("prune() left no candidate arms to play")

        sg = np.array([self._arm_sg(lower_bounds[i], upper_bounds[i]) for i in pruned_ids])
        # The noise scale is tracked in every round, including rounds with a
        # single candidate.
        self.max_sg_over_time = max(self.max_sg_over_time, np.sqrt(np.amax(sg)))

        several = sg.size > 1
        if several:
            self.count.append(len(sg))
            action_idx = self.get_best_arm(all_arms, lower_bounds, upper_bounds, pruned_ids)
        else:
            action_idx = pruned_ids[0]

        reward = rewards[action_idx]
        self.update_theta(all_arms[action_idx], reward)
        instantaneous_regret = rewards[true_best_arm_idx] - reward
        self.regret.append(self.regret[-1] + instantaneous_regret)
        if action_idx != true_best_arm_idx:
            self.count2 += 1
            # NOTE(review): single-candidate rounds have never been logged in
            # count3; kept as is -- confirm whether that is intended.
            if several:
                self.count3.append(instantaneous_regret)

    def get_regret(self):
        """Return the list of cumulative regrets, one entry per round plus the initial 0."""
        return self.regret
