import numpy as np
import random

def sherman_morrison(X, V, w=1):
    """Return the inverse of ``A + w * X X^T`` given ``V = A^{-1}``.

    Uses the Sherman-Morrison rank-one identity

        (A + w x x^T)^{-1} = V - w (V x)(V x)^T / (1 + w x^T V x),

    written in a form that relies on ``V`` being symmetric, so that
    ``x^T V`` equals ``(V x)^T``.

    Args:
        X: 1-D feature vector of length ``d``.
        V: ``(d, d)`` symmetric matrix, the current inverse.
        w: Scalar weight of the rank-one term. The default of 1 gives the
            plain ``A + X X^T`` update.

    Returns:
        A new ``(d, d)`` array; ``V`` itself is not modified.
    """
    Vx = V @ X
    denom = 1.0 + w * (X @ Vx)
    # Scale the outer product before dividing so that the default w=1
    # evaluates exactly the unweighted expression.
    return V - (w * np.outer(Vx, Vx)) / denom

def l2norm(arr):
    """Return the square root of the sum of squared entries of ``arr``."""
    squared = arr * arr
    return np.sqrt(np.sum(squared))

def wl2normsq(arr, V):
    """Return the quadratic form ``arr @ V @ arr``."""
    left = arr @ V
    return left @ arr

class Algorithm:
    """Common base for the bandit algorithms in this module.

    Stores the dimension ``d`` and a display ``name`` and offers two small
    helpers used by the concrete algorithms.
    """

    def __init__(self, d, name):
        """Record the dimension ``d`` and the algorithm's ``name``."""
        self.name = name
        self.d = d

    def updlogdet(self, logdet, Vinv, Xt):
        """Return ``logdet + log(1 + Xt^T Vinv Xt)``.

        When ``Vinv`` is the inverse of a matrix ``V`` and ``logdet`` is
        ``log det V``, the matrix determinant lemma makes the result
        ``log det (V + Xt Xt^T)``.
        """
        quad_form = wl2normsq(Xt, Vinv)
        increment = np.log(1 + quad_form)
        return logdet + increment

    def greedy(self, arms, theta_hat):
        """Return the index of the largest entry of ``arms @ theta_hat``."""
        scores = arms @ theta_hat
        return np.argmax(scores)

class Greedy(Algorithm):
    """Always plays the arm that looks best under the current estimate.

    The estimate is ``theta_hat = Vinv @ b``, where ``Vinv`` starts as the
    identity and receives a rank-one update after every observed reward.
    """

    def __init__(self, d):
        super().__init__(d, 'Greedy')

    def reset(self):
        """Reinitialise the estimator state (call before the first round)."""
        dim = self.d
        self.Vinv = np.eye(dim)
        self.b = np.zeros(dim)
        self.theta_hat = np.zeros(dim)

    def select_ac(self, arms):
        """Choose the greedy arm and remember its feature row.

        ``arms`` is assumed to hold one feature vector per row -- confirm
        against the caller.
        """
        chosen = self.greedy(arms, self.theta_hat)
        self.X_a = arms[chosen]
        return chosen

    def update(self, reward):
        """Fold ``reward`` for the last selected arm into the estimate."""
        x = self.X_a
        self.b += reward * x
        self.Vinv = sherman_morrison(x, self.Vinv)
        self.theta_hat = np.dot(self.Vinv, self.b)

class LinTS(Algorithm):
    """Linear Thompson sampling.

    Each round draws a parameter from a Gaussian centred at ``theta_hat``
    with covariance ``beta**2 * Vinv`` and plays the arm that is best for
    the drawn parameter.
    """

    def __init__(self, d, sigma, delta, S):
        super().__init__(d, 'LinTS')
        self.S = S
        self.delta = delta
        self.sigma = sigma

    def reset(self):
        """Reinitialise the estimator and the confidence radius ``beta``."""
        dim = self.d
        self.Vinv = np.eye(dim)
        self.b = np.zeros(dim)
        self.theta_hat = np.zeros(dim)
        # alpha starts at 2 log(1/delta); update() adds a log-term per round.
        self.alpha = 2 * np.log(1 / self.delta)
        self.beta = self.sigma * np.sqrt(self.alpha) + self.S

    def select_ac(self, arms):
        """Sample a parameter, pick its best arm and remember the features."""
        cov = (self.beta ** 2) * self.Vinv
        sampled_theta = np.random.multivariate_normal(self.theta_hat, cov)
        chosen = self.greedy(arms, sampled_theta)
        self.X_a = arms[chosen]
        return chosen

    def update(self, reward):
        """Fold ``reward`` for the last selected arm into the state."""
        x = self.X_a
        self.b += reward * x
        # alpha must be advanced with the inverse from *before* this
        # round's rank-one update, so this precedes the Vinv refresh.
        self.alpha = self.updlogdet(self.alpha, self.Vinv, x)
        self.beta = self.sigma * np.sqrt(self.alpha) + self.S
        self.Vinv = sherman_morrison(x, self.Vinv)
        self.theta_hat = np.dot(self.Vinv, self.b)

        
