import numpy as np
import torch

import metaworld.envs.mujoco.env_dict as _env_dict
from gymnasium.wrappers.time_limit import TimeLimit


from task.abs_task import AbsTask


def trim_mw_obs(obs):
    """Drop the duplicated robot entries from a Metaworld observation.

    Keeps ``obs[:18]`` and ``obs[22:]`` (i.e. removes indices 18-21) and
    returns the concatenation as a float32 array.
    """
    # Credit: CPL https://github.com/jhejna/cpl
    # The double robot observation is removed so that only the two object
    # observations remain stacked, which keeps everything more Markovian.
    keep_front, keep_back = obs[:18], obs[22:]
    return np.concatenate([keep_front, keep_back], dtype=np.float32)


class MetaworldTask(AbsTask):
    """``AbsTask`` adapter around a single Metaworld environment.

    The environment class is wrapped in a gymnasium ``TimeLimit``.
    Observations are trimmed with ``trim_mw_obs`` (4 entries dropped, cast to
    float32). ``reset`` can optionally perturb the hand's initial position.
    """

    def __init__(
        self, task_name, device, timesteps_per_rollout, randomize_hand=True, seed=0
    ):
        """Build and configure the wrapped environment.

        Args:
            task_name: key looked up in ``ALL_V2_ENVIRONMENTS``, falling back to
                ``ALL_V1_ENVIRONMENTS`` (a ``KeyError`` propagates if absent).
            device: torch device used when scoring with a learned reward in
                ``evaluate``.
            timesteps_per_rollout: ``TimeLimit`` horizon; a negative value
                means use the environment's own ``max_path_length``.
            randomize_hand: if True, ``reset`` adds uniform noise to the
                hand's initial position.
            seed: passed to the environment's ``seed`` and to ``set_seed``.
        """
        super().__init__()
        self.device = device
        if task_name in _env_dict.ALL_V2_ENVIRONMENTS:
            env_cls = _env_dict.ALL_V2_ENVIRONMENTS[task_name]
        else:
            env_cls = _env_dict.ALL_V1_ENVIRONMENTS[task_name]

        env_cls_init = env_cls()

        env_cls_init._freeze_rand_vec = False
        env_cls_init._set_task_called = True
        env_cls_init._partially_observable = False

        env_cls_init.seed(seed)
        self.randomize_hand = randomize_hand

        max_path_length = (
            env_cls_init.max_path_length
            if timesteps_per_rollout < 0
            else timesteps_per_rollout
        )

        self.env = TimeLimit(env_cls_init, max_path_length)
        self.set_seed(seed)

        # trim_mw_obs removes 4 entries (indices 18-21) from every observation.
        self.obs_dim = self.env.observation_space.shape[0] - 4
        self.action_dim = self.env.action_space.shape[0]
        self.action_range = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max()),
        ]

    def step(self, action):
        """Advance the environment by one step.

        Args:
            action: numpy array or torch tensor. Tensors are detached and
                moved to the CPU before being passed to the environment.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``observation`` is the trimmed float32 vector, and
            ``info["discount"]`` is set to 1.0.
        """
        if isinstance(action, torch.Tensor):
            action = action.detach().cpu().numpy()

        observation, reward, terminated, truncated, info = self.env.step(action)

        info["discount"] = 1.0
        # trim_mw_obs already casts to float32.
        observation = trim_mw_obs(observation)

        return observation, reward, terminated, truncated, info

    def _get_obs(self):
        """Return the trimmed observation read directly from the unwrapped env."""
        return trim_mw_obs(self.env.unwrapped._get_obs())

    def reset(self):
        """Reset the environment, optionally randomizing the hand position.

        Returns:
            ``(observation, info)``. ``observation`` is the trimmed float32
            observation read after any hand randomization. ``info`` is the dict
            returned by the wrapped ``reset``.
        """
        _, info = self.env.reset()
        # Read env attributes from the unwrapped env instead of relying on the
        # wrapper's deprecated attribute forwarding.
        env = self.env.unwrapped

        if self.randomize_hand:
            # Perturb the env's hand_init_pos with uniform noise and clip the
            # result to the mocap bounds.
            high = np.array([0.25, 0.15, 0.2], dtype=np.float32)
            hand_init_pos = env.hand_init_pos + np.random.uniform(low=-high, high=high)
            hand_init_pos = np.clip(
                hand_init_pos, a_min=env.mocap_low, a_max=env.mocap_high
            )
            hand_init_pos = np.expand_dims(hand_init_pos, axis=0)
            # Repeatedly set the mocap target and simulate; presumably this
            # lets the hand move to the new target before the first step.
            for _ in range(50):
                env.data.mocap_pos = hand_init_pos
                env.data.mocap_quat = np.array([1, 0, 1, 0])
                env.do_simulation([-1, 1], env.frame_skip)

        return self._get_obs(), info

    def max_episode_steps(self):
        """Return the ``TimeLimit`` horizon."""
        return self.env._max_episode_steps

    def set_state(self, state):
        """Restore a flat state vector into the simulator.

        ``state`` is split at ``self._split_shapes`` into qpos, qvel,
        mocap_pos, mocap_quat and rand_vec. If the rand_vec differs from the
        env's current one, it is frozen into the env and the env is reset
        before the joint and mocap state are written.

        NOTE(review): ``get_state`` (used to populate ``_split_shapes``) is not
        defined in this class. It is presumably provided by ``AbsTask`` or a
        subclass; confirm. The ``sim`` API used here must also exist on the
        installed Metaworld version.
        """
        # Private attributes ("_"-prefixed) must be read and written on the
        # unwrapped env. Gymnasium wrappers refuse to forward such names, and
        # assigning through the wrapper would only set them on the wrapper.
        env = self.env.unwrapped
        joint_state = env.sim.get_state()
        if not hasattr(self, "_split_shapes"):
            self.get_state()  # Load the split
        qpos, qvel, mocap_pos, mocap_quat, rand_vec = np.split(
            state, self._split_shapes, axis=0
        )
        if not np.all(env._last_rand_vec == rand_vec):
            # We need to set the rand vec and then reset.
            # NOTE(review): _freeze_rand_vec is left True afterwards; confirm
            # later resets are meant to keep reusing this rand_vec.
            env._freeze_rand_vec = True
            env._last_rand_vec = rand_vec
            # Reset through the wrapper so the TimeLimit step counter resets too.
            self.env.reset()
        joint_state.qpos[:] = qpos
        joint_state.qvel[:] = qvel
        env.set_env_state(
            (
                joint_state,
                (np.expand_dims(mocap_pos, axis=0), np.expand_dims(mocap_quat, axis=0)),
            )
        )
        env.sim.forward()

    def evaluate(self, agent, reward, num_episodes, step, logger):
        """Roll out ``agent`` deterministically and log aggregate metrics.

        Args:
            agent: object with ``reset()`` and ``act(observation, sample=False)``.
            reward: optional learned reward with ``reward(obs, action)``
                returning a single-element tensor. If ``None``, the model
                reward is 0.
            num_episodes: number of evaluation episodes.
            step: step value forwarded to ``logger.log``.
            logger: object with ``log(name, metrics, step)``.

        Returns:
            Dict of mean/std environment reward, mean/std model reward,
            mean/std success, and mean episode length. An episode's success is
            the last truthy ``info["success"]`` it reported, or 0.0 if none.
        """
        env_rewards = []
        model_rewards = []
        success_rates = []
        episode_lengths = []

        for _ in range(num_episodes):
            observation, info = self.reset()

            agent.reset()
            terminated = False
            truncated = False
            episode_env_reward = 0
            episode_model_reward = 0
            success_rate = 0.0
            episode_length = 0
            while not (terminated or truncated):
                action = agent.act(observation, sample=False)
                if isinstance(action, torch.Tensor):
                    action = action.detach().cpu().numpy()

                next_observation, env_reward, terminated, truncated, info = self.step(
                    action
                )

                if reward is not None:
                    obs_tensor = torch.from_numpy(observation).to(self.device)
                    action_tensor = torch.from_numpy(action).to(self.device)
                    # Evaluation only: no autograd graph is needed.
                    with torch.no_grad():
                        model_reward = reward.reward(obs_tensor, action_tensor).item()
                else:
                    model_reward = 0.0

                observation = next_observation

                episode_env_reward += env_reward
                episode_model_reward += model_reward
                if info["success"]:
                    success_rate = info["success"]
                episode_length += 1

            env_rewards.append(episode_env_reward)
            model_rewards.append(episode_model_reward)
            success_rates.append(success_rate)
            episode_lengths.append(episode_length)

        env_rewards = np.array(env_rewards)
        model_rewards = np.array(model_rewards)
        success_rates = np.array(success_rates)
        episode_lengths = np.array(episode_lengths)

        metrics = {
            "env_rewards_mean": np.mean(env_rewards),
            "env_rewards_std": np.std(env_rewards),
            "model_rewards_mean": np.mean(model_rewards),
            "model_rewards_std": np.std(model_rewards),
            "success_rates_mean": np.mean(success_rates),
            "success_rates_std": np.std(success_rates),
            "episode_lengths_mean": np.mean(episode_lengths),
        }
        logger.log("eval_metrics", metrics, step)
        return metrics
