from copy import deepcopy
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

import gym
import metaworld
import numpy as np

from gym.wrappers import TimeLimit

from transfer.envs.utils.wrappers import (
    FinishEarly,
    OneHotAdder,
    RandomizationWrapper,
    RenderWrapper,
    RewardSparsifier,
    StitchedSuccessCounter,
    SuccessCounter,
    TimestepAdder,
)


def get_mt50() -> metaworld.MT50:
    """Build the MT50 benchmark under a fixed NumPy seed.

    The global NumPy random state is saved, the global generator is seeded
    with 1 while ``metaworld.MT50()`` is constructed, and the saved state is
    put back afterwards, so calling this function does not perturb the
    caller's random stream.

    Returns:
      metaworld.MT50: freshly constructed benchmark object.
    """
    saved_random_state = np.random.get_state()
    try:
        np.random.seed(1)
        return metaworld.MT50()
    finally:
        # Restore even when construction fails, so an error here never leaves
        # the process-wide generator stuck on the fixed seed.
        np.random.set_state(saved_random_state)


# Benchmark instance shared by the helpers in this module; built once at import time.
MT50 = get_mt50()
# Step cap handed to the TimeLimit wrapper of every single-task environment below.
META_WORLD_TIME_HORIZON = 200
# Names obtained by iterating MT50.train_classes; get_task_name indexes into this list.
MT50_TASK_NAMES = list(MT50.train_classes)
# NOTE(review): presumably the raw Meta-World observation / action vector lengths;
# neither constant is referenced in the visible part of this file - confirm with callers.
MW_OBS_LEN = 39
MW_ACT_LEN = 4


def get_task_name(name_or_number: Union[int, str]) -> str:
    """Resolve a task given by name or by MT50 index to a task name.

    Args:
      name_or_number: a task name, or anything ``int()`` accepts (an int or a
        numeric string), which is then used as an index into
        ``MT50_TASK_NAMES``.

    Returns:
      The name stored at that index. When the argument cannot be converted to
      an int, or the index is out of range, the argument is returned unchanged.
    """
    try:
        index = int(name_or_number)
    except (TypeError, ValueError):
        # Not numeric: the argument is taken to be a task name already.
        return name_or_number
    try:
        return MT50_TASK_NAMES[index]
    except IndexError:
        # Kept for backward compatibility: out-of-range numbers are passed
        # through unchanged rather than raising here.
        return name_or_number


def set_simple_goal(env: gym.Env, name: str) -> None:
    """Set ``env`` to the first MT50 train task whose ``env_name`` equals ``name``.

    Args:
      env: environment exposing ``set_task``.
      name: task name looked up among ``MT50.train_tasks``.

    Raises:
      IndexError: if no train task carries that name.
    """
    # Stop at the first match instead of materialising every matching task.
    goal = next((task for task in MT50.train_tasks if task.env_name == name), None)
    if goal is None:
        # Same exception type as indexing an empty list, but with a usable message.
        raise IndexError(f"No MT50 train task found for env name {name!r}")
    env.set_task(goal)


def get_subtasks(name: str) -> List[metaworld.Task]:
    """Collect every MT50 train task whose ``env_name`` equals ``name``.

    Args:
      name: task name to filter ``MT50.train_tasks`` by.

    Returns:
      Matching tasks in the order they appear in ``MT50.train_tasks``;
      an empty list when nothing matches.
    """
    matching = []
    for candidate in MT50.train_tasks:
        if candidate.env_name == name:
            matching.append(candidate)
    return matching


def get_mt50_idx(env: gym.Env) -> int:
    """Return the single value stored in ``env._env_discrete_index``.

    Args:
      env: environment carrying an ``_env_discrete_index`` mapping.

    Returns:
      The only value of that mapping.

    Raises:
      AssertionError: if the mapping does not hold exactly one entry.
    """
    indices = tuple(env._env_discrete_index.values())
    assert len(indices) == 1
    (only_index,) = indices
    return only_index


