import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

from .base import MixtureLayer


def get_activation(activation: str) -> nn.Module:
    """Build a fresh activation module from its string name.

    Args:
        activation: One of "identity", "relu", "leaky_relu", "gelu" or "tanh".

    Returns:
        A newly constructed module of the requested kind, using the
        constructor's default arguments.

    Raises:
        ValueError: If ``activation`` matches none of the supported names.

    """
    supported = (
        ("identity", nn.Identity),
        ("relu", nn.ReLU),
        ("leaky_relu", nn.LeakyReLU),
        ("gelu", nn.GELU),
        ("tanh", nn.Tanh),
    )
    for label, module_cls in supported:
        if activation == label:
            # Instantiate on demand so every caller gets its own module object.
            return module_cls()
    raise ValueError(f"Invalid activation function: {activation}")


class MixtureOfSoftmaxesMLP(MixtureLayer):
    """Mixture-of-softmaxes head with one two-layer MLP per mixture component.

    Each of the K components maps the input hidden state through
    ``linear -> batch norm -> activation -> dropout`` twice, scores the result
    against externally supplied class embeddings and applies a log-softmax.
    The K per-component distributions are then combined with gating weights
    predicted from the (dropped-out) hidden state by a linear classifier.

    The first layer of all components is stored as one stacked ``(K*D, D)``
    matrix; the second layer is stored as ``(K, D, D)`` and applied with a
    batched matmul.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_components: int,
        *,
        input_dropout: float = 0.0,
        component_dropout: float = 0.0,
        component_activation: str = "relu",
        entropy_reg_weight: float = 0.0,
        entropy_of_individual_components: bool = False,
        diversity_reg_weight: float = 0.1,
    ) -> None:
        """Initialize the mixture-of-softmaxes MLP layer.

        Args:
            hidden_dim: Dimension of input hidden states (D). Both MLP layers of
                every component keep this width.
            num_components: Number of mixture components (K).
            input_dropout: Dropout probability applied to the input hidden states.
            component_dropout: Dropout probability applied after each MLP layer's
                activation.
            component_activation: Name of the activation used after each MLP
                layer; resolved through ``get_activation``.
            entropy_reg_weight: Weight of the gating-entropy regularizer. Values
                ``<= 0`` disable it.
            entropy_of_individual_components: If True, the entropy is computed
                per sample and averaged over the batch; if False, the entropy of
                the batch-averaged gating distribution is used.
            diversity_reg_weight: Stored on the instance and encoded in the layer
                name. No diversity penalty is computed anywhere in this class.

        Raises:
            ValueError: If ``component_activation`` is not a supported name.

        """
        # The name encodes the configuration, e.g. "mixMLP||4_ent0.1_div0.1".
        name = f"mixMLP||{num_components}"
        if entropy_reg_weight > 0:
            name += f"_ent{'samp' if entropy_of_individual_components else ''}{entropy_reg_weight}"
        name += f"_div{diversity_reg_weight}" if diversity_reg_weight > 0 else ""
        super().__init__(num_components, name=name, return_log_prob=True)

        self.entropy_reg_weight = entropy_reg_weight
        self.entropy_of_individual_components = entropy_of_individual_components
        self.diversity_reg_weight = diversity_reg_weight

        self.input_dropout = nn.Dropout(input_dropout)
        self.component_dropout = nn.Dropout(component_dropout)
        self.component_activation = get_activation(component_activation)

        # Gating network: predicts one logit per component from the hidden state.
        # Only the weight is re-initialized; the bias keeps nn.Linear's default.
        self.mixture_classifier = nn.Linear(hidden_dim, num_components)
        nn.init.xavier_uniform_(self.mixture_classifier.weight)

        # Component MLP parameters.
        # Layer 1 stacks all K (D, D) matrices row-wise so one F.linear call
        # produces every component's pre-activation at once.
        # NOTE(review): xavier_uniform_ derives its fans from the full tensor shape,
        # so the bounds below reflect the stacked (K*D, D) / 3-D (K, D, D) shapes
        # rather than a single (D, D) component -- confirm this scale is intended.
        self.layer1_weights = nn.Parameter(torch.randn(num_components * hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.layer1_weights)
        self.layer1_bias = nn.Parameter(torch.randn(num_components * hidden_dim))
        self.layer2_weights = nn.Parameter(torch.randn(num_components, hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.layer2_weights)
        self.layer2_bias = nn.Parameter(torch.randn(num_components, hidden_dim))
        # Batch norm runs over the flattened (K*D) feature axis after each layer.
        self.batch_norm1 = nn.BatchNorm1d(num_components * hidden_dim)
        self.batch_norm2 = nn.BatchNorm1d(num_components * hidden_dim)

        # Entropy cached by the most recent forward(); read by regularization_term().
        self.current_entropy = 0.0

    def forward(
        self,
        hidden_states: torch.Tensor,
        class_embeddings: torch.Tensor,
        *,
        return_prob: bool = False,
        debug: bool = True,
    ) -> torch.Tensor:
        """Compute the mixture of softmaxes in log space.

        As a side effect, ``self.current_entropy`` is refreshed on every call:
        it holds the gating entropy when the module is training and
        ``entropy_reg_weight > 0``, and ``0.0`` otherwise.

        Args:
            hidden_states: (batch, hidden_dim)
            class_embeddings: (num_classes, hidden_dim)
            return_prob: If True, return probabilities; otherwise log probabilities.
            debug: Accepted for interface compatibility; not used by this method.

        Returns:
            (batch, num_classes) mixture log-probabilities, or probabilities when
            ``return_prob`` is True.

        Raises:
            ValueError: If ``hidden_states`` is not 2-dimensional.

        """
        if hidden_states.dim() != 2:
            raise ValueError(
                f"hidden_states must have shape (batch, hidden_dim), got {tuple(hidden_states.shape)}"
            )
        hidden_states = self.input_dropout(hidden_states)

        # Gating distribution over components.
        pi_logits = self.mixture_classifier(hidden_states)  # (batch, K)
        log_pi = F.log_softmax(pi_logits, dim=-1)  # (batch, K)

        # Cache the gating entropy for the regularizer (training only).
        if self.entropy_reg_weight > 0 and self.training:
            if self.entropy_of_individual_components:
                # Mean over the batch of each sample's gating entropy.
                pi_distrib = torch.distributions.Categorical(logits=pi_logits)
                self.current_entropy = pi_distrib.entropy().mean()
            else:
                # Entropy of the batch-averaged gating distribution; the epsilon
                # keeps log() finite when a component's average weight is zero.
                pi_probs = torch.softmax(pi_logits, dim=-1)
                avg_pi_probs = pi_probs.mean(dim=0)
                self.current_entropy = -torch.sum(avg_pi_probs * torch.log(avg_pi_probs + 1e-10))
        else:
            self.current_entropy = 0.0

        K = self.num_components

        # Layer 1: one stacked linear map yields all K component activations.
        h = F.linear(hidden_states, self.layer1_weights, self.layer1_bias)  # (batch, K*D)
        h = self.batch_norm1(h)
        h = self.component_activation(h)
        h = self.component_dropout(h)
        h = rearrange(h, "b (k d) -> k b d", k=K)

        # Layer 2: per-component matmul, (K, batch, D) @ (K, D, D) -> (K, batch, D).
        h = torch.bmm(h, self.layer2_weights)
        # Broadcast the (K, D) bias over the batch axis instead of materializing it.
        h = h + self.layer2_bias.unsqueeze(1)
        h = rearrange(h, "k b d -> b (k d)")
        h = self.batch_norm2(h)
        h = self.component_activation(h)
        h = self.component_dropout(h)
        h = rearrange(h, "b (k d) -> b k d", k=K)

        # Score every component's context vector against every class embedding.
        logits = torch.matmul(h, class_embeddings.t())  # (batch, K, num_classes)
        log_softmax_outputs = F.log_softmax(logits, dim=-1)  # (batch, K, num_classes)

        # Mix in log space: log sum_k exp(log pi_k + log p_k(class)).
        # log_pi is broadcast over the class axis.
        log_outputs = torch.logsumexp(log_pi.unsqueeze(-1) + log_softmax_outputs, dim=1)
        return torch.exp(log_outputs) if return_prob else log_outputs

    def regularization_term(self) -> torch.Tensor:
        """Return the entropy regularization term cached by the last forward pass.

        The term is ``-entropy_reg_weight * current_entropy`` (negative because
        the entropy is to be maximized) when ``entropy_reg_weight > 0``, and
        ``0.0`` otherwise. No other regularizer is computed here.

        Returns:
            A scalar tensor when an entropy tensor was cached by ``forward``;
            otherwise the plain Python float ``0.0``.

        """
        reg = 0.0
        if self.entropy_reg_weight > 0:
            reg -= self.entropy_reg_weight * self.current_entropy

        return reg
