import numpy as np

from rlkit.data_management.replay_buffer import ReplayBuffer
from rlkit.torch.data_aug import random_crop


class SimpleReplayBuffer(ReplayBuffer):
    """Fixed-capacity circular buffer of vector-observation transitions.

    Transitions are written into preallocated NumPy arrays at a write cursor
    (``_top``) that wraps around once ``max_replay_buffer_size`` entries have
    been stored, overwriting the oldest data. A second, independent circular
    store keeps episode start observations.
    """

    def __init__(self, max_replay_buffer_size, observation_dim, action_dim):
        """Allocate storage for ``max_replay_buffer_size`` transitions.

        Args:
            max_replay_buffer_size: Number of transitions (and, separately,
                start observations) that can be held before wrapping.
            observation_dim: Length of one flattened observation.
            action_dim: Length of one flattened action.
        """
        self._observation_dim = observation_dim
        self._action_dim = action_dim
        self._max_replay_buffer_size = max_replay_buffer_size
        self._observations = np.zeros((max_replay_buffer_size, observation_dim))
        self._start_obs = np.zeros((max_replay_buffer_size, observation_dim))
        self._start_obs_top = 0
        # It's a bit memory inefficient to save the observations twice,
        # but it makes the code *much* easier since you no longer have to
        # worry about termination conditions.
        self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
        self._actions = np.zeros((max_replay_buffer_size, action_dim))
        # Make everything a 2D np array to make it easier for other code to
        # reason about the shape of the data
        self._rewards = np.zeros((max_replay_buffer_size, 1))
        self._sparse_rewards = np.zeros((max_replay_buffer_size, 1))
        # self._terminals[i] = a terminal was received at time i
        self._terminals = np.zeros((max_replay_buffer_size, 1), dtype="uint8")
        self._env_infos = np.zeros((max_replay_buffer_size,), dtype="object")
        self._num_start_obs = 0
        self.clear()

    def add_sample(
        self,
        observation,
        action,
        reward,
        terminal,
        next_observation,
        env_info,
        **kwargs
    ):
        """Store one transition at the write cursor and advance the cursor.

        The sparse reward is read from ``env_info["sparse_reward"]`` and
        defaults to 0 when that key is absent. Extra keyword arguments are
        accepted and ignored.
        """
        self._observations[self._top] = observation
        self._actions[self._top] = action
        self._rewards[self._top] = reward
        self._terminals[self._top] = terminal
        self._next_obs[self._top] = next_observation
        self._sparse_rewards[self._top] = env_info.get("sparse_reward", 0)
        self._env_infos[self._top] = env_info
        self._advance()

    def add_start_obs(self, observation):
        """Store an episode start observation in its own circular store."""
        self._start_obs[self._start_obs_top] = observation
        self._advance_start_obs()

    def terminate_episode(self):
        """Record where the finished episode began and open a new one."""
        # store the episode beginning once the episode is over
        # n.b. allows last episode to loop but whatever
        self._episode_starts.append(self._cur_episode_start)
        self._cur_episode_start = self._top

    def size(self):
        """Return the number of transitions currently stored."""
        return self._size

    def clear(self):
        """Reset all cursors and counters; the underlying arrays are kept."""
        self._top = 0
        self._size = 0
        self._episode_starts = []
        self._cur_episode_start = 0
        self._num_start_obs = 0
        # Rewind the start-obs write cursor together with its counter.
        # Otherwise new start observations would land at the old cursor
        # while random_start_obs samples indices [0, _num_start_obs).
        self._start_obs_top = 0

    def _advance(self):
        """Move the transition write cursor forward, wrapping at capacity."""
        self._top = (self._top + 1) % self._max_replay_buffer_size
        if self._size < self._max_replay_buffer_size:
            self._size += 1

    def _advance_start_obs(self):
        """Move the start-obs write cursor forward, wrapping at capacity."""
        self._start_obs_top = (self._start_obs_top + 1) % self._max_replay_buffer_size
        if self._num_start_obs < self._max_replay_buffer_size:
            self._num_start_obs += 1

    def sample_data(self, indices):
        """Return a dict of the stored arrays indexed by ``indices``."""
        return dict(
            observations=self._observations[indices],
            actions=self._actions[indices],
            rewards=self._rewards[indices],
            terminals=self._terminals[indices],
            next_observations=self._next_obs[indices],
            env_infos=self._env_infos[indices],
            sparse_rewards=self._sparse_rewards[indices],
        )

    def sample_start_obs(self, indices):
        """Return ``{"start_obs": ...}`` for the given start-obs indices."""
        return dict(start_obs=self._start_obs[indices])

    def random_batch(self, batch_size):
        """Return a batch of unordered transitions, sampled with replacement.

        ``np.random.randint`` raises ``ValueError`` if the buffer is empty.
        """
        indices = np.random.randint(0, self._size, batch_size)
        return self.sample_data(indices)

    def random_start_obs(self, batch_size):
        """Return a batch of start observations, sampled with replacement.

        ``np.random.randint`` raises ``ValueError`` if no start observation
        has been stored.
        """
        indices = np.random.randint(0, self._num_start_obs, batch_size)
        return self.sample_start_obs(indices)

    def random_sequence(self, batch_size):
        """Return a batch made of whole, randomly chosen trajectories.

        Episodes are drawn uniformly (with replacement) and concatenated in
        order until at least ``batch_size`` transitions are gathered; the
        final trajectory is truncated so that exactly ``batch_size``
        transitions are returned.

        Raises:
            ValueError: If ``batch_size`` is positive and no sampleable
                episode has been recorded via ``terminate_episode``.
        """
        starts = self._episode_starts
        # An episode is delimited by two consecutive recorded starts, so the
        # most recently recorded start has no end marker and is not sampled
        # (this also sidesteps dealing with the newest episode wrapping).
        # Spans with end <= begin (empty, or wrapped around the ring) are
        # dropped up front: they contribute no indices and would otherwise
        # let the loop below spin forever.
        # NOTE(review): starts are never pruned when old data is overwritten,
        # so after wrap-around a span may cover newer transitions -- confirm
        # whether callers rely on this only before the buffer fills.
        episodes = [
            (begin, end)
            for begin, end in zip(starts[:-1], starts[1:])
            if end > begin
        ]
        if batch_size > 0 and not episodes:
            raise ValueError(
                "random_sequence needs at least one complete, non-empty "
                "episode recorded via terminate_episode()"
            )
        indices = []
        while len(indices) < batch_size:
            begin, end = episodes[np.random.randint(len(episodes))]
            indices.extend(range(begin, end))
        # cut off the last traj if needed to respect batch size
        indices = indices[:batch_size]
        return self.sample_data(indices)

    def num_steps_can_sample(self):
        """Return the number of transitions available for sampling."""
        return self._size

