from abc import ABC, abstractmethod

import torch


class BaseMemory(ABC, torch.nn.Module):
    """Abstract ``torch.nn.Module`` base class for memory components.

    Concrete subclasses must override :meth:`forward`. Because ``forward`` is
    declared abstract, instantiating ``BaseMemory`` itself raises ``TypeError``.

    Attributes:
        latent_size: The constructor argument, stored unchanged. Its type and
            meaning are left to subclasses (presumably an integer feature
            dimension -- TODO confirm).
    """

    def __init__(self, latent_size):
        # ``ABC`` defines no ``__init__`` of its own, so with this base order the
        # call resolves to ``torch.nn.Module.__init__``, which sets up the
        # Module's internal state before we store anything on ``self``.
        super(BaseMemory, self).__init__()
        self.latent_size = latent_size

    @abstractmethod
    def forward(
        self, z: torch.Tensor, h: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the memory's output; must be overridden by subclasses.

        Args:
            z: First input tensor (presumably a latent code -- TODO confirm).
            h: Second input tensor (presumably the previous memory/hidden
                state -- TODO confirm).

        Returns:
            A pair of tensors whose meaning is defined by the subclass.
        """
