from typing import cast

import numpy as np
from tianshou.data import ReplayBuffer, Batch
from tianshou.data.types import RolloutBatchProtocol


class CustomReplayBuffer(ReplayBuffer):
    """ReplayBuffer that stores an extra ``pi_option`` field and samples recent data.

    Differences from :class:`tianshou.data.ReplayBuffer` implemented here:

    * ``pi_option`` is a reserved/input key and is returned by ``__getitem__``
      alongside the standard transition fields.
    * :meth:`sample` returns the most recently added transitions as a contiguous,
      chronologically ordered window instead of a uniform random sample.
    """

    # Keys the buffer is allowed to hold. This mirrors tianshou's default key set
    # with the extra "pi_option" field (verify against the installed tianshou
    # version if the base class changes).
    _reserved_keys = (
        "obs",
        "act",
        "pi_option",
        "rew",
        "terminated",
        "truncated",
        "done",
        "obs_next",
        "info",
        "policy",
    )
    # Keys that __getitem__ assembles explicitly; any other stored key is copied
    # through generically. "done" is reserved but not listed here — presumably the
    # base class derives it from terminated/truncated on add (confirm).
    _input_keys = (
        "obs",
        "act",
        "pi_option",
        "rew",
        "terminated",
        "truncated",
        "obs_next",
        "info",
        "policy",
    )

    def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]:
        """Return the ``batch_size`` most recently added transitions, oldest first.

        The window ends at the last written slot (``self.last_index[0]``) and wraps
        around the circular storage when needed, so the result is contiguous in
        insertion order.

        :param batch_size: number of transitions to return. ``0`` or ``None``
            returns every stored transition in chronological order, matching the
            base-class convention. A negative value returns an empty batch.
        :return: ``(batch, indices)`` where ``indices`` is an integer
            ``np.ndarray`` of the buffer positions of the returned transitions.
        :raises ValueError: if ``batch_size`` exceeds the number of stored
            transitions.
        """
        buffer_size = len(self)
        if batch_size is None or batch_size == 0:
            # Base-class convention: 0 (or None) means "everything in the buffer".
            batch_size = buffer_size
        if batch_size > buffer_size:
            # Explicit check instead of `assert`, which disappears under `python -O`
            # and would let the modulo below silently wrap into duplicate indices.
            raise ValueError(
                f"batch_size ({batch_size}) exceeds the number of stored "
                f"transitions ({buffer_size})"
            )
        if batch_size <= 0:
            # Negative request or empty buffer: nothing to index. This also avoids
            # a modulo-by-zero when the buffer is empty.
            indices = np.array([], dtype=int)
        else:
            # Start `batch_size - 1` slots before the last write; numpy's integer
            # remainder follows Python semantics, so negative starts wrap correctly.
            first = self.last_index[0] + 1 - batch_size
            indices = (first + np.arange(batch_size)) % buffer_size
        return self[indices], indices

    def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBatchProtocol:
        """Return a data batch: self[index].

        Builds the same fields as the base class plus ``pi_option``; any other
        stored key not in ``_input_keys`` is copied through as well.

        If stack_num is larger than 1, return the stacked obs and obs_next with shape
        (batch, len, ...).
        """
        if isinstance(index, slice):  # change slice to np array
            # buffer[:] will get all available data
            indices = (
                self.sample_indices(0)
                if index == slice(None)
                else self._indices[: len(self)][index]
            )
        else:
            indices = index  # type: ignore
        # raise KeyError first instead of AttributeError,
        # to support np.array([ReplayBuffer()])
        obs = self.get(indices, "obs")
        if self._save_obs_next:
            obs_next = self.get(indices, "obs_next", Batch())
        else:
            # obs_next is not stored separately; read it from the "next" slot's obs.
            obs_next = self.get(self.next(indices), "obs", Batch())
        batch_dict = {
            "obs": obs,
            "act": self.act[indices],
            "pi_option": self.pi_option[indices],
            "rew": self.rew[indices],
            "terminated": self.terminated[indices],
            "truncated": self.truncated[indices],
            "done": self.done[indices],
            "obs_next": obs_next,
            "info": self.get(indices, "info", Batch()),
            # TODO: what's the use of this key?
            "policy": self.get(indices, "policy", Batch()),
        }
        # Pass through any additional stored keys (this also re-sets "done", which
        # is not in _input_keys, to the same value as above).
        for key in self._meta.__dict__:
            if key not in self._input_keys:
                batch_dict[key] = self._meta[key][indices]
        return cast(RolloutBatchProtocol, Batch(batch_dict))
