"""A replay buffer that efficiently stores and can sample whole paths."""
"""Code adapted from garage"""
import collections

import numpy as np

from garage import StepType, TimeStepBatch
import logging

# NOTE(review): this configures the root logger (INFO level) as a side effect
# of importing this module. No other use of `logging` is visible in this
# file -- confirm whether callers rely on this before moving or removing it.
logging.basicConfig(level=logging.INFO)


class PathBuffer:
    """A replay buffer that stores and can sample whole episodes.

    This buffer only stores valid steps, and doesn't require paths to
    have a maximum length. Each key of a path is stored in its own
    pre-allocated array of ``capacity_in_transitions`` rows which is used
    as a circular buffer: once the end is reached, new paths wrap around
    to index 0 and the oldest paths they overwrite are evicted.

    Args:
        capacity_in_transitions (int): Total memory allocated for the buffer.
        env_spec (EnvSpec): Environment specification. If None, it is taken
            from the first batch passed to :meth:`add_episode_batch`.

    """

    def __init__(self, capacity_in_transitions, env_spec=None):
        self._capacity = capacity_in_transitions
        self._env_spec = env_spec
        self._transitions_stored = 0
        # Index at which the next path will start. Always kept strictly
        # below the capacity (it wraps to 0), so that the first segment of
        # a stored path is never empty.
        self._first_idx_of_next_path = 0
        # Each path in the buffer has a tuple of two ranges in
        # self._path_segments. If the path is stored in a single contiguous
        # region of the buffer, the second range will be range(0, 0).
        # The "left" side of the deque contains the oldest episode.
        self._path_segments = collections.deque()
        # Maps key -> array of shape (capacity, flat_dim), allocated lazily
        # the first time a key is seen.
        self._buffer = {}

    def add_episode_batch(self, episodes):
        """Add a EpisodeBatch to the buffer.

        Every episode in the batch is flattened with the batch's env_spec
        and stored as an individual path via :meth:`add_path`.

        Args:
            episodes (EpisodeBatch): Episodes to add.

        """
        if self._env_spec is None:
            self._env_spec = episodes.env_spec
        env_spec = episodes.env_spec
        obs_space = env_spec.observation_space
        act_space = env_spec.action_space
        for eps in episodes.split():
            terminals = np.array(
                [step_type == StepType.TERMINAL for step_type in eps.step_types],
                dtype=bool,
            )
            path = {
                "observations": obs_space.flatten_n(eps.observations),
                "next_observations": obs_space.flatten_n(eps.next_observations),
                "actions": act_space.flatten_n(eps.actions),
                "rewards": eps.rewards.reshape(-1, 1),
                "terminals": terminals.reshape(-1, 1),
            }
            self.add_path(path)

    def add_path(self, path):
        """Add a path to the buffer.

        The path is validated completely before any state is modified, so
        a rejected path leaves the buffer untouched.

        Args:
            path (dict): A dict of array of shape (path_len, flat_dim).

        Raises:
            ValueError: If a key is missing from path, path has wrong shape,
                path is empty, or path is longer than the buffer capacity.

        """
        for key, buf_arr in self._buffer.items():
            path_array = path.get(key, None)
            if path_array is None:
                raise ValueError("Key {} missing from path.".format(key))
            if len(path_array.shape) != 2 or path_array.shape[1] != buf_arr.shape[1]:
                raise ValueError("Array {} has wrong shape.".format(key))
        # Keys that are not allocated yet must be 2-D as well; checking here
        # avoids failing half-way through the write below.
        for key, array in path.items():
            if len(array.shape) != 2:
                raise ValueError("Array {} has wrong shape.".format(key))
        path_len = self._get_path_length(path)
        first_seg, second_seg = self._next_path_segments(path_len)
        # Remove paths which will overlap with this one.
        self._evict_overlapping(first_seg)
        self._evict_overlapping(second_seg)
        self._path_segments.append((first_seg, second_seg))
        split = len(first_seg)
        for key, array in path.items():
            buf_arr = self._get_or_allocate_key(key, array)
            # numpy doesn't special case range indexing, so it's very slow.
            # Slice manually instead, which is faster than any other method.
            buf_arr[first_seg.start : first_seg.stop] = array[:split]
            buf_arr[second_seg.start : second_seg.stop] = array[split:]
        next_idx = second_seg.stop if second_seg else first_seg.stop
        # A path that ends exactly on the last slot must send the cursor
        # back to 0. Leaving it at `capacity` would give the next path an
        # empty first segment, which the eviction logic used to miss.
        self._first_idx_of_next_path = 0 if next_idx >= self._capacity else next_idx
        self._transitions_stored = min(
            self._capacity, self._transitions_stored + path_len
        )

    def _evict_overlapping(self, segment):
        """Drop the oldest paths while they overlap with ``segment``.

        Both segments of the oldest path are checked, so a wrapped path is
        evicted no matter which of its parts gets overwritten.

        Args:
            segment (range): Indices which are about to be overwritten.

        """
        while self._path_segments and any(
            self._segments_overlap(segment, old_seg)
            for old_seg in self._path_segments[0]
        ):
            self._path_segments.popleft()

    def sample_path(self):
        """Sample a single path from the buffer.

        Returns:
            path: A dict of arrays of shape (path_len, flat_dim).

        Raises:
            ValueError: If the buffer holds no path (raised by
                ``np.random.randint``).

        """
        path_idx = np.random.randint(len(self._path_segments))
        first_seg, second_seg = self._path_segments[path_idx]
        first_seg_indices = np.arange(first_seg.start, first_seg.stop)
        second_seg_indices = np.arange(second_seg.start, second_seg.stop)
        # First segment comes first so the steps stay in temporal order.
        indices = np.concatenate([first_seg_indices, second_seg_indices])
        return {key: buf_arr[indices] for key, buf_arr in self._buffer.items()}

    def sample_transitions(self, batch_size):
        """Sample a batch of transitions from the buffer.

        Transitions are drawn uniformly, with replacement, from the rows
        that have been written so far.

        Args:
            batch_size (int): Number of transitions to sample.

        Returns:
            dict: A dict of arrays of shape (batch_size, flat_dim).

        Raises:
            ValueError: If the buffer holds no transition (raised by
                ``np.random.randint``).

        """
        idx = np.random.randint(self._transitions_stored, size=batch_size)
        return {key: buf_arr[idx] for key, buf_arr in self._buffer.items()}

    def sample_timesteps(self, batch_size):
        """Sample a batch of timesteps from the buffer.

        Args:
            batch_size (int): Number of timesteps to sample.

        Returns:
            TimeStepBatch: The batch of timesteps.

        """
        samples = self.sample_transitions(batch_size)
        # Only the terminal flag is stored, so every non-terminal step is
        # reported as StepType.MID.
        step_types = np.array(
            [
                StepType.TERMINAL if terminal else StepType.MID
                for terminal in samples["terminals"].reshape(-1)
            ],
            dtype=StepType,
        )
        return TimeStepBatch(
            env_spec=self._env_spec,
            episode_infos={},
            observations=samples["observations"],
            actions=samples["actions"],
            rewards=samples["rewards"].flatten(),
            next_observations=samples["next_observations"],
            step_types=step_types,
            env_infos={},
            agent_infos={},
        )

    def _next_path_segments(self, n_indices):
        """Compute where the next path should be stored.

        Args:
            n_indices (int): Path length.

        Returns:
            tuple: Two ranges of indices where path should be stored. The
                second range is range(0, 0) unless the path wraps around
                the end of the buffer.

        Raises:
            ValueError: If path length is greater than the size of buffer.

        """
        if n_indices > self._capacity:
            raise ValueError("Path is too long to store in buffer.")
        start = self._first_idx_of_next_path
        end = start + n_indices
        if end > self._capacity:
            second_end = end - self._capacity
            return (range(start, self._capacity), range(0, second_end))
        return (range(start, end), range(0, 0))

    def _get_or_allocate_key(self, key, array):
        """Get or allocate key in the buffer.

        Args:
            key (str): Key in buffer.
            array (numpy.ndarray): Array corresponding to key. Its second
                dimension and dtype define the newly allocated storage.

        Returns:
            numpy.ndarray: A NumPy array corresponding to key in the buffer.

        """
        buf_arr = self._buffer.get(key, None)
        if buf_arr is None:
            buf_arr = np.zeros((self._capacity, array.shape[1]), array.dtype)
            self._buffer[key] = buf_arr
        return buf_arr

    def clear(self):
        """Clear buffer."""
        self._transitions_stored = 0
        self._first_idx_of_next_path = 0
        self._path_segments.clear()
        self._buffer.clear()

    @staticmethod
    def _get_path_length(path):
        """Get path length.

        Args:
            path (dict): Path.

        Returns:
            length: Path length.

        Raises:
            ValueError: If path is empty or has inconsistent lengths.

        """
        length_key = None
        length = None
        for key, value in path.items():
            if length is None:
                length = len(value)
                length_key = key
            elif len(value) != length:
                raise ValueError(
                    "path has inconsistent lengths between "
                    "{!r} and {!r}.".format(length_key, key)
                )
        # Covers both an empty dict (length is None) and zero-length arrays.
        if not length:
            raise ValueError("Nothing in path")
        return length

    @staticmethod
    def _segments_overlap(seg_a, seg_b):
        """Compute if two segments overlap.

        Args:
            seg_a (range): List of indices of the first segment.
            seg_b (range): List of indices of the second segment.

        Returns:
            bool: True iff the input ranges overlap at at least one index.

        """
        # Empty segments never overlap.
        if not seg_a or not seg_b:
            return False
        # Two non-empty half-open intervals intersect iff the later start
        # lies before the earlier stop.
        return max(seg_a.start, seg_b.start) < min(seg_a.stop, seg_b.stop)

    @property
    def n_transitions_stored(self):
        """Return the size of the replay buffer.

        Returns:
            int: Size of the current replay buffer.

        """
        return int(self._transitions_stored)
