from .base_model import BaseMOMAB
from utils import pareto_front, sherman_morrison
import numpy as np

class MOLB_UCB(BaseMOMAB):
    """Multi-objective linear bandit that plays a random arm of the UCB Pareto front.

    One ridge-regression estimate is kept per objective (rows of
    ``theta_hat``), all sharing a single inverse design matrix ``Vinv``.

    This class reads ``self.m`` and ``self.t`` but never assigns them; they
    are expected to be provided by ``BaseMOMAB.__init__`` (not visible in
    this file).
    """

    def __init__(
        self,
        d: int,
        m: int,
        delta: float = 0.01,
        noise: float = 0.1,
        lam: "float | None" = None,
        name: str = "MOLB_UCB",
    ):
        """Set up the regression state.

        Args:
            d: Dimension of the context vectors.
            m: Number of objectives; forwarded to ``BaseMOMAB.__init__``.
            delta: Value dividing the argument of the logarithm in
                :meth:`cal_log_1`.
            noise: Factor multiplying the square-root log term in
                ``bound_hat``.
            lam: Regularisation weight; ``None`` selects the default of 1.
            name: Label stored on the instance.
        """
        super().__init__(m=m)
        self.noise = noise
        self.delta = delta

        # Identity test rather than ``== None`` so that array-like or
        # operator-overloading arguments are never compared element-wise.
        self.lam = 1 if lam is None else lam
        self.d = d

        # Regression state.
        # NOTE(review): ``hat_theta`` is never read in this class; it is kept
        # only in case external code accesses it.
        self.hat_theta = np.zeros(self.d)
        self.V = self.lam * np.eye(d)
        self.Vinv = (1 / self.lam) * np.eye(d)
        self.b_t = np.zeros((self.m, self.d))
        self.theta_hat = np.zeros((self.m, d))
        # sqrt(lam) + noise * sqrt(cal_log_1()), evaluated once here with the
        # value of ``self.t`` at construction time.
        # NOTE(review): nothing in this class refreshes or uses this value
        # afterwards -- confirm whether it was meant to scale the bonus in
        # ``select_ac``.
        self.bound_hat = np.sqrt(self.lam) + self.noise * np.sqrt(self.cal_log_1())

        self.name = name
        self.settings = {'lambda': self.lam}

    def cal_log_1(self):
        """Return ``d * log(m * (1 + (t - 1) / (d * lam)) / delta)``."""
        growth = 1 + (self.t - 1) / (self.d * self.lam)
        return self.d * np.log(self.m * growth / self.delta)

    def pareto_front(self, Y):
        """Return the row indices of ``Y`` that are not strictly dominated.

        Row ``i`` is discarded when some row ``j`` still in the candidate
        list is strictly larger than it in every column.

        Args:
            Y: 2-D array with one row per arm and one column per objective.

        Returns:
            List of surviving row indices in increasing order.
        """
        n_arms = Y.shape[0]
        survivors = list(range(n_arms))
        for i in range(n_arms):
            # ``any`` stops at the first dominating row, so ``survivors`` is
            # only modified after the scan over it has finished.
            if any(np.max(Y[i, :] - Y[j, :]) < 0 for j in survivors):
                survivors.remove(i)
        return survivors

    def select_ac(self, contexts):
        """Pick an arm uniformly at random from the Pareto front of the UCBs.

        Side effect: ``self.Vinv`` is replaced by
        ``sherman_morrison(contexts[idx], self.Vinv)`` for the chosen arm
        before returning, so a following :meth:`update` call already sees
        the new inverse.

        Args:
            contexts: Array with one context vector per row; it is multiplied
                with ``theta_hat.T``, so rows are expected to have length
                ``d``.

        Returns:
            Index of the selected row of ``contexts``.
        """
        means = np.matmul(contexts, self.theta_hat.T)
        # The exploration width x^T Vinv x depends only on the arm, not on
        # the objective, so compute it once and broadcast over the columns.
        widths = np.array([np.sqrt(x @ self.Vinv @ x) for x in contexts])
        # NOTE(review): the bonus is added with a fixed multiplier of 1;
        # ``bound_hat`` is not involved here.
        ucbs = means + widths[:, np.newaxis]
        empirical_pareto_front = self.pareto_front(ucbs)
        idx = np.random.choice(empirical_pareto_front)

        # presumably the rank-one Sherman-Morrison inverse update -- the
        # helper lives in ``utils``; verify there.
        # NOTE(review): ``self.V`` is not updated alongside ``self.Vinv``.
        self.Vinv = sherman_morrison(contexts[idx], self.Vinv)
        return idx

    def update(self, reward, context):
        """Fold one observed reward vector into the per-objective estimates.

        Args:
            reward: Array-like of rewards; it is reshaped into a column so
                that each entry scales ``context`` for one row of ``b_t``.
            context: Context vector of the arm that was played.
        """
        self.b_t += np.asarray(reward).reshape(-1, 1) * context
        # Uses ``self.Vinv`` as left by the most recent ``select_ac`` call.
        self.theta_hat = (self.Vinv @ self.b_t.T).T
        self.t += 1
        