"""
Replay buffer implementation.
"""

from __future__ import annotations

import numpy as np
import torch


class ReplayBuffer:
    """
    Fixed-capacity circular replay buffer.

    Transitions ``(state, action, reward, next_state, done)`` are kept in
    preallocated numpy arrays. Once ``capacity`` transitions are stored, new
    ones overwrite the oldest. ``sample`` returns torch tensors on ``device``.
    """

    def __init__(self, capacity, state_dim, device, seed=None):
        """
        Allocate storage for ``capacity`` transitions.

        Args:
            capacity: Maximum number of transitions kept; must be positive.
            state_dim: Length of each flat state vector.
            device: Target passed to ``Tensor.to`` for every sampled tensor.
            seed: Optional seed forwarded to ``np.random.default_rng`` so
                sampling can be made reproducible. The default ``None``
                leaves the generator unseeded, as before.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        self.capacity = int(capacity)
        self.state_dim = int(state_dim)
        if self.capacity <= 0:
            # A zero-sized ring cannot hold any transition; fail here rather
            # than with an opaque indexing error on the first ``add``.
            raise ValueError(f"capacity must be positive, got {self.capacity}")
        self.device = device

        self.states = np.zeros((self.capacity, self.state_dim), dtype=np.float32)
        self.next_states = np.zeros((self.capacity, self.state_dim), dtype=np.float32)
        self.actions = np.zeros(self.capacity, dtype=np.int64)
        self.rewards = np.zeros(self.capacity, dtype=np.float32)
        self.dones = np.zeros(self.capacity, dtype=np.bool_)

        self.size = 0  # number of valid transitions currently stored
        self.ptr = 0  # index the next transition will be written to
        self.rng = np.random.default_rng(seed)

    @staticmethod
    def _as_batch(values, dtype, shape, name):
        """Convert ``values`` to ``dtype`` and broadcast it to ``shape``."""
        arr = np.asarray(values, dtype=dtype)
        try:
            return np.broadcast_to(arr, shape)
        except ValueError as err:
            raise ValueError(
                f"{name} has shape {arr.shape}, which does not match "
                f"the expected batch shape {shape}"
            ) from err

    def add(self, s, a, r, s_next, done):
        """
        Store one transition at the write pointer and advance it.

        When the buffer is full the write pointer has wrapped around, so the
        oldest stored transition is the one overwritten.
        """
        idx = self.ptr
        self.states[idx] = s
        self.next_states[idx] = s_next
        self.actions[idx] = int(a)
        self.rewards[idx] = float(r)
        self.dones[idx] = bool(done)

        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def add_batch(self, S, A, R, S_next, D):
        """
        Add a batch of transitions.

        The batch size ``K`` is the length of ``A`` (a scalar ``A`` is
        treated as a batch of one). ``S`` and ``S_next`` must broadcast to
        ``(K, state_dim)``; ``R`` and ``D`` must broadcast to ``(K,)``.
        Inputs are validated before anything is written, so a rejected batch
        leaves the buffer untouched.

        If ``K >= capacity`` only the last ``capacity`` transitions are kept;
        they fill the whole buffer starting at index 0 and the write pointer
        is reset to 0.

        Raises:
            ValueError: If any input cannot be broadcast to the batch size.
        """
        A = np.atleast_1d(np.asarray(A, dtype=np.int64))
        K = int(A.shape[0])
        if K == 0:
            return

        # Normalise every input to the full batch size *before* any
        # truncation. Slicing each array independently would let inputs of
        # different lengths slip through in the overflow branch below.
        A = self._as_batch(A, np.int64, (K,), "A")
        S = self._as_batch(S, np.float32, (K, self.state_dim), "S")
        S_next = self._as_batch(S_next, np.float32, (K, self.state_dim), "S_next")
        R = self._as_batch(R, np.float32, (K,), "R")
        D = self._as_batch(D, np.bool_, (K,), "D")

        if K >= self.capacity:
            keep = self.capacity
            self.states[:] = S[-keep:]
            self.next_states[:] = S_next[-keep:]
            self.actions[:] = A[-keep:]
            self.rewards[:] = R[-keep:]
            self.dones[:] = D[-keep:]
            self.size = keep
            # Index 0 now holds the oldest kept transition, so it is the
            # next slot to overwrite.
            self.ptr = 0
            return

        # Destination slots, wrapping around the end of the ring.
        idxs = (self.ptr + np.arange(K)) % self.capacity
        self.states[idxs] = S
        self.next_states[idxs] = S_next
        self.actions[idxs] = A
        self.rewards[idxs] = R
        self.dones[idxs] = D

        self.ptr = (self.ptr + K) % self.capacity
        self.size = min(self.size + K, self.capacity)

    def sample(self, batch_size):
        """
        Draw a random batch of stored transitions.

        Indices are drawn independently and uniformly from the valid region
        ``[0, size)``, so the same transition may appear more than once. The
        request is clamped to the number of stored transitions.

        Returns:
            Tuple ``(S, A, R, S_next, D)`` of tensors on ``self.device`` with
            dtypes float32, int64, float32, float32 and bool respectively.
        """
        batch_size = min(int(batch_size), self.size)
        idxs = self.rng.integers(0, self.size, size=batch_size)

        # Fancy indexing copies, so the returned tensors never alias the
        # buffer's own storage.
        S = torch.from_numpy(self.states[idxs]).to(self.device)
        A = torch.from_numpy(self.actions[idxs]).to(self.device)
        R = torch.from_numpy(self.rewards[idxs]).to(self.device)
        S_next = torch.from_numpy(self.next_states[idxs]).to(self.device)
        D = torch.from_numpy(self.dones[idxs]).to(self.device)

        return S, A, R, S_next, D


# Names exported by ``from <module> import *``; only the buffer class is public.
__all__ = ["ReplayBuffer"]
