from __future__ import annotations

from dataclasses import dataclass
import numpy as np


@dataclass
class PrioritizedBatch:
    """Minibatch returned by ``PrioritizedReplayBuffer.sample``.

    Every array holds one entry per sampled transition, in the same order
    as ``idxs``.
    """

    # Transition data gathered from the buffer's storage arrays at ``idxs``.
    obs: np.ndarray
    act: np.ndarray
    rew: np.ndarray
    next_obs: np.ndarray
    done: np.ndarray
    cost: np.ndarray
    # Buffer slot of each sampled transition; hand these back to
    # ``update_priorities`` along with the new priorities.
    idxs: np.ndarray
    # Importance-sampling weights, scaled so that the largest one is 1.0.
    weights: np.ndarray


class PrioritizedReplayBuffer:
    """Fixed-capacity ring buffer with proportional prioritized sampling.

    Transition ``i`` is drawn with probability ``prio[i] ** alpha`` divided by
    the sum over all stored transitions. Each sampled transition comes with an
    importance-sampling weight ``(N * P(i)) ** -beta``, and the weights are
    normalized by their maximum.

    Args:
        obs_dim: Length of each observation vector.
        act_dim: Length of each action vector.
        size: Capacity. Once full, the oldest slot is overwritten next.
        alpha: Priority exponent (``0`` makes sampling uniform).
        beta: Default importance-sampling exponent used by ``sample``.

    Raises:
        ValueError: If ``size`` is not positive.
    """

    def __init__(self, obs_dim: int, act_dim: int, size: int, alpha: float = 0.6, beta: float = 0.4):
        if size <= 0:
            raise ValueError(f"size must be positive, got {size}")
        self.size = size
        self.alpha = alpha
        self.beta = beta
        # Next slot to write; wraps modulo ``size``.
        self.ptr = 0
        # Becomes True once the write pointer has wrapped at least once.
        self.full = False
        self.obs = np.zeros((size, obs_dim), dtype=np.float32)
        self.act = np.zeros((size, act_dim), dtype=np.float32)
        self.rew = np.zeros((size, 1), dtype=np.float32)
        self.next_obs = np.zeros((size, obs_dim), dtype=np.float32)
        self.done = np.zeros((size, 1), dtype=np.float32)
        self.cost = np.zeros((size, 1), dtype=np.float32)
        # Per-slot priorities; only the first ``len(self)`` entries are meaningful.
        self.prio = np.ones(size, dtype=np.float32)
        # Added to |priority| so that no stored transition has zero probability.
        self.eps = 1e-6

    def add(self, obs, act, rew, next_obs, done, cost):
        """Store one transition at the write pointer, overwriting the oldest when full.

        The new transition gets the largest priority among the currently
        stored transitions (1.0 when the buffer is empty). It is therefore
        at least as likely to be sampled as any other stored transition.
        """
        n = len(self)  # computed before ptr/full change below
        i = self.ptr
        self.obs[i] = obs
        self.act[i] = act
        self.rew[i] = rew
        self.next_obs[i] = next_obs
        self.done[i] = done
        self.cost[i] = cost
        # Take the max over stored entries only. Unused slots still hold their
        # initial 1.0 and must not leak into the max. Emptiness is judged by
        # ``len(self)``, not ``ptr``: ``ptr == 0`` also happens after a wrap-around.
        self.prio[i] = self.prio[:n].max() if n > 0 else 1.0
        self.ptr = (i + 1) % self.size
        if self.ptr == 0:
            self.full = True

    def __len__(self):
        """Return the number of stored transitions (``size`` once the buffer has wrapped)."""
        return self.size if self.full else self.ptr

    def _prob(self):
        """Return float64 sampling probabilities ``prio**alpha / sum(prio**alpha)`` over stored slots."""
        # Normalize in float64 so the vector stays within rng.choice's
        # sum-to-one tolerance even for very large buffers.
        scaled = self.prio[: len(self)].astype(np.float64) ** self.alpha
        return scaled / scaled.sum()

    def sample(
        self,
        batch_size: int,
        rng: np.random.Generator | None = None,
        beta: float | None = None,
    ) -> PrioritizedBatch:
        """Draw a prioritized minibatch.

        Args:
            batch_size: Number of transitions to draw.
            rng: Random generator. A fresh ``np.random.default_rng()`` is
                used when omitted.
            beta: Importance-sampling exponent for this call. Defaults to
                ``self.beta``, which lets callers anneal beta without
                mutating the buffer.

        Returns:
            A ``PrioritizedBatch`` holding the sampled transitions, their
            slot indices, and float32 importance-sampling weights scaled so
            that the largest one is 1.0.

        Raises:
            ValueError: If the buffer is empty.
        """
        n = len(self)
        if n == 0:
            raise ValueError("cannot sample from an empty PrioritizedReplayBuffer")
        if rng is None:
            rng = np.random.default_rng()
        if beta is None:
            beta = self.beta
        p = self._prob()
        # Draw distinct slots when the buffer holds enough transitions;
        # otherwise fall back to sampling with replacement.
        # NOTE(review): weighted sampling without replacement does not follow
        # ``p`` exactly, so the IS weights below are an approximation then.
        idxs = rng.choice(n, size=batch_size, replace=n < batch_size, p=p)
        weights = (n * p[idxs]) ** (-beta)
        if weights.size:  # an empty batch (batch_size == 0) has no max to divide by
            weights = weights / weights.max()
        weights = weights.astype(np.float32)
        return PrioritizedBatch(
            obs=self.obs[idxs],
            act=self.act[idxs],
            rew=self.rew[idxs],
            next_obs=self.next_obs[idxs],
            done=self.done[idxs],
            cost=self.cost[idxs],
            idxs=idxs,
            weights=weights,
        )

    def update_priorities(self, idxs: np.ndarray, prios: np.ndarray):
        """Set the priority of each slot in ``idxs`` to ``|prios| + eps``.

        ``prios`` may be flat, a scalar, or column-shaped such as ``(B, 1)``.
        Column-shaped input matches the per-transition columns of a sampled
        batch and is flattened before assignment.
        """
        values = np.abs(np.asarray(prios, dtype=np.float32))
        if values.ndim > 1:
            # (B, 1) would not broadcast onto the (B,) fancy-indexed target.
            values = values.reshape(-1)
        self.prio[idxs] = values + self.eps

