from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import numpy as np


@dataclass
class ReplayBatch:
    """A minibatch of transitions stored as parallel arrays.

    Every field shares the leading batch axis ``B``; row ``i`` of each array
    belongs to the same transition. The dtypes noted below are the ones
    ``ReplayBuffer`` allocates for its storage.
    """

    # Observation before the action: uint8, [B, C, H, W] (e.g. [B, 1, 84, 84]).
    obs: np.ndarray
    # Action index taken in ``obs``: int64, [B].
    actions: np.ndarray
    # Reward received for the action: float32, [B].
    rewards: np.ndarray
    # Observation after the action: uint8, same shape as ``obs``.
    next_obs: np.ndarray
    # Done flag stored as a float: float32, [B], values in {0, 1}.
    dones: np.ndarray


class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Storage is preallocated as five parallel NumPy arrays (observations,
    next observations, actions, rewards, done flags). Once ``capacity``
    transitions have been added, each new transition overwrites the oldest.

    Attributes:
        capacity: Maximum number of transitions held at once.
        obs, next_obs: uint8 arrays of shape ``(capacity, c, h, w)``.
        actions: int64 array of shape ``(capacity,)``.
        rewards: float32 array of shape ``(capacity,)``.
        dones: float32 array of shape ``(capacity,)`` holding 0.0 or 1.0.
        ptr: Index of the slot the next ``add`` call writes to.
        size: Number of valid transitions currently stored.
    """

    def __init__(self, capacity: int, *, obs_shape: Tuple[int, int, int]) -> None:
        """Allocate zero-filled storage for ``capacity`` transitions.

        Args:
            capacity: Maximum number of transitions; converted with ``int()``.
            obs_shape: ``(c, h, w)`` shape of a single observation.

        Raises:
            ValueError: If ``capacity`` is not a positive integer.
        """
        capacity = int(capacity)
        # Fail fast: a zero-capacity buffer would otherwise construct fine
        # and only blow up later, inside ``add``.
        if capacity <= 0:
            raise ValueError(f"capacity must be a positive integer, got {capacity}.")
        self.capacity = capacity

        c, h, w = obs_shape
        frame_store_shape = (self.capacity, c, h, w)
        self.obs = np.zeros(frame_store_shape, dtype=np.uint8)
        self.next_obs = np.zeros(frame_store_shape, dtype=np.uint8)
        self.actions = np.zeros((self.capacity,), dtype=np.int64)
        self.rewards = np.zeros((self.capacity,), dtype=np.float32)
        self.dones = np.zeros((self.capacity,), dtype=np.float32)

        self.ptr = 0
        self.size = 0

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return int(self.size)

    def add(self, obs: np.ndarray, action: int, reward: float, next_obs: np.ndarray, done: bool) -> None:
        """Store one transition, overwriting the oldest one when full.

        ``obs`` and ``next_obs`` are written into the uint8 storage by plain
        NumPy assignment, so NumPy's casting and broadcasting rules apply;
        no explicit shape or dtype check is performed here.

        Args:
            obs: Observation before the action.
            action: Action taken; stored as ``int(action)``.
            reward: Reward received; stored as ``float(reward)``.
            next_obs: Observation after the action.
            done: Done flag for the transition; stored as 1.0 or 0.0.
        """
        slot = self.ptr
        self.obs[slot] = obs
        self.next_obs[slot] = next_obs
        self.actions[slot] = int(action)
        self.rewards[slot] = float(reward)
        self.dones[slot] = 1.0 if bool(done) else 0.0

        # Advance the write cursor circularly; size saturates at capacity.
        self.ptr = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int, rng: np.random.Generator) -> ReplayBatch:
        """Draw ``batch_size`` stored transitions uniformly at random.

        Indices are drawn independently from ``[0, size)``, i.e. with
        replacement, so a batch may contain the same transition twice.

        Args:
            batch_size: Number of transitions to draw; converted with ``int()``.
            rng: NumPy random generator used to draw the indices.

        Returns:
            A ``ReplayBatch`` whose arrays are indexed out of the buffer
            storage with the sampled index array.

        Raises:
            ValueError: If the buffer holds no transitions.
        """
        if self.size <= 0:
            raise ValueError("Cannot sample from empty replay buffer.")
        idxs = rng.integers(0, self.size, size=int(batch_size), endpoint=False)
        return ReplayBatch(
            obs=self.obs[idxs],
            actions=self.actions[idxs],
            rewards=self.rewards[idxs],
            next_obs=self.next_obs[idxs],
            dones=self.dones[idxs],
        )

