import torch
import torch.utils.data

from .list_replay_buffer import SimpleReplayBuffer
from ..utils import numpy_to_torch_dtype_dict


class TorchReplayBuffer(SimpleReplayBuffer, torch.utils.data.dataset.Dataset):
    """Replay buffer whose per-key storage is a preallocated torch tensor.

    Implements the ``_buf_init`` / ``_buf_add`` / ``_buf_sample`` hooks
    (presumably invoked by ``SimpleReplayBuffer`` -- confirm against the base
    class) on top of tensors that live on ``device``.
    """

    def __init__(self, env, device='cpu', action_dtype=None, *args, **kwargs):
        """Store the buffer configuration and initialise the base class.

        Args:
            env: Environment whose ``observation_space`` / ``action_space``
                provide the shapes and default dtypes of the stored tensors.
            device: Anything accepted by ``torch.device``; all storage is
                allocated on it.
            action_dtype: Optional dtype for the ``'action'`` storage. When
                ``None`` the dtype is derived from ``env.action_space``.
                Assumed to be a ``torch.dtype`` -- TODO confirm with callers.
            *args: Forwarded to the base class constructor.
            **kwargs: Forwarded to the base class constructor.
        """
        # Attributes are set before super().__init__() in case the base
        # constructor already calls _buf_init, which reads them.
        self.env = env
        self.device = torch.device(device)
        self.action_dtype = action_dtype
        super().__init__(*args, **kwargs)

    def _buf_init(self, key, max_buf_size):
        """Allocate the uninitialised storage tensor for ``key``.

        Args:
            key: One of ``'state'``, ``'action'``, ``'next_state'``,
                ``'reward'``, ``'done'`` or ``'timeout'``.
            max_buf_size: Size of the leading dimension of the storage.

        Returns:
            A ``torch.empty`` tensor of shape ``(max_buf_size, *item_shape)``
            on ``self.device``.

        Raises:
            KeyError: If ``key`` is not one of the known keys, or if a space
                dtype is missing from ``numpy_to_torch_dtype_dict``.
        """
        state_dtype = numpy_to_torch_dtype_dict[self.env.observation_space.dtype.type]
        # Honour the explicit override passed to the constructor; only fall
        # back to the environment's action space when none was supplied.
        if self.action_dtype is not None:
            action_dtype = self.action_dtype
        else:
            action_dtype = numpy_to_torch_dtype_dict[self.env.action_space.dtype.type]
        shapes = {
            'state': (state_dtype, self.env.observation_space.shape),
            'action': (action_dtype, self.env.action_space.shape),
            'next_state': (state_dtype, self.env.observation_space.shape),
            'reward': (torch.float32, ()),
            'done': (torch.bool, ()),
            'timeout': (torch.bool, ()),
        }
        dtype, shape = shapes[key]
        return torch.empty(max_buf_size, *shape, dtype=dtype, device=self.device)

    def _buf_add(self, buf, idx, data):
        """Write ``data`` into ``buf`` at ``idx``.

        ``data`` is converted with ``torch.as_tensor`` on ``self.device``
        before being assigned into the storage tensor.
        """
        buf[idx] = torch.as_tensor(data, device=self.device)

    def _buf_sample(self, buf, indices):
        """Return the entries of ``buf`` selected by ``indices``."""
        return buf[indices]