class LinUCB(Algorithm):
    """Upper-confidence-bound algorithm for linear bandits.

    Each arm ``x`` is scored by ``x^T theta_hat + beta * ||x||_{Vinv}``,
    where ``||x||_{Vinv} = sqrt(x^T Vinv x)``, and the arm with the highest
    score is played.
    """

    def __init__(self, d, sigma, delta, S):
        super().__init__(d, 'LinUCB')
        self.sigma = sigma
        self.delta = delta
        self.S = S

    def reset(self):
        """Reinitialise the estimator and the confidence radius ``beta``."""
        self.theta_hat = np.zeros(self.d)
        self.b = np.zeros(self.d)
        self.Vinv = np.eye(self.d)
        # alpha starts at 2 log(1/delta); update() adds a log-term per round.
        self.alpha = 2 * np.log(1 / self.delta)
        self.beta = self.sigma * np.sqrt(self.alpha) + self.S

    def select_ac(self, arms):
        """Return the index of the arm with the largest upper confidence bound.

        ``arms`` is assumed to hold one feature vector per row -- confirm
        against the caller. The chosen row is stored in ``self.X_a`` for
        the following ``update`` call.
        """
        # Row-wise quadratic forms x^T Vinv x.
        quad = np.sum((arms @ self.Vinv) * arms, axis=1)
        # beta is the radius for the (unsquared) Vinv-norm, so the bonus is
        # the square root of the quadratic form. Clip at zero first so that
        # rounding error cannot feed a tiny negative value to sqrt.
        width = np.sqrt(np.maximum(quad, 0.0))
        ucb = arms @ self.theta_hat + self.beta * width
        a_t = np.argmax(ucb)
        self.X_a = arms[a_t]
        return a_t

    def update(self, reward):
        """Fold ``reward`` for the last selected arm into the state."""
        self.b += reward * self.X_a
        # alpha must be advanced with the inverse from *before* this
        # round's rank-one update, so this precedes the Vinv refresh.
        self.alpha = self.updlogdet(self.alpha, self.Vinv, self.X_a)
        self.beta = self.sigma * np.sqrt(self.alpha) + self.S
        self.Vinv = sherman_morrison(self.X_a, self.Vinv)
        self.theta_hat = np.dot(self.Vinv, self.b)
    
class INFEX(Algorithm):
    """Wraps a base algorithm and consults it only every ``m``-th round.

    On rounds whose counter is a multiple of ``m`` the wrapped algorithm
    selects the arm; on all other rounds the arm is chosen greedily from
    the wrapped algorithm's ``theta_hat``. Every reward is forwarded to
    the wrapped algorithm.
    """

    def __init__(self, d, alg, sigma, delta, S, m):
        # ``alg`` is a class (or factory) called as alg(d, sigma, delta, S).
        self.alg = alg(d, sigma, delta, S)
        super().__init__(d, f'INFEX({self.alg.name}, {m})')
        self.M = m

    def reset(self):
        """Restart the round counter and reset the wrapped algorithm."""
        self.alg.reset()
        self.t = 0

    def select_ac(self, arms):
        """Pick an arm and store its feature row on the wrapped algorithm."""
        self.t += 1
        use_base = (self.t % self.M == 0)
        if not use_base:
            choice = self.greedy(arms, self.alg.theta_hat)
        else:
            choice = self.alg.select_ac(arms)

        # The wrapped algorithm's update() reads X_a, so set it here for
        # both kinds of round.
        self.alg.X_a = arms[choice]
        return choice

    def update(self, reward):
        """Pass ``reward`` on to the wrapped algorithm."""
        self.alg.update(reward)

