import numpy as np
from functools import reduce

def mat_mul(*matrices):
    """Return the chained product ``matrices[0] . matrices[1] . ...``.

    The product is formed left to right with ``np.dot``. A single argument is
    returned unchanged; calling with no arguments raises ``TypeError`` (reduce
    over an empty sequence without an initial value).
    """
    return reduce(np.dot, matrices)

class Vector:
    """Regularised least-squares estimate of an unknown vector with a confidence ball.

    Tracks the Gram matrix ``V = I + sum(a a^T)`` together with the history of
    queried arms ``X`` (one column per query) and observed signals ``Y`` (one
    row per query); ``theta_hat`` is the centre ``V^{-1} X Y``. Arms are
    assumed to be ``(d, 1)`` column vectors, as ``update`` appends them as
    columns of ``X``.

    Args:
        d: dimension of the estimated vector.
        sigma: noise scale used in the confidence radius ``beta``.
        delta: confidence parameter used in ``beta``.
    """

    def __init__(self, d, sigma, delta):
        self.theta_hat = np.zeros((d, 1))  # the center (empirical estimate) of the ball
        self.V = np.eye(d)
        # Seed X / Y with one zero column / row so np.append and np.vstack always
        # extend a 2-D array; the zero pair contributes nothing to X @ Y.
        self.X = np.zeros((d, 1))
        self.Y = np.zeros((1, 1))
        self.sigma = sigma
        self.delta = delta
        self.theta_norm = 1  # presumably a bound on ||theta||; added to beta()
        self.d = d
        self.num_queries = 0

    def beta(self):
        """Return the confidence radius after ``num_queries`` observations."""
        log_term = np.log((1 + self.num_queries) / self.delta)
        return self.sigma * np.sqrt(self.d * log_term) + self.theta_norm

    def update(self, a, signal):
        """Record the observation ``(a, signal)`` and refit ``theta_hat``."""
        self.V += a @ a.T
        self.X = np.append(self.X, a, axis=1)
        self.Y = np.vstack((self.Y, np.array(signal)))
        self.num_queries += 1
        # Solving V theta = X Y avoids forming V^{-1} explicitly: cheaper and
        # numerically more stable (V = I + PSD is always positive definite).
        self.theta_hat = np.linalg.solve(self.V, self.X @ self.Y)

    def tilde(self, a):
        """Return the optimistic vector ``theta_hat + sqrt(beta / (a^T V a)) * a``."""
        # NOTE(review): textbook OFUL/LinUCB widens by the V^{-1}-norm of `a`
        # (theta_hat + beta * V^{-1} a / ||a||_{V^{-1}}); this uses a^T V a and
        # sqrt(beta) instead -- confirm the scaling is intentional.
        return self.theta_hat + self.uncertainty(a) * a

    def uncertainty(self, a):
        """Return the scale ``sqrt(beta / (a^T V a))`` applied to ``a`` in ``tilde``."""
        # .item() handles (1, 1) results and 1-D arms alike; float() on an
        # ndim > 0 array is deprecated since NumPy 1.25.
        quad = (a.T @ self.V @ a).item()
        return np.sqrt(self.beta() / quad)

class RestrictedVector():
    """Regularised least-squares estimate whose noise scale is supplied per call.

    Tracks the Gram matrix ``V = I + sum(a a^T)`` together with the history of
    queried arms ``X`` (one column per query) and observed signals ``Y`` (one
    row per query); ``theta_hat`` is the centre ``V^{-1} X Y``. Unlike a fixed
    ``sigma``, the noise scale is passed to ``beta``/``tilde``/``uncertainty``
    on every call. Arms are assumed to be ``(d, 1)`` column vectors.

    Args:
        d: dimension of the estimated vector.
        delta: confidence parameter used in ``beta``.
    """

    def __init__(self, d, delta):
        self.theta_hat = np.zeros((d, 1))  # the center (empirical estimate) of the ball
        self.V = np.eye(d)
        # Seed X / Y with one zero column / row so np.append and np.vstack always
        # extend a 2-D array; the zero pair contributes nothing to X @ Y.
        self.X = np.zeros((d, 1))
        self.Y = np.zeros((1, 1))
        self.delta = delta
        self.theta_norm = 1  # presumably a bound on ||theta||; added to beta()
        self.d = d
        self.num_queries = 0

    def beta(self, sigma):
        """Return the confidence radius for noise scale ``sigma``."""
        log_term = np.log((1 + self.num_queries) / self.delta)
        return sigma * np.sqrt(self.d * log_term) + self.theta_norm

    def update(self, a, signal):
        """Record the observation ``(a, signal)`` and refit ``theta_hat``."""
        self.V += a @ a.T
        self.X = np.append(self.X, a, axis=1)
        self.Y = np.vstack((self.Y, np.array(signal)))
        self.num_queries += 1
        # Solving V theta = X Y avoids forming V^{-1} explicitly: cheaper and
        # numerically more stable (V = I + PSD is always positive definite).
        self.theta_hat = np.linalg.solve(self.V, self.X @ self.Y)

    def tilde(self, a, sigma):
        """Return ``theta_hat + sqrt(beta(sigma) / (a^T V a)) * a``."""
        # NOTE(review): textbook OFUL/LinUCB widens by the V^{-1}-norm of `a`;
        # this uses a^T V a and sqrt(beta) -- confirm the scaling is intentional.
        return self.theta_hat + self.uncertainty(a, sigma) * a

    def uncertainty(self, a, sigma):
        """Return the scale ``sqrt(beta(sigma) / (a^T V a))`` used in ``tilde``."""
        # .item() handles (1, 1) results and 1-D arms alike; float() on an
        # ndim > 0 array is deprecated since NumPy 1.25.
        quad = (a.T @ self.V @ a).item()
        return np.sqrt(self.beta(sigma) / quad)