def get_single_env(
    task: Union[int, str],
    add_one_hot: bool = False,
    one_hot_idx: int = 0,
    one_hot_len: int = 1,
    success_as_reward: bool = False,
    randomization: str = "random_init_all",
    append_timestep: bool = False,
    done_on_success: bool = False,
    reward_early_finish: bool = False,
) -> gym.Env:
    """
    Return a single task environment.

    The raw MT50 environment is wrapped, from the inside out, in
    RandomizationWrapper, optionally TimestepAdder, optionally OneHotAdder,
    optionally RewardSparsifier, TimeLimit, optionally FinishEarly,
    SuccessCounter and RenderWrapper. The one-hot wrapper exists so that a
    model operating on many envs can differentiate between them.

    Args:
      task: task name or MT50 number
      add_one_hot: if True, wrap the env in OneHotAdder
      one_hot_idx: one-hot identifier (indicates order among different tasks that we consider)
      one_hot_len: length of the one-hot encoding, number of tasks that we consider
      success_as_reward: if True, wrap the env in RewardSparsifier
      randomization: randomization kind, one of 'deterministic', 'random_init_all',
                     'random_init_fixed20', 'random_init_small_box'.
      append_timestep: if True, wrap the env in TimestepAdder
      done_on_success: if True, wrap the env in FinishEarly
      reward_early_finish: forwarded to FinishEarly; only used when done_on_success is True

    Return:
      gym.Env: single-task environment, with `name` set to the task name and `num_envs` set to 1
    """
    task_name = get_task_name(task)
    env = MT50.train_classes[task_name]()
    env = RandomizationWrapper(env, get_subtasks(task_name), randomization)
    # Observation-augmenting wrappers go first, directly around the randomized env.
    if append_timestep:
        env = TimestepAdder(env)
    if add_one_hot:
        env = OneHotAdder(env, one_hot_idx=one_hot_idx, one_hot_len=one_hot_len)
    if success_as_reward:
        env = RewardSparsifier(env)
    # Currently TimeLimit is needed since SuccessCounter looks at dones.
    env = TimeLimit(env, META_WORLD_TIME_HORIZON)
    if done_on_success:
        env = FinishEarly(env, reward_early_finish=reward_early_finish)
    env = SuccessCounter(env)
    env = RenderWrapper(env)
    # Metadata is attached to the outermost wrapper.
    env.name = task_name
    env.num_envs = 1
    return env


def assert_equal_excluding_goal_dimensions(os1: gym.spaces.Box, os2: gym.spaces.Box) -> None:
    """Assert two box spaces share their bounds everywhere except indices 36..38.

    Args:
      os1: first space to compare.
      os2: second space to compare.

    Raises:
      AssertionError: if ``low`` or ``high`` differ before index 36 or from
        index 39 onwards.
    """
    before_goal = slice(None, 36)
    after_goal = slice(39, None)
    for region in (before_goal, after_goal):
        assert np.array_equal(os1.low[region], os2.low[region])
        assert np.array_equal(os1.high[region], os2.high[region])


def remove_goal_bounds(obs_space: gym.spaces.Box) -> None:
    """Unbound entries 36..38 of ``obs_space`` in place.

    The lower bound of those entries becomes ``-inf`` and the upper bound
    ``+inf``; every other entry is left as it was.

    Args:
      obs_space: box space whose ``low`` / ``high`` arrays are modified.
    """
    goal_span = slice(36, 39)
    obs_space.high[goal_span] = np.inf
    obs_space.low[goal_span] = -np.inf


