from gym import Wrapper
from gym.spaces.box import Box
import gym
import d4rl
import environments
import numpy as np
import pickle

class MOOfflineEnv(Wrapper):
    """Gym wrapper that exposes vector-valued rewards and offline datasets.

    With ``dataset='d4rl'`` the wrapped environment is
    ``"<env_name lowercased>-medium-v2"`` and the reward returned by
    :meth:`step` is recomputed from the simulator state. With any other
    ``dataset`` value the environment ``env_name`` is created as-is and its
    own reward is passed through unchanged.
    """

    def __init__(self, env_name, dataset='d4rl', num_objective=2, reset_noise_scale=1e-3):
        """Create and wrap the underlying gym environment.

        Args:
            env_name: Environment name. Lower-cased and suffixed with
                ``-medium-v2`` when ``dataset == 'd4rl'``; otherwise passed
                to ``gym.make`` verbatim.
            dataset: Dataset family; ``'d4rl'`` and ``'d4morl'`` are the
                values :meth:`get_dataset` understands.
            num_objective: Number of reward objectives.
            reset_noise_scale: Accepted for interface compatibility; this
                wrapper does not read it.
        """
        self.num_objective = num_objective
        self.env_name = env_name
        self.dataset = dataset
        if dataset == 'd4rl':
            env = gym.make(f"{self.env_name.lower()}-medium-v2")
        else:
            env = gym.make(f"{env_name}")
        env.reward_space = np.zeros((num_objective,))
        self._max_episode_steps = env._max_episode_steps
        super(MOOfflineEnv, self).__init__(env)

        if self.env.spec.id == 'MO-Hopper-v2':
            # Expose a unit action box; step() rescales actions for this env.
            self.action_space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

    def cal_reward_in_d4rl(self, action):
        """Step the wrapped env and rebuild the reward from simulator state.

        Args:
            action: Action forwarded to the wrapped environment.

        Returns:
            ``(obs, reward, done, info)``. ``reward`` is
            ``np.array([env_reward])`` when ``num_objective == 1``, otherwise
            ``np.array([reward1, reward2])`` where ``reward1`` is the change
            in ``qpos[0]`` divided by ``dt`` plus an alive bonus, and
            ``reward2`` is ``4 - sum(action**2)`` plus the same bonus.
        """
        xposbefore = self.env.sim.data.qpos[0]
        obs, reward, done, info = self.env.step(action)
        xposafter = self.env.sim.data.qpos[0]
        alive_bonus = 1.0
        reward1 = (xposafter - xposbefore) / self.env.dt + alive_bonus
        reward2 = 4.0 - 1.0 * np.square(action).sum() + alive_bonus
        if self.num_objective == 1:
            reward = np.array([reward])
        else:
            reward = np.array([reward1, reward2])
        return obs, reward, done, info

    def step(self, action):
        """Advance the environment by one step.

        For ``MO-Hopper-v2`` the action is first multiplied element-wise by
        ``[2, 2, 4]``. With ``dataset == 'd4rl'`` the reward comes from
        :meth:`cal_reward_in_d4rl`; otherwise the wrapped env's reward is
        returned unchanged.
        """
        if self.env.spec.id == 'MO-Hopper-v2':
            action = action * np.array([2, 2, 4])

        if self.dataset == 'd4rl':
            obs, reward, done, info = self.cal_reward_in_d4rl(action)
        else:
            obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

    def get_dataset(self, dataset_type):
        """Load the offline dataset for this environment.

        Args:
            dataset_type: Dataset variant, used in the d4rl env id
                (``"<env>-<dataset_type>-v2"``) or in the d4morl pickle
                file name.

        Returns:
            For ``'d4rl'``: the list of per-episode dicts produced by
            :meth:`d4rl2d4morl`. For ``'d4morl'``: whatever object is stored
            in ``./data/<env_name>/<env_name>_50000_<dataset_type>.pkl``.

        Raises:
            ValueError: If ``self.dataset`` is neither ``'d4rl'`` nor
                ``'d4morl'`` (previously this fell through and returned
                ``None``).
        """
        if self.dataset == 'd4rl':
            env = gym.make(f"{self.env_name.lower()}-{dataset_type}-v2")
            dataset = env.get_dataset()
            return self.d4rl2d4morl(dataset, env)
        elif self.dataset == 'd4morl':
            dataset_path = f"./data/{self.env_name}/{self.env_name}_50000_{dataset_type}.pkl"
            # NOTE: pickle.load can execute arbitrary code; only load dataset
            # files that come from a trusted source.
            with open(dataset_path, 'rb') as f:
                dataset = pickle.load(f)
            return dataset
        raise ValueError(f"Unsupported dataset: {self.dataset!r}")

    def d4rl2d4morl(self, dataset, env, max_episode_len=1000):
        """Convert a flat transition dataset into per-episode dicts.

        Args:
            dataset: Mapping with the keys ``'observations'``, ``'actions'``,
                ``'next_observations'``, ``'terminals'``, ``'timeouts'`` and
                ``'infos/qpos'``, each indexable per transition.
            env: Environment providing ``dt`` for the first objective.
            max_episode_len: An episode is also closed after this many steps,
                even without a terminal/timeout flag.

        Returns:
            A list with one dict per closed episode, holding
            ``'observations'``, ``'actions'``, ``'next_observations'``,
            ``'raw_rewards'``, ``'terminals'`` (True only on the last step)
            and ``'preference'`` (the episode return divided by its L1 norm,
            repeated for every step). Transitions after the last closed
            episode are not emitted.
        """
        def calculate_reward_vector(info_qpos, action, timeouts, terminals):
            """Return an ``(N, 2)`` array of per-transition reward vectors."""
            N = action.shape[0]
            reward1, reward2 = [], []
            alive_bonus = 1.0
            for i in range(N):
                # The last row has no successor, so it is compared with itself.
                xposafter, xposbefore = info_qpos[min(i+1, N-1), 0], info_qpos[i, 0]
                r1 = (xposafter - xposbefore)/env.dt + alive_bonus
                r2 = 4.0 - 1.0 * np.square(action[i]).sum() + alive_bonus

                if bool(terminals[i]) or timeouts[i]:
                    # Row i+1 belongs to the next episode, so the difference
                    # above is meaningless here. Reuse the previous reward as
                    # an approximation; if there is none (the very first
                    # transition already ends an episode) assume zero
                    # velocity instead of indexing an empty list.
                    r1 = reward1[-1] if reward1 else alive_bonus

                reward1.append(r1)
                reward2.append(r2)

            reward1, reward2 = np.array(reward1), np.array(reward2)
            return np.stack([reward1, reward2], axis=1)

        reward = calculate_reward_vector(dataset['infos/qpos'], dataset['actions'],
                                         dataset['timeouts'], dataset['terminals'])
        assert len(reward)==len(dataset['observations'])

        # Hoist the per-key lookups out of the per-transition loop.
        all_observations = dataset['observations']
        all_actions = dataset['actions']
        all_next_observations = dataset['next_observations']
        all_terminals = dataset['terminals']
        all_timeouts = dataset['timeouts']

        step = 0
        ret = np.zeros((reward.shape[-1]))
        observations, actions, next_observations, raw_rewards, terminals = [], [], [], [], []
        d4morl_dataset = []
        for i, rw in enumerate(reward):
            step += 1
            ret += rw
            observations.append(all_observations[i])
            actions.append(all_actions[i])
            next_observations.append(all_next_observations[i])
            raw_rewards.append(rw)

            if all_terminals[i] or all_timeouts[i] or step==max_episode_len:
                terminals.append(True)
                # Normalised episode return, broadcast to one row per step.
                preference = (ret / np.linalg.norm(ret, ord=1)).reshape(1, -1).repeat(len(raw_rewards), 0)
                d4morl_dataset.append({
                    'observations': np.array(observations),
                    'actions': np.array(actions),
                    'next_observations': np.array(next_observations),
                    'raw_rewards': np.array(raw_rewards),
                    'terminals': np.array(terminals),
                    'preference': preference,
                })
                # Start accumulating the next episode from scratch.
                step = 0
                ret = np.zeros((reward.shape[-1]))
                observations, actions, next_observations, raw_rewards, terminals = [], [], [], [], []
            else:
                terminals.append(False)
        return d4morl_dataset

    def get_normalized_score(self, tot_rewards):
        """Return the normalized score for a vector of total rewards.

        With a single objective the wrapped env's ``get_normalized_score`` is
        applied to ``tot_rewards[0]`` and wrapped in a one-element array;
        otherwise ``tot_rewards`` is returned unchanged.
        """
        if self.num_objective == 1:
            return np.array([self.env.get_normalized_score(tot_rewards[0])])
        else:
            return tot_rewards
        
        
