from .base_model import BaseMOMAB
from utils import pareto_front, sherman_morrison
import numpy as np

class MOLB_Greedy(BaseMOMAB):
    """Epsilon-greedy multi-objective linear bandit.

    Each of the ``m`` objectives is modelled as linear in the
    ``d``-dimensional arm context. The per-objective parameters are estimated
    by ridge regression sharing one regularised Gram matrix. At every round
    the policy pulls a uniformly random arm with probability ``epsilon``;
    otherwise it picks uniformly at random among the arms on the empirical
    Pareto front of the estimated mean reward vectors.
    """

    def __init__(self, d, m, delta=0.01, noise=0.1, epsilon=0.01, lam=None,
                 name="MOLB_Greedy"):
        """Create the policy.

        Args:
            d: Dimension of the arm contexts.
            m: Number of objectives (forwarded to ``BaseMOMAB``).
            delta: Stored on the instance; not used by this class.
            noise: Stored on the instance; not used by this class.
            epsilon: Probability of a uniformly random (exploration) pull.
            lam: Ridge regularisation strength; ``None`` means 1.
            name: Display name of the policy.
        """
        super().__init__(m=m)
        self.noise = noise
        self.delta = delta

        self.lam = 1 if lam is None else lam
        self.d = d
        self.epsilon = epsilon
        # Counter incremented on every exploration pull (starts at 1).
        self.explore = 1
        # NOTE(review): ``hat_theta`` is never read in this class; kept in case
        # outside code references it. The live estimate is ``theta_hat``.
        self.hat_theta = np.zeros(self.d)
        # Regularised Gram matrix V = lam*I + sum_t x_t x_t^T and its inverse.
        self.V = self.lam * np.eye(d)
        self.Vinv = (1 / self.lam) * np.eye(d)
        # Per-objective sums of reward * context, shape (m, d).
        # NOTE(review): assumes BaseMOMAB.__init__ sets ``self.m`` — confirm.
        self.b_t = np.zeros((self.m, self.d))
        # Per-objective ridge estimates, shape (m, d).
        self.theta_hat = np.zeros((self.m, d))

        self.settings = {'lambda' : self.lam}
        self.name = name

    def pareto_front(self, Y):
        """Return indices of rows of ``Y`` not strictly dominated by any row.

        Row ``j`` strictly dominates row ``i`` when ``Y[j, k] > Y[i, k]`` for
        every column ``k``. Rows containing NaN in a compared column never
        dominate or get dominated through that column.

        Args:
            Y: 2-D array of shape (K, m), one row of objective values per arm.

        Returns:
            List of int row indices, in increasing order.
        """
        Y = np.asarray(Y)
        # beats[i, j] is True when row j is strictly better than row i in
        # every objective.
        beats = (Y[np.newaxis, :, :] > Y[:, np.newaxis, :]).all(axis=2)
        dominated = beats.any(axis=1)
        return np.flatnonzero(~dominated).tolist()

    def select_ac(self, contexts):
        """Choose the arm to pull this round.

        Args:
            contexts: Array of shape (K, d), one context row per arm.

        Returns:
            Index of the chosen arm.
        """
        # Exploration: with probability epsilon pull a uniformly random arm.
        if np.random.random() < self.epsilon:
            self.explore += 1
            return np.random.randint(contexts.shape[0])

        # Exploitation: estimated mean reward vectors, shape (K, m).
        means = contexts @ self.theta_hat.T
        # Choose uniformly among the non-dominated arms.
        return np.random.choice(self.pareto_front(means))

    def update(self, reward, context):
        """Incorporate the reward vector observed for the pulled arm.

        Updates the Gram matrix ``V`` and its inverse ``Vinv``, the
        reward-weighted context sums ``b_t``, and the ridge estimates
        ``theta_hat`` for every objective.

        Args:
            reward: The m observed rewards, one per objective.
            context: Context vector (d entries) of the pulled arm.
        """
        x = np.asarray(context, dtype=float).reshape(-1)
        r = np.asarray(reward, dtype=float).reshape(-1)

        # V <- V + x x^T. Keep the inverse in sync with the Sherman-Morrison
        # rank-one update, Vinv <- Vinv - (Vinv x)(Vinv x)^T / (1 + x^T Vinv x).
        # This form relies on Vinv being symmetric, which holds by construction.
        self.V += np.outer(x, x)
        vx = self.Vinv @ x
        self.Vinv -= np.outer(vx, vx) / (1.0 + x @ vx)

        self.b_t += np.outer(r, x)
        self.theta_hat = (self.Vinv @ self.b_t.T).T
        # NOTE(review): ``self.t`` is assumed to be initialised by BaseMOMAB.
        self.t += 1
        