from itertools import combinations

import torch
import torch.nn.functional as F
from einops import repeat
from torch import nn

from .base import MixtureLayer

# Module-level debug switch. It is not read anywhere in this file's visible code.
# NOTE(review): presumably consumed by other modules or left over from debugging —
# confirm before relying on or removing it.
DEBUG_IDENTITY = False


def get_activation(activation: str) -> nn.Module:
    """Build a new activation module from its string identifier.

    Args:
        activation: One of "identity", "relu", "leaky_relu", "gelu" or "tanh".

    Returns:
        A freshly constructed module of the requested kind (default arguments).

    Raises:
        ValueError: If ``activation`` is not one of the supported identifiers.

    """
    # Map each identifier to its module class; instantiation happens on lookup
    # so every call hands back an independent module object.
    factories = {
        "identity": nn.Identity,
        "relu": nn.ReLU,
        "leaky_relu": nn.LeakyReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
    }
    factory = factories.get(activation)
    if factory is None:
        raise ValueError(f"Invalid activation function: {activation}")
    return factory()


class MixtureOfSoftmaxes(MixtureLayer):
    """Mixture-of-Softmaxes output layer.

    Every input hidden state is projected into ``num_components`` context
    vectors. Each context vector is scored against the class embeddings and
    softmax-normalized, and the per-component distributions are combined with
    input-dependent mixture weights.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_components: int,
        *,
        input_dropout: float = 0.0,
        component_dropout: float = 0.0,
        component_activation: str = "identity",
        entropy_reg_weight: float = 0.0,
        entropy_of_individual_components: bool = True,
        diversity_reg_weight: float = 0.0,
    ) -> None:
        """Initialize Mixture of Softmaxes layer.

        Args:
            hidden_dim: Dimension of input hidden states
            num_components: Number of mixture components (K)
            input_dropout: Dropout probability for input hidden states
            component_dropout: Dropout probability for context vectors
            component_activation: Name of the activation applied after each
                component projection (see ``get_activation``)
            entropy_reg_weight: Weight for entropy regularization of mixture weights
            entropy_of_individual_components: If True, regularize the mean entropy
                of each sample's mixture distribution; otherwise regularize the
                entropy of the batch-averaged mixture distribution
            diversity_reg_weight: Weight for diversity regularization of projections

        """
        # The layer name encodes the active regularizers so that runs are distinguishable.
        name = f"mix{num_components}"
        if entropy_reg_weight > 0:
            if entropy_of_individual_components:
                name = name + f"_entsamp{entropy_reg_weight}"
            else:
                name = name + f"_ent{entropy_reg_weight}"
        if diversity_reg_weight > 0:
            name = name + f"_div{diversity_reg_weight}"
        super().__init__(num_components, name=name, return_log_prob=True)
        self.entropy_reg_weight = entropy_reg_weight
        self.entropy_of_individual_components = entropy_of_individual_components
        self.diversity_reg_weight = diversity_reg_weight

        self.input_dropout = nn.Dropout(input_dropout)
        self.component_activation = get_activation(component_activation)
        self.component_dropout = nn.Dropout(component_dropout)

        # The following defines the mixture weights parameters (w_pi in the MoS paper).
        # These are target vectors for each component. We measure the interaction of
        # the encoded hidden state with each component vector, so we can interpret them
        # as "object embeddings" for each component.
        self.mixture_classifier = nn.Linear(hidden_dim, num_components)
        torch.nn.init.xavier_uniform_(self.mixture_classifier.weight)

        # The following defines projection matrices for each component (w_h in the MoS paper).
        # These encode an interaction between the hidden state and the component vector.
        # From the hidden state (s,r), we get a hidden state (s,r,c_i) for each component i.
        # These can be interpreted as "relation embeddings" for each component.
        # The activation and dropout modules are shared by all components; the
        # Linear and BatchNorm1d layers are created per component.
        self.component_projections = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim, bias=False),
                    nn.BatchNorm1d(hidden_dim),
                    self.component_activation,
                    self.component_dropout,
                )
                for _ in range(num_components)
            ],
        )
        for proj in self.component_projections:
            for layer in proj:
                if isinstance(layer, nn.Linear):
                    torch.nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        torch.nn.init.zeros_(layer.bias)

        # Entropy of the mixture weights from the latest training forward pass;
        # 0.0 until forward() has computed one.
        self.current_entropy = 0.0

    def forward(
        self,
        hidden_states: torch.Tensor,
        class_embeddings: torch.Tensor | None = None,
        *,
        return_prob: bool = False,
        debug: bool = False,
    ) -> torch.Tensor:
        """Compute Mixture of Softmaxes with numerically stable operations.

        Args:
            hidden_states: (batch, hidden_dim)
            class_embeddings: (num_classes, hidden_dim). Required; the ``None``
                default only exists for signature compatibility.
            return_prob: Whether to return probabilities or log probabilities
            debug: Accepted for interface compatibility; not used by this method.

        Returns:
            (batch, num_classes) log-probabilities, or probabilities if
            ``return_prob`` is True.

        Raises:
            ValueError: If ``class_embeddings`` is None.

        """
        if class_embeddings is None:
            raise ValueError("class_embeddings must be provided to MixtureOfSoftmaxes.forward")

        hidden_states = self.input_dropout(hidden_states)

        # Compute mixture weights
        pi_logits = self.mixture_classifier(hidden_states)  # (batch, num_components)
        log_pi = F.log_softmax(pi_logits, dim=-1)  # (batch, num_components)

        # Compute entropy for regularization (if enabled); it is consumed later by
        # regularization_term(), so it is stored on the module.
        if self.entropy_reg_weight > 0 and self.training:
            if self.entropy_of_individual_components:
                # Mean over the batch of each sample's mixture entropy.
                pi_distrib = torch.distributions.Categorical(logits=pi_logits)
                self.current_entropy = pi_distrib.entropy().mean()
            else:
                # Entropy of the batch-averaged mixture distribution.
                pi_probs = torch.softmax(pi_logits, dim=-1)
                avg_pi_probs = pi_probs.mean(dim=0)
                self.current_entropy = -torch.sum(avg_pi_probs * torch.log(avg_pi_probs + 1e-10))
        else:
            self.current_entropy = 0.0

        # Compute context vectors
        h = torch.stack(
            [proj(hidden_states) for proj in self.component_projections],
            dim=1,
        )  # (batch, num_components, hidden_dim)

        # Compute component-wise logits
        logits = torch.matmul(h, class_embeddings.t())  # (batch, K, num_classes)

        # Compute log-softmax per component
        log_softmax_outputs = F.log_softmax(logits, dim=-1)  # (batch, K, num_classes)

        # Mix the components in log space: log sum_k pi_k * p_k(class).
        # log_pi is broadcast over the class axis instead of being materialized.
        log_outputs = torch.logsumexp(log_pi.unsqueeze(-1) + log_softmax_outputs, dim=1)

        return torch.exp(log_outputs) if return_prob else log_outputs

    def regularization_term(self) -> torch.Tensor:
        """Combine the mixture-entropy and projection-diversity regularizers.

        Returns:
            The regularization value to add to the loss. When no tensor-valued
            term is active this is the plain Python float 0.0 rather than a tensor.

        """
        reg = 0.0

        # Add entropy regularization term (negative because we want to maximize entropy)
        if self.entropy_reg_weight > 0:
            reg -= self.entropy_reg_weight * self.current_entropy

        # Add diversity regularization term (needs at least one pair of components)
        if self.diversity_reg_weight > 0 and self.num_components > 1:
            cosine_similarities = self.projection_similarities()
            diversity_penalty = torch.mean(cosine_similarities)
            reg += self.diversity_reg_weight * diversity_penalty

        return reg

    def projection_similarities(self) -> torch.Tensor:
        """Compute pairwise cosine similarity of projection matrices.

        The similarity of two matrices is their Frobenius inner product divided
        by the product of their Frobenius norms.

        Returns:
            tensor of shape (num_components choose 2); empty when there are
            fewer than two components.

        """
        # Each projection is an nn.Sequential whose first layer is the Linear map;
        # the Sequential itself has no ``weight`` attribute.
        projection_weights = [proj[0].weight for proj in self.component_projections]

        # Calculate pairwise cosine similarities
        cosine_similarities = []
        for w_i, w_j in combinations(projection_weights, 2):
            frob_inner_prod = torch.sum(w_i * w_j)
            frob_norm_i = torch.linalg.matrix_norm(w_i)
            frob_norm_j = torch.linalg.matrix_norm(w_j)
            cosine_similarities.append(frob_inner_prod / (frob_norm_i * frob_norm_j))

        if not cosine_similarities:
            # torch.stack rejects an empty list; return an empty tensor instead.
            return self.mixture_classifier.weight.new_empty(0)
        return torch.stack(cosine_similarities)
