from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictPrioritizedReplayBuffer
from tensordict import TensorDict

from config import BUFFER_PRIO


class ReplayMemoryMARL:
    """Replay memory backed by a torchrl TensorDict replay buffer.

    The underlying buffer is a ``TensorDictPrioritizedReplayBuffer`` when the
    ``BUFFER_PRIO`` config flag is truthy, and a plain
    ``TensorDictReplayBuffer`` otherwise. In both cases the buffer is built
    on a ``LazyTensorStorage`` sized by ``capacity`` and configured with
    ``batch_size``.
    """

    # Forwarded as ``alpha`` and ``eps`` to the prioritized buffer.
    prio_alpha = 0.4
    prio_eps = 1e-7

    def __init__(self, capacity, batch_size):
        """Store the sizing parameters and build the initial buffer.

        Args:
            capacity: Size handed to ``LazyTensorStorage``.
            batch_size: Batch size handed to the replay buffer.
        """
        self.capacity = capacity
        self.batch_size = batch_size
        # The attribute has to exist up front: reset() deletes it first.
        self.memory = None
        self.reset()

    def push(self, dct: dict):
        """Wrap ``dct`` in a ``TensorDict`` and append it to the buffer.

        The TensorDict batch size is the length of the first value in
        ``dct``.
        NOTE(review): presumably every value shares that leading
        dimension -- confirm against callers.

        Args:
            dct: Mapping of field name to batched data; must be non-empty
                (an empty dict raises ``IndexError``).
        """
        columns = list(dct.values())
        n_rows = len(columns[0])
        batch = TensorDict(dct, batch_size=n_rows)
        self.memory.extend(batch)

    def sample(self, batch_size=None):
        """Return a sample drawn from the underlying buffer.

        Args:
            batch_size: Forwarded unchanged to the buffer's ``sample``.
                NOTE(review): ``None`` presumably falls back to the batch
                size given at construction -- confirm in the torchrl docs.
        """
        buffer = self.memory
        return buffer.sample(batch_size)

    def update_priority(self, transitions):
        """Forward ``transitions`` to ``update_tensordict_priority``."""
        buffer = self.memory
        buffer.update_tensordict_priority(transitions)

    def reset(self):
        """Drop the current buffer and install a newly constructed one."""
        # Remove the old reference before constructing the replacement.
        del self.memory

        if not BUFFER_PRIO:
            self.memory = TensorDictReplayBuffer(
                storage=LazyTensorStorage(self.capacity),
                batch_size=self.batch_size,
            )
            return

        self.memory = TensorDictPrioritizedReplayBuffer(
            alpha=self.prio_alpha,
            # Author's note: beta is redundant here because prioritized
            # sampling == importance weighting, hence 0.
            beta=0,
            eps=self.prio_eps,
            storage=LazyTensorStorage(self.capacity),
            batch_size=self.batch_size,
        )

    def __len__(self):
        """Return ``len()`` of the underlying buffer."""
        buffer = self.memory
        return len(buffer)
