import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, List, Dict, Tuple, Tuple

class MultiHeadAttentionPruned(nn.Module):
    """
    Self-attention block whose heads can be pruned and later restored.

    The attribute layout mirrors ``nn.MultiheadAttention`` (``in_proj_weight``,
    ``in_proj_bias``, ``out_proj``) so the module can stand in for the
    ``self_attention`` module of a transformer encoder layer.

    Head bookkeeping (note the historical, somewhat inverted naming):

    * ``self.active_heads`` -- IDs of the heads currently laid out, in order,
      in the projection weights.
    * ``self.pruned_heads`` -- IDs of the heads to *keep* when ``prune()`` runs.
    """

    def __init__(
            self, 
            active_heads: torch.Tensor, 
            original_heads: torch.Tensor,
            embed_dim: int,
            in_proj_weight: torch.Tensor,  
            out_proj_weight: torch.Tensor,
            in_proj_bias: Optional[torch.Tensor] = None,
            out_proj_bias: Optional[torch.Tensor] = None,
            dropout=0.0
            ):
        """
        Multi-head attention layer with dynamic head pruning and unpruning.

        Parameters:
        -----------
        active_heads : torch.Tensor
            IDs of the attention heads to keep. Stored as ``self.pruned_heads``.
            An empty tensor disables the block: ``forward`` then returns its input.
        original_heads : torch.Tensor
            IDs of the heads laid out, in order, in the given projection weights.
            Its length is the number of heads before pruning.
        embed_dim : int
            Embedding dimension.
        in_proj_weight : torch.Tensor
            Combined QKV input projection weight (3*embed_dim x embed_dim).
            NOTE: if this is already an ``nn.Parameter`` it is reused, not copied,
            so pruning rewrites that parameter's ``.data`` in place.
        out_proj_weight : torch.Tensor
            Output projection weight (embed_dim x embed_dim).
        in_proj_bias : torch.Tensor or None
            Input projection bias (3*embed_dim) or None.
        out_proj_bias : torch.Tensor or None
            Output projection bias (embed_dim) or None.
        dropout : float
            Attention dropout probability, applied only in training mode.
        """
        super().__init__()
        assert isinstance(active_heads, torch.Tensor), f"active_heads should a tensor, instead got {type(active_heads)}"
        assert (in_proj_bias is None) == (out_proj_bias is None), \
            f"Either both in_proj_bias and out_proj_bias should be None or both should be provided, instead got {type(in_proj_bias)} and {type(out_proj_bias)}" 
        bias = in_proj_bias is not None
        original_num_heads = original_heads.shape[0]
        self.embed_dim = embed_dim
        self.num_heads = original_num_heads
        self.dropout = dropout
        self.out_proj = nn.modules.linear.NonDynamicallyQuantizableLinear(
            out_proj_weight.shape[1],
            out_proj_weight.shape[0],
            bias=bias,
            device=out_proj_weight.device,
        )
        self.out_proj.weight.data = out_proj_weight
        self.in_proj_weight = in_proj_weight if isinstance(in_proj_weight, nn.Parameter) else nn.Parameter(in_proj_weight)

        if bias:
            self.in_proj_bias = in_proj_bias if isinstance(in_proj_bias, nn.Parameter) else nn.Parameter(in_proj_bias)
            self.out_proj.bias.data = out_proj_bias
        else:
            # Register an explicit None (as nn.MultiheadAttention does). Reading
            # `self.in_proj_bias` then yields None instead of raising AttributeError.
            self.register_parameter("in_proj_bias", None)

        self.original_num_heads = original_num_heads
        self.pruned_heads = active_heads      # heads to keep
        self.active_heads = original_heads    # heads currently in the weights
        self.head_dim = in_proj_weight.shape[0] // 3 // original_num_heads

        if self.pruned_heads.numel() == 0:
            self.active_heads = self.pruned_heads
        else:
            self.prune()

    def forward(
            self, 
            x: torch.Tensor, 
            k: torch.Tensor, 
            v: torch.Tensor, 
            attn_mask: Optional[torch.Tensor] = None,
            need_weights=False
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Self-attention over ``x``, treated as (batch, tokens, embed_dim).

        ``k``, ``v`` and ``need_weights`` are not used. Keys and values are
        projected from ``x``. These arguments are kept for call compatibility
        with the implementation of ViT in torchvision.models. The second return
        value is always an empty tensor (no attention weights are produced).
        """
        if self.active_heads.numel() == 0:
            # NOTE(review): this returns the input itself, not a zero (or
            # out_proj-bias-only) contribution. If the caller wraps this block in
            # a residual connection, the input is added to the stream. Confirm
            # this is intended.
            return x, torch.empty(0)

        B, T, _ = x.size()

        # QKV projection -> (3, B, num_heads, T, head_dim)
        qkv = F.linear(x, self.in_proj_weight, self.in_proj_bias)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k_heads, v_heads = qkv[0], qkv[1], qkv[2]

        # SDPA applies dropout whenever dropout_p > 0, regardless of module mode,
        # so gate it on self.training as nn.MultiheadAttention does.
        dropout_p = self.dropout if self.training else 0.0
        # SDPA applies the 1/sqrt(head_dim) scaling itself.
        # NOTE(review): SDPA treats a boolean attn_mask as True = "may attend",
        # the opposite of nn.MultiheadAttention's convention. Float masks are
        # added in both.
        attn_output = F.scaled_dot_product_attention(
            q, k_heads, v_heads, attn_mask=attn_mask, dropout_p=dropout_p
        )

        # Concatenate heads back to (B, T, num_heads * head_dim), then project.
        attn_output = attn_output.transpose(1, 2).reshape(B, T, self.head_dim * self.num_heads)
        out = self.out_proj(attn_output)
        return out, torch.empty(0)

    def _head_positions(self, heads: torch.Tensor) -> torch.Tensor:
        """
        Return, as a long tensor, the position of each head ID of ``heads``
        within ``self.active_heads`` (i.e. its slot in the current weights).

        Raises ValueError if a head ID does not occur exactly once.
        """
        positions = []
        for head in heads.tolist():
            matches = (self.active_heads == head).nonzero(as_tuple=True)[0]
            if matches.numel() != 1:
                raise ValueError(
                    f"Head {head} must appear exactly once in active_heads "
                    f"{self.active_heads.tolist()}, found {matches.numel()} matches"
                )
            positions.append(int(matches[0]))
        return torch.tensor(positions, dtype=torch.long)

    def _select_heads(self, packed: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """
        Keep the head-sized row groups of ``packed`` at ``positions``.

        ``packed`` is one projection's weight or bias with ``num_heads * head_dim``
        rows. The result has ``len(positions) * head_dim`` rows in
        ``positions`` order.
        """
        trailing = tuple(packed.shape[1:])
        grouped = packed.reshape(self.num_heads, self.head_dim, *trailing)
        kept = grouped[positions.to(packed.device)]
        # Explicit sizes (no -1) so an empty selection still reshapes correctly.
        return kept.reshape(positions.numel() * self.head_dim, *trailing)

    def prune(
            self, 
            verbose: bool = False
        ) -> None:
        """
        Shrink the projections to the heads listed in ``self.pruned_heads``.

        On the first prune, CPU copies of the weights, biases and
        ``active_heads`` are stored so ``unprune()`` can restore the original
        layout. Nothing happens when ``active_heads`` and ``pruned_heads``
        already have the same length.

        Raises
        ------
        ValueError
            If a head in ``pruned_heads`` does not appear exactly once in
            ``active_heads``. The module is left unchanged in that case.
        """
        if verbose:
            print(f"Pruning heads to {len(self.pruned_heads)} heads")

        if len(self.active_heads) == len(self.pruned_heads):
            if verbose:
                print("Head Block already pruned")
            return

        # Resolve positions before mutating anything.
        keep = self._head_positions(self.pruned_heads)
        num_kept = keep.numel()
        has_bias = self.in_proj_bias is not None

        if not hasattr(self, "original_in_proj_weight"):
            # Snapshot only once, so unprune() always returns to the original
            # layout (and self.original_num_heads stays consistent with it).
            self.original_in_proj_weight = self.in_proj_weight.data.clone().to("cpu")
            self.original_out_proj_weight = self.out_proj.weight.data.clone().to("cpu")
            self.original_in_proj_bias = self.in_proj_bias.data.clone().to("cpu") if has_bias else None
            self.original_out_proj_bias = self.out_proj.bias.data.clone().to("cpu") if has_bias else None
            self.original_active_heads = self.active_heads.clone()

        # in_proj rows are [q; k; v]; each third holds num_heads * head_dim rows grouped by head.
        self.in_proj_weight.data = torch.cat(
            [self._select_heads(w, keep) for w in self.in_proj_weight.data.chunk(3, dim=0)],
            dim=0,
        )
        if has_bias:
            self.in_proj_bias.data = torch.cat(
                [self._select_heads(b, keep) for b in self.in_proj_bias.data.chunk(3, dim=0)],
                dim=0,
            )

        # out_proj input columns are the concatenated head outputs, grouped by head.
        out_features = self.out_proj.weight.shape[0]
        out_weight = self.out_proj.weight.data.reshape(out_features, self.num_heads, self.head_dim)
        out_weight = out_weight[:, keep.to(out_weight.device), :].reshape(out_features, num_kept * self.head_dim)

        new_out_proj = nn.modules.linear.NonDynamicallyQuantizableLinear(
            out_weight.shape[1],
            out_weight.shape[0],
            bias=has_bias,
            device=out_weight.device,
        )
        new_out_proj.weight.data = out_weight
        if has_bias:
            # The output bias is indexed by output feature, so pruning heads leaves it unchanged.
            new_out_proj.bias.data = self.out_proj.bias.data.to(out_weight.device)
        self.out_proj = new_out_proj

        self.active_heads = self.pruned_heads
        self.num_heads = num_kept

    def unprune(
            self, 
            verbose: bool = False
        ) -> None:
        """
        Restore the weights, biases and head layout saved by ``prune()``.

        If the block has not been pruned, nothing changes (optionally a message
        is printed). ``self.pruned_heads`` is left untouched, so a later
        ``prune()`` re-applies the same pruning.
        """
        if not hasattr(self, "original_in_proj_weight"):
            if verbose:
                print("Head has not been pruned yet. Please call prune() first.")
            return

        device = self.in_proj_weight.device
        self.in_proj_weight.data = self.original_in_proj_weight.to(device)

        has_bias = self.in_proj_bias is not None
        restored_out_proj = nn.modules.linear.NonDynamicallyQuantizableLinear(
            self.original_out_proj_weight.shape[1],
            self.original_out_proj_weight.shape[0],
            bias=has_bias,
            device=device,
        )
        restored_out_proj.weight.data = self.original_out_proj_weight.to(device)
        if has_bias:
            self.in_proj_bias.data = self.original_in_proj_bias.to(self.in_proj_bias.device)
            restored_out_proj.bias.data = self.original_out_proj_bias.to(device)
        self.out_proj = restored_out_proj

        # Restore the head IDs laid out before pruning. Fall back to
        # 0..n-1 for modules pruned before this snapshot attribute existed.
        original_active_heads = getattr(self, "original_active_heads", None)
        if original_active_heads is None:
            original_active_heads = torch.arange(self.original_num_heads)
        self.active_heads = original_active_heads
        self.num_heads = self.original_num_heads

        del self.original_in_proj_weight
        del self.original_in_proj_bias
        del self.original_out_proj_weight
        del self.original_out_proj_bias
        if hasattr(self, "original_active_heads"):
            del self.original_active_heads
    
def get_module_by_path(model, path):
    """
    Resolve a dotted attribute path (e.g. ``"encoder.layers"``) starting at ``model``.

    Each dot-separated segment is looked up with ``getattr``. A missing
    segment raises ``AttributeError``.
    """
    current = model
    remaining = path
    while True:
        segment, separator, remaining = remaining.partition('.')
        current = getattr(current, segment)
        if not separator:
            return current

# --- Utils to Prune/Unprune ViT ---
## --- Prune Heads ---
def prune_heads(
        self,
        layers: Optional[List[int]] = None
    ) -> None:
    """
    Calls the `prune()` method on the self-attention module of the selected encoder layers.
    The function is meant to be bound to a model instance exposing `encoder.layers`.

    Args:
    -----------
    self: The ViT model instance.
    layers : Optional[List[int]], default=None
        List of layer indices whose attention heads should be pruned.
        If None, all layers are pruned. Indices not matching any layer are ignored.

    Layers whose attention module has no callable `prune` attribute are skipped
    with a printed warning.
    """
    layers = layers if layers is not None else list(range(len(self.encoder.layers)))
    assert isinstance(layers, list), f"Layers to prune must be a list, but got {type(layers)}"
    assert len(layers) <= len(self.encoder.layers), \
        f"Layers to prune must be < {len(self.encoder.layers)}, but got {layers}"
    for i, layer in enumerate(self.encoder.layers):
        if i not in layers:
            continue
        # Default of None: without it getattr raises and the skip branch is unreachable.
        prune = getattr(layer.self_attention, "prune", None)
        if callable(prune):
            prune()
        else:
            print(f"[!] Layer {i} does not have a prune method. Skipping.")

def unprune_heads(
        self
    ) -> None:
    """
    Restores all pruned attention heads in every encoder layer.

    Notes:
    ------
    - Calls the `unprune` method on each layer's self-attention module if available.
    - Silently skips layers without a callable `unprune` attribute.
    """
    for layer in self.encoder.layers:
        # Default of None so modules without `unprune` are skipped rather than raising.
        unprune = getattr(layer.self_attention, "unprune", None)
        if callable(unprune):
            unprune()

def get_active_heads(
        self,
        model_name: str = "ViT"
    ) -> Dict[int, torch.Tensor]:
    """
    Map each transformer layer index to the sorted IDs of its active attention heads.

    Args:
    -----------
    self: The model instance.
    model_name : str, default="ViT"
        Chooses where the layers and their attention modules are looked up:
        "ViT" -> `encoder.layers[i].self_attention`,
        "OpenCLIP" -> `visual.transformer.resblocks[i].attn`.
        Any other value raises KeyError.

    Returns:
    --------
    active_heads : Dict[int, torch.Tensor]
        Layer index -> sorted tensor of head indices. Attention modules that
        expose `active_heads` report those values, sorted. Any other module
        reports `torch.arange(num_heads)`.

    Example:
    --------
    {
        0: tensor([0, 1, 3]),
        1: tensor([0, 1, 2, 3]),
        ...
    }
    """
    attention_locations = {
        "OpenCLIP" : ["visual.transformer.resblocks", "attn"],
        "ViT" : ["encoder.layers", "self_attention"]
    }
    layers_path, attn_attr = attention_locations[model_name]

    def _sorted_heads(mha) -> torch.Tensor:
        # Unpruned (e.g. stock) attention modules: every head is active.
        if not hasattr(mha, "active_heads"):
            return torch.arange(mha.num_heads)
        return torch.sort(mha.active_heads).values

    return {
        idx: _sorted_heads(getattr(layer, attn_attr))
        for idx, layer in enumerate(get_module_by_path(self, layers_path))
    }