from torch import nn
from transformers import DistilBertTokenizer, DistilBertModel
import torch

# Default compute device: CUDA when torch reports it as available, otherwise CPU.
# NOTE(review): not referenced by the definitions visible in this part of the
# file — presumably imported by callers; confirm before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def apply_embedding_perturbation(embeddings, embedding_perturb_std):
    """Return ``embeddings`` plus zero-mean Gaussian noise.

    Args:
        embeddings: Tensor to perturb; it is not modified in place.
        embedding_perturb_std: Standard deviation the unit-variance noise is
            scaled by.

    Returns:
        A new tensor with the same shape, dtype and device as ``embeddings``.
    """
    # randn_like already samples on the input's device with the input's dtype.
    gaussian = torch.randn_like(embeddings)
    scaled_noise = gaussian * embedding_perturb_std
    return embeddings + scaled_noise


class BERTClassifier(nn.Module):
    """Pretrained DistilBERT encoder with one linear head on the first-token state.

    The same head is used both for the final prediction and, on request, for
    reading out every intermediate encoder layer.
    """

    def __init__(self, num_classes=1):
        """
        Args:
            num_classes: Output width of the linear classifier head.
        """
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased", output_hidden_states=True)
        self.num_classes = num_classes
        # 768 is the input width of the head; it must match the encoder's hidden size.
        self.classifier_head = nn.Linear(768, num_classes)

        # Keep the pretrained embedding weights fixed during training.
        self.bert.embeddings.requires_grad_(False)

    def forward(self, tokens, embedding_perturb_std=0.0, output_intermediates=False):
        """
        Args:
            tokens: A dictionary containing 'input_ids', 'attention_mask', etc.
                - input_ids, attention_mask: tensors of shape [batch_size, sequence_length]
            embedding_perturb_std: Standard deviation of the Gaussian noise added
                to the embeddings; no noise is added unless it is > 0.
            output_intermediates: If True, return logits read out from every
                encoder layer instead of only the last one.
        Returns:
            logits: tensor [batch_size, num_classes] if not output_intermediates
                   list of tensors [batch_size, num_classes] (one per layer) if output_intermediates
        """
        input_embeddings = self.bert.embeddings(tokens["input_ids"])
        if embedding_perturb_std > 0.0:
            input_embeddings = apply_embedding_perturbation(input_embeddings, embedding_perturb_std)

        # NOTE(review): the output of the embedding module is fed back in through
        # `inputs_embeds`. Depending on the installed transformers version the
        # model may apply position embeddings / LayerNorm to `inputs_embeds`
        # again — confirm this is the intended behaviour.
        encoder_out = self.bert(
            inputs_embeds=input_embeddings,
            attention_mask=tokens["attention_mask"],
            output_hidden_states=output_intermediates,  # hidden states are only requested when needed
        )

        if output_intermediates:
            # hidden_states[0] is the embedding output, so it is skipped; every
            # remaining layer's first-token ([CLS]) state goes through the shared head.
            return [
                self.classifier_head(layer_state[:, 0, :])
                for layer_state in encoder_out.hidden_states[1:]
            ]

        # encoder_out[0] is the last hidden state; position 0 holds the [CLS] token.
        cls_state = encoder_out[0][:, 0, :]
        return self.classifier_head(cls_state)


