from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import numpy as np


@dataclass
class ReplayBatch:
    """A batch of B transitions held as parallel arrays indexed along axis 0.

    The dtypes and shapes noted on each field are the ones produced by
    ``ReplayBuffer.sample``; the dataclass itself does not enforce them.
    """

    obs: np.ndarray          # uint8 [B,C,H,W] (e.g. [B,1,84,84]); C,H,W follow the buffer's obs_shape
    actions: np.ndarray      # int64 [B]
    next_actions: np.ndarray # int64 [B]; the `next_action` stored with each transition
    rewards: np.ndarray      # float32 [B]
    next_obs: np.ndarray     # uint8 [B,C,H,W], same layout as `obs`
    dones: np.ndarray        # float32 [B] in {0,1}; 1.0 where the stored `done` flag was truthy


class ReplayBuffer:
    """Fixed-capacity circular store of transitions in preallocated NumPy arrays.

    Each transition is written into slot ``ptr`` of six parallel arrays
    (``obs``, ``next_obs``, ``actions``, ``next_actions``, ``rewards``,
    ``dones``). ``ptr`` advances modulo ``capacity``, so once the buffer is
    full every new transition overwrites the oldest one. ``size`` counts the
    slots that hold valid data and never exceeds ``capacity``.
    """

    def __init__(self, capacity: int, *, obs_shape: Tuple[int, int, int]) -> None:
        """Preallocate storage for ``capacity`` transitions.

        Args:
            capacity: Maximum number of transitions held; must be at least 1.
            obs_shape: ``(c, h, w)`` shape of a single observation.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        self.capacity = int(capacity)
        if self.capacity <= 0:
            # Fail here rather than on the first add(): a zero-slot buffer can
            # never store anything, and indexing into it only produces an
            # opaque IndexError far away from the misconfiguration.
            raise ValueError(
                f"capacity must be a positive integer, got {capacity!r}."
            )

        c, h, w = obs_shape
        frame_store_shape = (self.capacity, c, h, w)
        scalar_store_shape = (self.capacity,)

        # Observations are kept as uint8 to keep the two frame arrays compact.
        self.obs = np.zeros(frame_store_shape, dtype=np.uint8)
        self.next_obs = np.zeros(frame_store_shape, dtype=np.uint8)
        self.actions = np.zeros(scalar_store_shape, dtype=np.int64)
        self.next_actions = np.zeros(scalar_store_shape, dtype=np.int64)
        self.rewards = np.zeros(scalar_store_shape, dtype=np.float32)
        self.dones = np.zeros(scalar_store_shape, dtype=np.float32)

        self.ptr = 0   # slot the next add() will write to
        self.size = 0  # number of slots currently holding valid transitions

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return int(self.size)

    def add(
        self,
        obs: np.ndarray,
        action: int,
        next_action: int,
        reward: float,
        next_obs: np.ndarray,
        done: bool,
    ) -> None:
        """Write one transition at the current write position.

        Observations are copied into the uint8 storage by NumPy assignment
        (which converts other dtypes to uint8); the scalar fields are coerced
        with ``int``/``float``/``bool`` before being stored. When the buffer is
        already full, the oldest transition is overwritten.

        Args:
            obs: Observation, assignable to a ``(c, h, w)`` slot.
            action: Action stored in ``actions``.
            next_action: Action stored in ``next_actions``.
            reward: Reward stored as float32.
            next_obs: Following observation, same layout as ``obs``.
            done: Stored as ``1.0`` when truthy, ``0.0`` otherwise.
        """
        slot = self.ptr
        self.obs[slot] = obs
        self.next_obs[slot] = next_obs
        self.actions[slot] = int(action)
        self.next_actions[slot] = int(next_action)
        self.rewards[slot] = float(reward)
        self.dones[slot] = 1.0 if bool(done) else 0.0

        # Advance circularly; size saturates at capacity once the ring is full.
        self.ptr = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int, rng: np.random.Generator) -> ReplayBatch:
        """Draw ``batch_size`` stored transitions uniformly at random.

        Indices are drawn independently from ``[0, size)``, so the same
        transition can appear more than once in a batch. The returned arrays
        are produced by integer-array indexing and are therefore copies, not
        views into the buffer's storage.

        Args:
            batch_size: Number of transitions to draw.
            rng: NumPy random generator used to pick the indices.

        Returns:
            A ``ReplayBatch`` whose arrays all have ``batch_size`` rows.

        Raises:
            ValueError: If the buffer holds no transitions.
        """
        if self.size <= 0:
            raise ValueError("Cannot sample from empty replay buffer.")
        idxs = rng.integers(0, self.size, size=int(batch_size), endpoint=False)
        return ReplayBatch(
            obs=self.obs[idxs],
            actions=self.actions[idxs],
            next_actions=self.next_actions[idxs],
            rewards=self.rewards[idxs],
            next_obs=self.next_obs[idxs],
            dones=self.dones[idxs],
        )

