"""A replay buffer that efficiently stores and can sample whole paths."""

import collections

import numpy as np
import torch


class PathBufferEx:
    """A replay buffer that stores and can sample whole paths.

    This buffer only stores valid steps, and doesn't require paths to
    have a maximum length. Storage is a ring of ``capacity_in_transitions``
    rows per key. A path that does not fit before the end of the ring is
    split into two segments: the tail of the ring and its head.

    Besides the keys of each added path, two per-step keys are stored:
    ``"timesteps"`` (index of the step inside its path) and ``"epilength"``
    (length of the path the step belongs to), both of shape ``(capacity, 1)``.

    When ``pixel_shape`` is given, every key in ``self._pixel_keys`` is stored
    split into a ``uint8`` ``"{key}_pixel"`` array (the first
    ``prod(pixel_shape)`` columns) and a ``"{key}_state"`` array (the
    remaining columns). The sampling methods concatenate them back.

    Args:
        capacity_in_transitions (int): Total memory allocated for the buffer.
        pixel_shape (tuple or None): Shape of the pixel part of pixel
            observations, or None when observations are not split.
        max_episode_length (int): Number of candidate future offsets
            considered by
            ``sample_transitions_with_future_and_initial_observations``.
        use_goal (int or bool): If truthy, ``"options"`` is also stored in
            the pixel/state split layout.
        debug_noconst (int or None): Experiment switch selecting variants in
            ``sample_transitions_with_goals``. When None (default), it is
            read from ``tests.main.args.debug_noconst``.

    """

    def __init__(
        self,
        capacity_in_transitions,
        pixel_shape,
        max_episode_length=1000,
        use_goal=0,
        debug_noconst=None,
    ):
        self._capacity = capacity_in_transitions
        self._transitions_stored = 0
        self._first_idx_of_next_path = 0
        # Each path in the buffer has a tuple of two ranges in
        # self._path_segments. If the path is stored in a single contiguous
        # region of the buffer, the second range will be range(0, 0).
        # The "left" side of the deque contains the oldest path.
        self._path_segments = collections.deque()
        self._buffer = {}

        if pixel_shape is not None:
            self._pixel_dim = np.prod(pixel_shape)
        else:
            self._pixel_dim = None
        self._pixel_keys = ["obs", "next_obs"]

        # Not used by this class; kept as public attributes for callers.
        self.indices_to_episode_timesteps = np.array([], dtype=np.int32)
        self.epilengths = np.array([], dtype=np.int32)

        self.max_episode_length = max_episode_length

        if debug_noconst is None:
            # Original behavior: read the switch from the global experiment args.
            from tests.main import args

            debug_noconst = args.debug_noconst
        self.debug_noconst = debug_noconst

        self.use_goal = use_goal
        if self.use_goal:
            self._pixel_keys.append("options")

    def add_path(self, path):
        """Add a path to the buffer.

        Paths at the old end of the ring whose storage overlaps the new
        path's segments are dropped from the path index first. Note that
        ``path`` is modified in place: ``"timesteps"`` and ``"epilength"``
        entries are added to it.

        Args:
            path (dict): A dict of array of shape (path_len, flat_dim).

        Raises:
            ValueError: If path is empty, has inconsistent lengths, or is
                longer than the buffer capacity.

        """
        path_len = self._get_path_length(path)
        first_seg, second_seg = self._next_path_segments(path_len)
        # Remove paths which will overlap with this one.
        while self._path_segments and self._segments_overlap(
            first_seg, self._path_segments[0][0]
        ):
            self._path_segments.popleft()
        while self._path_segments and self._segments_overlap(
            second_seg, self._path_segments[0][0]
        ):
            self._path_segments.popleft()
        self._path_segments.append((first_seg, second_seg))

        # Per-step bookkeeping consumed by the goal / future samplers.
        path["timesteps"] = np.arange(path_len, dtype=np.int64)[:, None]
        path["epilength"] = np.full([path_len], path_len, dtype=np.int64)[:, None]

        for key, array in path.items():
            if self._pixel_dim is not None and key in self._pixel_keys:
                pixel_key = f"{key}_pixel"
                state_key = f"{key}_state"
                if pixel_key not in self._buffer:
                    self._buffer[pixel_key] = np.random.randint(
                        0, 255, (self._capacity, self._pixel_dim), dtype=np.uint8
                    )  # For memory preallocation
                    self._buffer[state_key] = np.zeros(
                        (self._capacity, array.shape[1] - self._pixel_dim),
                        dtype=array.dtype,
                    )
                self._write_segments(
                    self._buffer[pixel_key],
                    array[:, : self._pixel_dim],
                    first_seg,
                    second_seg,
                )
                self._write_segments(
                    self._buffer[state_key],
                    array[:, self._pixel_dim :],
                    first_seg,
                    second_seg,
                )
            elif key == "hidden_states":
                if key not in self._buffer:
                    self._buffer[key] = np.random.rand(
                        self._capacity, *array.shape[1:]
                    ).astype(array.dtype)
                self._write_segments(self._buffer[key], array, first_seg, second_seg)
            else:
                buf_arr = self._get_or_allocate_key(key, array)
                self._write_segments(buf_arr, array, first_seg, second_seg)

        # The next path starts right after the last row written.
        if second_seg.stop != 0:
            self._first_idx_of_next_path = second_seg.stop
        else:
            self._first_idx_of_next_path = first_seg.stop

        self._transitions_stored = min(
            self._capacity, self._transitions_stored + path_len
        )

    def sample_transitions(self, batch_size, only_first_and_last=False):
        """Sample a batch of transitions from the buffer.

        Args:
            batch_size (int): Number of transitions to sample.
            only_first_and_last (bool): If True, draw ``batch_size // 10``
                stored steps. For each step, return the pairs
                (first step + k, last step - k), k = 0..4, of its path,
                interleaved. ``batch_size`` must be even, and
                ``10 * (batch_size // 10)`` rows are returned.

        Returns:
            dict: A dict of arrays of shape (batch_size, flat_dim).

        """
        if only_first_and_last:
            assert batch_size % 2 == 0
            # TODO: also assert all episode lengths are equal.
            num_consecutive_steps = 5
            idx = np.random.choice(
                self._transitions_stored, batch_size // 2 // num_consecutive_steps
            )
            assert idx.ndim == 1, idx.shape
            timesteps = self._buffer["timesteps"][idx].squeeze(-1)
            epilengths = self._buffer["epilength"][idx].squeeze(-1)
            # NOTE(review): hard-coded episode length from the original setup.
            assert epilengths[0] == 200
            assert timesteps.shape == idx.shape == epilengths.shape

            initial_idx = idx - timesteps  # index for initial observations
            final_idx = initial_idx + epilengths - 1  # index for final observations
            offsets = np.arange(num_consecutive_steps)
            # Wrap around the ring: a path may straddle the end of the buffer.
            initial_idx = (initial_idx[:, None] + offsets).reshape(-1) % self._capacity
            final_idx = (final_idx[:, None] - offsets).reshape(-1) % self._capacity
            assert initial_idx.shape == final_idx.shape

            idx = np.stack([initial_idx, final_idx], axis=-1).reshape(-1)
        else:
            idx = np.random.choice(self._transitions_stored, batch_size)

        return self._gather(idx)

    def sample_transitions_with_goals(
        self,
        batch_size,
        discount=0.99,
        p_trajgoal=0.625,
    ):
        """Sample transitions together with goals (adapted from HILP).

        Args:
            batch_size (int): Number of transitions to sample.
            discount (float): Discount of the geometric distribution used to
                pick goals later in the same trajectory.
            p_trajgoal (float): Probability that ``"goals"`` comes from the
                same trajectory instead of uniformly from the buffer.

        Returns:
            dict: The stored keys at the sampled indices, plus:
                ``"goals"``: ``next_obs`` (pixel part only in pixel mode) at
                the goal index.
                ``"trajgoals"``: goal used for the policy update. This is a
                future observation of the same path or the stored
                ``"options"``, depending on ``debug_noconst``.
                ``"success_rewards"``: 1.0 where goal index == sample index.
                ``"masks"``: ``1 - success_rewards``.

        """
        n_stored = self.n_transitions_stored
        indx = np.random.randint(n_stored - 1, size=batch_size)
        p_currgoal = 0  # goals at the current state are disabled

        batch_size = len(indx)

        # Random goals
        goal_indx = np.random.randint(n_stored, size=batch_size)

        # Goals from the same trajectory
        timesteps = self._buffer["timesteps"][indx].squeeze(-1)
        epilengths = self._buffer["epilength"][indx].squeeze(-1)
        initial_indx = indx - timesteps  # index for initial observations
        final_state_indx = initial_indx + epilengths - 1  # index for final observations
        # Clamp, e.g. when self._capacity % epilength != 0 (like pointmaze).
        final_state_indx = np.minimum(final_state_indx, n_stored - 1)

        # Geometric offset into the future, capped at the path's last step.
        us = np.random.rand(batch_size)
        middle_goal_indx = np.minimum(
            indx + np.ceil(np.log(1 - us) / np.log(discount)).astype(int),
            final_state_indx,
        )
        goal_indx = np.where(
            np.random.rand(batch_size) < p_trajgoal / (1.0 - p_currgoal),
            middle_goal_indx,
            goal_indx,
        )

        success = indx == goal_indx
        success_rewards = success.astype(float)[:, None]
        masks = 1.0 - success.astype(float)[:, None]
        next_obs_key = "next_obs" if self._pixel_dim is None else "next_obs_pixel"
        goals = self._buffer[next_obs_key][goal_indx]

        assert self._buffer["cur_exploration"].dtype == bool

        # Strategy for trajgoals: 20% stored option, 80% future observation for
        # non-exploratory steps; future observation only for exploratory steps.
        futures_indx = np.random.randint(
            indx,
            final_state_indx + 1,
            size=batch_size,
        )
        if self.debug_noconst == 21:
            # future with a geometric distribution
            us = np.random.rand(batch_size)
            futures_indx = np.minimum(
                indx + np.ceil(np.log(1 - us) / np.log(discount)).astype(int),
                final_state_indx,
            )

        cur_exploration = self._buffer["cur_exploration"][indx].squeeze(-1)
        futures = self._buffer[next_obs_key][futures_indx]
        options_key = "options" if self._pixel_dim is None else "options_pixel"
        options = self._buffer[options_key][indx]

        if self.debug_noconst == 40:
            # always use future because we do not have goal given in training time
            trajgoals = futures
        else:
            trajgoals = np.where(
                ((np.random.rand(batch_size) < 0.8) | cur_exploration).reshape(-1, 1),
                futures,
                options,
            )

        if self.debug_noconst == 11:
            # trajgoals: simply sample k observations from later in the path
            k = 13
            trajgoal_indx = np.random.randint(
                low=indx, high=final_state_indx + 1, size=(k, batch_size)
            ).transpose(1, 0)
            obs_key = "obs" if self._pixel_dim is None else "obs_pixel"
            trajgoals = self._buffer[obs_key][trajgoal_indx]
        elif self.debug_noconst == 12:
            raise NotImplementedError

        ret_dict = self._gather(indx)

        assert (
            goals.ndim
            == success_rewards.ndim
            == masks.ndim
            == ret_dict["rewards"].ndim
            == 2
        )

        ret_dict["goals"] = goals
        ret_dict["trajgoals"] = trajgoals  # goal that we update the policy with
        ret_dict["success_rewards"] = success_rewards
        ret_dict["masks"] = masks
        return ret_dict

    def sample_transitions_with_future_and_initial_observations(
        self, batch_size, future_observation_discount=0.99
    ):
        """Sample transitions with a discounted future observation each.

        For a sampled step with in-path timestep ``t`` in a path of length
        ``L``, an offset ``delta`` in ``[0, min(L - t, max_episode_length))``
        is drawn with probability proportional to
        ``future_observation_discount ** delta``. ``next_obs`` of step
        ``t + delta`` is returned as ``"future_obs"``.

        Args:
            batch_size (int): Number of transitions to sample.
            future_observation_discount (float): Base of the geometric weights.

        Returns:
            dict: The stored keys at the sampled indices plus ``"future_obs"``.

        """
        idx = np.random.choice(self._transitions_stored, batch_size)
        timesteps = self._buffer["timesteps"][idx]
        epilengths = self._buffer["epilength"][idx]

        deltas = np.arange(self.max_episode_length)
        pdeltas = np.where(
            # Only offsets that stay inside the path (t + delta < L) are valid.
            (timesteps[:, None] + deltas) < epilengths[:, None],
            future_observation_discount**deltas,
            0,
        )
        deltas = (
            torch.distributions.Categorical(
                probs=torch.from_numpy(pdeltas),
            )
            .sample()
            .squeeze(-1)
            .numpy()
        )

        assert idx.shape == deltas.shape
        # Wrap around the ring: a path stored across the buffer end continues at 0.
        future_idx = (idx + deltas) % self._capacity
        if self._pixel_dim is not None:
            future_observations = self._gather_pixel_key("next_obs", future_idx)
        else:
            future_observations = self._buffer["next_obs"][future_idx]

        ret_dict = self._gather(idx)
        ret_dict["future_obs"] = future_observations
        return ret_dict

    def sample_trajectories(self, batch_size):
        """Sample a batch of trajectories from the buffer.

        Sampled paths must all have the same length (no padding is done),
        otherwise ``np.stack`` fails.

        Args:
            batch_size (int): Number of trajectories to sample.

        Returns:
            dict: A dict of arrays of shape (batch_size, path_len, flat_dim).

        Raises:
            NotImplementedError: In pixel mode.

        """
        idx = np.random.choice(len(self._path_segments), batch_size)
        path_segments = [self._path_segments[i] for i in idx]

        if self._pixel_dim is not None:
            raise NotImplementedError

        return {
            key: np.stack(
                [
                    np.concatenate(
                        [
                            buf_arr[first_seg.start : first_seg.stop],
                            buf_arr[second_seg.start : second_seg.stop],
                        ],
                        axis=0,
                    )
                    for (first_seg, second_seg) in path_segments
                ],
                axis=0,
            )
            for key, buf_arr in self._buffer.items()
        }

    def _gather_pixel_key(self, key, idx):
        """Return rows ``idx`` of a pixel-split key, re-joining pixel and state parts."""
        pixel_rows = self._buffer[f"{key}_pixel"][idx]
        state_arr = self._buffer[f"{key}_state"]
        if state_arr.shape[1] != 0:
            return np.concatenate([pixel_rows, state_arr[idx]], axis=1)
        return pixel_rows

    def _gather(self, idx):
        """Return ``{key: rows idx}`` for every stored key, merging pixel splits."""
        if self._pixel_dim is None:
            return {key: buf_arr[idx] for key, buf_arr in self._buffer.items()}
        ret_dict = {}
        keys = set(self._buffer)
        for key in self._pixel_keys:
            keys.remove(f"{key}_pixel")
            keys.remove(f"{key}_state")
            ret_dict[key] = self._gather_pixel_key(key, idx)
        for key in keys:
            ret_dict[key] = self._buffer[key][idx]
        return ret_dict

    @staticmethod
    def _write_segments(buf_arr, values, first_seg, second_seg):
        """Write ``values`` into ``buf_arr`` across the two (possibly empty) segments."""
        split = len(first_seg)
        buf_arr[first_seg.start : first_seg.stop] = values[:split]
        buf_arr[second_seg.start : second_seg.stop] = values[split:]

    def _next_path_segments(self, n_indices):
        """Compute where the next path should be stored.

        Args:
            n_indices (int): Path length.

        Returns:
            tuple: Two ranges of indices where the path should be stored; the
                second is ``range(0, 0)`` unless the path wraps around.

        Raises:
            ValueError: If path length is greater than the size of buffer.

        """
        if n_indices > self._capacity:
            raise ValueError("Path is too long to store in buffer.")
        start = self._first_idx_of_next_path
        end = start + n_indices
        if end > self._capacity:
            second_end = end - self._capacity
            return (range(start, self._capacity), range(0, second_end))
        else:
            return (range(start, end), range(0, 0))

    def _get_or_allocate_key(self, key, array):
        """Get or allocate key in the buffer.

        New arrays are zeros of shape ``(capacity, *array.shape[1:])`` with
        ``array``'s dtype, so 1-D per-step arrays are supported too.

        Args:
            key (str): Key in buffer.
            array (numpy.ndarray): Array corresponding to key.

        Returns:
            numpy.ndarray: A NumPy array corresponding to key in the buffer.

        """
        buf_arr = self._buffer.get(key, None)
        if buf_arr is None:
            buf_arr = np.zeros((self._capacity, *array.shape[1:]), array.dtype)
            self._buffer[key] = buf_arr
        return buf_arr

    def clear(self):
        """Clear buffer."""
        self._transitions_stored = 0
        self._first_idx_of_next_path = 0
        self._path_segments.clear()
        self._buffer.clear()

    @staticmethod
    def _get_path_length(path):
        """Get path length.

        Args:
            path (dict): Path.

        Returns:
            length: Path length.

        Raises:
            ValueError: If path is empty or has inconsistent lengths.

        """
        length_key = None
        length = None
        for key, value in path.items():
            if length is None:
                length = len(value)
                length_key = key
            elif len(value) != length:
                raise ValueError(
                    "path has inconsistent lengths between "
                    "{!r} and {!r}.".format(length_key, key)
                )
        if not length:
            raise ValueError("Nothing in path")
        return length

    @staticmethod
    def _segments_overlap(seg_a, seg_b):
        """Compute if two segments overlap.

        Args:
            seg_a (range): List of indices of the first segment.
            seg_b (range): List of indices of the second segment.

        Returns:
            bool: True iff the input ranges overlap at at least one index.

        """
        # Empty segments never overlap.
        if not seg_a or not seg_b:
            return False
        first = seg_a
        second = seg_b
        if seg_b.start < seg_a.start:
            first, second = seg_b, seg_a
        assert first.start <= second.start
        return first.stop > second.start

    @property
    def n_transitions_stored(self):
        """Return the size of the replay buffer.

        Returns:
            int: Size of the current replay buffer.

        """
        return int(self._transitions_stored)


if __name__ == "__main__":
    # Smoke test: fill a small ring buffer with many short synthetic paths,
    # then repeatedly draw plain and goal-conditioned batches.
    path_len = 10
    buf = PathBufferEx(50, None, max_episode_length=200)

    for _ in range(100):
        steps = np.arange(path_len)[:, None]
        synthetic_path = {
            "obs": steps,
            "next_obs": steps + 1,
            "options": np.random.randn(path_len, 4),
            "rewards": np.random.randn(path_len, 1),
            "actions": np.random.rand(path_len, 3),
            "dones": np.zeros((path_len, 1), dtype=np.int64),
            "cur_exploration": np.random.rand(path_len)[:, None] < 0.5,
        }
        buf.add_path(synthetic_path)

    print(buf.sample_transitions(2))

    for _ in range(1000):
        print(buf.sample_transitions_with_goals(10))
        small_batch = buf.sample_transitions_with_goals(2)
        # next_obs values are 1..10, so every goal is bounded by 10.
        assert np.all(small_batch["goals"] <= 10)

    for _ in range(1000):
        # Goals are not necessarily later than obs: some goals are random.
        samples = buf.sample_transitions_with_goals(2)
