from .base_env import BaseEnvMOMAB
import numpy as np
from utils import initial_thetalist, generate_diverse_theta, uniform_cxts, pareto_front, pareto_front_v2

class Multiobjective_env(BaseEnvMOMAB):
    """Multi-objective multi-armed bandit environment with Gaussian reward noise.

    Each of the ``K`` arms has a vector of ``num_obj`` expected rewards stored
    in ``self.exp_rewards``. Two reward models are supported:

    * ``linear=True``: a parameter matrix ``theta_list`` is drawn once via
      ``generate_diverse_theta(d, num_obj)`` and arm contexts via
      ``uniform_cxts(K, d)``; expected rewards are ``contexts @ theta_list.T``.
    * ``linear=False``: expected rewards are drawn i.i.d. from
      ``Uniform(-1, 1)`` with shape ``(K, num_obj)``; there are no contexts
      (``contexts`` and ``theta_list`` are ``None``).

    After every (re)computation of ``exp_rewards`` the Pareto quantities
    ``pareto_idx_old`` / ``pareto_regret_old`` (from ``pareto_front``) and
    ``pareto_idx_ours`` / ``pareto_regret_ours`` (from ``pareto_front_v2``)
    are refreshed. Their exact semantics are defined in ``utils``.

    Args:
        K: number of arms.
        num_obj: number of objectives (reward dimensions).
        d: context dimension (only used by the linear model).
        sig: standard deviation of the Gaussian noise added by
            ``action_reward``.
        linear: use the linear-contextual reward model.
        contextual: redraw contexts on every ``update_env`` call
            (requires ``linear=True``).
    """

    def __init__(self, K, num_obj=3, d=10, sig=0.1, linear=False, contextual=False):
        super().__init__(K=K)
        self.num_obj = num_obj
        self.linear = linear
        self.contextual = contextual
        self.sig = sig
        self.d = d

        if linear:
            # theta is drawn before the contexts, matching the original RNG order.
            self.theta_list = generate_diverse_theta(d, num_obj)
            self._resample_contexts(K)
        else:
            # Non-linear arms have no feature vectors; define the attributes
            # anyway so view_context() returns None instead of raising
            # AttributeError.
            self.theta_list = None
            self.contexts = None
            self.exp_rewards = np.random.uniform(-1, 1, (K, num_obj))

        self._update_pareto()

    def _resample_contexts(self, num_arms):
        """Draw fresh arm contexts and recompute the linear expected rewards."""
        self.contexts = uniform_cxts(num_arms, self.d)
        # Presumably contexts is (num_arms, d) and theta_list is (num_obj, d),
        # giving exp_rewards of shape (num_arms, num_obj) -- TODO confirm in utils.
        self.exp_rewards = np.matmul(self.contexts, self.theta_list.T)

    def _update_pareto(self):
        """Recompute both Pareto-front variants for the current exp_rewards."""
        self.pareto_idx_old, self.pareto_regret_old = pareto_front(self.exp_rewards)
        # pareto_front_v2 consumes the indices produced by pareto_front, so the
        # two calls must stay in this order.
        self.pareto_idx_ours, self.pareto_regret_ours = pareto_front_v2(
            self.exp_rewards, self.pareto_idx_old
        )

    def warm_up(self):
        """No warm-up phase is performed; always returns None."""
        return None

    def view_context(self):
        """Return the current arm contexts (None for the non-linear model)."""
        return self.contexts

    def action_reward(self, idx):
        """Return a noisy reward vector for pulling arm ``idx``.

        The expected reward ``exp_rewards[idx]`` is perturbed with i.i.d.
        Gaussian noise of standard deviation ``sig`` on each of the
        ``num_obj`` objectives.
        """
        noise = np.random.normal(0, self.sig, size=self.num_obj)
        return self.exp_rewards[idx] + noise

    def mean_reward(self, idx):
        """Return the noiseless expected reward vector of arm ``idx``."""
        return self.exp_rewards[idx]

    def update_env(self):
        """Advance the environment by one round.

        In contextual mode, draw new contexts and recompute the expected
        rewards and both Pareto fronts. Otherwise the environment is static
        and this is a no-op.

        Raises:
            ValueError: if ``contextual`` was requested without the linear
                model, since there is no ``theta_list`` to map contexts to
                expected rewards.
        """
        if not self.contextual:
            return
        if self.theta_list is None:
            raise ValueError(
                "contextual=True requires linear=True: there is no theta_list "
                "to map new contexts to expected rewards"
            )
        self._resample_contexts(self.K)
        self._update_pareto()
        



        
        