class ContinualLearningEnv(gym.Env):
    """Present a list of environments one after another, ``steps_per_env`` steps each.

    All environments must share the action space and, outside indices 36..38,
    the observation-space bounds. The exposed observation space is a copy of
    the first environment's space with those three entries unbounded.

    Once ``len(envs) * steps_per_env`` steps have been taken, both ``step``
    and ``reset`` raise ``RuntimeError``.
    """

    def __init__(self, envs: List[gym.Env], steps_per_env: int) -> None:
        """Validate the environments and initialise the step bookkeeping.

        Args:
          envs: environments to visit, in order; each must expose ``name`` and
            ``pop_successes`` for ``pop_successes`` of this class to work.
          steps_per_env: number of steps spent in each environment.

        Raises:
          AssertionError: if action spaces or (non-goal) observation bounds differ.
          IndexError: if ``envs`` is empty.
        """
        reference = envs[0]
        for env in envs:
            assert reference.action_space == env.action_space
            assert_equal_excluding_goal_dimensions(reference.observation_space, env.observation_space)
        self.action_space = reference.action_space
        # Copy before editing so the first environment keeps its own bounds.
        self.observation_space = deepcopy(reference.observation_space)
        remove_goal_bounds(self.observation_space)

        self.envs = envs
        self.num_envs = len(envs)
        self.steps_per_env = steps_per_env
        self.steps_limit = self.num_envs * self.steps_per_env
        self.cur_step = 0
        self.cur_seq_idx = 0
        # Filled by pop_successes; defined up front so it can be read before
        # the first call without raising AttributeError.
        self.avg_env_success: Dict[str, float] = {}

    def _check_steps_bound(self) -> None:
        """Raise ``RuntimeError`` once the total step budget has been used up."""
        if self.cur_step >= self.steps_limit:
            raise RuntimeError("Steps limit exceeded for ContinualLearningEnv!")

    def pop_successes(self) -> List[bool]:
        """Drain the success records of every environment.

        Also rebuilds ``avg_env_success``, mapping ``env.name`` to the mean of
        the records just popped, for environments that returned any.

        Returns:
          All popped records concatenated in environment order.
        """
        all_successes = []
        self.avg_env_success = {}
        for env in self.envs:
            successes = env.pop_successes()
            all_successes += successes
            if len(successes) > 0:
                self.avg_env_success[env.name] = np.mean(successes)
        return all_successes

    def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
        """Step the currently active environment.

        ``info["seq_idx"]`` is set to the index of the environment that was
        stepped. When the step count reaches a multiple of ``steps_per_env``
        the episode is ended (``done`` forced to True and
        ``info["TimeLimit.truncated"]`` set) and the next environment becomes
        active.

        Raises:
          RuntimeError: if the total step budget is already exhausted.
        """
        self._check_steps_bound()
        obs, reward, done, info = self.envs[self.cur_seq_idx].step(action)
        info["seq_idx"] = self.cur_seq_idx

        self.cur_step += 1
        if self.cur_step % self.steps_per_env == 0:
            # If we hit limit for current env, end the episode.
            # This may cause border episodes to be shorter than 200.
            done = True
            info["TimeLimit.truncated"] = True

            self.cur_seq_idx += 1

        return obs, reward, done, info

    def reset(self) -> np.ndarray:
        """Reset the currently active environment and return its observation.

        Raises:
          RuntimeError: if the total step budget is already exhausted.
        """
        self._check_steps_bound()
        return self.envs[self.cur_seq_idx].reset()


def get_cl_env(tasks: List[Union[int, str]], steps_per_task: int, randomization: str = "random_init_all") -> gym.Env:
    """Return continual learning environment.

    Each task gets its own environment wrapped, from the inside out, in
    RandomizationWrapper, OneHotAdder (index = position in `tasks`,
    length = number of tasks), TimeLimit, SuccessCounter and RenderWrapper.

    Args:
      tasks: list of task names or MT50 numbers
      steps_per_task: steps the agent will spend in each of single environments
      randomization: randomization kind, one of 'deterministic', 'random_init_all',
                     'random_init_fixed20', 'random_init_small_box'.

    Returns:
      gym.Env: continual learning environment
    """
    task_names = [get_task_name(task) for task in tasks]
    num_tasks = len(task_names)
    envs = []
    for i, task_name in enumerate(task_names):
        env = MT50.train_classes[task_name]()
        env = RandomizationWrapper(env, get_subtasks(task_name), randomization)
        env = OneHotAdder(env, one_hot_idx=i, one_hot_len=num_tasks)
        # NOTE(review): the name is attached to the OneHotAdder layer, before the
        # remaining wrappers are applied, whereas get_single_env / get_mt_env set it
        # on the outermost wrapper. ContinualLearningEnv.pop_successes reads
        # `env.name` from the outermost wrapper, so this presumably relies on the
        # wrappers forwarding attribute access - confirm.
        env.name = task_name
        env = TimeLimit(env, META_WORLD_TIME_HORIZON)
        env = SuccessCounter(env)
        env = RenderWrapper(env)
        envs.append(env)
    cl_env = ContinualLearningEnv(envs, steps_per_task)
    cl_env.name = "ContinualLearningEnv"
    return cl_env


