
import random
import numpy as np
import torch

def set_seed(seed: int):
    """Seed every random number generator this module relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Integer seed forwarded unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions with uniform sampling.

    Transitions are stored in preallocated NumPy arrays. Once ``capacity``
    transitions have been added, new ones overwrite the oldest slots.

    Args:
        obs_dim: Length of one flat observation vector.
        capacity: Maximum number of stored transitions (cast to ``int``).
    """

    def __init__(self, obs_dim, capacity=100_000):
        self.capacity = int(capacity)
        self.ptr = 0  # index of the slot the next add() writes to
        self.full = False  # becomes True once ptr has wrapped around
        self.obs = np.zeros((self.capacity, obs_dim), dtype=np.float32)
        self.next_obs = np.zeros((self.capacity, obs_dim), dtype=np.float32)
        self.act = np.zeros((self.capacity,), dtype=np.int64)
        self.rew = np.zeros((self.capacity,), dtype=np.float32)
        self.done = np.zeros((self.capacity,), dtype=np.float32)

    def add(self, obs, act, rew, next_obs, done):
        """Store one transition at the write pointer and advance it."""
        i = self.ptr
        self.obs[i] = obs
        self.act[i] = act
        self.rew[i] = rew
        self.next_obs[i] = next_obs
        self.done[i] = done
        self.ptr = (self.ptr + 1) % self.capacity
        if self.ptr == 0:
            # Wrapped around: every slot now holds a real transition.
            self.full = True

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.capacity if self.full else self.ptr

    def sample(self, batch_size=256):
        """Draw ``batch_size`` transitions uniformly, with replacement.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Dict with tensors ``obs``, ``next_obs``, ``rew``, ``done``
            (float32) and ``act`` (int64), each with leading dimension
            ``batch_size``.

        Raises:
            ValueError: If no transition has been added yet.
        """
        n = len(self)
        if n == 0:
            # Fail with a clear message instead of the opaque "low >= high"
            # error that np.random.randint(0, 0) would raise.
            raise ValueError("ReplayBuffer empty")
        idx = np.random.randint(0, n, size=batch_size)
        # Fancy indexing already returns fresh arrays, so as_tensor can wrap
        # them without a second copy.
        return dict(
            obs=torch.as_tensor(self.obs[idx], dtype=torch.float32),
            act=torch.as_tensor(self.act[idx], dtype=torch.long),
            rew=torch.as_tensor(self.rew[idx], dtype=torch.float32),
            next_obs=torch.as_tensor(self.next_obs[idx], dtype=torch.float32),
            done=torch.as_tensor(self.done[idx], dtype=torch.float32),
        )

class Logger:
    """Minimal stdout logger that also remembers the latest value per key."""

    def __init__(self):
        # Most recent value seen for every key ever passed to log().
        self.last = {}

    def log(self, **kwargs):
        """Record ``kwargs`` in ``self.last`` and print them on one line.

        Only the keys passed to this call are printed, formatted as
        ``key=value`` pairs joined by `` | ``; stdout is flushed immediately.
        """
        self.last.update(kwargs)
        fields = [f"{key}={value}" for key, value in kwargs.items()]
        line = " | ".join(fields)
        print(line, flush=True)


