"""This module implements a Hindsight Experience Replay (HER).

See: https://arxiv.org/abs/1707.01495.
"""
import copy

import numpy as np
from src.replay_buffer import PathBuffer


class HERReplayBuffer(PathBuffer):
    """Replay buffer for HER (Hindsight Experience Replay).

    It constructs hindsight examples using future strategy: for every
    transition of an added path (except the last), ``replay_k`` extra
    transitions are stored whose desired goal is replaced by a goal that
    was actually achieved later in the same path.

    Args:
        replay_k (int): Number of HER transitions to add for each regular
            Transition. Setting this to 0 means that no HER replays will
            be added. Integer-valued floats (e.g. ``4.0``) are accepted
            and converted to ``int``.
        reward_fn (callable): Function to re-compute the reward with
            substituted goals. It is called as
            ``reward_fn(achieved_goal, goal, None)``.
        capacity_in_transitions (int): total size of transitions in the buffer.
        env_spec (EnvSpec): Environment specification.

    Raises:
        ValueError: If ``replay_k`` is negative or not integer-valued.
    """

    def __init__(self, replay_k, reward_fn, capacity_in_transitions, env_spec):
        # Validate first so that an invalid value is never stored.
        if not float(replay_k).is_integer() or replay_k < 0:
            raise ValueError("replay_k must be an integer and >= 0.")

        # An integer-valued float passes the check above, but numpy refuses
        # a float as a sample ``size``; normalise to a real int here.
        self._replay_k = int(replay_k)
        self._reward_fn = reward_fn
        super().__init__(capacity_in_transitions, env_spec)

    def _sample_her_goals(self, path, transition_idx):
        """Samples HER goals from the given path.

        Goals are randomly sampled (with replacement) starting from the
        index after transition_idx in the given path.

        Args:
            path (dict[str, np.ndarray]): A dict containing the transition
                keys. ``path["observations"]`` must be a sequence of
                observation dicts that each have an ``"achieved_goal"``
                entry.
            transition_idx (int): index of the current transition. Only
                transitions after the current transitions will be randomly
                sampled for HER goals. Must be smaller than
                ``len(path["observations"]) - 1``.

        Returns:
            list: ``replay_k`` achieved goals taken from the sampled
                observations.

        """
        observations = path["observations"]
        goal_indexes = np.random.randint(
            transition_idx + 1, len(observations), size=self._replay_k
        )
        # Index only the sampled observations instead of converting the
        # whole path to an array on every call.
        return [observations[index]["achieved_goal"] for index in goal_indexes]

    def _flatten_dicts(self, path):
        """Flattens the observation entries of ``path`` in place.

        A single observation dict is flattened with ``flatten``; anything
        else is treated as a batch of observations and flattened with
        ``flatten_n``.

        Args:
            path (dict): Dict with ``"observations"`` and
                ``"next_observations"`` entries; both are replaced by
                their flattened form.

        """
        # NOTE(review): self._env_spec is not assigned in this class; it is
        # presumably set by PathBuffer.__init__ -- confirm.
        obs_space = self._env_spec.observation_space
        for key in ["observations", "next_observations"]:
            if not isinstance(path[key], dict):
                path[key] = obs_space.flatten_n(path[key])
            else:
                path[key] = obs_space.flatten(path[key])

    def add_path(self, path):
        """Adds a path to the replay buffer.

        For each transition in the given path except the last one,
        replay_k HER transitions will be added to the buffer in addition
        to the one in the path. The last transition is added without
        sampling additional HER goals.

        The caller's ``path`` dict is left untouched: the observation
        entries are unflattened/flattened on a shallow copy.

        Args:
            path(dict[str, np.ndarray]): Each key in the dict must map
                to a np.ndarray of shape :math:`(T, S^*)`.

        """
        obs_space = self._env_spec.observation_space

        # Shallow copy: the entries are only rebound below, never modified
        # in place, so this is enough to protect the caller's dict.
        path = dict(path)

        if not isinstance(path["observations"][0], dict):
            # unflatten dicts if they've been flattened
            path["observations"] = obs_space.unflatten_n(path["observations"])
            path["next_observations"] = obs_space.unflatten_n(
                path["next_observations"]
            )

        # create HER transitions and add them to the buffer; nothing to do
        # when no HER replays were requested.
        if self._replay_k > 0:
            for idx in range(path["actions"].shape[0] - 1):
                transition = {key: sample[idx] for key, sample in path.items()}
                her_goals = self._sample_her_goals(path, idx)

                # create replay_k transitions using the HER goals
                for goal in her_goals:
                    # Deep copy so the goal substitution below cannot leak
                    # into the original path's observation dicts.
                    t_new = copy.deepcopy(transition)
                    a_g = t_new["next_observations"]["achieved_goal"]

                    t_new["rewards"] = np.array(self._reward_fn(a_g, goal, None))
                    t_new["observations"]["desired_goal"] = goal
                    t_new["next_observations"]["desired_goal"] = copy.deepcopy(goal)
                    t_new["terminals"] = np.array(False)

                    # flatten the observation dicts now that we're done
                    # with them
                    self._flatten_dicts(t_new)

                    # Every entry becomes a length-1 "path" of shape (1, -1).
                    for key in t_new.keys():
                        t_new[key] = t_new[key].reshape(1, -1)

                    # Since we're using a PathBuffer, add each new
                    # transition as its own path.
                    super().add_path(t_new)

        self._flatten_dicts(path)
        super().add_path(path)

    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: The state to be pickled for the instance (a shallow copy
                of the instance ``__dict__``).

        """
        new_dict = self.__dict__.copy()
        return new_dict

    def __setstate__(self, state):
        """Object.__setstate__.

        Args:
            state (dict): Unpickled state.

        """
        self.__dict__ = state
