import numpy as np
import torch as th



class PairedSAEBuffer:
    """Fixed-capacity ring buffer of (observation, agent action, human action) triples.

    Triples are stored in preallocated NumPy arrays. Once ``capacity`` entries
    have been written, each new entry overwrites the oldest one.
    ``sample_pairs`` turns stored triples into a shuffled, labelled batch of
    (state, action) rows: label 1 for the human action, label 0 for the agent
    action.
    """

    def __init__(self, capacity: int, obs_shape, act_shape, device="cpu", dtype_np=np.float32):
        """Allocate zero-filled storage for ``capacity`` triples.

        Args:
            capacity: Maximum number of stored triples. Coerced with ``int()``,
                so integral floats such as ``1e5`` are accepted.
            obs_shape: Shape of one observation; an int or an iterable of ints.
            act_shape: Shape of one action; an int or an iterable of ints.
            device: Torch device on which ``sample_pairs`` places its tensors.
            dtype_np: NumPy dtype of the storage arrays.
        """
        self.capacity = int(capacity)
        self.device = device
        self.ptr = 0      # next slot to write
        self.size_ = 0    # number of valid entries, at most ``capacity``

        obs_shape = self._as_shape(obs_shape)
        act_shape = self._as_shape(act_shape)
        # Allocate with the coerced capacity: np.zeros rejects float dimensions,
        # so the raw argument would fail for values like 1e5.
        self.obs = np.zeros((self.capacity, *obs_shape), dtype=dtype_np)
        self.a_agent = np.zeros((self.capacity, *act_shape), dtype=dtype_np)
        self.a_human = np.zeros((self.capacity, *act_shape), dtype=dtype_np)

    @staticmethod
    def _as_shape(shape):
        """Return ``shape`` as a tuple, wrapping a bare integer dimension."""
        if isinstance(shape, (int, np.integer)):
            return (int(shape),)
        return tuple(shape)

    @staticmethod
    def _to_numpy(x):
        """Convert ``x`` (ndarray, torch tensor, or array-like) to an ndarray.

        ndarrays are returned unchanged; tensors are detached and moved to the
        CPU first; anything else goes through ``np.asarray``.
        """
        if isinstance(x, np.ndarray):
            return x
        if th.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def add_pair(self, s, a_agent, a_human):
        """Store one triple at the write pointer, overwriting the oldest when full.

        Args:
            s: Observation, assignable to one row of ``obs``.
            a_agent: Agent action, assignable to one row of ``a_agent``.
            a_human: Human action, assignable to one row of ``a_human``.
        """
        slot = self.ptr
        self.obs[slot] = self._to_numpy(s)
        self.a_agent[slot] = self._to_numpy(a_agent)
        self.a_human[slot] = self._to_numpy(a_human)
        self.ptr = (slot + 1) % self.capacity
        self.size_ = min(self.size_ + 1, self.capacity)

    def size(self) -> int:
        """Return the number of triples currently stored."""
        return self.size_

    def sample_pairs(self, k: int):
        """Sample ``k`` stored triples and return a shuffled batch of ``2k`` labelled rows.

        Indices are drawn uniformly with replacement from the valid entries.
        Each sampled triple yields two rows that share its observation: one
        with the human action (label 1) and one with the agent action (label 0).

        Args:
            k: Number of triples to draw.

        Returns:
            Tuple ``(states, actions, labels)`` of float32 tensors on
            ``self.device`` with leading dimension ``2k``; ``labels`` has
            shape ``[2k, 1]``.

        Raises:
            AssertionError: If the buffer holds no entries.
        """
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
        if self.size_ <= 0:
            raise AssertionError("PairedSAEBuffer is empty.")

        idx = np.random.randint(0, self.size_, size=k)
        s = self.obs[idx]           # [k, *obs_shape]
        a_h = self.a_human[idx]     # [k, *act_shape]
        a_a = self.a_agent[idx]     # [k, *act_shape]

        states = np.concatenate([s, s], axis=0)              # [2k, *obs_shape]
        actions = np.concatenate([a_h, a_a], axis=0)         # [2k, *act_shape]
        labels = np.concatenate([
            np.ones((k, 1), dtype=np.float32),               # human -> 1
            np.zeros((k, 1), dtype=np.float32),              # agent -> 0
        ], axis=0)

        # One shared permutation keeps each state/action/label row aligned.
        perm = np.random.permutation(2 * k)
        states = th.as_tensor(states[perm], dtype=th.float32, device=self.device)
        actions = th.as_tensor(actions[perm], dtype=th.float32, device=self.device)
        labels = th.as_tensor(labels[perm], dtype=th.float32, device=self.device)
        return states, actions, labels