def normalization(x, mean, std):
    """Standardize ``x`` as ``(x - mean) / std``.

    The division is unguarded: no epsilon is added, so callers must make
    sure ``std`` is non-zero.
    """
    centered = x - mean
    return centered / std

def min_max_normalization(x, min, max, eps=1e-8):
    """Rescale ``x`` as ``(x - min) / (max - min + eps)``.

    Args:
        x: Value(s) to rescale.
        min: Lower reference value.
        max: Upper reference value.
        eps: Added to the range so that ``max == min`` does not divide by
            zero.
    """
    offset = x - min
    span = max - min + eps
    return offset / span


def normalize_rewards(raw_rewards, mode="minmax", reward_min=None, reward_max=None,
                      reward_mean=None, reward_std=None, reward_return_scale=None,
                      reward_max_scale=None, reward_balanced_scale=None, eps=1e-8):
    """Normalize reward vectors according to ``mode``.

    If the statistics a mode needs are ``None``, ``raw_rewards`` is returned
    unchanged.

    Modes:
        - none: return ``raw_rewards`` as-is
        - minmax: ``(x - reward_min) / (reward_max - reward_min + eps)``
        - zscore: ``(x - reward_mean) / (reward_std + eps)``
        - return: ``x * reward_return_scale``
        - maxscale: ``x * reward_max_scale``
        - balanced: ``x * reward_balanced_scale``

    Raises:
        ValueError: If ``mode`` is not one of the modes listed above.
    """
    if mode == "none":
        return raw_rewards

    if mode == "minmax":
        have_bounds = reward_min is not None and reward_max is not None
        if have_bounds:
            return min_max_normalization(raw_rewards, reward_min, reward_max, eps)
        return raw_rewards

    if mode == "zscore":
        have_moments = reward_mean is not None and reward_std is not None
        if have_moments:
            return (raw_rewards - reward_mean) / (reward_std + eps)
        return raw_rewards

    # The remaining modes all multiply by a precomputed scale; they differ
    # only in which scale argument they read.
    if mode == "return":
        scale = reward_return_scale
    elif mode == "maxscale":
        scale = reward_max_scale
    elif mode == "balanced":
        # Per the original author's note: the balanced scale is meant to
        # equalize the objectives' expected totals by scaling "ore" up to
        # match "fuel" rather than scaling "fuel" down.
        scale = reward_balanced_scale
    else:
        raise ValueError(f"Unknown reward normalization mode: {mode}")

    if scale is None:
        return raw_rewards
    return raw_rewards * scale


def social_welfare(returns, alpha, eps=1e-5):
    """Score a vector of per-objective returns with an alpha-fairness utility.

    Returns are first clipped from below at ``eps`` so that the logarithm
    and fractional powers stay defined.

    Args:
        returns: Array-like of returns, one entry per objective.
        alpha: Fairness parameter.
            ``|alpha - 1| < 1e-6``: Nash Social Welfare, ``sum(log(returns))``.
            ``alpha == 0``: Utilitarian Social Welfare, ``sum(returns)``.
            otherwise: ``sum(returns ** (1 - alpha)) / (1 - alpha)``.
            This expression is not shifted by ``-1 / (1 - alpha)``, so its
            magnitude grows as ``alpha`` approaches 1 instead of converging
            to the log branch.
        eps: Lower clip applied to ``returns``.

    Returns:
        The social welfare score (higher is better).
    """
    clipped = np.maximum(returns, eps)

    is_nash = abs(alpha - 1.0) < 1e-6
    if is_nash:
        return np.sum(np.log(clipped))

    exponent = 1 - alpha
    return np.sum(np.power(clipped, exponent)) / exponent