class MultiTaskEnv(gym.Env):
    """Cycle through a list of environments, switching to the next one on every reset.

    All environments must share the action space and, outside indices 36..38,
    the observation-space bounds. The exposed observation space is a copy of
    the first environment's space with those three entries unbounded.

    Once ``len(envs) * steps_per_env`` steps have been taken, both ``step``
    and ``reset`` raise ``RuntimeError``.
    """

    def __init__(self, envs: List[gym.Env], steps_per_env: int, cycle_mode: str = "episode") -> None:
        """Validate the environments and initialise the step bookkeeping.

        Args:
          envs: environments to cycle through; each must expose ``name`` and
            ``pop_successes`` for ``pop_successes`` of this class to work.
          steps_per_env: per-environment share of the total step budget.
          cycle_mode: only ``"episode"`` is accepted.

        Raises:
          AssertionError: if ``cycle_mode`` is not ``"episode"``, or action
            spaces / (non-goal) observation bounds differ.
          IndexError: if ``envs`` is empty.
        """
        assert cycle_mode == "episode"
        reference = envs[0]
        for env in envs:
            assert reference.action_space == env.action_space
            assert_equal_excluding_goal_dimensions(reference.observation_space, env.observation_space)
        self.action_space = reference.action_space
        # Copy before editing so the first environment keeps its own bounds.
        self.observation_space = deepcopy(reference.observation_space)
        remove_goal_bounds(self.observation_space)

        self.envs = envs
        self.num_envs = len(envs)
        self.steps_per_env = steps_per_env
        self.cycle_mode = cycle_mode

        self.steps_limit = self.num_envs * self.steps_per_env
        self.cur_step = 0
        self._cur_seq_idx = 0
        # Filled by pop_successes; defined up front so it can be read before
        # the first call without raising AttributeError.
        self.avg_env_success: Dict[str, float] = {}

    def _check_steps_bound(self) -> None:
        """Raise ``RuntimeError`` once the total step budget has been used up."""
        if self.cur_step >= self.steps_limit:
            raise RuntimeError("Steps limit exceeded for MultiTaskEnv!")

    def pop_successes(self) -> List[bool]:
        """Drain the success records of every environment.

        Also rebuilds ``avg_env_success``, mapping ``env.name`` to the mean of
        the records just popped, for environments that returned any.

        Returns:
          All popped records concatenated in environment order.
        """
        all_successes = []
        self.avg_env_success = {}
        for env in self.envs:
            successes = env.pop_successes()
            all_successes += successes
            if len(successes) > 0:
                self.avg_env_success[env.name] = np.mean(successes)
        return all_successes

    def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
        """Step the currently selected environment.

        ``info["mt_seq_idx"]`` is set to the index of the environment that was
        stepped.

        Raises:
          RuntimeError: if the total step budget is already exhausted.
        """
        self._check_steps_bound()
        obs, reward, done, info = self.envs[self._cur_seq_idx].step(action)
        info["mt_seq_idx"] = self._cur_seq_idx
        # Not reachable while __init__ asserts cycle_mode == "episode"; kept so
        # the per-step cycling logic stays in place.
        if self.cycle_mode == "step":
            self._cur_seq_idx = (self._cur_seq_idx + 1) % self.num_envs
        self.cur_step += 1

        return obs, reward, done, info

    def reset(self) -> np.ndarray:
        """Advance to the next environment (wrapping around) and reset it.

        The index is advanced before it is used, so the first reset after
        construction selects index ``1 % num_envs`` rather than 0.

        Raises:
          RuntimeError: if the total step budget is already exhausted.
        """
        self._check_steps_bound()
        if self.cycle_mode == "episode":
            self._cur_seq_idx = (self._cur_seq_idx + 1) % self.num_envs
        obs = self.envs[self._cur_seq_idx].reset()
        return obs


