import numpy as np
import torch

import gym
from gym.envs.mujoco.walker2d_v3 import Walker2dEnv
from garage.envs import (
    GymEnv,
    TaskOnehotWrapper,
)


class WalkerVelocityMTEnv(gym.Env):
    """Multi-task container of velocity-tracking walker environments.

    One train environment and one test environment are built for every entry
    of ``target_velocities``. Each is a ``WalkerVelocityEnv`` wrapped in
    ``GymEnv`` and then ``TaskOnehotWrapper``.

    This class defines no ``step``/``reset`` of its own; the per-task
    environments are exposed through ``get_train_envs`` / ``get_test_envs``.
    ``action_space`` and ``observation_space`` are taken from the first train
    environment.

    Args:
        include_task_id: When falsy, ``split_observation`` zeroes the trailing
            ``num_tasks`` entries of the observation it returns.
        target_velocities: Sized collection of target velocities, one per
            task. Must be indexable by task index.
        velocity_reward_weight: Forwarded to every ``WalkerVelocityEnv``.
        walker_reward_weight: Forwarded to every ``WalkerVelocityEnv``.
        velocity_bonus: Forwarded to every ``WalkerVelocityEnv``.
        velocity_bonus_range: Forwarded to every ``WalkerVelocityEnv``.
    """

    def __init__(
        self,
        include_task_id,
        target_velocities,
        velocity_reward_weight,
        walker_reward_weight,
        velocity_bonus,
        velocity_bonus_range,
    ):
        super().__init__()
        self._num_tasks = len(target_velocities)
        self._task_velocities = target_velocities
        self._include_task_id = include_task_id
        self._walker_reward_weight = walker_reward_weight
        self._velocity_reward_weight = velocity_reward_weight
        self._velocity_bonus_range = velocity_bonus_range
        self._velocity_bonus = velocity_bonus
        self._init_envs()

    def _init_envs(self):
        """Build the per-task train/test environments and pick the current one."""
        # The tuple is evaluated left to right, so environments are still
        # created in the order train_0, test_0, train_1, test_1, ...
        env_pairs = [
            (self._make_env(task_idx), self._make_env(task_idx))
            for task_idx in range(self._num_tasks)
        ]
        self._train_envs = [train_env for train_env, _ in env_pairs]
        self._test_envs = [test_env for _, test_env in env_pairs]
        # The first train environment supplies the action/observation spaces.
        self._curr_env = self._train_envs[0]

    def _make_env(self, env_idx):
        """Create the wrapped environment for the task at index ``env_idx``."""
        walker = WalkerVelocityEnv(
            target_velocity=self._task_velocities[env_idx],
            velocity_reward_weight=self._velocity_reward_weight,
            walker_reward_weight=self._walker_reward_weight,
            velocity_bonus=self._velocity_bonus,
            velocity_bonus_range=self._velocity_bonus_range,
        )
        return TaskOnehotWrapper(
            GymEnv(walker),
            task_index=env_idx,
            n_total_tasks=self._num_tasks,
        )

    @property
    def action_space(self):
        """Action space of the current (first train) environment."""
        return self._curr_env.action_space

    @property
    def observation_space(self):
        """Observation space of the current (first train) environment."""
        return self._curr_env.observation_space

    @property
    def num_tasks(self):
        """Number of tasks, i.e. ``len(target_velocities)``."""
        return self._num_tasks

    def get_train_envs(self):
        """Return the list of per-task train environments."""
        return self._train_envs

    def get_test_envs(self):
        """Return the list of per-task test environments."""
        return self._test_envs

    def get_task_id(self, observation):
        """Return the argmax over the trailing ``num_tasks`` observation entries.

        Assumes those trailing entries hold the one-hot task encoding
        (presumably appended by ``TaskOnehotWrapper`` -- confirm).

        Args:
            observation: ``np.ndarray``; anything else is handed to
                ``torch.argmax``. May carry leading batch dimensions.

        Returns:
            Index (or array/tensor of indices) of the largest trailing entry
            along the last axis.
        """
        onehot_tail = observation[..., -self._num_tasks :]
        if isinstance(observation, np.ndarray):
            task_ids = np.argmax(onehot_tail, axis=-1)
        else:
            task_ids = torch.argmax(onehot_tail, dim=-1)
        if not task_ids.shape:
            # Zero-dimensional result: index with the empty tuple.
            task_ids = task_ids[()]
        return task_ids

    def split_observation(self, observation):
        """Split an observation into a policy input and task information.

        Args:
            observation: ``np.ndarray`` (copied with ``copy``) or an object
                providing ``clone`` such as a torch tensor.

        Returns:
            Tuple ``(obs_without_task, task_info)``. ``obs_without_task`` is a
            copy of ``observation`` whose trailing ``num_tasks`` entries are
            zeroed unless ``include_task_id`` is set. ``task_info`` is the
            unmodified input observation itself (not a copy).
        """
        if isinstance(observation, np.ndarray):
            obs_without_task = observation.copy()
        else:
            obs_without_task = observation.clone()

        if not self._include_task_id:
            # Hide the task identity by zeroing the trailing entries.
            obs_without_task[..., -self._num_tasks :] = 0

        return obs_without_task, observation


