from abc import ABC, abstractmethod

import torch
from torch import nn


class MixtureLayer(nn.Module, ABC):
    """Abstract base class for mixture layers.

    Subclasses must implement :meth:`forward`. :meth:`regularization_term`
    may be overridden; the base implementation contributes zero.

    Attributes:
        num_components: Number of mixture components, stored exactly as
            passed to the constructor (not validated here).
        name: Identifier for this layer, stored exactly as passed.
        return_log_prob: Keyword-only flag stored for subclasses; presumably
            selects log-probabilities over probabilities in ``forward`` —
            NOTE(review): confirm against concrete subclasses.

    """

    def __init__(self, num_components: int, name: str, *, return_log_prob: bool = False):
        super().__init__()
        self.num_components = num_components
        self.name = name
        self.return_log_prob = return_log_prob

    @abstractmethod
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            hidden_states: (batch, hidden_dim)

        Returns:
            (batch, num_classes)

        """

    def regularization_term(self) -> torch.Tensor:
        """Regularization term for the mixture layer.

        The base implementation contributes nothing and returns a 0-dim zero
        tensor. The zero takes the device and dtype of the module's first
        floating-point parameter. It can therefore be combined with other
        per-module terms (e.g. via ``torch.stack``) without a device or dtype
        mismatch after ``.to()`` / ``.cuda()`` / ``.double()``. A module with
        no floating-point parameters gets ``torch.tensor(0.0)``, which uses
        the default dtype and device.

        Returns:
            A 0-dim tensor equal to zero.

        """
        # Use a real parameter as the reference so the zero follows the
        # module wherever it has been moved or cast.
        reference = next((p for p in self.parameters() if p.is_floating_point()), None)
        if reference is None:
            return torch.tensor(0.0)
        return torch.zeros((), device=reference.device, dtype=reference.dtype)