class LinUCB():
    """Optimistic linear bandit scoring each arm as UCB plus a known offset.

    An arm's score is the UCB of ``arms[i]`` under a ``Vector`` estimate plus
    the known term ``theta_z.T @ arms_z[i]``; only the remaining part of the
    observed reward is fed back to the estimator.
    """

    def __init__(self, dim, sigma, delta):
        self.d, self.sigma, self.delta = dim, sigma, delta
        self.restart()

    def restart(self):
        """Reset the estimator and all bookkeeping."""
        self.theta = Vector(self.d, self.sigma, self.delta)
        self.count2 = 0    # number of rounds in which a suboptimal arm was played
        self.count3 = []   # instantaneous regret of each suboptimal play
        self.regret = [0]  # cumulative regret after each round

    def update_theta(self, arm, reward):
        """Forward an (arm, reward) observation to the estimator."""
        self.theta.update(arm, reward)

    def get_UCB(self, arm):
        """Return ``tilde(arm).T . arm``, the optimistic value of ``arm``."""
        optimistic = self.theta.tilde(arm)
        return np.dot(optimistic.T, arm)

    def get_best_arm(self, arms, arms_z, theta_z):
        """Return the index of the highest-scoring arm (first one on ties)."""
        scores = []
        for i, arm in enumerate(arms):
            known = np.dot(theta_z.T, arms_z[i])
            scores.append(self.get_UCB(arm) + known)
        return np.argmax(scores)

    def iterate(self, arms, theta_z, arms_z, rewards, true_best_arm_idx):
        """Play one round, update the estimator and record the regret."""
        chosen = self.get_best_arm(arms, arms_z, theta_z)
        known = np.dot(theta_z.T, arms_z[chosen])
        # Only the part of the reward not explained by theta_z is learned.
        self.update_theta(arms[chosen], rewards[chosen] - known)
        step_regret = rewards[true_best_arm_idx] - rewards[chosen]
        self.regret.append(self.regret[-1] + step_regret)
        if chosen != true_best_arm_idx:
            self.count2 += 1.0
            self.count3.append(step_regret)

    def get_regret(self):
        """Return the cumulative-regret list (the live list, not a copy)."""
        return self.regret

