from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np

from tianshou.data import Batch, ReplayBuffer


class HERReplayBuffer(ReplayBuffer):
    """Implementation of Hindsight Experience Replay. arXiv:1707.01495.

    HERReplayBuffer is to be used with goal-based environments where the
    observation is a dictionary with keys ``observation``, ``achieved_goal`` and
    ``desired_goal``. Currently support only HER's future strategy, online sampling.

    Sampling temporarily re-writes goals and rewards directly inside the stored
    data. The original values are cached and written back before the next sampling
    or before any method that writes to or saves the buffer.

    :param int size: the size of the replay buffer.
    :param compute_reward_fn: a function that takes 2 ``np.array`` arguments,
        ``achieved_goal`` and ``desired_goal``, and returns rewards as ``np.array``.
        The two arguments are of shape (batch_size, ...original_shape) and the returned
        rewards must be of shape (batch_size,).
    :param int horizon: the maximum number of steps in an episode.
    :param int future_k: the 'k' parameter introduced in the paper. In short, there
        will be at most k episodes that are re-written for every 1 unaltered episode
        during the sampling.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(
        self,
        size: int,
        compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
        horizon: int,
        future_k: float = 8.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(size, **kwargs)
        self.horizon = horizon
        # Fraction of the sampled episodes that get relabeled: k / (k + 1) in the
        # paper's terms is approximated here as 1 - 1 / k.
        self.future_p = 1 - 1 / future_k
        self.compute_reward_fn = compute_reward_fn
        # Cache of the data overwritten by the last `rewrite_transitions` call and
        # the buffer indices it belongs to; both are empty when nothing is altered.
        self._original_meta = Batch()
        self._altered_indices = np.array([])

    def _restore_cache(self) -> None:
        """Write cached original meta back to `self._meta`.

        It's called every time before 'writing', 'sampling' or 'saving' the buffer.
        """
        # `reset` may run before `__init__` has created the cache attributes
        # (presumably from the base-class constructor), so tolerate their absence.
        if not hasattr(self, '_altered_indices'):
            return

        if self._altered_indices.size == 0:
            return
        self._meta[self._altered_indices] = self._original_meta
        # Clean
        self._original_meta = Batch()
        self._altered_indices = np.array([])

    def reset(self, keep_statistics: bool = False) -> None:
        """Restore any re-written transitions, then reset via the parent class."""
        self._restore_cache()
        return super().reset(keep_statistics)

    def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
        """Restore any re-written transitions, then save via the parent class."""
        self._restore_cache()
        return super().save_hdf5(path, compression)

    def set_batch(self, batch: Batch) -> None:
        """Restore any re-written transitions, then delegate to the parent class."""
        self._restore_cache()
        return super().set_batch(batch)

    def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray:
        """Restore any re-written transitions, then delegate to the parent class."""
        self._restore_cache()
        return super().update(buffer)

    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Restore any re-written transitions, then add via the parent class."""
        self._restore_cache()
        return super().add(batch, buffer_ids)

    def sample_indices(self, batch_size: int) -> np.ndarray:
        """Get a random sample of index with size = batch_size.

        Return all available indices in the buffer if batch_size is 0; return an \
        empty numpy array if batch_size < 0 or no available index can be sampled. \
        Additionally, some episodes of the sampled transitions will be re-written \
        according to HER.
        """
        self._restore_cache()
        indices = super().sample_indices(batch_size=batch_size)
        # `rewrite_transitions` modifies its argument in place, so pass a copy to
        # keep the returned indices in their sampled order.
        self.rewrite_transitions(indices.copy())
        return indices

    def rewrite_transitions(self, indices: np.ndarray) -> None:
        """Re-write the goal of some sampled transitions' episodes according to HER.

        Currently applies only HER's 'future' strategy. The new goals will be written \
        directly to the internal batch data temporarily and will be restored right \
        before the next sampling or when using some of the buffer's method (e.g. \
        `add`, `save_hdf5`, etc.). This is to make sure that n-step returns \
        calculation etc., performs correctly without additional alteration.

        :param indices: sampled buffer indices. The array is modified in place.
        """
        if indices.size == 0:
            return

        # Sort indices keeping chronological order
        indices[indices < self._index] += self.maxsize
        indices = np.sort(indices)
        indices[indices >= self.maxsize] -= self.maxsize

        # Construct episode trajectories
        indices = [indices]
        for _ in range(self.horizon - 1):
            indices.append(self.next(indices[-1]))
        indices = np.stack(indices)

        # Calculate future timestep to use
        current = indices[0]
        terminal = indices[-1]
        # The buffer is circular: an episode that wraps around the end of the
        # storage has terminal < current, so the remaining episode length and the
        # resulting future index must both be computed modulo the buffer size.
        episodes_len = (terminal - current + self.maxsize) % self.maxsize
        future_offset = np.random.uniform(size=len(current)) * episodes_len
        future_offset = future_offset.astype(int)
        future_t = (current + future_offset) % self.maxsize

        # Compute indices
        #   open indices are used to find longest, unique trajectories among
        #   presented episodes
        unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1])
        unique_ep_indices = indices[:, unique_ep_open_indices]
        #   close indices are used to find max future_t among presented episodes
        unique_ep_close_indices = np.hstack(
            [(unique_ep_open_indices - 1)[1:],
             len(terminal) - 1]
        )
        #   episode indices that will be altered
        her_ep_indices = np.random.choice(
            len(unique_ep_open_indices),
            size=int(len(unique_ep_open_indices) * self.future_p),
            replace=False
        )

        # Cache original meta
        self._altered_indices = unique_ep_indices.copy()
        self._original_meta = self._meta[self._altered_indices].copy()

        # Copy original obs, ep_rew (and obs_next), and obs of future time step
        ep_obs = self[unique_ep_indices].obs
        ep_rew = self[unique_ep_indices].rew
        if self._save_obs_next:
            ep_obs_next = self[unique_ep_indices].obs_next
            future_obs = self[future_t[unique_ep_close_indices]].obs_next
        else:
            future_obs = self[self.next(future_t[unique_ep_close_indices])].obs

        # Re-assign goals and rewards via broadcast assignment
        ep_obs.desired_goal[:, her_ep_indices] = \
            future_obs.achieved_goal[None, her_ep_indices]
        if self._save_obs_next:
            ep_obs_next.desired_goal[:, her_ep_indices] = \
                future_obs.achieved_goal[None, her_ep_indices]
            ep_rew[:, her_ep_indices] = \
                self._compute_reward(ep_obs_next)[:, her_ep_indices]
        else:
            tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs
            # The stored data are only re-written at the end of this method, so
            # the goals fetched here are still the original ones. Relabel them
            # first; otherwise the rewards would be computed against the old goal.
            tmp_ep_obs_next.desired_goal[:, her_ep_indices] = \
                future_obs.achieved_goal[None, her_ep_indices]
            ep_rew[:, her_ep_indices] = \
                self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices]

        # Sanity check
        assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape
        assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape
        assert ep_rew.shape == unique_ep_indices.shape

        # Re-write meta
        self._meta.obs[unique_ep_indices] = ep_obs
        if self._save_obs_next:
            self._meta.obs_next[unique_ep_indices] = ep_obs_next
        self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32)

    def _compute_reward(self, obs: Batch, lead_dims: int = 2) -> np.ndarray:
        """Compute rewards for goal observations with several leading dimensions.

        The first ``lead_dims`` dimensions of ``obs.achieved_goal`` and
        ``obs.desired_goal`` are flattened into a single batch dimension before
        calling ``compute_reward_fn``, and the result is reshaped so that its
        leading dimensions match those of ``obs.observation``.

        :param obs: a batch with ``observation``, ``achieved_goal`` and
            ``desired_goal`` entries.
        :param int lead_dims: the number of leading dimensions to flatten.

        :return: the rewards, with the leading dimensions restored.
        """
        lead_shape = obs.observation.shape[:lead_dims]
        g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:])
        ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:])
        rewards = self.compute_reward_fn(ag, g)
        return rewards.reshape(*lead_shape, *rewards.shape[1:])