def get_mt_env(
    tasks: List[Union[int, str]],
    steps_per_task: int,
    randomization: str = "random_init_all",
    success_as_reward: bool = False,
    append_timestep: bool = False,
    done_on_success: bool = False,
    reward_early_finish: bool = False,
    one_hot_len: Optional[int] = None,
) -> gym.Env:
    """Return multi-task learning environment.

    Each task gets its own environment wrapped, from the inside out, in
    RandomizationWrapper, optionally TimestepAdder, OneHotAdder, optionally
    RewardSparsifier, TimeLimit, optionally FinishEarly, SuccessCounter and
    RenderWrapper; the resulting list is handed to MultiTaskEnv.

    Args:
      tasks: list of task names or MT50 numbers
      steps_per_task: agent will be limited to steps_per_task * len(tasks) steps
      randomization: randomization kind, one of 'deterministic', 'random_init_all',
                     'random_init_fixed20', 'random_init_small_box'.
      success_as_reward: if True, wrap each env in RewardSparsifier
      append_timestep: if True, wrap each env in TimestepAdder
      done_on_success: if True, wrap each env in FinishEarly
      reward_early_finish: forwarded to FinishEarly; only used when done_on_success is True
      one_hot_len: length of the one-hot encoding; defaults to the number of tasks

    Returns:
      gym.Env: multi-task learning environment
    """
    task_names = [get_task_name(task) for task in tasks]
    num_tasks = len(task_names)

    if one_hot_len is None:
        one_hot_len = num_tasks

    envs = []
    for i, task_name in enumerate(task_names):
        env = MT50.train_classes[task_name]()
        env = RandomizationWrapper(env, get_subtasks(task_name), randomization)
        if append_timestep:
            env = TimestepAdder(env)
        # The one-hot index is the task's position in `tasks`.
        env = OneHotAdder(env, one_hot_idx=i, one_hot_len=one_hot_len)
        if success_as_reward:
            env = RewardSparsifier(env)
        env = TimeLimit(env, META_WORLD_TIME_HORIZON)
        if done_on_success:
            env = FinishEarly(env, reward_early_finish=reward_early_finish)
        env = SuccessCounter(env)
        env = RenderWrapper(env)
        # The name goes on the outermost wrapper; MultiTaskEnv.pop_successes reads it.
        env.name = task_name
        envs.append(env)
    mt_env = MultiTaskEnv(envs, steps_per_task)
    mt_env.name = "MultiTaskEnv"
    return mt_env


class StitchedTasksEnv(gym.Env):
    """Chain several environments into one episode, advancing on every success.

    The episode starts in the first environment. Whenever a step reports
    ``info["success"]`` the next environment is reset and becomes active; a
    success in the last environment ends the episode.

    NOTE(review): unlike the sibling composite environments in this file, the
    goal entries of the copied observation space are not unbounded here -
    confirm this is intended.
    """

    def __init__(self, envs, verbose: bool = False, accumulate_rewards: bool = False, continue_from_pos: bool = False):
        """Validate the environments and initialise the episode state.

        Args:
          envs: environments to chain, in the order they are visited.
          verbose: if True, print a summary line whenever an episode ends.
          accumulate_rewards: if True, the reward of each transition step is
            added to an offset that is included in all later rewards of the
            episode.
          continue_from_pos: if True, the next environment's ``hand_init_pos``
            is set from the finished environment's ``data.mocap_pos[0]``
            before it is reset.

        Raises:
          AssertionError: if action spaces or (non-goal) observation bounds differ.
          IndexError: if ``envs`` is empty.
        """
        reference = envs[0]
        for env in envs:
            assert reference.action_space == env.action_space
            assert_equal_excluding_goal_dimensions(reference.observation_space, env.observation_space)
        self.action_space = reference.action_space
        self.observation_space = deepcopy(reference.observation_space)
        self.verbose = verbose

        self.envs = envs
        self.current_env_idx = 0
        self.timestep = 0
        self.transition_times = []
        self.accumulated_reward = 0.0
        self.accumulate_rewards = accumulate_rewards
        self.continue_from_pos = continue_from_pos

    @property
    def viewer(self):
        """Viewer of the currently active environment."""
        return self.envs[self.current_env_idx].viewer

    def reset(self) -> np.ndarray:
        """Restart the chain: clear episode state and reset every environment.

        Returns:
          The observation produced by resetting the first environment.
        """
        self.current_env_idx = 0
        self.accumulated_reward = 0.0
        self.transition_times = []
        self.timestep = 0
        obs = self.envs[0].reset()
        for env in self.envs[1:]:
            env.reset()
        return obs

    def step(self, action: Any):
        """Step the active environment and move to the next one on success.

        Returns:
          ``(obs, reward, done, info)``. ``reward`` is the active
          environment's reward plus the accumulated offset as it stood before
          this step. On a transition ``obs`` is the reset observation of the
          next environment, ``done`` is False and ``info["transition"]`` is
          True. ``info["env_stage"]`` holds the active index after the step,
          which equals ``len(envs)`` once the last environment has succeeded.
        """
        finished_env = self.envs[self.current_env_idx]
        obs, reward, done, info = finished_env.step(action)
        new_reward = reward + self.accumulated_reward
        info["transition"] = False

        self.timestep += 1
        if info["success"]:
            self.current_env_idx += 1
            self.transition_times.append(self.timestep)

            if self.current_env_idx < len(self.envs):
                info["transition"] = True
                next_env = self.envs[self.current_env_idx]
                if self.continue_from_pos:
                    # Read the hand position only when it is needed, and copy
                    # it: the indexed value may be a view into the finished
                    # environment's simulator buffer, which changes on reset.
                    next_env.unwrapped.hand_init_pos = np.array(finished_env.data.mocap_pos[0])
                obs = next_env.reset()
                if self.accumulate_rewards:
                    self.accumulated_reward += reward
                # A success in a non-final stage must not end the episode.
                done = False
            else:
                done = True

        if done and self.verbose:
            print(f"Env {self.current_env_idx}, {self.timestep}. Hist: {self.transition_times}")
        info["env_stage"] = self.current_env_idx
        return obs, new_reward, done, info


