import torch
import json
import numpy as np

class MORLScalarizer:
    """Collapse a dict of per-objective rewards into a single scalar reward.

    Each reward value is first multiplied by a per-objective preference weight
    (loaded from a JSON file), then the weighted rewards are reduced across
    objectives by the scalarization function selected with ``func``.

    Reward values may be plain numbers (one value per objective) or
    equally-shaped arrays; in the latter case the extra axes are presumably a
    batch (TODO confirm with callers). Every reduction is taken over the
    objective axis (dim 0), so array inputs yield one result per element.
    """

    def __init__(self, func, weight_file, uncertainty_weight=0.5, soft_max_min_temperature=1.0):
        """Select the scalarization function and load the preference weights.

        Args:
            func: Name of the scalarization (see ``func_dict`` below for the
                accepted names and aliases).
            weight_file: Path to a JSON object mapping each objective name
                (the keys later passed to ``scalarize``) to its weight.
            uncertainty_weight: Penalty coefficient on the variance used by
                ``uncertainty_weighted``.
            soft_max_min_temperature: Temperature of the softmin used by
                ``soft_max_min``; expected to be > 0 (0 divides by zero).

        Raises:
            KeyError: If ``func`` is not a known scalarization name.
        """
        func_dict = {
            "max_min": self.max_min,
            "minimax": self.max_min,
            "worst_case": self.max_min,
            "soft_max_min": self.soft_max_min,
            "soft_minimax": self.soft_max_min,
            "max_avg": self.linear,
            "linear": self.linear,
            "zero_syco": self.zero_syco,
            "uncertainty_weighted": self.uncertainty_weighted,
            "lower_third": self.lower_third,
            "max_median": self.max_median,
            "bernoulli_nash": self.bernoulli_nash,
        }
        try:
            self.func = func_dict[func]
        except KeyError:
            # Still a KeyError (callers may catch it), but name the options.
            raise KeyError(
                f"unknown scalarization function {func!r}; "
                f"expected one of {sorted(func_dict)}"
            ) from None
        self.uncertainty_weight = uncertainty_weight
        self.temperature = soft_max_min_temperature

        # Context manager closes the handle immediately instead of leaking it
        # until garbage collection; JSON is specified as UTF-8.
        with open(weight_file, encoding="utf-8") as f:
            self.preference_weights = json.load(f)

    @staticmethod
    def _to_tensor(rewards):
        """Stack the dict's values into a float32 tensor, objectives on dim 0."""
        return torch.tensor(np.array(list(rewards.values())), dtype=torch.float32)

    @staticmethod
    def _to_output(value):
        """Return a Python float for a 0-d result, otherwise the tensor itself."""
        return value.item() if value.dim() == 0 else value

    def apply_weighting(self, rewards):
        """Return a new dict with each reward multiplied by its preference weight.

        Raises:
            KeyError: If a reward key has no entry in the weight file.
        """
        return {key: value * self.preference_weights[key] for key, value in rewards.items()}

    def scalarize(self, rewards):
        """Weight ``rewards`` and reduce them with the configured function."""
        return self.func(self.apply_weighting(rewards))

    def max_min(self, rewards):
        """Return the smallest weighted reward over objectives (as a tensor)."""
        return self._to_tensor(rewards).min(dim=0).values

    def soft_max_min(self, rewards):
        """Return the softmin-weighted average of the rewards.

        Weights are ``softmax(-r / temperature)`` over objectives, so lower
        rewards dominate; as the temperature approaches 0 this tends to
        ``max_min`` and as it grows it tends to the plain mean.
        """
        r = self._to_tensor(rewards)
        # torch.softmax subtracts the max internally, so large-magnitude
        # rewards no longer overflow exp() into inf/inf = NaN.
        weights = torch.softmax(-r / self.temperature, dim=0)
        return self._to_output((weights * r).sum(dim=0))

    def linear(self, rewards):
        """Return the sum of the weighted rewards over objectives (as a tensor)."""
        return self._to_tensor(rewards).sum(dim=0)

    def uncertainty_weighted(self, rewards):
        """Return mean - uncertainty_weight * population variance over objectives."""
        # Pure torch ops: np.mean/np.sum on a Tensor forward numpy-only kwargs
        # (axis=, out=) to Tensor.mean/sum and raise TypeError.
        r = self._to_tensor(rewards)
        mean_reward = r.mean(dim=0)
        variance = ((r - mean_reward) ** 2).mean(dim=0)  # divides by k, not k-1
        return self._to_output(mean_reward - self.uncertainty_weight * variance)

    def zero_syco(self, rewards):
        """Return the sum of all rewards except the ``"sycophancy"`` one.

        The input dict is left untouched; summing with that objective set to
        zero is the same as leaving it out.
        """
        kept = {key: value for key, value in rewards.items() if key != "sycophancy"}
        return self._to_output(self._to_tensor(kept).sum(dim=0))

    def lower_third(self, rewards):
        """Return the mean of the lowest third of rewards (at least one)."""
        sorted_r = torch.sort(self._to_tensor(rewards), dim=0).values
        # Floor division alone selects nothing for < 3 objectives (mean -> NaN).
        k = max(1, sorted_r.shape[0] // 3)
        return self._to_output(sorted_r[:k].mean(dim=0))

    def max_median(self, rewards):
        """Return the median reward over objectives.

        For an even number of objectives torch returns the lower of the two
        middle values rather than their average.
        """
        return self._to_output(torch.median(self._to_tensor(rewards), dim=0).values)

    def bernoulli_nash(self, rewards):
        """Return the product of the weighted rewards over objectives.

        NOTE(review): a product is only a sensible welfare measure when all
        rewards are positive — confirm that holds upstream.
        """
        return self._to_output(torch.prod(self._to_tensor(rewards), dim=0))
    
