import torch

class ErrorBuffer:
    """Fixed-size store of the most recent per-node error tensor for each sample.

    Slot ``i`` of ``maps`` holds the latest error tensor reported for sample
    id ``i``, or ``None`` if that sample has not been updated yet. Stored
    tensors are detached from the autograd graph and moved to CPU.
    """

    def __init__(self, num_samples):
        """Create an empty buffer with ``num_samples`` slots, all ``None``."""
        # One slot per sample id; filled lazily by update_batch.
        self.maps = [None] * num_samples

    def update_batch(self, sample_ids, node_errors_split):
        """Store the error tensors of one batch, overwriting previous entries.

        Args:
            sample_ids: Tensor or iterable of length B with the integer sample
                id of each batch element.
            node_errors_split: Sequence of B tensors, one per sample, each of
                shape ``[Ni, 1]`` or ``[Ni]``.

        Raises:
            ValueError: if the number of ids differs from the number of tensors.
            IndexError: if a sample id is outside ``[0, num_samples)``.
        """
        sid_list = sample_ids.tolist() if torch.is_tensor(sample_ids) else list(sample_ids)
        errs = list(node_errors_split)
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # and zip() would then silently drop the unmatched tail.
        if len(sid_list) != len(errs):
            raise ValueError(
                f"got {len(sid_list)} sample ids but {len(errs)} error tensors"
            )
        n_slots = len(self.maps)
        for sid, err in zip(sid_list, errs):
            # Reject negative ids explicitly; plain list indexing would wrap
            # them around and overwrite an unrelated sample's slot.
            if not 0 <= sid < n_slots:
                raise IndexError(f"sample id {sid} out of range [0, {n_slots})")
            e = err.detach().cpu()
            # Drop only a trailing singleton *feature* axis ([Ni, 1] -> [Ni]).
            # squeeze(-1) on a 1-D tensor of length 1 would collapse a
            # single-node sample to a 0-d scalar, unlike every other entry.
            if e.dim() > 1 and e.size(-1) == 1:
                e = e.squeeze(-1)
            self.maps[sid] = e

    def get_all(self):
        """Return the internal list of stored tensors (not a copy)."""
        return self.maps
