import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

from .base import MixtureLayer

# Module-level debug switch. Nothing in the visible part of this file reads it;
# presumably it is consulted elsewhere — TODO confirm, or remove if unused.
DEBUG_IDENTITY = False


def get_activation(activation: str) -> nn.Module:
    if activation == "identity":
        return nn.Identity()
    if activation == "relu":
        return nn.ReLU()
    if activation == "leaky_relu":
        return nn.LeakyReLU()
    if activation == "gelu":
        return nn.GELU()
    if activation == "tanh":
        return nn.Tanh()
    raise ValueError(f"Invalid activation function: {activation}")


class MixtureOfSoftmaxes(MixtureLayer):
    """Mixture-of-Softmaxes output layer scored against external class embeddings.

    Each input hidden state is projected into ``num_components`` component
    vectors. Every component vector is scored against the class embeddings and
    softmax-normalized over classes, and the resulting per-component
    distributions are combined with input-dependent mixture weights. The
    combination is carried out in log space.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_components: int,
        *,
        input_dropout: float = 0.0,
        component_dropout: float = 0.0,
        component_activation: str = "identity",
        entropy_reg_weight: float = 0.0,
        entropy_of_individual_components: bool = True,
        diversity_reg_weight: float = 0.0,
    ) -> None:
        """Initialize Mixture of Softmaxes layer.

        Args:
            hidden_dim: Dimension of input hidden states
            num_components: Number of mixture components (K)
            input_dropout: Dropout probability applied to the input hidden states
            component_dropout: Dropout probability applied to the activated
                component projections
            component_activation: Name of the activation applied to the component
                projections (resolved through ``get_activation``)
            entropy_reg_weight: Weight for entropy regularization of mixture weights;
                the term is only active when this is > 0
            entropy_of_individual_components: If True, regularize the batch mean of
                each example's mixture-weight entropy; if False, regularize the
                entropy of the batch-averaged mixture weights
            diversity_reg_weight: Stored on the instance and appended to the layer
                name when > 0; ``regularization_term`` in this class does not use it

        """
        # Encode the regularization settings in the layer name handed to the base class.
        name = f"mix||{num_components}"
        if entropy_reg_weight > 0:
            if entropy_of_individual_components:
                name = name + f"_entsamp{entropy_reg_weight}"
            else:
                name = name + f"_ent{entropy_reg_weight}"
        if diversity_reg_weight > 0:
            name = name + f"_div{diversity_reg_weight}"
        super().__init__(num_components, name=name, return_log_prob=True)
        self.entropy_reg_weight = entropy_reg_weight
        self.entropy_of_individual_components = entropy_of_individual_components
        self.diversity_reg_weight = diversity_reg_weight

        self.input_dropout = nn.Dropout(input_dropout)
        self.component_activation = get_activation(component_activation)
        self.component_dropout = nn.Dropout(component_dropout)

        # The following defines the mixture weights parameters (w_pi in the MoS paper).
        # These are target vectors for each component. We measure the interaction of
        # the encoded hidden state with each component vector, so we can interpret them
        # as "object embeddings" for each component.
        self.mixture_classifier = nn.Linear(hidden_dim, num_components)
        torch.nn.init.xavier_uniform_(self.mixture_classifier.weight)

        # All K component projections are stored as one stacked (K*D, D) weight and a
        # (K*D,) bias so that a single linear call produces every component at once.
        self.layer_weights = nn.Parameter(torch.randn(num_components * hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.layer_weights)
        self.layer_bias = nn.Parameter(torch.randn(num_components * hidden_dim))
        self.batch_norm = nn.BatchNorm1d(num_components * hidden_dim)

        # Entropy of the mixture weights from the most recent forward pass;
        # 0.0 until a training-mode forward pass with entropy regularization runs.
        self.current_entropy = 0.0

    def _mixture_entropy(self, pi_logits: torch.Tensor) -> torch.Tensor:
        """Return the scalar mixture-weight entropy used for regularization.

        Args:
            pi_logits: (batch, num_components) unnormalized mixture logits

        Returns:
            A 0-dim tensor holding the entropy selected by
            ``entropy_of_individual_components``.

        """
        if self.entropy_of_individual_components:
            # Batch mean of each example's own mixture-weight entropy.
            return torch.distributions.Categorical(logits=pi_logits).entropy().mean()

        # Entropy of the mixture weights averaged over the batch.
        avg_pi_probs = torch.softmax(pi_logits, dim=-1).mean(dim=0)
        # The 1e-10 offset keeps log() finite if an averaged weight is exactly 0.
        return -torch.sum(avg_pi_probs * torch.log(avg_pi_probs + 1e-10))

    def forward(
        self,
        hidden_states: torch.Tensor,
        class_embeddings: torch.Tensor | None = None,
        *,
        return_prob: bool = False,
        debug: bool = False,
    ) -> torch.Tensor:
        """Compute Mixture of Softmaxes with numerically stable operations.

        Also updates ``self.current_entropy`` as a side effect: it holds the
        mixture-weight entropy when entropy regularization is enabled and the
        module is in training mode, and 0.0 otherwise.

        Args:
            hidden_states: (batch, hidden_dim)
            class_embeddings: (num_classes, hidden_dim). Required; the ``None``
                default exists only in the signature and is rejected.
            return_prob: If True return probabilities, otherwise log probabilities
            debug: Currently unused by this method

        Returns:
            (batch, num_classes)

        Raises:
            ValueError: If ``class_embeddings`` is None or ``hidden_states`` is
                not 2-dimensional.

        """
        if class_embeddings is None:
            raise ValueError("class_embeddings must be provided to MixtureOfSoftmaxes.forward")
        if hidden_states.dim() != 2:
            raise ValueError(
                "hidden_states must have shape (batch, hidden_dim), "
                f"got {tuple(hidden_states.shape)}"
            )

        hidden_states = self.input_dropout(hidden_states)

        # Compute mixture weights
        pi_logits = self.mixture_classifier(hidden_states)  # (batch, num_components)
        log_pi = F.log_softmax(pi_logits, dim=-1)  # (batch, num_components)

        # Record entropy for regularization (only while training, if enabled)
        if self.entropy_reg_weight > 0 and self.training:
            self.current_entropy = self._mixture_entropy(pi_logits)
        else:
            self.current_entropy = 0.0

        # Project into all K component vectors with one stacked linear map:
        # (B, D) @ (D, K*D) -> (B, K*D), then split into (B, K, D).
        h = F.linear(hidden_states, self.layer_weights, self.layer_bias)
        h = self.batch_norm(h)
        h = self.component_activation(h)
        h = self.component_dropout(h)
        h = rearrange(h, "b (k d) -> b k d", k=self.num_components)

        # Compute component-wise logits
        logits = torch.matmul(h, class_embeddings.t())  # (batch, K, num_classes)
        # Compute log-softmax per component
        log_softmax_outputs = F.log_softmax(logits, dim=-1)  # (batch, K, num_classes)
        # Mix the components in log space: log sum_k exp(log pi_k + log p_k(class))
        num_classes = logits.size(-1)
        log_pi = repeat(log_pi, "b k -> b k num_classes", num_classes=num_classes)
        log_outputs = torch.logsumexp(log_pi + log_softmax_outputs, dim=1)

        return torch.exp(log_outputs) if return_prob else log_outputs

    def regularization_term(self) -> torch.Tensor:
        """Return the entropy regularization term from the latest forward pass.

        Only the mixture-weight entropy contributes; ``diversity_reg_weight`` is
        not used here. The value is based on ``self.current_entropy`` as set by
        the most recent call to ``forward``.

        Returns:
            ``-entropy_reg_weight * current_entropy`` when entropy regularization
            is enabled. NOTE: when it is disabled (or ``current_entropy`` is still
            the float 0.0) the result is a plain Python float equal to zero rather
            than a tensor.

        """
        reg = 0.0

        # Negative sign: minimizing the loss should maximize the entropy.
        if self.entropy_reg_weight > 0:
            reg -= self.entropy_reg_weight * self.current_entropy

        return reg
