
# ----------------------------
# === Replay Buffer & Trainer ===
# ----------------------------

import threading
from collections import deque
import numpy as np 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class ReplayBuffer:
    """Bounded, lock-guarded FIFO store of training examples.

    Examples live in a ``deque`` created with ``maxlen=capacity``, so once the
    buffer is full every newly pushed example evicts the oldest one.
    ``push`` and ``sample`` serialise on ``self.lock``.
    """

    def __init__(self, capacity=10000):
        """Create an empty buffer that keeps at most ``capacity`` examples."""
        self.buf = deque(maxlen=capacity)
        self.lock = threading.Lock()

    def push(self, examples):
        """Append every item of the iterable ``examples`` to the buffer."""
        with self.lock:
            self.buf.extend(examples)

    def sample(self, batch_size):
        """Return ``batch_size`` examples drawn without replacement.

        Positions are drawn with ``np.random.choice(..., replace=False)``, so
        no buffer slot appears twice in the result.

        Args:
            batch_size: Number of examples to draw.

        Returns:
            A list with ``batch_size`` stored examples.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of examples
                currently stored (this includes any positive request on an
                empty buffer), or if NumPy rejects the size (e.g. negative).
        """
        with self.lock:
            available = len(self.buf)
            # Check up front: NumPy's own error for this case talks about
            # "population" and "replace=False", which hides the real cause
            # (the buffer simply has not been filled far enough yet).
            if batch_size > available:
                raise ValueError(
                    f"cannot sample {batch_size} examples from a replay "
                    f"buffer holding {available}"
                )
            idx = np.random.choice(available, size=batch_size, replace=False)
            return [self.buf[i] for i in idx]

    def __len__(self):
        """Return the number of stored examples (read without the lock)."""
        return len(self.buf)

class Trainer:
    """Runs single Adam optimisation steps on a policy/value network."""

    def __init__(self, net: nn.Module, replay: "ReplayBuffer", device='cuda', lr=1e-3):
        """Bind the network, replay buffer and optimizer.

        Args:
            net: Module whose forward pass returns ``(policy_logits, value)``.
            replay: Replay buffer kept on the instance as ``self.replay``.
            device: Requested device. NOTE: whenever CUDA is unavailable the
                trainer uses ``'cpu'``, whatever was requested here.
            lr: Learning rate passed to Adam.
        """
        self.net = net
        self.replay = replay
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Move the model first, then build the optimizer from the moved
        # parameters (the order recommended by the torch.optim docs).
        self.net.to(self.device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)

    def train_step(self, batch):
        """Run one gradient step on ``batch`` and report the losses.

        Args:
            batch: Non-empty sequence of examples (GameExample) exposing
                ``board``, ``pi`` and ``value``. Boards are stacked along a
                new leading batch axis -- presumably giving (B, C, H, W).

        Returns:
            Tuple of Python floats
            ``(total_loss, policy_loss, value_loss, policy_entropy)``.
        """
        boards = np.stack([ex.board for ex in batch], axis=0)
        pis = np.stack([ex.pi for ex in batch], axis=0)
        vals = np.array([ex.value for ex in batch], dtype=np.float32)
        boards_t = torch.from_numpy(boards).float().to(self.device)
        pis_t = torch.from_numpy(pis).float().to(self.device)
        vals_t = torch.from_numpy(vals).float().to(self.device)

        self.net.train()
        p_logits, v = self.net(boards_t)

        # Policy loss: cross entropy against the target distribution.
        p_log = F.log_softmax(p_logits, dim=1)
        policy_loss = -(pis_t * p_log).sum(dim=1).mean()

        # Policy entropy is reported for logging only; keep it out of the
        # autograd graph.
        with torch.no_grad():
            entropy = -(p_log.exp() * p_log).sum(dim=1).mean()

        # Value loss (MSE). A value head shaped (B, 1) would broadcast against
        # the (B,) targets into a (B, B) comparison and silently give a wrong
        # loss, so align the prediction with the target shape first.
        v = v.reshape(vals_t.shape)
        value_loss = F.mse_loss(v, vals_t)

        loss = policy_loss + value_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return (
            float(loss.item()),
            float(policy_loss.item()),
            float(value_loss.item()),
            float(entropy.item()),
        )