import numpy as np

from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer, \
SimpleImageReplayBuffer
from gym.spaces import Box, Discrete, Tuple


class MultiTaskReplayBuffer(object):
    """Collection of per-task replay buffers addressed by task key.

    Each public method takes the task key as its first argument and
    forwards the call to the buffer stored under that key in
    ``self.task_buffers``.
    """

    def __init__(self, max_replay_buffer_size, env, tasks, visual=False,
                 obs_shape=(3, 100, 100),):
        """
        :param max_replay_buffer_size: capacity handed to every per-task buffer
        :param env: environment whose ``observation_space`` and
            ``action_space`` are stored and used to size the buffers
        :param tasks: iterable of task keys; one buffer is created per key
        :param visual: when true, ``SimpleImageReplayBuffer`` is used instead
            of ``SimpleReplayBuffer``
        :param obs_shape: observation shape passed to the image buffers
            (only read when ``visual`` is true)
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        # Choose the buffer factory once; it is invoked once per task so the
        # dimension lookups happen exactly as often as there are tasks.
        if visual:
            def make_buffer():
                return SimpleImageReplayBuffer(
                    max_replay_buffer_size=max_replay_buffer_size,
                    observation_shape=obs_shape,
                    action_dim=get_dim(self._action_space),
                )
        else:
            def make_buffer():
                return SimpleReplayBuffer(
                    max_replay_buffer_size=max_replay_buffer_size,
                    observation_dim=get_dim(self._ob_space),
                    action_dim=get_dim(self._action_space),
                )

        self.task_buffers = {task_key: make_buffer() for task_key in tasks}

    def add_sample(
        self,
        task,
        observation,
        action,
        reward,
        terminal,
        next_observation,
        env_info,
        **kwargs
    ):
        """Store one transition in the buffer of ``task``.

        When the action space is ``Discrete`` the action is first replaced by
        the matching row of an ``n x n`` identity matrix (a one-hot encoding
        for an integer action index).
        """
        space = self._action_space
        if isinstance(space, Discrete):
            action = np.eye(space.n)[action]

        target = self.task_buffers[task]
        target.add_sample(
            observation,
            action,
            reward,
            terminal,
            next_observation,
            env_info,
            **kwargs
        )

    def terminate_episode(self, task):
        """Forward ``terminate_episode`` to the buffer of ``task``."""
        target = self.task_buffers[task]
        target.terminate_episode()

    def random_batch(self, task, batch_size, sequence=False):
        """Sample from the buffer of ``task``.

        Uses the buffer's ``random_sequence`` when ``sequence`` is true and
        its ``random_batch`` otherwise; ``batch_size`` is passed through.
        """
        target = self.task_buffers[task]
        if sequence:
            return target.random_sequence(batch_size)
        return target.random_batch(batch_size)

    def random_start_obs(self, task, batch_size):
        """Return ``random_start_obs(batch_size)`` from the buffer of ``task``."""
        target = self.task_buffers[task]
        return target.random_start_obs(batch_size)

    def num_steps_can_sample(self, task):
        """Return ``num_steps_can_sample()`` of the buffer of ``task``."""
        target = self.task_buffers[task]
        return target.num_steps_can_sample()

    def add_path(self, task, path):
        """Add a single path to the buffer of ``task``."""
        target = self.task_buffers[task]
        target.add_path(path)

    def add_paths(self, task, paths):
        """Add every path in ``paths``, in order, to the buffer of ``task``."""
        for single_path in paths:
            # Looked up per path, so an empty ``paths`` never touches the
            # task mapping.
            self.task_buffers[task].add_path(single_path)

    def clear_buffer(self, task):
        """Call ``clear()`` on the buffer of ``task``."""
        target = self.task_buffers[task]
        target.clear()


def get_dim(space):
    """Return the flat dimensionality of ``space``.

    Resolution order:

    * ``Box``: number of elements in ``space.low``
    * ``Discrete``: ``space.n``
    * ``Tuple``: sum of the dimensions of its sub-spaces (recursive)
    * any object exposing ``flat_dim``: that attribute
    * the legacy ``Box`` from ``rand_param_envs`` (only when that optional
      package can be imported): number of elements in ``space.low``

    :param space: the space object to measure
    :return: the dimension as described above
    :raises TypeError: if ``space`` matches none of the supported kinds
    """
    if isinstance(space, Box):
        return space.low.size
    if isinstance(space, Discrete):
        return space.n
    if isinstance(space, Tuple):
        return sum(get_dim(subspace) for subspace in space.spaces)
    if hasattr(space, "flat_dim"):
        return space.flat_dim

    # Import OldBox lazily so rand_param_envs is only needed when running
    # the rand_param envs. A missing package must not mask the TypeError
    # below, so the import failure is tolerated.
    try:
        from rand_param_envs.gym.spaces.box import Box as OldBox
    except ImportError:
        OldBox = None

    if OldBox is not None and isinstance(space, OldBox):
        return space.low.size
    raise TypeError("Unknown space: {}".format(space))