class WalkerVelocityEnv(Walker2dEnv):
    """Walker2d variant rewarded for moving at a target x-velocity.

    The base environment's forward-reward weight is forced to ``0.0``. The
    reward returned by ``step`` is a weighted sum of the base reward and a
    velocity-tracking term, plus an optional flat bonus.

    Args:
        target_velocity: Desired x-velocity.
        velocity_reward_weight: Weight of the velocity-tracking term.
        walker_reward_weight: Weight of the base ``Walker2dEnv`` reward.
        velocity_bonus: Flat bonus added when the velocity error is below
            ``velocity_bonus_range``.
        velocity_bonus_range: Velocity-error threshold for the bonus.
        *args: Extra positional arguments forwarded to ``Walker2dEnv``.
        **kwargs: Extra keyword arguments forwarded to ``Walker2dEnv``;
            ``forward_reward_weight`` is always overridden with ``0.0``.
    """

    def __init__(
        self,
        target_velocity,
        velocity_reward_weight,
        walker_reward_weight,
        velocity_bonus,
        velocity_bonus_range,
        *args,
        **kwargs,
    ):
        kwargs["forward_reward_weight"] = 0.0
        # Attributes are assigned before super().__init__ on purpose:
        # presumably the base MuJoCo constructor already calls step(), which
        # reads them -- confirm against the installed gym version.
        self._target_velocity = target_velocity
        self._velocity_reward_weight = velocity_reward_weight
        self._walker_reward_weight = walker_reward_weight
        self._velocity_bonus = velocity_bonus
        self._velocity_bonus_range = velocity_bonus_range
        # Remember the pass-through arguments so __deepcopy__ can rebuild an
        # identically configured environment.
        self._extra_init_args = tuple(args)
        self._extra_init_kwargs = dict(kwargs)
        super().__init__(*args, **kwargs)

    def step(self, action):
        """Step the base environment and replace its reward.

        Returns:
            Tuple ``(observation, reward, terminated, info)`` where ``reward``
            is ``walker_reward_weight * base_reward +
            velocity_reward_weight * velocity_reward`` plus ``velocity_bonus``
            when the velocity error is below ``velocity_bonus_range``.
        """
        observation, base_reward, terminated, info = super().step(action)
        velocity_diff = abs(info["x_velocity"] - self._target_velocity)

        # Linear falloff from 1 at zero error, clamped at 0.
        # NOTE(review): the scale 2.5 + target_velocity is zero or negative
        # for target_velocity <= -2.5 -- confirm targets are non-negative.
        velocity_reward = max(
            1 - velocity_diff / (2.5 + self._target_velocity),
            0,
        )

        reward = (
            self._walker_reward_weight * base_reward
            + self._velocity_reward_weight * velocity_reward
        )
        if velocity_diff < self._velocity_bonus_range:
            reward += self._velocity_bonus
        return observation, reward, terminated, info

    def __deepcopy__(self, memodict=None):
        """Return a freshly constructed environment with the same settings.

        The copy is built by calling the constructor again, so it carries the
        same reward parameters and the same pass-through ``Walker2dEnv``
        arguments; nothing else is copied from this instance.

        Args:
            memodict: Memo dictionary supplied by ``copy.deepcopy``; unused
                because no sub-objects of this instance are copied.
        """
        # getattr keeps this working for instances that somehow lack the
        # recorded pass-through arguments.
        extra_args = getattr(self, "_extra_init_args", ())
        extra_kwargs = getattr(self, "_extra_init_kwargs", {})
        cls = self.__class__
        result = cls.__new__(cls)
        result.__init__(
            self._target_velocity,
            self._velocity_reward_weight,
            self._walker_reward_weight,
            self._velocity_bonus,
            self._velocity_bonus_range,
            *extra_args,
            **extra_kwargs,
        )
        return result
