import torch
import torch.nn.functional as F
from einops import repeat
from torch import nn

from .base import MixtureLayer


def get_activation(activation: str) -> nn.Module:
    if activation == "identity":
        return nn.Identity()
    if activation == "relu":
        return nn.ReLU()
    if activation == "leaky_relu":
        return nn.LeakyReLU()
    if activation == "gelu":
        return nn.GELU()
    if activation == "tanh":
        return nn.Tanh()
    raise ValueError(f"Invalid activation function: {activation}")


class MixtureOfSoftmaxesMLP(MixtureLayer):
    """Mixture of Softmaxes (MoS) output layer with a 2-layer MLP per component.

    For an input hidden state ``x`` the layer computes K mixture weights
    ``pi(x) = softmax(mixture_classifier(x))`` and K component states
    ``h_k = MLP_k(x)``. Each component scores ``h_k`` against caller-supplied class
    embeddings ``E`` and takes a softmax over classes. The output distribution is
    ``sum_k pi_k(x) * softmax(h_k @ E.T)``, evaluated in log space with
    ``logsumexp`` for numerical stability.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_components: int,
        *,
        input_dropout: float = 0.0,
        component_dropout: float = 0.0,
        component_activation: str = "identity",
        entropy_reg_weight: float = 0.0,  # Weight for entropy regularization
        entropy_of_individual_components: bool = False,
        diversity_reg_weight: float = 0.1,  # Recorded only; see docstring
        mlp_hidden_dim: int | None = None,  # Hidden dimension for the MLP projections
    ) -> None:
        """Initialize Mixture of Softmaxes layer.

        Args:
            hidden_dim: Dimension of the input hidden states. It is also the output
                width of each component MLP, so it must match the embedding
                dimension of ``class_embeddings`` passed to ``forward``.
            num_components: Number of mixture components (K).
            input_dropout: Dropout probability applied to the input hidden states.
            component_dropout: Dropout probability applied after each activation
                inside the component MLPs.
            component_activation: Name of the activation used inside the component
                MLPs. One of ``"identity"``, ``"relu"``, ``"leaky_relu"``,
                ``"gelu"``, ``"tanh"``.
            entropy_reg_weight: Weight of the mixture-entropy bonus returned by
                ``regularization_term``. A value of 0 disables it.
            entropy_of_individual_components: If True, the bonus uses the mean
                per-sample entropy of the mixture weights. Otherwise it uses the
                entropy of the batch-averaged mixture weights.
            diversity_reg_weight: Stored on the instance and encoded in the layer
                name only. NOTE(review): no diversity penalty is computed anywhere
                in this class, so this value currently does not affect training;
                confirm whether a subclass or caller consumes it.
            mlp_hidden_dim: Hidden width of the component MLPs. If None, uses
                ``hidden_dim``.

        Raises:
            ValueError: If ``component_activation`` is not a supported name.

        """
        # The layer name encodes the configuration, e.g. "mixMLP4_ent0.01_div0.1".
        name = f"mixMLP{num_components}"
        if entropy_reg_weight > 0:
            if entropy_of_individual_components:
                name = name + f"_entsamp{entropy_reg_weight}"
            else:
                name = name + f"_ent{entropy_reg_weight}"
        name = name + f"_div{diversity_reg_weight}" if diversity_reg_weight > 0 else name
        super().__init__(num_components, name=name, return_log_prob=True)
        self.entropy_reg_weight = entropy_reg_weight
        self.entropy_of_individual_components = entropy_of_individual_components
        self.diversity_reg_weight = diversity_reg_weight

        self.input_dropout = nn.Dropout(input_dropout)
        # These two module instances are reused in every component MLP below.
        # Activations from get_activation and nn.Dropout have no parameters, so
        # sharing them does not tie any weights across components.
        self.component_activation = get_activation(component_activation)
        self.component_dropout = nn.Dropout(component_dropout)

        # Set MLP hidden dimension if not provided
        if mlp_hidden_dim is None:
            mlp_hidden_dim = hidden_dim

        # The following defines the mixture weights parameters (w_pi in the MoS paper).
        # These are target vectors for each component. We measure the interaction of
        # the encoded hidden state with each component vector, so we can interpret them
        # as "object embeddings" for each component.
        self.mixture_classifier = nn.Linear(hidden_dim, num_components)
        torch.nn.init.xavier_uniform_(self.mixture_classifier.weight)

        # One 2-layer MLP per component. Keep the layer order unchanged: the
        # Sequential indices are part of the state_dict keys of saved checkpoints.
        self.component_projections = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_dim, mlp_hidden_dim),
                    nn.BatchNorm1d(mlp_hidden_dim),
                    self.component_activation,
                    self.component_dropout,
                    nn.Linear(mlp_hidden_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),  # BatchNorm on output
                    self.component_activation,
                    self.component_dropout,
                )
                for _ in range(num_components)
            ],
        )

        # Xavier-uniform weights and zero biases for every Linear in the MLPs.
        for proj in self.component_projections:
            for layer in proj:
                if isinstance(layer, nn.Linear):
                    torch.nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        torch.nn.init.zeros_(layer.bias)

        # Entropy of the mixture weights from the most recent training-mode forward.
        # Stays the float 0.0 until then, and is reset to 0.0 in eval mode.
        self.current_entropy = 0.0

    def forward(
        self,
        hidden_states: torch.Tensor,
        class_embeddings: torch.Tensor,
        *,
        return_prob: bool = False,
        debug: bool = True,
    ) -> torch.Tensor:
        """Compute the Mixture of Softmaxes distribution over classes.

        Args:
            hidden_states: (batch, hidden_dim). The component MLPs apply
                ``BatchNorm1d`` to Linear outputs, so 2-D input is expected.
                NOTE(review): BatchNorm1d in training mode typically rejects a batch
                of size 1; confirm callers never pass one while training.
            class_embeddings: (num_classes, hidden_dim). Always required: component
                logits are ``h_k @ class_embeddings.T``.
            return_prob: If True, return probabilities; otherwise log-probabilities.
            debug: Unused.

        Returns:
            (batch, num_classes) log-probabilities, or probabilities when
            ``return_prob`` is True.

        Side effects:
            Sets ``self.current_entropy``. In training mode with
            ``entropy_reg_weight > 0`` it holds the entropy tensor; otherwise it is
            the float 0.0.

        """
        hidden_states = self.input_dropout(hidden_states)

        # Compute mixture weights
        pi_logits = self.mixture_classifier(hidden_states)  # (batch, num_components)
        log_pi = F.log_softmax(pi_logits, dim=-1)  # (batch, num_components)

        # Compute entropy for regularization (if enabled)
        if self.entropy_reg_weight > 0 and self.training:
            if self.entropy_of_individual_components:
                # Mean over the batch of each sample's mixture-weight entropy.
                pi_distrib = torch.distributions.Categorical(logits=pi_logits)
                self.current_entropy = pi_distrib.entropy().mean()
            else:
                # Entropy of the batch-averaged mixture weights, i.e. how evenly
                # components are used across the batch. The epsilon guards log(0).
                pi_probs = torch.softmax(pi_logits, dim=-1)
                avg_pi_probs = pi_probs.mean(dim=0)
                self.current_entropy = -torch.sum(avg_pi_probs * torch.log(avg_pi_probs + 1e-10))
        else:
            self.current_entropy = 0.0

        h = torch.stack(
            [proj(hidden_states) for proj in self.component_projections],
            dim=1,
        )  # (batch, num_components, hidden_dim)

        logits = torch.matmul(h, class_embeddings.t())  # (batch, K, num_classes)
        # Compute softmax probabilities per component
        log_softmax_outputs = F.log_softmax(logits, dim=-1)  # (batch, K, num_classes)

        # Mix the components in log space:
        #   log p(c | x) = logsumexp_k [log pi_k(x) + log softmax_k(c)]
        num_classes = logits.size(-1)
        log_pi = repeat(log_pi, "b k -> b k num_classes", num_classes=num_classes)
        log_outputs = torch.logsumexp(log_pi + log_softmax_outputs, dim=1)

        return torch.exp(log_outputs) if return_prob else log_outputs

    def regularization_term(self) -> torch.Tensor | float:
        """Return the regularization loss to add to the training objective.

        Only the mixture-entropy bonus is implemented:
        ``-entropy_reg_weight * current_entropy``. The sign is negative so that
        minimizing the loss maximizes the entropy recorded by the last ``forward``.
        ``diversity_reg_weight`` is not used here.

        Returns:
            A tensor after a training-mode ``forward`` with entropy regularization
            enabled. Otherwise the Python float ``0.0``: when disabled, before any
            forward, or after an eval-mode forward.

        """
        reg = 0.0
        # Add entropy regularization term (negative because we want to maximize entropy)
        if self.entropy_reg_weight > 0:
            reg -= self.entropy_reg_weight * self.current_entropy

        return reg
