from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class TransitionBatch:
    """Bundle of parallel arrays describing a batch of transitions.

    ``ReplayBuffer.sample`` builds instances by indexing each of its storage
    arrays with the same row indices, so row ``k`` of every field refers to
    the same stored transition. The dataclass itself performs no shape or
    dtype validation.
    """

    # Shapes below are what ReplayBuffer.sample produces; other producers
    # may differ.
    obs: np.ndarray  # observations, (batch, obs_dim)
    act: np.ndarray  # actions, (batch, act_dim)
    rew: np.ndarray  # rewards, (batch, 1)
    next_obs: np.ndarray  # observations following each action, (batch, obs_dim)
    done: np.ndarray  # done flags stored as floats, (batch, 1)
    # (batch, 1); presumably a safety/constraint cost signal -- TODO confirm
    # against the code that calls ReplayBuffer.add.
    cost: np.ndarray


class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions held in float32 arrays.

    Storage is preallocated at construction. Once ``size`` transitions have
    been written, each new transition overwrites the oldest one. Sampling
    draws row indices uniformly, with replacement, from the rows written so
    far.
    """

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        size: int,
        seed: Optional[int] = None,
    ):
        """Allocate storage for up to ``size`` transitions.

        Args:
            obs_dim: Length of each observation vector.
            act_dim: Length of each action vector.
            size: Maximum number of transitions held; must be positive.
            seed: Seed for the generator ``sample`` falls back to when the
                caller does not pass ``rng``. ``None`` leaves it unseeded.

        Raises:
            ValueError: If ``size`` is not positive.
        """
        # A zero-capacity buffer can neither store nor sample anything, and
        # would only fail later inside add(); reject it up front.
        if size <= 0:
            raise ValueError(f"size must be a positive integer, got {size!r}")
        self.size = size
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros((size, 1), dtype=np.float32)
        self.next_obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.done_buf = np.zeros((size, 1), dtype=np.float32)
        self.cost_buf = np.zeros((size, 1), dtype=np.float32)
        # ptr is the next row to write; full flips once ptr has wrapped.
        self.ptr, self.full = 0, False
        # One generator per buffer, so sample() does not build (and re-seed)
        # a new one on every call and can be made reproducible via ``seed``.
        self._rng = np.random.default_rng(seed)

    def add(self, obs, act, rew, next_obs, done, cost):
        """Write one transition at the current position and advance it.

        Values are cast to float32 on assignment. Scalar ``rew``, ``done``
        and ``cost`` fill their single-column row. When the buffer is full
        the oldest transition is overwritten.
        """
        i = self.ptr
        self.obs_buf[i] = obs
        self.act_buf[i] = act
        self.rew_buf[i] = rew
        self.next_obs_buf[i] = next_obs
        self.done_buf[i] = done
        self.cost_buf[i] = cost
        self.ptr = (i + 1) % self.size
        # Wrapping back to row 0 means every row has been written at least
        # once, so the whole buffer is valid from here on.
        if self.ptr == 0:
            self.full = True

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size if self.full else self.ptr

    def sample(self, batch_size: int, rng: Optional[np.random.Generator] = None) -> TransitionBatch:
        """Draw ``batch_size`` stored transitions uniformly with replacement.

        Args:
            batch_size: Number of transitions to draw.
            rng: Generator to draw indices from. Defaults to the buffer's
                own generator (seeded by the constructor's ``seed``).

        Returns:
            A ``TransitionBatch`` whose arrays share the same row indices.
            Integer-array indexing copies, so the batch does not alias the
            buffer's storage.

        Raises:
            ValueError: If no transition has been added yet.
        """
        stored = len(self)
        if stored == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        if rng is None:
            rng = self._rng
        idxs = rng.integers(0, stored, size=batch_size)
        return TransitionBatch(
            obs=self.obs_buf[idxs],
            act=self.act_buf[idxs],
            rew=self.rew_buf[idxs],
            next_obs=self.next_obs_buf[idxs],
            done=self.done_buf[idxs],
            cost=self.cost_buf[idxs],
        )