class OLSBandit(Algorithm):
    """OLS Bandit with forced sampling.

    Two least-squares estimates are kept: ``theta_tilde`` is fitted only on
    forced-exploration rounds, ``theta_hat`` on every round. On
    non-exploration rounds ``theta_tilde`` pre-selects the arms whose
    estimated reward is within ``h / 2`` of the best, and ``theta_hat``
    picks the arm among those candidates.

    Args:
        d: Dimension of the feature vectors.
        q: Positive rate of the forced-sampling schedule; an exploration
            round happens whenever ``exp(q * explore_cnt) - 1 < t``.
        h: Gap parameter used for the ``h / 2`` pre-selection threshold.

    Raises:
        ValueError: If ``q`` is not strictly positive.
    """

    def __init__(self, d, q, h):
        super().__init__(d, 'OLSBandit')
        # With q <= 0 the schedule threshold never grows and the catch-up
        # loop in select_ac would never terminate.
        if not q > 0:
            raise ValueError(f"q must be positive, got {q!r}")
        self.q = q
        self.h = h

    def reset(self):
        """Reinitialise both estimators and the exploration schedule."""
        # A large initial inverse (1000 * I) is used for computational
        # stability; it corresponds to a regulariser of 1e-3 * I.
        self.Vinvtilde = 1000 * np.eye(self.d)
        self.btilde = np.zeros(self.d)
        self.theta_tilde = np.zeros(self.d)

        self.Vinvhat = 1000 * np.eye(self.d)
        self.bhat = np.zeros(self.d)
        self.theta_hat = np.zeros(self.d)

        self.t = 0
        self.explore_cnt = 0
        self.explore = False

    def select_ac(self, arms):
        """Return the index of the arm to play this round.

        ``arms`` is assumed to hold one feature vector per row -- confirm
        against the caller. The chosen row is stored in ``self.X_a`` and
        ``self.explore`` records whether this was a forced-sampling round.
        """
        self.t += 1
        if np.exp(self.q * self.explore_cnt) - 1 < self.t:
            self.explore = True
            a_t = random.randrange(len(arms))
            # Advance the schedule past the current round.
            while np.exp(self.q * self.explore_cnt) - 1 < self.t:
                self.explore_cnt += 1
        else:
            self.explore = False
            forced_est = arms @ self.theta_tilde
            # Candidates: arms within h/2 of the best forced-sample estimate.
            # The best arm always qualifies, so the set is never empty and a
            # single-arm input is handled as well.
            candidates = np.flatnonzero(
                forced_est >= forced_est.max() - self.h / 2
            )
            # The all-sample estimate decides among the candidates only.
            all_est = arms[candidates] @ self.theta_hat
            a_t = candidates[np.argmax(all_est)]

        self.X_a = arms[a_t]
        return a_t

    def update(self, reward):
        """Fold ``reward`` for the last selected arm into the estimators."""
        # The all-sample estimator learns from every round.
        self.bhat += reward * self.X_a
        self.Vinvhat = sherman_morrison(self.X_a, self.Vinvhat)
        self.theta_hat = np.dot(self.Vinvhat, self.bhat)

        # The forced-sample estimator learns only from exploration rounds.
        if self.explore:
            self.btilde += reward * self.X_a
            self.Vinvtilde = sherman_morrison(self.X_a, self.Vinvtilde)
            self.theta_tilde = np.dot(self.Vinvtilde, self.btilde)


class epsGreedy(Algorithm):
    """Epsilon-greedy with exploration probability ``t ** (-1/3)``.

    With that probability a uniformly random arm is played on round ``t``;
    otherwise the arm that is best under ``theta_hat`` is played.
    """

    def __init__(self, d):
        super().__init__(d, 'ε-Greedy')

    def reset(self):
        """Reinitialise the estimator state and the round counter."""
        dim = self.d
        self.t = 0
        self.Vinv = np.eye(dim)
        self.b = np.zeros(dim)
        self.theta_hat = np.zeros(dim)

    def select_ac(self, arms):
        """Pick an arm for this round and remember its feature row."""
        self.t += 1
        eps = self.t ** (- 1/3)
        # Weighted coin flip: True (explore) with weight eps.
        coin = random.choices([False, True], weights = [1 - eps, eps])
        if coin[0]:
            chosen = random.choice(range(len(arms)))
        else:
            chosen = self.greedy(arms, self.theta_hat)
        self.X_a = arms[chosen]
        return chosen

    def update(self, reward):
        """Fold ``reward`` for the last selected arm into the estimate."""
        x = self.X_a
        self.b += reward * x
        self.Vinv = sherman_morrison(x, self.Vinv)
        self.theta_hat = np.dot(self.Vinv, self.b)