import abc

import torch


class BaseModel(torch.nn.Module, abc.ABC):
    """Abstract parent of every model used by the discrete diffusion samplers.

    Concrete subclasses must implement :meth:`forward`; the base class only
    records the sequence length and vocabulary size.
    """

    def __init__(self, ndim: int, vocab_size: int) -> None:
        """Store the sequence length and vocabulary size.

        Args:
            ndim: Number of positions in each input/output sequence.
            vocab_size: Count of distinct tokens in the vocabulary
                (the mask token is included in this count).
        """
        super().__init__()
        self.ndim, self.vocab_size = ndim, vocab_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Compute output logits for a batch of token sequences.

        Args:
            x: Tensor of shape (batch_size, ndim) holding integer token ids
                drawn from {0, 1, ..., vocab_size - 1}.
            **kwargs: Extra keyword arguments a model may need
                (for example, the timestep).

        Returns:
            Tensor of logits whose leading dimension is batch_size; the
            remaining dimensions are defined by the concrete model.
        """
        raise NotImplementedError()


class MaskedDiffusionModel(BaseModel, abc.ABC):
    """Base class for masked diffusion models."""

    @abc.abstractmethod
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Forward pass.

        The signature mirrors ``BaseModel.forward`` (including ``**kwargs``)
        so that a masked diffusion model can be used wherever a ``BaseModel``
        is expected without narrowing the accepted call forms.

        Args:
            x: (batch_size, ndim) tensor of input sequences,
                where each element is an integer in {0, 1, ..., vocab_size - 1}.
                The last token (index vocab_size - 1) is the mask token.
            **kwargs: Additional keyword arguments (e.g., timestep).
                Implementations that do not need them may ignore them.

        Returns:
            (batch_size, ndim, vocab_size) tensor of log probabilities.
        """
        raise NotImplementedError