def get_stitched_env(
    tasks: List[Union[int, str]],
    ordering: Optional[List[int]],
    randomization: str = "random_init_all",
    verbose: bool = False,
    success_as_reward: bool = False,
    accumulate_rewards: bool = False,
    continue_from_pos: bool = False,
    append_timestep: bool = False,
    done_on_success: bool = False,
    reward_early_finish: bool = False,
):
    """Return stitched-tasks environment.

    One environment is built per task, the environments are rearranged
    according to `ordering`, chained with StitchedTasksEnv, and the result is
    wrapped in TimeLimit and StitchedSuccessCounter.

    Args:
      tasks: list of task names or MT50 numbers
      ordering: indices into `tasks` giving the order in which the environments are
                chained; None keeps the order of `tasks`
      randomization: randomization kind, one of 'deterministic', 'random_init_all',
                     'random_init_fixed20', 'random_init_small_box'.
      verbose: forwarded to StitchedTasksEnv
      success_as_reward: if True, wrap each env in RewardSparsifier
      accumulate_rewards: forwarded to StitchedTasksEnv
      continue_from_pos: forwarded to StitchedTasksEnv
      append_timestep: if True, wrap each env in TimestepAdder
      done_on_success: if True, wrap each env in FinishEarly
      reward_early_finish: forwarded to FinishEarly; only used when done_on_success is True

    Returns:
      The StitchedTasksEnv wrapped in TimeLimit and StitchedSuccessCounter.
    """
    if ordering is None:
        ordering = list(range(len(tasks)))

    task_names = [get_task_name(task) for task in tasks]
    num_tasks = len(task_names)
    envs = []
    for i, task_name in enumerate(task_names):
        env = MT50.train_classes[task_name]()
        env = RandomizationWrapper(env, get_subtasks(task_name), randomization)
        if append_timestep:
            env = TimestepAdder(env)
        # The one-hot index is the task's position in `tasks`, not its position in `ordering`.
        env = OneHotAdder(env, one_hot_idx=i, one_hot_len=num_tasks)
        if success_as_reward:
            env = RewardSparsifier(env)
        env = TimeLimit(env, META_WORLD_TIME_HORIZON)
        if done_on_success:
            env = FinishEarly(env, reward_early_finish=reward_early_finish)
        env = RenderWrapper(env)
        env.name = task_name
        envs.append(env)

    rearranged_envs = [envs[order_idx] for order_idx in ordering]
    stitched_env = StitchedTasksEnv(
        rearranged_envs, accumulate_rewards=accumulate_rewards, continue_from_pos=continue_from_pos, verbose=verbose
    )
    stitched_env.name = "StitchedTasksEnv"
    # NOTE(review): the overall horizon scales with len(tasks), not len(ordering);
    # the two differ when `ordering` repeats or omits indices - confirm this is intended.
    stitched_env = TimeLimit(stitched_env, META_WORLD_TIME_HORIZON * len(tasks))
    stitched_env = StitchedSuccessCounter(stitched_env)
    return stitched_env