class SymmetricUnrolledTransformer(nn.Module):
    """
    A symmetric unrolled transformer, from scratch (not-pretrained); only the
    token embedding module is borrowed from pretrained DistilBERT.
    Supports single and multilayer unrolling, with either one shared (D1, D2)
    pair for all layers or a separate pair per layer.
    """

    def __init__(self, embedding_dim, hidden_dim, num_layers, diff_d, alpha=1):
        """
        Args:
            embedding_dim: Dimension of the embeddings.
            hidden_dim: Dimension of the hidden states. `forward` adds a
                hidden_dim-wide update to the embedding_dim-wide state, so the
                two are expected to be equal.
            num_layers: Number of layers in the transformer.
            diff_d: If True, use different D1 and D2 for each layer; otherwise
                a single (D1, D2) pair is shared by every layer.
            alpha: Scaling factor for the FFN.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.diff_d = diff_d
        self.alpha = alpha

        # One weight pair per layer, or a single shared pair stored with a
        # leading dimension of 1.
        num_weight_sets = num_layers if diff_d else 1
        self.D1 = nn.Parameter(torch.empty(num_weight_sets, embedding_dim, hidden_dim))
        self.D2 = nn.Parameter(torch.empty(num_weight_sets, hidden_dim, hidden_dim))

        # NOTE(review): Xavier init is applied to the stacked 3-D tensors, so the
        # fan computation sees all three dimensions rather than one 2-D matrix
        # per layer — confirm this scaling is intended.
        nn.init.xavier_normal_(self.D1)
        nn.init.xavier_normal_(self.D2)

        # NOTE(review): `readout` is not used by any method visible in this class.
        self.readout = nn.Linear(embedding_dim, 1)

        # Borrow the embeddings from the BERT pretrained model.
        bert = DistilBertModel.from_pretrained("distilbert-base-uncased", output_hidden_states=True)
        self.embeddings = bert.embeddings

    def _layer_weights(self, layer):
        """Return the (D1, D2) matrices for `layer`, honouring weight sharing."""
        index = layer if self.diff_d else 0
        return self.D1[index], self.D2[index]

    def forward(self, tokens, embedding_perturb_std=0.0, output_intermediates=True):
        """Embed the tokens and run the unrolled attention + FFN layers.

        Args:
            tokens: Dictionary with an 'input_ids' entry that is passed to the
                embedding module (presumably [batch_size, sequence_length] —
                confirm against the caller).
            embedding_perturb_std: Standard deviation of Gaussian noise added to
                the embeddings; no noise is added unless it is > 0.
            output_intermediates: If True, return every intermediate state.

        Returns:
            If output_intermediates: the input embeddings followed by each
            layer's output, stacked along a new leading dimension of size
            num_layers + 1. Otherwise only the last layer's output.
        """
        embeddings = self.embeddings(tokens["input_ids"])

        if embedding_perturb_std > 0.0:
            embeddings = apply_embedding_perturbation(embeddings, embedding_perturb_std)

        # The identity matrix is the same for every layer, so build it once,
        # directly on the right device and with the weights' dtype.
        identity = torch.eye(self.hidden_dim, device=embeddings.device, dtype=self.D2.dtype)

        layer_outputs = [embeddings]
        current_embeddings = embeddings

        for i in range(self.num_layers):
            # With shared weights D1/D2 hold a single matrix; indexing them by
            # the layer number would raise IndexError from the second layer on.
            D1, D2 = self._layer_weights(i)

            # Attention
            HD = current_embeddings @ D1.unsqueeze(0)
            # NOTE(review): softmax normalises over dim=1, not the last axis —
            # confirm this is the intended normalisation axis.
            attention = torch.softmax(HD @ HD.transpose(-2, -1), dim=1)
            Z = attention @ HD

            # FFN + residual: blend the identity with the symmetrised D2.
            D2S = (1 - self.alpha) * identity + self.alpha / 2 * (D2 + D2.transpose(-2, -1))
            current_embeddings = current_embeddings + (Z @ D2S)

            layer_outputs.append(current_embeddings)

        if output_intermediates:
            return torch.stack(layer_outputs)
        return current_embeddings

    def compute_energy(self, hidden_states):
        """Compute the energy function g(H,D₁,D₂) = g₁(H,D₁) + g₂(H,D₂) for each layer.

        Only hidden_states[0] .. hidden_states[num_layers - 1] are evaluated;
        when given the stacked output of `forward` (num_layers + 1 entries,
        starting with the embeddings) the final entry is therefore not used.

        Args:
            hidden_states (torch.Tensor): Hidden states H of shape [num_layers, batch_size, hidden_dim]
                NOTE(review): `forward` returns tensors with an extra sequence
                dimension; the pairwise term below then pairs entries along the
                first axis of each slice — confirm the intended input.

        Returns:
            torch.Tensor: Energy values for each layer [num_layers]
        """
        layer_energies = []

        for layer in range(self.num_layers):
            current_hidden = hidden_states[layer]  # [batch_size, hidden_dim] per the documented input
            D1, D2 = self._layer_weights(layer)

            D1H = current_hidden @ D1.transpose(-2, -1)  # [batch_size, hidden_dim]

            # g1: sum of Gaussian kernels over all pairs plus a Frobenius penalty.
            diff_matrix = D1H.unsqueeze(1) - D1H.unsqueeze(0)  # [batch_size, batch_size, hidden_dim]
            squared_dists = torch.sum(diff_matrix**2, dim=-1)  # [batch_size, batch_size]
            exp_term = torch.exp(-0.5 * squared_dists).sum()

            frobenius_norm_term = 0.5 * torch.norm(D1H, p='fro') ** 2

            g1 = exp_term + frobenius_norm_term

            # g2: quadratic form in D2, Frobenius penalty and L1 norm of H.
            trace_term = (
                0.5
                * torch.diagonal(current_hidden @ D2 @ current_hidden.transpose(-2, -1), dim1=-2, dim2=-1)
                .sum(-1)
                .mean()
            )
            h_frobenius_term = 0.5 * torch.norm(current_hidden, p='fro') ** 2
            phi_term = torch.norm(current_hidden, p=1)

            g2 = trace_term + h_frobenius_term + phi_term

            layer_energies.append(g1 + g2)

        return torch.stack(layer_energies)  # [num_layers]

    def compute_energy_gradient(self, x):
        """Placeholder: not implemented yet; ignores `x` and returns None."""
        return None