# ----------------- PER Buffer & N-step wrapper (for Rainbow) -----------------
class PERBuffer:
    """Ring buffer with proportional prioritized sampling.

    Each stored transition carries a priority. ``sample`` draws index ``i``
    with probability proportional to ``(prior[i] + eps) ** alpha`` and returns
    importance-sampling weights ``(n * P(i)) ** -beta`` normalised by the
    largest weight in the batch. ``beta`` grows by ``beta_increment`` on every
    ``sample`` call and is capped at 1.0.

    Args:
        obs_dim: Length of one flat observation vector.
        capacity: Maximum number of stored transitions (cast to ``int``).
        alpha: Exponent applied to priorities when forming probabilities.
        beta: Initial importance-sampling exponent.
        beta_increment: Amount added to ``beta`` per ``sample`` call.
        eps: Small constant added to priorities.
    """

    def __init__(self, obs_dim, capacity=100_000, alpha=0.6, beta=0.4, beta_increment=1e-6, eps=1e-6):
        self.alpha = alpha
        self.beta = beta
        self.beta_inc = beta_increment
        self.eps = eps
        self.obs_dim = obs_dim
        self.capacity = int(capacity)
        self.ptr = 0  # index of the slot the next add() writes to
        self.full = False  # becomes True once ptr has wrapped around
        self.obs = np.zeros((self.capacity, obs_dim), dtype=np.float32)
        self.next_obs = np.zeros((self.capacity, obs_dim), dtype=np.float32)
        self.act = np.zeros((self.capacity,), dtype=np.int64)
        self.rew = np.zeros((self.capacity,), dtype=np.float32)
        self.done = np.zeros((self.capacity,), dtype=np.float32)
        self.prior = np.ones((self.capacity,), dtype=np.float32)  # default priority

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.capacity if self.full else self.ptr

    @staticmethod
    def _to_numpy(value):
        """Convert a torch tensor to a host NumPy array; pass anything else through."""
        if isinstance(value, torch.Tensor):
            # detach() drops autograd tracking, which would otherwise make
            # the NumPy conversion fail for tensors that require grad.
            return value.detach().cpu().numpy()
        return value

    def add(self, obs, act, rew, next_obs, done, priority=None):
        """Store one transition at the write pointer and advance it.

        Args:
            priority: Priority for the new transition. When ``None``, the
                maximum over the whole priority array is used (unwritten
                slots still hold the default 1.0), or 1.0 for the very
                first insertion.
        """
        i = self.ptr
        self.obs[i] = obs
        self.act[i] = act
        self.rew[i] = rew
        self.next_obs[i] = next_obs
        self.done[i] = done
        if priority is None:
            p = self.prior.max() if self.ptr > 0 or self.full else 1.0
        else:
            p = float(priority)
        self.prior[i] = p
        self.ptr += 1
        if self.ptr >= self.capacity:
            self.ptr = 0
            self.full = True

    def sample(self, batch):
        """Draw ``batch`` transitions according to their priorities.

        Returns:
            Dict with ``idx`` (NumPy array of sampled slots, to be passed
            back to ``update_priorities``), the transition tensors ``obs``,
            ``act``, ``rew``, ``next_obs``, ``done``, and float32
            ``weights`` whose maximum is 1.

        Raises:
            ValueError: If no transition has been added yet.
        """
        n = len(self)
        if n == 0:
            raise ValueError("PERBuffer empty")
        # Normalise in float64: float32 rounding over a large buffer can make
        # the probabilities miss np.random.choice's sum-to-one tolerance.
        scaled = (self.prior[:n].astype(np.float64) + self.eps) ** self.alpha
        probs = scaled / scaled.sum()
        idx = np.random.choice(n, size=batch, p=probs)
        weights = (n * probs[idx]) ** (-self.beta)
        weights = weights / weights.max()
        self.beta = min(1.0, self.beta + self.beta_inc)

        return dict(
            idx=idx,
            obs=torch.as_tensor(self.obs[idx], dtype=torch.float32),
            act=torch.as_tensor(self.act[idx], dtype=torch.long),
            rew=torch.as_tensor(self.rew[idx], dtype=torch.float32),
            next_obs=torch.as_tensor(self.next_obs[idx], dtype=torch.float32),
            done=torch.as_tensor(self.done[idx], dtype=torch.float32),
            weights=torch.as_tensor(weights, dtype=torch.float32),
        )

    def update_priorities(self, idx, td_errors):
        """Set the priorities of slots ``idx`` to ``|td_errors| + eps``.

        Args:
            idx: Slot indices, typically the ``idx`` entry from ``sample``.
                NumPy arrays, torch tensors and plain ints are accepted.
            td_errors: One error per index. NumPy arrays, torch tensors
                (including ones that require grad) and sequences are
                accepted; column-shaped ``(B, 1)`` inputs are flattened.
        """
        idx = self._to_numpy(idx)
        p = np.abs(np.asarray(self._to_numpy(td_errors))) + self.eps
        p = p.astype(np.float32)
        # Flatten column-shaped inputs so (B, 1) errors line up with (B,) indices.
        if isinstance(idx, np.ndarray) and idx.ndim > 1:
            idx = idx.reshape(-1)
        if p.ndim > 1:
            p = p.reshape(-1)
        self.prior[idx] = p

class NStepHelper:
    """Small window of recent (state, action, reward) triples for multi-step returns.

    ``n`` is stored on the instance but is not consulted by any method of
    this class; the window holds whatever has been pushed and not yet
    removed by ``step`` or ``clear``.
    """

    def __init__(self, n=3, gamma=0.99):
        self.n = int(n)
        self.gamma = float(gamma)
        self.buf = []  # list of (s,a,r), oldest entry first

    def push(self, s, a, r):
        """Append one (state, action, reward) triple to the window."""
        self.buf.append((s, a, r))

    def pop_ready(self, next_s, done):
        """Build a multi-step transition anchored at the oldest entry.

        The window itself is not modified. Every buffered reward is
        discounted by ``gamma ** position`` and summed.

        Returns:
            Tuple ``(s0, a0, R, next_s, done)`` where ``s0``/``a0`` come from
            the oldest entry and ``next_s``/``done`` are passed through.

        Raises:
            IndexError: If the window is empty.
        """
        discounted_return = 0.0
        for offset, transition in enumerate(self.buf):
            reward = transition[2]
            discounted_return += (self.gamma ** offset) * reward
        first_state, first_action = self.buf[0][0], self.buf[0][1]
        return first_state, first_action, discounted_return, next_s, done

    def step(self):
        """Drop the oldest entry, if there is one."""
        if not self.buf:
            return
        del self.buf[0]

    def clear(self):
        """Remove every buffered entry."""
        self.buf.clear()

