import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.models.vision_transformer import MLPBlock
from torch.nn.modules.activation import MultiheadAttention

from typing import Optional, List, Dict, Tuple, Union, Tuple

class MultiHeadAttentionPruned(nn.Module):
    def __init__(
            self, 
            active_heads: torch.Tensor, 
            original_heads: torch.Tensor,
            embed_dim: int,
            in_proj_weight: torch.Tensor,  
            out_proj_weight: torch.Tensor,
            in_proj_bias: Optional[torch.Tensor] = None,
            out_proj_bias: Optional[torch.Tensor] = None,
            dropout=0.0
            ):
        """
        Multi-head attention layer with dynamic head pruning and unpruning.

        The layer is built from the weights of an existing attention layer and is
        immediately pruned down to ``active_heads`` (see ``prune()``).

        Parameters:
        -----------
        active_heads : torch.Tensor
            Ids of the attention heads to keep. Stored as ``self.pruned_heads`` (the
            target set applied by ``prune()``). An empty tensor disables every head.
        original_heads : torch.Tensor
            Ids of the heads held by ``in_proj_weight`` before pruning; its length is
            the original number of heads.
        embed_dim : int
            Embedding dimension (number of rows of ``out_proj_weight``).
        in_proj_weight : torch.Tensor
            Combined QKV input projection weight (3*embed_dim x embed_dim).
        out_proj_weight : torch.Tensor
            Output projection weight (embed_dim x embed_dim).
        in_proj_bias : torch.Tensor or None
            Input projection bias (3*embed_dim) or None.
        out_proj_bias : torch.Tensor or None
            Output projection bias (embed_dim) or None. Must be None exactly when
            ``in_proj_bias`` is None.
        dropout : float
            Attention dropout probability; only applied in training mode.
        """
        super().__init__()
        assert isinstance(active_heads, torch.Tensor), f"active_heads should a tensor, instead got {type(active_heads)}"
        assert (in_proj_bias is None) == (out_proj_bias is None), \
            f"Either both in_proj_bias and out_proj_bias should be None or both should be provided, instead got {type(in_proj_bias)} and {type(out_proj_bias)}" 
        bias = in_proj_bias is not None
        original_num_heads = original_heads.shape[0]
        self.embed_dim = embed_dim
        self.num_heads = original_num_heads
        self.dropout = dropout
        self.out_proj = nn.modules.linear.NonDynamicallyQuantizableLinear(
            out_proj_weight.shape[1],
            out_proj_weight.shape[0],
            bias=bias,
            device=out_proj_weight.device,
        )
        self.out_proj.weight.data = out_proj_weight
        self.in_proj_weight = in_proj_weight if isinstance(in_proj_weight, nn.Parameter) else nn.Parameter(in_proj_weight)

        if bias:
            self.in_proj_bias = in_proj_bias if isinstance(in_proj_bias, nn.Parameter) else nn.Parameter(in_proj_bias)
            self.out_proj.bias.data = out_proj_bias
        else:
            # Register the absent bias explicitly so forward()/prune()/unprune() can test
            # `self.in_proj_bias is None` instead of raising AttributeError.
            self.register_parameter("in_proj_bias", None)

        self.original_num_heads = original_num_heads
        # `pruned_heads` is the target set of head ids; `active_heads` lists the head ids
        # the current weights hold. prune() maps the former onto the latter.
        self.pruned_heads = active_heads
        self.active_heads = original_heads
        self.head_dim = in_proj_weight.shape[0] // 3 // original_num_heads

        if self.pruned_heads.numel() == 0:
            self.active_heads = self.pruned_heads
        else:
            self.prune()

    def forward(
            self, 
            x: torch.Tensor, 
            k: torch.Tensor, 
            v: torch.Tensor, 
            attn_mask: Optional[torch.Tensor] = None,
            need_weights=False
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Self-attention over ``x`` using only the currently kept heads.

        ``k``, ``v``, ``attn_mask`` and ``need_weights`` are not used in this
        implementation. They are kept for compatibility with the implementation of
        ViT in torchvision.models. Attention weights are never returned; the second
        output is always an empty tensor.

        Args:
            x (torch.Tensor): Input of shape (B, T, embed_dim).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output of shape (B, T, embed_dim) and
            an empty tensor.
        """
        if self.active_heads.numel() == 0:
            # Every head is pruned, so the concatenated head output is empty and the
            # output projection contributes only its bias (or zeros). Returning `x`
            # here would re-add the layer input to the caller's residual stream.
            out = x.new_zeros(*x.shape[:-1], self.out_proj.out_features)
            if self.out_proj.bias is not None:
                out = out + self.out_proj.bias
            return out, torch.empty(0)

        B, T, _ = x.size()

        # QKV projection
        qkv = F.linear(x, self.in_proj_weight, self.in_proj_bias)  # [B, T, 3 * num_heads * head_dim]
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (q/k/v, B, num_heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # scaled_dot_product_attention applies the 1/sqrt(head_dim) scaling itself.
        # It has no notion of train/eval mode, so dropout has to be gated here.
        dropout_p = self.dropout if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout_p)

        # Combine heads
        attn_output = attn_output.transpose(1, 2).reshape(B, T, self.head_dim * self.num_heads)
        out = self.out_proj(attn_output)  # [B, T, embed_dim]
        return out, torch.empty(0)

    def prune(
            self, 
            verbose: bool = False
        ) -> None:
        """
        Prune the layer in place down to the heads listed in ``self.pruned_heads``.

        Each id in ``self.pruned_heads`` is located in ``self.active_heads`` (the heads
        the current weights hold). The matching Q/K/V rows of the input projection and
        the matching columns of the output projection are kept. The unpruned parameters
        and head ids are saved (on CPU) the first time this runs, so ``unprune()`` can
        restore them. Returns early if ``active_heads`` and ``pruned_heads`` already
        have the same length.

        Args:
            verbose (bool): If True, prints progress messages.
        """
        if verbose:
            print(f"Pruning heads to {len(self.pruned_heads)} heads")

        if len(self.active_heads) == len(self.pruned_heads):
            if verbose:
                print("Head Block already pruned")
            return

        has_bias = self.in_proj_bias is not None

        # Snapshot the unpruned state only once. A second prune() must not overwrite it
        # with already-pruned weights, or unprune() could no longer restore the original.
        if not hasattr(self, "original_in_proj_weight"):
            self.original_in_proj_weight = self.in_proj_weight.data.clone().to("cpu")
            self.original_out_proj_weight = self.out_proj.weight.data.clone().to("cpu")
            self.original_in_proj_bias = self.in_proj_bias.data.clone().to("cpu") if has_bias else None
            self.original_out_proj_bias = self.out_proj.bias.data.clone().to("cpu") if has_bias else None
            self.original_active_heads = self.active_heads.clone()

        num_kept = len(self.pruned_heads)
        in_features = self.in_proj_weight.shape[1]

        # Position of each kept head id inside the heads currently held by the weights.
        # The explicit dtype keeps an empty selection a valid (long) index tensor.
        kept_idx = torch.tensor(
            [(self.active_heads == h).nonzero(as_tuple=False).item() for h in self.pruned_heads],
            dtype=torch.long,
        )

        # Input projection: Q, K and V are stacked along dim 0, each holding num_heads
        # consecutive blocks of head_dim rows. Explicit sizes (not -1) keep the reshape
        # valid when no head is kept.
        kept_blocks = []
        for w in self.in_proj_weight.data.chunk(3, dim=0):
            per_head = w.reshape(self.num_heads, self.head_dim, in_features)
            kept_blocks.append(per_head[kept_idx].reshape(num_kept * self.head_dim, in_features))
        self.in_proj_weight.data = torch.cat(kept_blocks, dim=0)

        # Output projection: keep the input columns produced by the kept heads.
        out_w = self.out_proj.weight.data.view(self.embed_dim, self.num_heads, self.head_dim)
        out_w = out_w[:, kept_idx, :].reshape(self.embed_dim, num_kept * self.head_dim)

        out_b = None
        if has_bias:
            kept_bias = [
                b.reshape(self.num_heads, self.head_dim)[kept_idx].reshape(-1)
                for b in self.in_proj_bias.data.chunk(3, dim=0)
            ]
            self.in_proj_bias.data = torch.cat(kept_bias, dim=0)
            # Removing heads does not change the output bias; carry it over unchanged.
            out_b = self.out_proj.bias.data.clone()

        self.out_proj = nn.modules.linear.NonDynamicallyQuantizableLinear(
            out_w.shape[1],
            out_w.shape[0],
            bias=has_bias,
            device=out_w.device,
            dtype=out_w.dtype,
        )
        self.out_proj.weight.data = out_w
        if has_bias:
            self.out_proj.bias.data = out_b

        self.active_heads = self.pruned_heads
        self.num_heads = num_kept

    def unprune(
            self, 
            verbose: bool =False
        ) -> None:
        """
        Restore the parameters and head ids saved by the first ``prune()`` call.

        Does nothing (optionally printing a message) if the layer has not been pruned.
        Afterwards ``active_heads`` / ``num_heads`` again describe the restored weights.

        Args:
            verbose (bool): If True, prints a message when there is nothing to restore.
        """
        if not hasattr(self, "original_in_proj_weight"):
            if verbose:
                print("Head has not been pruned yet. Please call prune() first.")
            return

        device = self.in_proj_weight.device
        has_bias = self.in_proj_bias is not None

        self.in_proj_weight.data = self.original_in_proj_weight.to(device)
        self.out_proj = nn.modules.linear.NonDynamicallyQuantizableLinear(
            self.original_out_proj_weight.shape[1],
            self.original_out_proj_weight.shape[0],
            bias=has_bias,
            device=device,
            dtype=self.original_out_proj_weight.dtype,
        )
        self.out_proj.weight.data = self.original_out_proj_weight.to(device)
        if has_bias:
            self.in_proj_bias.data = self.original_in_proj_bias.to(device)
            self.out_proj.bias.data = self.original_out_proj_bias.to(device)

        # Restore the head ids that belong to the saved weights, and their count.
        self.active_heads = self.original_active_heads
        self.num_heads = len(self.active_heads)

        del self.original_in_proj_weight
        del self.original_in_proj_bias
        del self.original_out_proj_weight
        del self.original_out_proj_bias
        del self.original_active_heads

class PruneMLPBlock(MLPBlock):
    def __init__(
            self,
            in_proj_weight: torch.Tensor,
            in_proj_bias: torch.Tensor,
            out_proj_weight: torch.Tensor,
            out_proj_bias: torch.Tensor,
            active_rows: Union[List[int], torch.Tensor],
            dropout: float = 0.0,
            activation: Optional[nn.Module] = None
        ):
        """
        A modified MLPBlock that supports pruning and unpruning of intermediate dimensions 
        by selecting a subset of hidden units.

        The parameters are copied (cloned) into the block, which is then pruned to
        `active_rows` right away.

        Args:
            in_proj_weight (torch.Tensor): Weight for the input linear layer.
            in_proj_bias (torch.Tensor): Bias for the input linear layer.
            out_proj_weight (torch.Tensor): Weight for the output linear layer.
            out_proj_bias (torch.Tensor): Bias for the output linear layer.
            active_rows (List[int] or torch.Tensor): Indices of hidden units to keep active.
            dropout (float): Dropout rate passed to the parent MLPBlock.
            activation (nn.Module, optional): Activation placed at index 1 of the block.
                Defaults to a new `nn.GELU(approximate="none")` for each instance.
        """
        super().__init__(in_proj_weight.shape[1], in_proj_weight.shape[0], dropout)

        assert isinstance(active_rows, (list, torch.Tensor)), "active_rows should be a list or a tensor"

        if activation is None:
            # Build the default per instance: a module used as a default argument would
            # be one shared object registered as a child of every block.
            activation = nn.GELU(approximate="none")

        # Clone so the block never aliases the source model's parameters.
        self[0].weight = nn.Parameter(in_proj_weight.data.clone())
        self[0].bias = nn.Parameter(in_proj_bias.data.clone())
        self[1] = activation
        self[3].weight = nn.Parameter(out_proj_weight.data.clone())
        self[3].bias = nn.Parameter(out_proj_bias.data.clone())

        if not isinstance(active_rows, torch.Tensor):
            active_rows = torch.tensor(active_rows, dtype=torch.long)

        self.pruned_rows = active_rows
        self.active_rows = torch.arange(in_proj_weight.shape[0])
        self.prune()

    def forward(
            self, 
            x: torch.Tensor
        ) -> torch.Tensor:
        """
        Default forward pass for the MLPBlock: apply each submodule in order.
        """
        for module in self:
            x = module(x)
        return x

    def prune(
            self, 
            verbose: bool =False
        ) -> None:
        """
        Prunes the MLP block in place, keeping only the hidden units in `pruned_rows`.

        `pruned_rows` indexes the rows of the current first linear layer. The unpruned
        weights and biases are saved the first time this runs, so `unprune()` can restore
        them. Returns early if the first layer already has `len(pruned_rows)` rows.

        Args:
            verbose (bool): If True, prints debug info.
        """
        if verbose:
            print(f"Pruning MLPBlock to {len(self.pruned_rows)} rows")

        if self[0].weight.shape[0] == self.pruned_rows.shape[0]:
            if verbose:
                print("MLPBlock already pruned")
            return

        rows = self.pruned_rows
        # Hidden units are rows of the first linear layer and columns of the second.
        active_in_proj_weight = self[0].weight.data[rows].clone()
        active_in_proj_bias = self[0].bias.data[rows].clone()
        active_out_proj_weight = self[3].weight.data[:, rows].clone()
        active_out_proj_bias = self[3].bias.data.clone()

        # Snapshot only once, so a repeated prune() cannot overwrite the true originals.
        if not hasattr(self, "original_in_proj_weight"):
            self.original_in_proj_weight = self[0].weight.data.clone()
            self.original_in_proj_bias = self[0].bias.data.clone()
            self.original_out_proj_weight = self[3].weight.data.clone()
            self.original_out_proj_bias = self[3].bias.data.clone()

        self[0] = nn.Linear(active_in_proj_weight.shape[1], active_in_proj_weight.shape[0], bias=True)
        self[0].weight = nn.Parameter(active_in_proj_weight)
        self[0].bias = nn.Parameter(active_in_proj_bias)

        self[3] = nn.Linear(active_out_proj_weight.shape[1], active_out_proj_weight.shape[0], bias=True)
        self[3].weight = nn.Parameter(active_out_proj_weight)
        self[3].bias = nn.Parameter(active_out_proj_bias)

        self.active_rows = self.pruned_rows

    def unprune(
            self, 
            verbose: bool=False
        ) -> None:
        """
        Restores the MLP block to the state saved by the first `prune()` call.

        Does nothing (optionally printing a message) if the block has not been pruned.

        Args:
            verbose (bool): If True, prints debug info.
        """
        if not hasattr(self, "original_in_proj_weight"):
            if verbose:
                print("MLPBlock has not been pruned yet. Please call prune() first.")
            return 

        bias = self[0].bias is not None
        self[0] = nn.Linear(self.original_in_proj_weight.shape[1],
                            self.original_in_proj_weight.shape[0],
                            bias=bias)
        self[0].weight.data = self.original_in_proj_weight
        self[0].bias.data = self.original_in_proj_bias
        self[3] = nn.Linear(self.original_out_proj_weight.shape[1],
                            self.original_out_proj_weight.shape[0],
                            bias=bias)
        self[3].weight.data = self.original_out_proj_weight
        self[3].bias.data = self.original_out_proj_bias

        self.active_rows = torch.arange(self[0].weight.shape[0])

        del self.original_in_proj_weight
        del self.original_in_proj_bias
        del self.original_out_proj_weight
        del self.original_out_proj_bias
        
class MHA_pruned_token(MultiheadAttention):
    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            k_sampling: int = 197,
            num_tokens_to_keep: Optional[Union[int, str]] = None,
            cls_tokens: Optional[List[int]] = None,
            random: bool = False
        ) -> None:
        """
        Multi-head attention layer that prunes tokens based on attention-weighted significance.

        Args:
            embed_dim (int): Total dimension of the model.
            num_heads (int): Number of attention heads.
            k_sampling (int): Number of points used by inverse transform sampling in
                adaptive mode (at most `k_sampling - 1` tokens are drawn).
            num_tokens_to_keep (Optional[Union[int, str]]): Number of non-CLS tokens to
                keep, or "adaptive". Must be set to one of these before forward().
            cls_tokens (Optional[List[int]]): Indices of class tokens to always retain.
                Defaults to `[0]`. Scoring treats the first `len(cls_tokens)` positions
                as class tokens.
            random (bool): Whether to randomly sample tokens instead of using scoring.
        """
        super().__init__(embed_dim, num_heads)

        self.num_tokens_to_keep = num_tokens_to_keep
        self.tokens_kept = None 
        # A fresh list per instance instead of one shared mutable default.
        self.cls_tokens = [0] if cls_tokens is None else cls_tokens
        self.random = random
        self.k_sampling = k_sampling

    def forward(
            self,
            x: Tensor,
            q: Optional[Tensor] = None,
            k: Optional[Tensor] = None,
            need_weights: bool = False
        ) -> Tuple[Tensor, Tensor]:
        """
        Forward pass for token-pruned multi-head attention. q and k are not used in this implementation,
        but are kept for compatibility with the MultiheadAttention interface. The method computes attention 
        scores, selects tokens based on their significance scores, and returns the pruned output tensor 
        along with attention probabilities. The kept token positions (CLS first) are stored in
        `self.tokens_kept`.

        Args:
            x (Tensor): Input tensor of shape (B, T, C).
            q (Optional[Tensor]): Not used.
            k (Optional[Tensor]): Not used.
            need_weights (bool): Not used.

        Returns:
            Tuple[Tensor, Tensor]: Output tensor of shape (B, n_cls + K, C) and the full
            attention probabilities of shape (B, heads, T, T).

        Raises:
            ValueError: If `num_tokens_to_keep` is not a valid value for the current mode.
        """
        B, T, _ = x.size()

        qkv = F.linear(x, self.in_proj_weight, self.in_proj_bias)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, heads, T, head_dim)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = F.softmax(attn_scores, dim=-1)  # (B, heads, T, T)

        # Non-CLS tokens to keep, indexed relative to the first non-CLS position.
        kept = self._select_tokens(attn_probs, v)  # (B, K)

        # Prepend the CLS tokens and shift the others back to absolute positions.
        cls_tokens_tensor = torch.tensor(self.cls_tokens, dtype=torch.long, device=x.device).unsqueeze(0).expand(B, -1)
        tokens_to_keep = torch.cat([cls_tokens_tensor, kept + len(self.cls_tokens)], dim=1)  # (B, n_cls + K)
        self.tokens_kept = tokens_to_keep

        # Keep only the attention rows (queries) of the selected tokens.
        index = tokens_to_keep.unsqueeze(1).unsqueeze(-1).expand(-1, self.num_heads, -1, T)
        selected_attn = torch.gather(attn_probs, dim=2, index=index)  # (B, heads, n_cls + K, T)

        attn_output = selected_attn @ v  # (B, heads, n_cls + K, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(B, tokens_to_keep.shape[1], self.embed_dim)
        out = self.out_proj(attn_output)

        return out, attn_probs

    def _select_tokens(
            self,
            attn_probs: Tensor,
            v: Tensor
        ) -> Tensor:
        """
        Choose the non-CLS tokens to keep according to `random` / `num_tokens_to_keep`.

        Returns:
            Tensor: Indices of shape (B, K), relative to the first non-CLS position.
        """
        n_keep = self.num_tokens_to_keep

        if self.random:
            if not isinstance(n_keep, int) or n_keep < 0:
                raise ValueError(
                    f"random token sampling needs a non-negative int num_tokens_to_keep, got {n_keep!r}"
                )
            B, _, T, _ = v.shape
            n_candidates = T - len(self.cls_tokens)
            # A random permutation per sample, truncated: sampling without replacement,
            # so no token is kept twice.
            return torch.rand(B, n_candidates, device=v.device).argsort(dim=1)[:, :n_keep]

        if n_keep == "adaptive":
            sorted_scores, sorted_indices = self.score_assignment(attn_probs, v)
            return self.inverse_transform_sampling(sorted_scores, sorted_indices)

        if isinstance(n_keep, int) and n_keep >= 0:
            _, sorted_indices = self.score_assignment(attn_probs, v)
            # Scores are sorted ascending, so the top tokens are at the end. Slice from an
            # explicit start: `[:, -n_keep:]` would keep every token when n_keep == 0.
            start = max(sorted_indices.shape[1] - n_keep, 0)
            return sorted_indices[:, start:]

        raise ValueError(
            f"num_tokens_to_keep must be a non-negative int or 'adaptive', got {n_keep!r}"
        )

    def score_assignment(
            self,
            attn: Tensor,
            v: Tensor
        ) -> Tuple[Tensor, Tensor]:
        """
        Compute token significance scores based on attention and value norms.

        The score of a non-CLS token is the attention it receives from query row 0,
        weighted by its value norm. It is normalized per head and summed over heads.

        Args:
            attn (Tensor): Attention weights of shape (B, heads, T, T).
            v (Tensor): Value vectors of shape (B, heads, T, head_dim).

        Returns:
            Tuple[Tensor, Tensor]: Scores sorted in ascending order, and the indices
            (relative to the first non-CLS token) that produce that order, both (B, T-cls_tokens).
        """
        num_of_cls_tokens = len(self.cls_tokens)
        attn_probs_cls = attn[:, :, 0, num_of_cls_tokens:]       # (B, heads, T-cls_tokens)
        v_wo_cls = v[:, :, num_of_cls_tokens:, :]                # (B, heads, T-cls_tokens, head_dim)
        value_norms = v_wo_cls.norm(dim=-1)                      # (B, heads, T-cls_tokens)

        weighted_scores = attn_probs_cls * value_norms  # (B, heads, T-cls_tokens)
        normalization_factors = weighted_scores.sum(dim=-1, keepdim=True) + 1e-6  # (B, heads, 1)
        normalized_scores = weighted_scores / normalization_factors  # (B, heads, T-cls_tokens)

        significance_scores = normalized_scores.sum(dim=1)  # (B, T-cls_tokens)

        # Ascending order: the highest-scoring tokens are last.
        sorted_scores, sorted_indices = torch.sort(
            significance_scores, descending=False, dim=1
        )

        return sorted_scores, sorted_indices

    def create_ys(
            self,
            normalized_cdf: Tensor,
            n_tokens: int
        ) -> Tensor:
        """
        Sample y-axis positions for inverse transform sampling. Using the normalized cdf
        to create a fixed uniform sampling scheme that starts from the first meaningful token.
        So the uniform sampling over the y-axis is actually weighted in the token space. 

        Args:
            normalized_cdf (Tensor): Normalized cumulative distribution of scores, (B, N).
            n_tokens (int): Number of sampling points plus one.

        Returns:
            Tensor: Sampled y-values of shape (B, n_tokens-1).
        """
        B = normalized_cdf.shape[0]

        # fixed uniform sampling scheme
        ys = (torch.linspace(
                    start=0,
                    end=1.0,
                    steps=n_tokens - 1,
                    device=normalized_cdf.device,
                )
                .unsqueeze(0)
                .repeat(B, 1)
            )

        # get smallest non-zero value, to start sampling from first meaningful token
        ys_start = (
            torch.min(normalized_cdf + (normalized_cdf == 0).float() * 1e8, dim=1)[0]
            .unsqueeze(-1)
            .expand_as(ys)
        )

        steps = (
            torch.arange(0, n_tokens - 1, device=normalized_cdf.device)
            .unsqueeze(0)
            .expand_as(ys_start)
        )

        informative_bins = n_tokens - 1
        scaled_ys = (ys * informative_bins - ys_start * steps) / informative_bins
        ys = ys_start + scaled_ys  # shift y-axis to start from the first meaningful token

        return ys

    def inverse_transform_sampling(
            self,
            sorted_scores: Tensor,
            sorted_indices: Tensor
        ) -> Tensor:
        """
        Perform inverse transform sampling to select diverse tokens.

        Args:
            sorted_scores (Tensor): Sorted token scores (B, T-cls).
            sorted_indices (Tensor): Indices of the sorted scores (B, T-cls).

        Returns:
            Tensor: Selected token indices of shape (B, K), K <= k_sampling - 1.
        """
        cdf = torch.cumsum(sorted_scores, dim=1)  # (B, N)

        # Rescale the cdf of every sample to [0, 1].
        # NOTE(review): yields NaN when max == min (e.g. a single candidate token).
        cdf_min = cdf.min(dim=1, keepdim=True)[0]
        cdf_max = cdf.max(dim=1, keepdim=True)[0]
        normalized_cdf = (cdf - cdf_min) / (cdf_max - cdf_min)

        ys = self.create_ys(normalized_cdf, self.k_sampling)  # (B, S)

        # For each sample b and each sampling point y, pick the token whose cdf value is
        # closest to y: (B, S, 1) against (B, 1, N). Comparing (B, S) directly with
        # (B, 1, N) would mix batch entries and compare position by position.
        distances = torch.abs(ys.unsqueeze(2) - normalized_cdf.unsqueeze(1))  # (B, S, N)
        tokens_to_pick_ind = torch.argmin(distances, dim=2)  # (B, S)
        token_indices = torch.gather(sorted_indices, dim=1, index=tokens_to_pick_ind)  # (B, S)

        # NOTE(review): unique over dim=1 removes duplicate *columns* (one B-tuple each).
        # With B > 1, a sample's duplicates survive unless the whole column repeats.
        sampled_indices = torch.unique(token_indices, dim=1)

        return sampled_indices
    

# --- Utils to Prune/Unprune ViT ---
## --- Prune MLP ---
def prune_mlp(
    self,
    layers: Optional[List[int]] = None
    ) -> None:
    """
    Calls the `prune()` method on MLP blocks in the specified encoder layers. 
    The function is assigned to the pruned model instance.

    Layers whose MLP block has no callable `prune` are reported and skipped.

    Args:
        self: The ViT model instance.
        layers (List[int], optional): 
            List of layer indices to prune. 
            If None, all layers will be pruned.
    """
    layers = layers if layers is not None else list(range(len(self.encoder.layers)))
    assert isinstance(layers, list), f"Layers to prune must be a list, but got {type(layers)}"
    assert len(layers) <= len(self.encoder.layers), \
        f"Layers to prune must be < {len(self.encoder.layers)}, but got {layers}"
    for i, layer in enumerate(self.encoder.layers):
        if i not in layers:
            continue
        mlp = layer.mlp
        # getattr needs a default: without one, a block lacking prune() raises
        # AttributeError and the "Skipping" branch below is unreachable.
        prune_fn = getattr(mlp, "prune", None)
        if callable(prune_fn):
            prune_fn()
        else:
            print(f"[!] Layer {i} does not have a prune method. Skipping.")
    
def unprune_mlp(
    self
    ) -> None:
    """
    Calls the `unprune()` method on all MLP blocks in the encoder.
    The function is assigned to the pruned model instance.

    Blocks without a callable `unprune` are skipped silently.

    Args:
        self: The ViT model instance.
    """
    for layer in self.encoder.layers:
        # Default of None so blocks lacking unprune() are skipped instead of raising.
        unprune_fn = getattr(layer.mlp, "unprune", None)
        if callable(unprune_fn):
            unprune_fn()

def get_active_mlp(self) -> Dict[int, torch.Tensor]:
    """
    Retrieve the indices of the currently active hidden units of every encoder MLP block.

    Returns:
    --------
    active_rows : Dict[int, torch.Tensor]
        Maps each layer index to a sorted tensor of active hidden-unit indices. Blocks
        that track `active_rows` (e.g. PruneMLPBlock) report those; other blocks report
        every row of their first linear layer (`mlp[0]`).
    """
    active_rows: Dict[int, torch.Tensor] = {}
    for i, layer in enumerate(self.encoder.layers):
        mlp = layer.mlp
        # `active_rows` is stored on the block itself, not on its first Linear (mlp[0]),
        # which is where the previous lookup looked (so it never found it).
        if hasattr(mlp, "active_rows"):
            active_rows[i] = mlp.active_rows.sort().values
        else:
            active_rows[i] = torch.arange(mlp[0].weight.shape[0])
    return active_rows

## --- Prune Heads ---
def prune_heads(
        self,
        layers: Optional[List[int]] = None
    ) -> None:
    """
    Calls the `prune()` method on Pruned Attention blocks in the specified encoder layers. 
    The function is assigned to the pruned model instance.

    Layers whose self-attention module has no callable `prune` are reported and skipped.

    Args:
    -----------
    self: The ViT model instance.
    layers : Optional[List[int]], default=None
        List of layer indices whose attention heads should be pruned.
        If None, all layers are pruned.
    """
    layers = layers if layers is not None else list(range(len(self.encoder.layers)))
    assert isinstance(layers, list), f"Layers to prune must be a list, but got {type(layers)}"
    assert len(layers) <= len(self.encoder.layers), \
        f"Layers to prune must be < {len(self.encoder.layers)}, but got {layers}"
    for i, layer in enumerate(self.encoder.layers):
        if i not in layers:
            continue
        mha = layer.self_attention
        # getattr needs a default: without one, a module lacking prune() raises
        # AttributeError and the "Skipping" branch below is unreachable.
        prune_fn = getattr(mha, "prune", None)
        if callable(prune_fn):
            prune_fn()
        else:
            print(f"[!] Layer {i} does not have a prune method. Skipping.")

def unprune_heads(
        self
    ) -> None:
    """
    Restores all pruned attention heads in every encoder layer.

    Notes:
    ------
    - Calls the `unprune` method on each layer's self-attention module if available.
    - Silently skips layers without a callable `unprune` method.
    """
    for layer in self.encoder.layers:
        # Default of None so modules lacking unprune() are skipped instead of raising.
        unprune_fn = getattr(layer.self_attention, "unprune", None)
        if callable(unprune_fn):
            unprune_fn()

def get_active_heads(
        self
    ) -> Dict[int, torch.Tensor]:
    """
    Report which attention heads are currently active in each encoder layer.

    For every encoder layer, the self-attention module's `active_heads` is returned in
    ascending order when the module defines it. Otherwise every head index
    `0 .. num_heads - 1` is reported.

    Returns:
    --------
    active_heads : Dict[int, torch.Tensor]
        Layer index -> sorted tensor of active head indices.

    Example:
    --------
    {
        0: tensor([0, 1, 3]),
        1: tensor([0, 1, 2, 3]),
        ...
    }
    """
    per_layer: Dict[int, torch.Tensor] = {}
    for layer_idx, encoder_layer in enumerate(self.encoder.layers):
        attention = encoder_layer.self_attention
        if not hasattr(attention, "active_heads"):
            # Unpruned / foreign attention module: all heads count as active.
            per_layer[layer_idx] = torch.arange(attention.num_heads)
            continue
        per_layer[layer_idx] = attention.active_heads.sort().values
    return per_layer
