import numpy as np
import torch


class ReplayBuffer:
    """Fixed-capacity circular buffer of environment transitions.

    Each stored transition is ``(obs, action, reward, cost, next_obs,
    not_done, not_done_no_max)``. Once ``capacity`` transitions have been
    added, each new transition overwrites the oldest slot.

    Args:
        obs_shape: Shape of a single observation. A 1-D shape is stored as
            ``float32``; any other rank is stored as ``uint8`` (intended for
            pixel observations).
        action_shape: Shape of a single action; stored as ``float32``.
        capacity: Maximum number of transitions kept. Must be positive.
        device: Torch device that sampled tensors are placed on.

    Raises:
        ValueError: If ``capacity`` is not positive.
    """

    def __init__(self, obs_shape, action_shape, capacity, device):
        if capacity <= 0:
            # A zero-capacity buffer could never hold a transition: ``add``
            # would index an empty array and ``idx % capacity`` would divide
            # by zero. Fail fast with a clear message instead.
            raise ValueError(f"capacity must be positive, got {capacity}")

        self.capacity = capacity
        self.device = device

        # The proprioceptive obs is stored as float32, pixel obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.costs = np.empty((capacity, 1), dtype=np.float32)
        # Termination flags are stored inverted (1.0 = episode continues) so
        # they can be multiplied directly into bootstrapped targets.
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32)

        # Next slot to write; wraps around modulo ``capacity``.
        self.idx = 0
        # True once every slot has been written at least once.
        self.full = False

    def __len__(self):
        """Return the number of valid transitions currently stored."""
        return self.capacity if self.full else self.idx

    def add(self, obs, action, reward, cost, next_obs, done, done_no_max):
        """Add a new transition to the buffer, overwriting the oldest if full.

        Values are copied into preallocated arrays (cast to the buffer's
        dtypes by ``np.copyto``). ``done`` and ``done_no_max`` are stored
        negated as ``not_done`` / ``not_done_no_max``. ``done_no_max`` is
        presumably the termination flag that ignores time-limit truncation;
        confirm against the caller.
        """
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.costs[self.idx], cost)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)
        np.copyto(self.not_dones_no_max[self.idx], not done_no_max)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def _convert_to_tensor(self, array, dtype=torch.float32):
        """Move ``array`` to ``self.device`` and cast it to ``dtype``.

        The cast happens after the device transfer, so compact ``uint8``
        pixel data is copied in its original dtype and widened on the
        target device.
        """
        return torch.as_tensor(array, device=self.device).to(dtype)

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly, with replacement.

        Indices come from ``np.random.randint`` over the filled part of the
        buffer, so results depend on NumPy's global random state.

        Returns:
            Tuple ``(obses, actions, rewards, costs, next_obses, not_dones,
            not_dones_no_max)`` of ``float32`` tensors on ``self.device``,
            each with leading dimension ``batch_size``.

        Raises:
            ValueError: If the buffer holds no transitions yet.
        """
        size = len(self)
        if size == 0:
            # Without this guard numpy raises an opaque "low >= high" error.
            raise ValueError("cannot sample from an empty ReplayBuffer")
        idxs = np.random.randint(0, size, size=batch_size)

        # Fancy indexing copies the selected rows, so the returned tensors
        # never alias the buffer's storage.
        obses = self._convert_to_tensor(self.obses[idxs])
        actions = self._convert_to_tensor(self.actions[idxs])
        rewards = self._convert_to_tensor(self.rewards[idxs])
        costs = self._convert_to_tensor(self.costs[idxs])
        next_obses = self._convert_to_tensor(self.next_obses[idxs])
        not_dones = self._convert_to_tensor(self.not_dones[idxs])
        not_dones_no_max = self._convert_to_tensor(self.not_dones_no_max[idxs])

        return obses, actions, rewards, costs, next_obses, not_dones, not_dones_no_max
