import copy
import gym
import d4rl

import numpy as np
from onpolicy.debug import debug_print
import psutil

class D4RLEnv:
    """Single-agent wrapper around a D4RL gym environment.

    Spaces and step/reset results are wrapped in lists (one entry per agent,
    with exactly one agent). Depending on ``all_args.algorithm_name`` the
    wrapper additionally:

    * for ``diff-gail`` / ``diff-infogail``: widens the action to
      ``[reference_action, rl_action]`` plus a ``Discrete(2)`` switch that
      selects which of the two is executed in the env;
    * for ``infogail`` / ``diff-infogail`` / ``diayn``: prepends a latent code
      ``z`` (resampled every ``sample_z_interval`` steps and on reset) to
      every observation.

    If ``all_args.normalization_path`` is set, observations are min-max mapped
    to roughly ``[-1, 1]`` and actions are mapped from ``[-1, 1]`` back to the
    range given by the loaded ``action_min`` / ``action_max``.
    """

    # Algorithms whose per-agent action is [reference_action, rl_action, use_reference].
    _DIFF_ALGOS = ("diff-gail", "diff-infogail")
    # Algorithms whose observations are prefixed with the latent code z.
    _LATENT_ALGOS = ("infogail", "diff-infogail", "diayn")

    def __init__(self, all_args):
        """Build the wrapped env and the (possibly extended) spaces.

        Args:
            all_args: config object; reads ``scenario_name``, ``algorithm_name``,
                ``normalization_path`` (optional) and, for latent algorithms,
                ``z_latent_dim`` and ``sample_z_interval``.
        """
        self.scenario_name = all_args.scenario_name
        self.algorithm_name = all_args.algorithm_name

        self._env = gym.make(self.scenario_name)
        env_act_space = self._env.action_space
        env_obs_space = self._env.observation_space

        self.action_space = [[env_act_space]]
        self.observation_space = [env_obs_space]
        self.share_observation_space = [env_obs_space]
        self.gt_act_dim = env_act_space.shape[0]

        if self.algorithm_name in self._DIFF_ALGOS:
            # Continuous part = [reference_action, rl_action]; the Discrete(2)
            # part decides which of the two is sent to the env in step().
            self.action_space = [[gym.spaces.Box(low=np.concatenate([env_act_space.low, env_act_space.low]),
                                                 high=np.concatenate([env_act_space.high, env_act_space.high])),
                                  gym.spaces.Discrete(2)]]

        if self.algorithm_name in self._LATENT_ALGOS:
            self.z_latent_dim = all_args.z_latent_dim
            self.sample_z_interval = all_args.sample_z_interval
            # Observations become [z, env_obs]; z itself is unbounded.
            unbounded = np.inf * np.ones(self.z_latent_dim)
            self.observation_space = [gym.spaces.Box(np.concatenate([-unbounded, env_obs_space.low]),
                                                     np.concatenate([unbounded, env_obs_space.high]))]
            self.share_observation_space = copy.deepcopy(self.observation_space)
            self.z = np.random.randn(self.z_latent_dim)

        # Episode counters, so step() before reset() does not hit AttributeError.
        self._step = 0
        self._return = 0

        normalization_path = getattr(all_args, "normalization_path", None)
        self.normalize = normalization_path is not None
        if self.normalize:
            normalization = np.load(normalization_path)
            try:
                self.obs_min = normalization["obs_min"]
                self.obs_max = normalization["obs_max"]
                self.action_min = normalization["action_min"]
                self.action_max = normalization["action_max"]
            finally:
                # An .npz load returns an NpzFile that holds its file handle
                # open until closed; plain arrays have no close().
                close = getattr(normalization, "close", None)
                if close is not None:
                    close()

    def seed(self, seed=None):
        """Seed NumPy's global RNG; ``None`` falls back to seed 1.

        NOTE: only ``np.random`` is seeded here; the wrapped env is not
        reseeded explicitly.
        """
        np.random.seed(1 if seed is None else seed)

    def normalize_obs(self, obs):
        """Min-max map ``obs`` to roughly ``[-1, 1]`` using the loaded stats.

        The ``1e-6`` guards against division by zero on constant dimensions.
        """
        return 2 * ((obs - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) - 0.5)

    def unnormalize_action(self, action):
        """Map an action from ``[-1, 1]`` back to ``[action_min, action_max]``."""
        action = (action + 1) / 2  # [-1, 1] -> [0, 1]
        return action * (self.action_max - self.action_min) + self.action_min

    def _augment_obs(self, state):
        """Turn a raw env observation into the agent observation.

        Normalization (if enabled) is applied to the env observation only;
        for latent algorithms the current ``self.z`` is then prepended, so
        ``z`` is never rescaled by the observation statistics.
        """
        if self.normalize:
            state = self.normalize_obs(state)
        if self.algorithm_name in self._LATENT_ALGOS:
            state = np.concatenate([self.z, state])
        return state

    def step(self, action):
        """Advance the wrapped env by one step.

        Args:
            action: indexable whose first entry is this agent's action vector.
                For diff-gail / diff-infogail that vector is
                ``[reference_action, rl_action, use_reference]``, each action
                of length ``gt_act_dim`` and ``use_reference`` exactly 0 or 1.

        Returns:
            ``([obs], [np.array([reward])], [done], [info])``. When ``done``,
            ``info`` also contains ``episode_length`` and ``episode_return``.
        """
        self._step += 1
        action = action[0]

        if self.algorithm_name in self._DIFF_ALGOS:
            dim = self.gt_act_dim
            reference_action = action[:dim]
            rl_action = action[dim:dim * 2]
            use_reference = action[dim * 2]
            if use_reference > 0.5:
                assert use_reference == 1
                action = reference_action
            else:
                assert use_reference == 0
                action = rl_action

        # Unnormalize only the action that is actually executed: the discrete
        # switch must stay exactly 0/1, and the stats are presumably sized for
        # the env action (gt_act_dim entries).
        if self.normalize:
            action = self.unnormalize_action(action)

        state, reward, done, info = self._env.step(action)
        self._return += reward

        if self.algorithm_name in self._LATENT_ALGOS and self._step % self.sample_z_interval == 0:
            self.z = np.random.randn(self.z_latent_dim)
        state = self._augment_obs(state)

        if done:
            info.update({"episode_length": self._step,
                         "episode_return": self._return})

        return [state], [np.array([reward])], [done], [info]

    def reset(self):
        """Reset the env and episode counters; resample ``z`` for latent algorithms.

        Returns:
            ``[obs]`` processed exactly like the observations from ``step``.
        """
        self._step = 0
        self._return = 0
        state = self._env.reset()
        if self.algorithm_name in self._LATENT_ALGOS:
            self.z = np.random.randn(self.z_latent_dim)
        return [self._augment_obs(state)]

    def close(self):
        """Release the wrapped environment's resources."""
        self._env.close()