class RestrictedLinUCB():
    """LinUCB restricted to the arms that survive bound-based pruning.

    Callers supply per-arm ``upper``/``lower`` bounds each round. These bounds
    are used three ways:

    * to prune the arm set (``prune``);
    * to clip each surviving arm's UCB (``get_best_arm``);
    * to look up a per-arm ``sg`` value in ``sg_dict``.

    The running maximum of ``sqrt(sg)`` is passed to the ``RestrictedVector``
    estimator as its noise scale.
    """

    def __init__(self, dim, delta, sg_dict):
        self.d = dim
        self.delta = delta
        # Keyed by values in [0, 0.5] rounded to multiples of 0.0005 (see get_sg).
        self.sg_dict = sg_dict
        self.restart()

    def restart(self):
        """Reset the estimator and all bookkeeping."""
        self.theta = RestrictedVector(self.d, self.delta)
        # Sentinel below any non-negative sqrt(sg); replaced on the first round.
        self.max_sg_over_time = -10
        self.count = []    # number of surviving arms, for rounds with >1 survivor
        self.count2 = 0    # number of rounds in which a suboptimal arm was played
        self.count3 = []   # instantaneous regret of each suboptimal play
        self.regret = [0]  # cumulative regret after each round

    def get_sg(self, m):
        """Look up the ``sg`` value for ``m``, folded into [0, 0.5].

        ``m`` is replaced by ``1 - m`` when above 0.5, then rounded to the
        nearest multiple of 0.0005 (4 decimals) to match the keys of
        ``sg_dict``.

        Raises:
            KeyError: if the rounded value is not a key of ``sg_dict``.
        """
        if m > 0.5: m = 1-m
        def myround(x, prec=4, base=.0005):
            return round(base * round(float(x)/base),prec)
        m = myround(m)
        return self.sg_dict[m]

    def update_theta(self, arm, reward):
        """Forward an (arm, reward) observation to the estimator."""
        self.theta.update(arm, reward)

    def get_UCB(self, arm):
        """Return the optimistic value of ``arm`` using ``max_sg_over_time`` as sigma."""
        return np.dot(self.theta.tilde(arm, self.max_sg_over_time).T, arm)

    def get_best_arm(self, arms, theta_z, arms_z, upper, lower, pruned_ids):
        """Return the index of the surviving arm with the highest clipped score.

        Each arm in ``pruned_ids`` scores its UCB clipped to
        ``[lower[i], upper[i]]`` plus the known term ``theta_z.T . arms_z[i]``.
        Non-surviving arms score ``-inf``, so they can never beat a finite
        survivor score.
        """
        survivors = set(pruned_ids)  # O(1) membership inside the loop
        evals = np.full(len(arms), -np.inf)
        for i in range(len(arms)):
            if i in survivors:
                score = (np.clip(self.get_UCB(arms[i]), lower[i], upper[i])
                         + np.dot(theta_z.T, arms_z[i]))
                # .item(): assigning an ndim > 0 array to a scalar slot is
                # deprecated since NumPy 1.25.
                evals[i] = np.asarray(score).item()
        return np.argmax(evals)

    def prune(self, arms, upper, lower):
        """Return the indices of arms that survive two pruning rules.

        1. Drop arms whose upper bound is below the largest lower bound.
        2. Drop arms whose angle to the arm with the largest lower bound
           exceeds ``2 * arccos(max lower bound)``.
        """
        lmax = np.argmax(lower)
        candidates = [i for i in range(len(arms)) if upper[i] >= lower[lmax]]
        good_arm = arms[lmax]

        def angle(arm1, arm2):
            # arccos of the cosine similarity; vdot flattens (d, 1) or (d,)
            # arms. A zero-norm arm gives a NaN angle and is therefore pruned.
            cos = np.vdot(arm1, arm2) / (np.linalg.norm(arm1) * np.linalg.norm(arm2))
            return np.arccos(np.clip(cos, -1, 1))

        # Clip so arccos stays finite if a bound falls slightly outside [-1, 1].
        lvl = 2 * np.arccos(np.clip(lower[lmax], -1, 1))
        return [i for i in candidates if angle(arms[i], good_arm) <= lvl]

    def _arm_sg(self, lower, upper):
        """Return the ``sg`` value for an arm with bounds ``[lower, upper]``.

        Uses ``|lower|`` when it exceeds 0.5, else ``|upper|`` when it is below
        0.5; otherwise (the interval spans 0.5 in magnitude) defaults to 0.25.
        """
        if np.abs(lower) > 0.5:
            return self.get_sg(np.abs(lower))
        if np.abs(upper) < 0.5:
            return self.get_sg(np.abs(upper))
        return 0.25

    def iterate(self, all_arms, theta_z, arms_z, upper_bounds, lower_bounds, rewards, true_best_arm_idx):
        """Play one round: prune, choose an arm, update the estimator and regret.

        With more than one survivor the arm is chosen by ``get_best_arm``; with
        exactly one it is played directly. In both cases ``max_sg_over_time`` is
        raised to at least the largest ``sqrt(sg)`` among the survivors.

        Raises:
            IndexError: if no arm survives pruning.
        """
        pruned_ids = self.prune(all_arms, upper_bounds, lower_bounds)

        sg = np.zeros(len(pruned_ids))
        for k, idx in enumerate(pruned_ids):
            sg[k] = self._arm_sg(lower_bounds[idx], upper_bounds[idx])

        if sg.size > 1:
            self.count += [len(sg)]
            self.max_sg_over_time = max(self.max_sg_over_time, np.sqrt(np.amax(sg)))
            action_idx = self.get_best_arm(all_arms, theta_z, arms_z, upper_bounds, lower_bounds, pruned_ids)
        else:
            # Single survivor: play it without consulting the UCB. sg[0] raises
            # IndexError when nothing survived pruning.
            self.max_sg_over_time = max(self.max_sg_over_time, np.sqrt(sg[0]))
            action_idx = pruned_ids[0]

        # Only the part of the reward not explained by theta_z is learned.
        pseudoreward = rewards[action_idx] - np.dot(theta_z.T, arms_z[action_idx])
        self.update_theta(all_arms[action_idx], pseudoreward)
        instantaneous_regret = rewards[true_best_arm_idx] - rewards[action_idx]
        self.regret += [self.regret[-1] + instantaneous_regret]
        if action_idx != true_best_arm_idx:
            self.count2 += 1
            self.count3 += [instantaneous_regret]

    def get_regret(self):
        """Return the cumulative-regret list (the live list, not a copy)."""
        return self.regret