class SimpleImageReplayBuffer(ReplayBuffer):
    """Fixed-capacity circular buffer of image-observation transitions.

    Observations, next observations and start observations are stored as
    ``uint8`` arrays of shape ``(capacity, *observation_shape)``. Sampled
    observations are passed through ``random_crop`` before being returned.
    Writes wrap around once ``max_replay_buffer_size`` entries have been
    stored, overwriting the oldest data.
    """

    def __init__(self, max_replay_buffer_size, observation_shape, action_dim):
        """Allocate storage for ``max_replay_buffer_size`` transitions.

        Args:
            max_replay_buffer_size: Number of transitions (and, separately,
                start observations) that can be held before wrapping.
            observation_shape: Iterable of ints giving the shape of a single
                observation; it is unpacked after the capacity axis.
            action_dim: Length of one flattened action.
        """
        self._observation_shape = observation_shape
        self._action_dim = action_dim
        self._max_replay_buffer_size = max_replay_buffer_size
        self._observations = np.zeros(
            (max_replay_buffer_size, *observation_shape), dtype="uint8"
        )
        self._start_obs = np.zeros(
            (max_replay_buffer_size, *observation_shape), dtype="uint8"
        )
        self._start_obs_top = 0
        # It's a bit memory inefficient to save the observations twice,
        # but it makes the code *much* easier since you no longer have to
        # worry about termination conditions.
        self._next_obs = np.zeros(
            (max_replay_buffer_size, *observation_shape), dtype="uint8"
        )
        self._actions = np.zeros((max_replay_buffer_size, action_dim))
        # Make everything a 2D np array to make it easier for other code to
        # reason about the shape of the data
        self._rewards = np.zeros((max_replay_buffer_size, 1))
        self._sparse_rewards = np.zeros((max_replay_buffer_size, 1))
        # self._terminals[i] = a terminal was received at time i
        self._terminals = np.zeros((max_replay_buffer_size, 1), dtype="uint8")
        self._env_infos = np.zeros((max_replay_buffer_size,), dtype="object")
        self._num_start_obs = 0
        self.clear()

    def add_sample(
        self,
        observation,
        action,
        reward,
        terminal,
        next_observation,
        env_info,
        **kwargs
    ):
        """Store one transition at the write cursor and advance the cursor.

        The sparse reward is read from ``env_info["sparse_reward"]`` and
        defaults to 0 when that key is absent. Extra keyword arguments are
        accepted and ignored.
        """
        self._observations[self._top] = observation
        self._actions[self._top] = action
        self._rewards[self._top] = reward
        self._terminals[self._top] = terminal
        self._next_obs[self._top] = next_observation
        self._sparse_rewards[self._top] = env_info.get("sparse_reward", 0)
        self._env_infos[self._top] = env_info
        self._advance()

    def add_start_obs(self, observation):
        """Store an episode start observation in its own circular store."""
        self._start_obs[self._start_obs_top] = observation
        self._advance_start_obs()

    def terminate_episode(self):
        """Record where the finished episode began and open a new one."""
        # store the episode beginning once the episode is over
        # n.b. allows last episode to loop but whatever
        self._episode_starts.append(self._cur_episode_start)
        self._cur_episode_start = self._top

    def size(self):
        """Return the number of transitions currently stored."""
        return self._size

    def clear(self):
        """Reset all cursors and counters; the underlying arrays are kept."""
        self._top = 0
        self._size = 0
        self._episode_starts = []
        self._cur_episode_start = 0
        self._num_start_obs = 0
        # Rewind the start-obs write cursor together with its counter.
        # Otherwise new start observations would land at the old cursor
        # while random_start_obs samples indices [0, _num_start_obs).
        self._start_obs_top = 0

    def _advance(self):
        """Move the transition write cursor forward, wrapping at capacity."""
        self._top = (self._top + 1) % self._max_replay_buffer_size
        if self._size < self._max_replay_buffer_size:
            self._size += 1

    def _advance_start_obs(self):
        """Move the start-obs write cursor forward, wrapping at capacity."""
        self._start_obs_top = (self._start_obs_top + 1) % self._max_replay_buffer_size
        if self._num_start_obs < self._max_replay_buffer_size:
            self._num_start_obs += 1

    def sample_data(self, indices):
        """Return a dict of the stored arrays indexed by ``indices``.

        Observations and next observations are each passed through a
        separate ``random_crop`` call; the other fields are returned as
        stored.
        """
        observations = random_crop(self._observations[indices])
        next_observations = random_crop(self._next_obs[indices])
        return dict(
            observations=observations,
            actions=self._actions[indices],
            rewards=self._rewards[indices],
            terminals=self._terminals[indices],
            next_observations=next_observations,
            env_infos=self._env_infos[indices],
            sparse_rewards=self._sparse_rewards[indices],
        )

    def sample_start_obs(self, indices):
        """Return ``{"start_obs": ...}`` with ``random_crop`` applied."""
        start_obs = random_crop(self._start_obs[indices])
        return dict(start_obs=start_obs)

    def random_batch(self, batch_size):
        """Return a batch of unordered transitions, sampled with replacement.

        ``np.random.randint`` raises ``ValueError`` if the buffer is empty.
        """
        indices = np.random.randint(0, self._size, batch_size)
        return self.sample_data(indices)

    def random_start_obs(self, batch_size):
        """Return a batch of start observations, sampled with replacement.

        ``np.random.randint`` raises ``ValueError`` if no start observation
        has been stored.
        """
        indices = np.random.randint(0, self._num_start_obs, batch_size)
        return self.sample_start_obs(indices)

    def random_sequence(self, batch_size):
        """Return a batch made of whole, randomly chosen trajectories.

        Episodes are drawn uniformly (with replacement) and concatenated in
        order until at least ``batch_size`` transitions are gathered; the
        final trajectory is truncated so that exactly ``batch_size``
        transitions are returned.

        Raises:
            ValueError: If ``batch_size`` is positive and no sampleable
                episode has been recorded via ``terminate_episode``.
        """
        starts = self._episode_starts
        # An episode is delimited by two consecutive recorded starts, so the
        # most recently recorded start has no end marker and is not sampled
        # (this also sidesteps dealing with the newest episode wrapping).
        # Spans with end <= begin (empty, or wrapped around the ring) are
        # dropped up front: they contribute no indices and would otherwise
        # let the loop below spin forever.
        # NOTE(review): starts are never pruned when old data is overwritten,
        # so after wrap-around a span may cover newer transitions -- confirm
        # whether callers rely on this only before the buffer fills.
        episodes = [
            (begin, end)
            for begin, end in zip(starts[:-1], starts[1:])
            if end > begin
        ]
        if batch_size > 0 and not episodes:
            raise ValueError(
                "random_sequence needs at least one complete, non-empty "
                "episode recorded via terminate_episode()"
            )
        indices = []
        while len(indices) < batch_size:
            begin, end = episodes[np.random.randint(len(episodes))]
            indices.extend(range(begin, end))
        # cut off the last traj if needed to respect batch size
        indices = indices[:batch_size]
        return self.sample_data(indices)

    def num_steps_can_sample(self):
        """Return the number of transitions available for sampling."""
        return self._size
