import torch 
from Pruning.models.block import (
    MultiHeadAttentionPruned, PruneMLPBlock, prune_heads, unprune_heads,
    get_active_heads, prune_mlp, unprune_mlp, get_active_mlp
)
from Pruning.utils.utils import get_grads, compute_attention_head_scores, bernoulli_score_sampling
import types 

import copy

from torch.utils.data import DataLoader
from typing import Dict, Union, Optional
import warnings

def _expand_scores_to_head_slots(layer_scores, active_heads, num_slots, device):
    """Scatter the scores of the active heads into a vector with one slot per head index."""
    expanded = torch.zeros(num_slots, device=device)
    expanded[active_heads] = layer_scores
    # Every slot that is exactly zero (never written, or written with a zero
    # score) is replaced by a tiny negative value; presumably so that it ranks
    # below any real score -- TODO confirm against the pruning helpers.
    expanded[expanded == 0] = -1e-30
    return expanded


def _swap_in_pruned_attention(block, attn_attr, keep_heads, active_heads):
    """Replace `block.<attn_attr>` by a `MultiHeadAttentionPruned` built from its own weights."""
    mha = getattr(block, attn_attr)
    pruned = MultiHeadAttentionPruned(
        keep_heads,
        active_heads,
        mha.embed_dim,
        mha.in_proj_weight.data,
        mha.out_proj.weight.data,
        mha.in_proj_bias.data,
        mha.out_proj.bias.data,
    )
    setattr(block, attn_attr, pruned)


def structurally_prune_attention_heads(
        vit_model: torch.nn.Module,
        num_heads_to_prune: Union[int, Dict[int, list[int]]],
        strategy: str,
        context: Optional[str] = "layer",
        min_heads: int = 4,
        model_name: str = "ViT",
        current_active_heads: Optional[Dict[int, torch.Tensor]] = None,
        dataloader: Optional[DataLoader] = None,
        device: Optional[torch.device] = None,
        text_logits = None,
        verbose: bool = False,
        stocastic: bool = False,
        seed: Optional[int] = None,
        use_bias: bool = False
    ) -> torch.nn.Module:
    """
    Structurally prunes attention heads from a Vision Transformer (ViT) model based on a specified strategy.
    The input model is deep-copied; in every transformer block of the copy the attention module
    is replaced by a `MultiHeadAttentionPruned` built from the original projection weights.

    Args:
    -----------
    vit_model : torch.nn.Module
        The model to prune. Its transformer blocks and attention modules are looked up through the
        attribute paths selected by `model_name`.

    num_heads_to_prune : Union[int, Dict[int, list[int]]]
        - int: the same number of heads is pruned in each layer (context "layer"); for the
          "global" and "layer-global" contexts the value is forwarded to the global selection
          (presumably as the overall budget -- confirm against `get_global_heads_pruning`).
        - dict (context "layer" only): keys are the number of heads to prune and values are the
          lists of layer indices to prune that number from. Layers that are not listed are not pruned.
        - dict with strategy "predefined": keys are layer indices and values are the lists of head
          indices to prune. Layers that are not listed keep all their heads.

    strategy : str
        Pruning strategy. One of "predefined", "magnitude", "gradient", "gradient_squared",
        "attention_scores", "random". Every strategy except "magnitude", "random" and "predefined"
        requires both `dataloader` and `device`.

    context : Optional[str], default="layer"
        Context for pruning heads. One of:
        - "layer": prune heads per layer.
        - "global": prune heads globally across all layers.
        - "layer-global": prune heads per layer, but use heads distribution from global pruning.

    min_heads : int, default=4
        Minimum number of heads to keep in each layer; forwarded to the global selection only.

    model_name : str, default="ViT"
        Name of the model architecture. Currently supports "ViT" and "OpenCLIP".
        This is used to determine the naming convention for the model's layers.

    current_active_heads : Optional[Dict[int, torch.Tensor]], default=None
        A dictionary mapping layer indices to tensors of currently active head indices.
        If None, all heads are considered active, which requires every layer to report the same
        number of heads.
        NOTE: if you are pruning a model already pruned you can use 'model.get_active_heads()'
        to get the current active heads.

    dataloader : Optional[torch.utils.data.DataLoader], default=None
        Forwarded to `get_heads_scores`. Ignored (with a warning) for "random" and "predefined".

    device : Optional[torch.device], default=None
        Device for the pruned copy and the score tensors. Defaults to the device of the first
        model parameter.

    text_logits : Optional, default=None
        Forwarded to `get_heads_scores`.

    verbose : bool, default=False
        Whether to print per-layer pruning info.

    stocastic : bool, default=False
        Forwarded to the head-selection helpers; requires `seed` to be set.
        It is always disabled for the global pass of the "layer-global" context.

    seed : Optional[int], default=None
        Seed forwarded to the head-selection helpers. With strategy "random" and context "layer"
        a different seed, `seed + (layer_idx + 1) * 32`, is derived for every layer.

    use_bias : bool, default=False
        Forwarded to `get_heads_scores`.

    Returns:
    --------
    vit_model_pruned : torch.nn.Module
        The pruned copy of the model, with `prune_heads`, `unprune_heads` and `get_active_heads`
        bound to it as methods.

    Raises:
    -------
    AssertionError
        If the arguments are inconsistent (unknown model/strategy/context, missing dataloader,
        wrong type of `num_heads_to_prune` for the chosen context, too many heads to prune, ...).
    """

    models_naming = {
        "OpenCLIP" : ["visual.transformer.resblocks", "attn"],
        "ViT" : ["encoder.layers", "self_attention"]
    }
    allowed_strategies = ["magnitude", "gradient", "gradient_squared", 
                          "attention_scores", "random", "predefined"]
    allowed_contexts = ["layer", "global", "layer-global"]

    assert model_name in models_naming, \
        f"Model name must be one of {list(models_naming.keys())}, but got {model_name}."

    blocks_path, attn_attr = models_naming[model_name]
    # Work on a copy so the caller's dict is never modified.
    num_heads_to_prune = copy.deepcopy(num_heads_to_prune)

    assert strategy in allowed_strategies, \
        f"Strategy must be one of {allowed_strategies}, instead got {strategy}"
    assert strategy in ["magnitude", "random", "predefined"] or (dataloader is not None and device is not None), \
        f"Gradient required for {strategy} pruning, but dataloader or device not provided."
    assert context in allowed_contexts, \
        f"Context must be one of {allowed_contexts}, but got {context}."
    if strategy == "random":
        assert context in ["layer", "layer-global"], \
            f"Random pruning is only supported in 'layer' context, but got {context}."
    assert isinstance(num_heads_to_prune, (dict,int)), \
        f"num_heads_to_prune must be either an int or a dict, but got {type(num_heads_to_prune)}."
    if stocastic:
        assert seed is not None, \
            f"Stocastic pruning requires a seed to be set, but got seed={seed} and stocastic={stocastic}."
    if context in ["global", "layer-global"]:
        assert isinstance(num_heads_to_prune, int), \
            f"num_heads_to_prune must be an int when context is {context}, but got {type(num_heads_to_prune)}."
    if isinstance(num_heads_to_prune, dict):
        assert context == "layer", \
            f"When num_heads_to_prune is a dict, context must be 'layer', but got {context}."
    if strategy == "predefined":
        assert context == "layer", \
            f"When strategy is 'predefined', context must be 'layer', but got {context}."
        assert isinstance(num_heads_to_prune, dict), \
            f"When strategy is 'predefined', num_heads_to_prune must be a dict, but got {type(num_heads_to_prune)}."
    if strategy in ["random", "predefined"]:
        if dataloader is not None:
            warnings.warn(
                f"Using dataloader with strategy {strategy} is not necessary, as no gradients are computed"
                " and it will be ignored.",
                UserWarning
            )
        if device is not None:
            warnings.warn(
                f"Using device with strategy {strategy} is not necessary, as no gradients are computed"
                " and it will be ignored.",
                UserWarning
            )

    device = device or next(vit_model.parameters()).device

    transformer_blocks = get_module_by_path(vit_model, blocks_path)
    num_layers = len(transformer_blocks)
    original_num_heads = [getattr(block, attn_attr).num_heads for block in transformer_blocks]

    if current_active_heads is None:
        assert len(set(original_num_heads)) == 1, \
            f"The number of heads per is not the same across all layers {original_num_heads}, so current_active_heads must be provided. " 
        current_active_heads = {layer: torch.arange(num_heads) for layer, num_heads in enumerate(original_num_heads)}
    current_active_heads = {layer: active_heads.to(device) for layer, active_heads in current_active_heads.items()}

    # layer_to_heads maps every layer index to
    #   - the list of head indices to KEEP      (strategy "predefined"), or
    #   - the NUMBER of heads to prune          (every other strategy).
    if isinstance(num_heads_to_prune, int):
        layer_to_heads = {layer: num_heads_to_prune for layer in range(num_layers)}
    elif strategy == "predefined":
        requested = {
            layer_idx: [head for head in range(original_num_heads[layer_idx]) if head not in heads_to_prune]
            for layer_idx, heads_to_prune in num_heads_to_prune.items()
        }
        # Layers that were not mentioned keep all of their own heads.
        layer_to_heads = {
            layer: requested.get(layer, list(range(original_num_heads[layer])))
            for layer in range(num_layers)
        }
    else:
        requested = {layer: num_heads for num_heads, layers in num_heads_to_prune.items() for layer in layers}
        # Layers that were not mentioned are left untouched (0 heads to prune).
        layer_to_heads = {layer: requested.get(layer, 0) for layer in range(num_layers)}

    if context == "layer" and strategy != "predefined":
        for layer_idx, num_heads in layer_to_heads.items():
            assert num_heads < len(current_active_heads[layer_idx]), \
                f"Number of heads to prune must be < {len(current_active_heads[layer_idx])}, but got {num_heads} for layer {layer_idx}."

    vit_model_pruned = copy.deepcopy(vit_model).to(device)
    transformer_blocks = get_module_by_path(vit_model_pruned, blocks_path)

    # number of heads per layer might be different, so we init with the maximum number of heads
    max_head_idx = max(h.max().item() for h in current_active_heads.values())
    num_slots = max_head_idx + 1
    model_heads_scores = torch.zeros(num_layers, num_slots, device=device)
    keep_heads = {i: torch.tensor([], device=device) for i in range(num_layers)}    

    for layer_idx, block_new in enumerate(transformer_blocks):
        if strategy not in ["random", "predefined"]:
            layer_heads_scores = get_heads_scores(
                mha = getattr(block_new, attn_attr),
                vit_model = vit_model_pruned,
                strategy = strategy, 
                dataloader = dataloader,
                device = device,
                text_logits = text_logits,
                compute_grads = context == "layer",
                use_bias = use_bias
            )
            model_heads_scores[layer_idx,:] = _expand_scores_to_head_slots(
                layer_heads_scores, current_active_heads[layer_idx], num_slots, device
            )

        # The global contexts only collect scores here; the pruning happens below.
        if context != "layer":
            continue

        # Pruning layer by layer: the block is replaced right away, so the scores of
        # the following layers are computed on the already-pruned model.
        if strategy != "predefined":
            if strategy == "random" and seed is not None:
                # a different seed per layer, so that layers do not share the same draw
                layer_seed = seed + (layer_idx + 1) * 32
            else:
                # seed may be None here (it is only mandatory for stochastic pruning)
                layer_seed = seed
            layer_keep = get_layer_heads_pruning(
                model_heads_scores[layer_idx,:],
                len(current_active_heads[layer_idx]),
                layer_to_heads[layer_idx],
                stocastic,
                seed = layer_seed,
                random = strategy == "random"
            )
        else:
            layer_keep = torch.tensor(layer_to_heads[layer_idx])
        keep_heads[layer_idx] = layer_keep.to(device)
        _swap_in_pruned_attention(
            block_new, attn_attr, keep_heads[layer_idx], current_active_heads[layer_idx]
        )

    if context == "global":
        keep_heads = get_global_heads_pruning(
            model_heads_scores,
            original_num_heads,
            num_layers,
            num_heads_to_prune,
            min_heads,
            stocastic,
            seed
        )
        keep_heads = {i: keep_heads[i].to(device) for i in range(num_layers)}

        for layer_idx, block_new in enumerate(transformer_blocks):
            _swap_in_pruned_attention(
                block_new, attn_attr, keep_heads[layer_idx], current_active_heads[layer_idx]
            )

    if context == "layer-global":
        # NOTE(review): this branch always calls get_heads_scores, also for strategy
        # "random" (which the validation above allows with this context) -- confirm
        # that get_heads_scores supports it.
        keep_heads = get_global_heads_pruning(
            model_heads_scores,
            original_num_heads,
            num_layers,
            num_heads_to_prune,
            min_heads,
            stocastic = False,
            seed = seed
        )
        # preserve number of heads per layer following the global distribution from global pruning
        for layer_idx, keep_heads_layer in keep_heads.items():
            layer_to_heads[layer_idx] = original_num_heads[layer_idx] - keep_heads_layer.numel()

        for layer_idx, block_new in enumerate(transformer_blocks):
            # scores are recomputed for every layer, after the previous layers were pruned
            layer_heads_scores = get_heads_scores(
                mha = getattr(block_new, attn_attr),
                vit_model = vit_model_pruned,
                strategy = strategy, 
                dataloader = dataloader,
                device = device,
                text_logits = text_logits,
                compute_grads = True,
                use_bias = use_bias
            )
            model_heads_scores[layer_idx,:] = _expand_scores_to_head_slots(
                layer_heads_scores, current_active_heads[layer_idx], num_slots, device
            )

            layer_keep = get_layer_heads_pruning(
                model_heads_scores[layer_idx,:],
                original_num_heads[layer_idx],
                layer_to_heads[layer_idx],
                stocastic,
                seed
            ) 
            keep_heads[layer_idx] = layer_keep.to(device)
            _swap_in_pruned_attention(
                block_new, attn_attr, keep_heads[layer_idx], current_active_heads[layer_idx]
            )

    if verbose:
        for layer_idx, keep_heads_layer in keep_heads.items():
            print(f"[Layer {layer_idx}] Keeping heads ({keep_heads_layer.numel()}): {keep_heads_layer.sort()[0].tolist()}") 
        total_num_heads_kept = sum([keep_heads[i].numel() for i in range(num_layers)])
        print(f"Total number of heads kept: {total_num_heads_kept}") 

    # expose the pruning utilities as methods of the returned model
    vit_model_pruned.prune_heads = types.MethodType(prune_heads, vit_model_pruned)
    vit_model_pruned.unprune_heads = types.MethodType(unprune_heads, vit_model_pruned)
    vit_model_pruned.get_active_heads = types.MethodType(get_active_heads, vit_model_pruned)

    return vit_model_pruned

def _swap_in_pruned_mlp(block, mlp_paths, active_rows):
    """Replace the MLP of `block` by a `PruneMLPBlock` built from its own weights.

    `mlp_paths` holds the dotted paths (relative to `block`) of the first linear
    layer, the activation and the second linear layer, in that order.
    """
    fc1 = get_module_by_path(block, mlp_paths[0])
    activation = get_module_by_path(block, mlp_paths[1])
    fc2 = get_module_by_path(block, mlp_paths[2])
    pruned_mlp_block = PruneMLPBlock(
        fc1.weight.data,
        fc1.bias.data,
        fc2.weight.data,
        fc2.bias.data,
        active_rows,
        dropout = 0.0,
        activation = activation
    )
    # the first component of the path is the attribute holding the whole MLP
    setattr(block, mlp_paths[0].split('.')[0], pruned_mlp_block)


def structurally_prune_mlp(
        vit_model: torch.nn.Module,
        strategy: str,
        rows_to_prune: Union[int, float, Dict[int, list[int]]],
        model_name: str = "ViT",
        current_active_rows: Optional[Dict[int, torch.Tensor]] = None,
        context: Optional[str] = "layer",
        min_rows: int = 800,
        text_logits: torch.Tensor = None,
        dataloader: Optional[DataLoader] = None,
        device: Optional[torch.device] = None,
        verbose: bool = False,
        stocastic: bool = False,
        seed: Optional[int] = None
    ) -> torch.nn.Module:
    """
    Structurally prunes the hidden (intermediate) dimension of the MLP blocks in a Vision Transformer model.
    The input model is deep-copied and put in eval mode; in every transformer block of the copy
    the MLP is replaced by a `PruneMLPBlock` built from the original weights.

    Args:
    -----------
    vit_model : torch.nn.Module
        The Vision Transformer model to prune.

    strategy : str
        Pruning strategy. One of "magnitude", "gradient", "gradient_squared", "random".
        Every strategy except "magnitude" requires both `dataloader` and `device`.

    rows_to_prune : Union[int, float, Dict[int, list[int]]]
        Number or proportion of rows (neurons) to prune:
        - If int: absolute number of rows to prune (must be smaller than the number of rows of
          the first MLP layer of the first block).
        - If float: proportion (strictly between 0 and 1) of rows to prune in each layer.
        - If dict: keys are the number of rows to prune and values are the lists of layer indices
          to prune that number from. Layers that are not listed are not pruned.
        Only an int is accepted with the "global" and "layer-global" contexts.

    model_name : str, default="ViT"
        Name of the model architecture. Currently supports "ViT" and "OpenCLIP".
        This is used to determine the naming convention for the model's layers.

    current_active_rows : Optional[Dict[int, torch.Tensor]], default=None
        Accepted for API symmetry with the head pruning function; currently not used by this function.

    context : Optional[str], default="layer"
        Context for pruning MLP layers. One of:
        - "layer": prune MLP layers per transformer block.
        - "global": prune MLP layers globally across all transformer blocks.
        - "layer-global": prune MLP layers per transformer block, but use the distribution of rows from global pruning.

    min_rows : int, default=800
        Minimum number of rows to keep in each MLP layer; forwarded to the global selection only.

    text_logits : Optional, default=None
        Forwarded to `get_mlp_scores`.

    dataloader : Optional[torch.utils.data.DataLoader], default=None
        Forwarded to `get_mlp_scores`.

    device : Optional[torch.device], default=None
        Device for the pruned copy and the score tensors. Defaults to the device of the first
        model parameter.

    verbose : bool, default=False
        Whether to print pruning info per layer.

    stocastic : bool, default=False
        Forwarded to the row-selection helpers. It is always disabled for the global pass of the
        "layer-global" context.

    seed : Optional[int], default=None
        Seed forwarded to the row-selection helpers.

    Returns:
    --------
    vit_model_pruned : torch.nn.Module
        The pruned copy of the model, with `prune_mlp`, `unprune_mlp` and `get_active_mlp`
        bound to it as methods.

    Raises:
    -------
    AssertionError
        If the arguments are inconsistent (unknown model/strategy/context, missing dataloader,
        invalid amount of rows, non-int `rows_to_prune` with a global context, ...).
    ValueError
        If `rows_to_prune` is neither an int, a float nor a dict.
    """

    models_naming = {
        "OpenCLIP" : ["visual.transformer.resblocks", ("mlp.c_fc", "mlp.gelu", "mlp.c_proj")],
        "ViT" : ["encoder.layers", ("mlp.0", "mlp.1", "mlp.3")]
    }
    allowed_strategies = ["magnitude", "gradient", "gradient_squared", "random"]
    allowed_contexts = ["layer", "global", "layer-global"]

    assert model_name in models_naming, \
        f"Model name must be one of {list(models_naming.keys())}, but got {model_name}."

    blocks_path, mlp_paths = models_naming[model_name]

    assert strategy in allowed_strategies , \
        f"Strategy must be one of {allowed_strategies}, instead got {strategy}"
    assert context in allowed_contexts, \
        f"Context must be one of {allowed_contexts}, but got {context}."
    assert strategy == "magnitude" or (dataloader is not None and device is not None), \
        f"Gradient required for {strategy} pruning, but dataloader or device not provided."

    # Checks that depend only on the arguments run before the model is touched.
    if not isinstance(rows_to_prune, (int, float, dict)):
        raise ValueError(f"rows_to_prune must be either an int, float or a dict, but got {type(rows_to_prune)}.")
    if context in ["global", "layer-global"]:
        assert isinstance(rows_to_prune, int), \
            f"num_rows_to_prune must be an int when context is {context}, but got {type(rows_to_prune)}."

    transformer_blocks = get_module_by_path(vit_model, blocks_path)
    num_layers = len(transformer_blocks)
    # assume all MLPs have the same number of rows as the one of the first block
    original_num_rows = get_module_by_path(transformer_blocks[0], mlp_paths[0]).weight.shape[0]

    # layer_to_rows maps every layer index to the number of rows to prune from it
    if isinstance(rows_to_prune, int):
        assert rows_to_prune < original_num_rows, \
            f"Number of rows to prune must be < {original_num_rows}, but got {rows_to_prune}"
        layer_to_rows = {i: rows_to_prune for i in range(num_layers)}
    elif isinstance(rows_to_prune, float):
        assert rows_to_prune > 0 and rows_to_prune < 1, \
            f"Percentage of rows to prune must be between 0 and 1, but got {rows_to_prune}"
        num_rows_to_prune = int(original_num_rows * rows_to_prune)
        layer_to_rows = {i: num_rows_to_prune for i in range(num_layers)}
    else:
        assert all(isinstance(layers, list) for layers in rows_to_prune.values()), \
            f"rows_to_prune must be a dict with lists of layers, but got {list(rows_to_prune.values())}."
        assert all(isinstance(num_rows, int) for num_rows in rows_to_prune.keys()), \
            f"rows_to_prune must be a dict with int keys, but got {list(rows_to_prune.keys())}"
        requested = {layer: num_rows for num_rows, layers in rows_to_prune.items() for layer in layers}
        # layers that were not mentioned are left untouched (0 rows to prune)
        layer_to_rows = {i: requested.get(i, 0) for i in range(num_layers)}
        assert all(n <= original_num_rows for n in layer_to_rows.values()), \
            f"Cannot prune > {original_num_rows} rows. Got: {layer_to_rows}"

    device = device or next(vit_model.parameters()).device
    vit_model_pruned = copy.deepcopy(vit_model).to(device)
    vit_model_pruned.eval()

    transformer_blocks = get_module_by_path(vit_model_pruned, blocks_path)
    model_mlp_scores = torch.zeros([num_layers, original_num_rows], device=device)
    active_rows = {i: torch.tensor([], device=device) for i in range(num_layers)}

    for i, block in enumerate(transformer_blocks):
        model_mlp_scores[i,:] = get_mlp_scores(
            get_module_by_path(block, mlp_paths[0]),
            vit_model_pruned,
            strategy, 
            context,
            dataloader,
            device,
            text_logits
        )

        # The global contexts only collect scores here; the pruning happens below.
        if context == "layer":
            # The block is replaced right away, so the scores of the following
            # layers are computed on the already-pruned model.
            layer_rows = get_layer_mlp_pruning(
                model_mlp_scores[i,:],
                original_num_rows,
                layer_to_rows[i],
                stocastic,
                seed
            )
            active_rows[i] = layer_rows.to(device)
            _swap_in_pruned_mlp(block, mlp_paths, active_rows[i])

    if context == "global":
        active_rows = get_global_mlp_pruning(
            model_mlp_scores,
            original_num_rows,
            num_layers,
            rows_to_prune,
            min_rows,
            stocastic,
            seed    
        )

        for i, block in enumerate(transformer_blocks):
            _swap_in_pruned_mlp(block, mlp_paths, active_rows[i])

    if context == "layer-global":
        active_rows = get_global_mlp_pruning(
            model_mlp_scores,
            original_num_rows,
            num_layers,
            rows_to_prune,
            min_rows,
            stocastic = False,
            seed = seed   
        )
        # preserve number of rows per layer following the global distribution from global pruning
        for layer_idx, active_rows_layer in active_rows.items():
            layer_to_rows[layer_idx] = original_num_rows - active_rows_layer.numel()

        for i, block in enumerate(transformer_blocks):
            # scores are recomputed for every layer, after the previous layers were pruned
            model_mlp_scores[i,:] = get_mlp_scores(
                get_module_by_path(block, mlp_paths[0]),
                vit_model_pruned,
                strategy, 
                context,
                dataloader,
                device,
                text_logits
            )

            layer_rows = get_layer_mlp_pruning(
                model_mlp_scores[i,:],
                original_num_rows,
                layer_to_rows[i],
                stocastic,
                seed
            )
            active_rows[i] = layer_rows.to(device)
            _swap_in_pruned_mlp(block, mlp_paths, active_rows[i])

    if verbose:
        for layer_idx, active_rows_layer in active_rows.items():
            print(f"[Layer {layer_idx}] Keeping rows: {active_rows_layer.numel()}")

    # expose the pruning utilities as methods of the returned model
    vit_model_pruned.prune_mlp = types.MethodType(prune_mlp, vit_model_pruned) 
    vit_model_pruned.unprune_mlp = types.MethodType(unprune_mlp, vit_model_pruned)
    vit_model_pruned.get_active_mlp = types.MethodType(get_active_mlp, vit_model_pruned)

    return vit_model_pruned

def get_module_by_path(model, path):
    """
    Resolve a dotted attribute path (e.g. "encoder.layers") starting from `model`.

    Each component of `path` is looked up with `getattr` on the object returned by
    the previous component; the object reached by the last component is returned.
    An `AttributeError` is raised by `getattr` if any component does not exist.
    """
    first, dot, remainder = path.partition('.')
    child = getattr(model, first)
    if dot:
        # there is still something after the first dot: keep descending
        return get_module_by_path(child, remainder)
    return child

def get_heads_scores(
    mha: torch.nn.Module,
    strategy: str,
    compute_grads: bool,
    vit_model: torch.nn.Module = None,
    dataloader: DataLoader = None,
    device: torch.device = None,
    text_logits: torch.Tensor = None,
    use_bias : bool = False
):
    """
    Given a MultiHeadAttention module that follows the design of torchvision or OpenCLIP, 
    compute the scores for each attention head based on the specified strategy. 
    Args:
    -----------
    mha : torch.nn.Module
        A MultiHeadAttention module from a Vision Transformer model. It must expose
        `embed_dim`, `num_heads`, `head_dim`, `in_proj_weight` and (if `use_bias`) `in_proj_bias`.
    vit_model : torch.nn.Module
        The Vision Transformer model from which the MultiHeadAttention module is taken.
        Needed for gradient-based strategies and for "attention_scores".
    strategy : str
        The strategy to use for scoring the attention heads. One of:
        - "magnitude": uses the mean absolute value of the weights.
        - "gradient": uses absolute gradient-weight product.
        - "gradient_squared": uses squared gradient-weight product.
        - "attention_scores": computes scores based on attention scores.
    compute_grads : bool
        If True, computes new gradients even if they were already computed.
        This is useful for example when pruning layer by layer to get new estimates after the pruning of the previous layers.
    device : torch.device
        The device on which to compute the gradients, if needed.
    dataloader : torch.utils.data.DataLoader
        A DataLoader providing data for backpropagation to compute gradients.
        In general this should be either the training or validation dataloader.
        For "attention_scores" only its first batch is used.
    text_logits : torch.Tensor
        The text logits to use for computing gradients in the case of gradient-based pruning strategies and CLIP models.
    use_bias : bool, default=False
        Whether to use the bias terms in the MultiHeadAttention module for scoring. (not tested yet)

    Returns:
    -----------
    For "magnitude", "gradient" and "gradient_squared": a tensor of shape (num_heads,) holding the
    average of the query, key and value scores of each head.
    For "attention_scores": whatever `compute_attention_head_scores` returns.

    Raises:
    -----------
    ValueError
        If `strategy` is not supported, or if `use_bias` is True but the module has no `in_proj_bias`.
    """

    supported_strategies = ["magnitude", "gradient", "gradient_squared", "attention_scores"]

    embed_dim = mha.embed_dim
    num_heads = mha.num_heads
    head_dim = mha.head_dim

    # in_proj_weight stacks the query, key and value projections along dim 0
    W = mha.in_proj_weight
    q_weight, k_weight, v_weight = W.chunk(3, dim=0)
    if use_bias:
        if mha.in_proj_bias is None:
            raise ValueError("use_bias=True requires the attention module to have an in_proj_bias.")
        q_bias, k_bias, v_bias = mha.in_proj_bias.chunk(3, dim=0)

    # Score attention heads
    if strategy == "magnitude":
        # rows of each projection are grouped per head: (num_heads, head_dim, in_features)
        q_score = q_weight.abs().reshape(num_heads, head_dim, -1).mean(dim=(1, 2))
        k_score = k_weight.abs().reshape(num_heads, head_dim, -1).mean(dim=(1, 2))
        v_score = v_weight.abs().reshape(num_heads, head_dim, -1).mean(dim=(1, 2))
        if use_bias:
            q_score += q_bias.abs().reshape(num_heads, head_dim).mean(dim=1)
            k_score += k_bias.abs().reshape(num_heads, head_dim).mean(dim=1)
            v_score += v_bias.abs().reshape(num_heads, head_dim).mean(dim=1)
    elif strategy in ["gradient", "gradient_squared"]:
        # (re)compute the gradients on request, or when none are available yet
        if compute_grads or mha.in_proj_weight.grad is None:
            get_grads(vit_model, dataloader, device, text_logits)
        G = mha.in_proj_weight.grad
        q_grad, k_grad, v_grad = G.chunk(3, dim=0)
        score_func = torch.abs if strategy == "gradient" else torch.square
        q_score = score_func(q_weight * q_grad).view(num_heads, head_dim, embed_dim).mean(dim=(1, 2))
        k_score = score_func(k_weight * k_grad).view(num_heads, head_dim, embed_dim).mean(dim=(1, 2))
        v_score = score_func(v_weight * v_grad).view(num_heads, head_dim, embed_dim).mean(dim=(1, 2))
        if use_bias:
            q_grad_bias, k_grad_bias, v_grad_bias = mha.in_proj_bias.grad.chunk(3, dim=0)
            q_score += score_func(q_bias * q_grad_bias).view(num_heads, head_dim).mean(dim=1)
            k_score += score_func(k_bias * k_grad_bias).view(num_heads, head_dim).mean(dim=1)
            v_score += score_func(v_bias * v_grad_bias).view(num_heads, head_dim).mean(dim=1)
    elif strategy == "attention_scores":
        # A DataLoader is iterable but not an iterator, so it has to go through iter()
        # before next(); this also works if an iterator is passed directly.
        # Each batch is unpacked as a pair whose first element is the images.
        images, _ = next(iter(dataloader))
        return compute_attention_head_scores(vit_model, images, device)  
    else:
        raise ValueError(f"Unknown strategy: {strategy}. Supported strategies are: {supported_strategies}.")

    scores = (q_score + k_score + v_score) / 3
    return scores

def get_mlp_scores(
        mlp: torch.nn.Module,
        vit_model,
        strategy: str,
        context: str,
        dataloader: DataLoader = None,
        device: torch.device = None,
        text_logits: torch.Tensor = None
):
    """
    Compute one importance score per row of the weight matrix of a linear layer.
    Args:
    -----------
    mlp : torch.nn.Module
        The layer to score; only its `weight` (and `weight.grad`) are read.
    vit_model : torch.nn.Module
        The full model, passed to `get_grads` when gradients have to be computed.
    strategy : str
        The strategy to use for scoring the rows. One of:
        - "magnitude": L1 norm of each weight row.
        - "gradient": sum over each row of the absolute gradient-weight product.
        - "gradient_squared": sum over each row of the squared gradient-weight product.
    context : str
        If "layer", gradients are recomputed on every call; otherwise they are only
        computed when the layer has no gradient yet.
    dataloader : torch.utils.data.DataLoader
        Data passed to `get_grads` for gradient-based strategies.
    device : torch.device
        Device passed to `get_grads` for gradient-based strategies.
    text_logits : torch.Tensor
        Text logits passed to `get_grads` for gradient-based strategies.

    Returns:
    -----------
    A tensor with one score per row of `mlp.weight`.

    Raises:
    -----------
    ValueError
        If `strategy` is not one of the supported strategies.
    """
    supported_strategies = ["magnitude", "gradient", "gradient_squared"]

    W = mlp.weight.data 

    if strategy == "magnitude":
        scores = W.abs().sum(dim=1)  # L1 norm
    elif strategy in ["gradient", "gradient_squared"]:
        if context == "layer" or mlp.weight.grad is None:
            get_grads(vit_model, dataloader, device, text_logits) 
        G = mlp.weight.grad
        score_func = torch.abs if strategy == "gradient" else torch.square
        scores = score_func(W * G).sum(dim=1)
    else:
        # without this branch an unsupported strategy would surface as an
        # UnboundLocalError on the return statement
        raise ValueError(f"Unknown strategy: {strategy}. Supported strategies are: {supported_strategies}.")
    return scores

def get_global_heads_pruning(
        model_heads_scores: torch.Tensor,
        original_num_heads: list[int],
        num_transformer_blocks: int,
        num_heads_to_prune: int,
        min_heads: int,
        stocastic: bool,
        seed: Optional[int] = None
) -> Dict[int, torch.Tensor]:
    """
    Get the heads to keep for global pruning of attention heads across all transformer blocks.
    Args:
    -----------
    model_heads_scores : torch.Tensor
        Tensor of shape (num_transformer_blocks, max_num_heads) containing the scores for each
        attention head in each transformer block. The scores are used to determine which heads to keep.
        The tensor is not modified: the function works on a private copy.
    original_num_heads : list[int]
        List of integers containing the original number of heads in each transformer block.
    num_transformer_blocks : int
        Number of transformer blocks in the model.
    num_heads_to_prune : int
        Number of heads to prune from each transformer block. However, the actual pruning is done globally,
        so this arg is used to determine the global number of heads to prune as 
        num_heads_to_prune * num_transformer_blocks.
    min_heads : int
        Minimum number of heads to keep in each transformer block after pruning. If the number of heads to
        prune results in fewer than `min_heads`, the pruning is adjusted to keep at least `min_heads` heads.
        If the number of heads in the transformer block is less than `min_heads`, the pruning is 
        adjusted to keep all heads.
    stocastic : bool
        If True, uses stochastic pruning based on the scores of the heads derived using the strategy.
    seed : Optional[int]
        Seed for the random number generator used in stochastic pruning. Required if `stocastic` is True.

    Returns:
    -----------
    A dict mapping each transformer block index to a tensor with the (layer-local) indices of the
    heads to keep in that block.
    """
    keep_heads = {i: torch.tensor([], device="cpu") for i in range(num_transformer_blocks)}
    # flatten() returns a view for contiguous inputs, so the in-place score overrides
    # below would leak into the caller's tensor; work on a detached copy instead.
    model_heads_scores = model_heads_scores.detach().flatten().clone()
    global_num_heads = sum(original_num_heads)
    max_num_heads = max(original_num_heads)
    global_num_heads_to_keep = global_num_heads - num_heads_to_prune * num_transformer_blocks
    # never ask for fewer heads than what the per-layer minimum requires in total
    total_min_required = sum(min(min_heads, h) for h in original_num_heads)
    global_num_heads_to_keep = max(global_num_heads_to_keep, total_min_required)
    # Select globally, then check every layer; when a layer ends up below its minimum,
    # pin its best heads, suppress its other heads and select again.
    while True:
        if stocastic:
            idx_heads_to_keep = bernoulli_score_sampling(model_heads_scores, global_num_heads_to_keep, seed, max_tries=200)
        else:
            idx_heads_to_keep = torch.topk(model_heads_scores, k=global_num_heads_to_keep, largest=True).indices
        valid = True
        for layer_idx in range(num_transformer_blocks):
            # heads scores are zero-padded to the maximum number of heads
            start, end = layer_idx * max_num_heads, (layer_idx + 1) * max_num_heads
            layer_mask = (idx_heads_to_keep >= start) & (idx_heads_to_keep < end)
            # convert flat indices back to layer-local head indices
            keep_heads_layer = idx_heads_to_keep[layer_mask] - start
            # if the number of heads in the layer is less than the minimum
            adjusted_min_heads = min(min_heads, original_num_heads[layer_idx])
            if keep_heads_layer.numel() < adjusted_min_heads:
                valid = False
                idx_heads_to_keep_layer = torch.topk(model_heads_scores[start:end], k=adjusted_min_heads, largest=True).indices
                abs_idx = start + idx_heads_to_keep_layer
                all_idx = torch.arange(start, end, device=model_heads_scores.device)
                abs_idx_to_neg_inf = all_idx[~torch.isin(all_idx, abs_idx)]
                # pin the best heads of this layer with a huge score ...
                model_heads_scores[abs_idx] = 1e+30 
                # ... and give the remaining ones a tiny negative score, which ranks
                # below any non-negative score (assumes scores are non-negative)
                model_heads_scores[abs_idx_to_neg_inf] = -1e-30
                break
            keep_heads[layer_idx] = keep_heads_layer
        if valid:
            break
    tot_num_heads_kept = sum([keep_heads[i].numel() for i in range(num_transformer_blocks)])
    assert tot_num_heads_kept == global_num_heads_to_keep, \
        f"Total number of heads kept {tot_num_heads_kept} does not match the expected {global_num_heads_to_keep}."
    return keep_heads

def get_layer_heads_pruning(
        heads_layer_scores: torch.Tensor,
        original_num_heads: int,
        num_heads_to_prune: int,
        stocastic: bool,
        seed: Optional[int] = None,
        random: bool = False
) -> torch.Tensor:
    """
    Get the heads to keep for pruning at the layer level. 
    Args:
    -----------
    heads_layer_scores : torch.Tensor
        Tensor of shape (num_heads,) containing the scores for each attention head in the layer.
    original_num_heads : int
        The original number of heads in the layer.
    num_heads_to_prune : int
        Number of heads to prune from the layer. The actual number of heads kept will be
        original_num_heads - num_heads_to_prune.
    stocastic : bool
        If True, uses stochastic pruning based on the scores of the heads derived using the strategy.
    seed : Optional[int]
        Seed for the random number generator used in stochastic pruning. Required if `stocastic`
        is True.
    random : bool, default=False
        If True, randomly selects heads to keep. This is just useful for debugging purposes as a baseline.

    Returns:
    -----------
    A tensor with the indices of the heads to keep. When nothing is pruned, all the
    indices are returned in ascending order.
    """
    new_num_heads = original_num_heads - num_heads_to_prune

    if new_num_heads == original_num_heads:
        # Nothing to prune: keep every head and skip the selection altogether.
        # The indices are created on the same device a top-k selection would use.
        device = heads_layer_scores.device if isinstance(heads_layer_scores, torch.Tensor) else None
        return torch.arange(original_num_heads, device=device)

    if random:
        assert seed is not None, \
            "Random pruning requires a seed to be set"
        # seeds the global torch RNG so that the random baseline is reproducible
        torch.manual_seed(seed)
        keep_heads = torch.randperm(original_num_heads)[:new_num_heads]
    elif stocastic:
        keep_heads = bernoulli_score_sampling(heads_layer_scores, new_num_heads, seed, max_tries=20)
    else:
        keep_heads = torch.topk(heads_layer_scores, k=new_num_heads, largest=True).indices
    assert keep_heads.numel() == new_num_heads, \
        f"Number of heads kept {keep_heads.numel()} does not match the expected {new_num_heads}."

    return keep_heads

def get_layer_mlp_pruning(
        mlp_layer_scores: torch.Tensor,
        original_num_rows: int,
        num_rows_to_prune: int,
        stocastic: bool,
        seed: Optional[int] = None
) -> torch.Tensor:
    """
    Get the MLP rows to keep for pruning at the layer level.
    Args:
    -----------
    mlp_layer_scores : torch.Tensor
        Tensor containing one score per row of the layer.
    original_num_rows : int
        The original number of rows in the layer.
    num_rows_to_prune : int
        Number of rows to prune from the layer. The number of rows kept is
        original_num_rows - num_rows_to_prune.
    stocastic : bool
        If True, the rows are selected with `bernoulli_score_sampling` instead of a top-k on the scores.
    seed : Optional[int]
        Seed passed to `bernoulli_score_sampling` when `stocastic` is True.

    Returns:
    -----------
    A tensor with the indices of the rows to keep. When nothing is pruned, all the
    indices are returned in ascending order.
    """
    new_num_rows = original_num_rows - num_rows_to_prune

    if new_num_rows == original_num_rows:
        # Nothing to prune: keep every row and skip the selection altogether.
        # The indices are created on the same device a top-k selection would use.
        device = mlp_layer_scores.device if isinstance(mlp_layer_scores, torch.Tensor) else None
        return torch.arange(original_num_rows, device=device)

    if stocastic:
        keep_rows = bernoulli_score_sampling(mlp_layer_scores, new_num_rows, seed, max_tries=3000)
    else:
        keep_rows = torch.topk(mlp_layer_scores, k=new_num_rows, largest=True).indices
    assert keep_rows.numel() == new_num_rows, \
        f"Number of rows kept {keep_rows.numel()} does not match the expected {new_num_rows}."
    return keep_rows

def get_global_mlp_pruning(
        model_mlp_scores: torch.Tensor,
        original_num_rows: int,
        num_transformer_blocks: int,
        num_rows_to_prune: int,
        min_rows: int,
        stocastic: bool,
        seed: Optional[int] = None    
) -> Dict[int, torch.Tensor]:
    """
    Get the MLP rows to keep for global pruning across all transformer blocks.
    Args:
    -----------
    model_mlp_scores : torch.Tensor
        Tensor of shape (num_transformer_blocks, original_num_rows) containing the score of each
        row in each transformer block. The tensor is not modified: the function works on a private copy.
    original_num_rows : int
        Number of rows in each transformer block before pruning.
    num_transformer_blocks : int
        Number of transformer blocks in the model.
    num_rows_to_prune : int
        Number of rows to prune per transformer block. The pruning is done globally, so this is
        only used to derive the global budget num_rows_to_prune * num_transformer_blocks.
    min_rows : int
        Minimum number of rows to keep in each transformer block. It is capped at `original_num_rows`,
        and the global number of rows to keep is raised if needed so that every block can reach it.
    stocastic : bool
        If True, the rows are selected with `bernoulli_score_sampling` instead of a top-k on the scores.
    seed : Optional[int]
        Seed passed to `bernoulli_score_sampling` when `stocastic` is True.

    Returns:
    -----------
    A dict mapping each transformer block index to a tensor with the (layer-local) indices of the
    rows to keep in that block.
    """
    keep_rows = {i: torch.tensor([], device="cpu") for i in range(num_transformer_blocks)}
    # flatten() returns a view for contiguous inputs, so the in-place score overrides
    # below would leak into the caller's tensor; work on a detached copy instead.
    model_mlp_scores = model_mlp_scores.detach().flatten().clone()
    global_num_rows = original_num_rows * num_transformer_blocks
    global_num_rows_to_keep = global_num_rows - num_rows_to_prune * num_transformer_blocks
    # A layer cannot keep more rows than it has, and the global budget must be large
    # enough for every layer to reach its minimum; otherwise the loop below can never
    # produce a valid selection and would not terminate.
    layer_min_rows = min(min_rows, original_num_rows)
    global_num_rows_to_keep = max(global_num_rows_to_keep, layer_min_rows * num_transformer_blocks)
    # Select globally, then check every layer; when a layer ends up below its minimum,
    # pin its best rows, suppress its other rows and select again.
    while True:
        if stocastic:
            idx_rows_to_keep = bernoulli_score_sampling(model_mlp_scores, global_num_rows_to_keep, seed, max_tries=10_000)
        else:
            idx_rows_to_keep = torch.topk(model_mlp_scores, k=global_num_rows_to_keep, largest=True).indices
        valid = True
        for layer_idx in range(num_transformer_blocks):
            start, end = layer_idx * original_num_rows, (layer_idx + 1) * original_num_rows
            layer_mask = (idx_rows_to_keep >= start) & (idx_rows_to_keep < end)
            # convert flat indices back to layer-local row indices
            keep_rows_layer = idx_rows_to_keep[layer_mask] - start
            if keep_rows_layer.numel() < layer_min_rows:
                valid = False
                idx_rows_to_keep_layer = torch.topk(model_mlp_scores[start:end], k=layer_min_rows, largest=True).indices
                abs_idx = start + idx_rows_to_keep_layer
                all_idx = torch.arange(start, end, device=model_mlp_scores.device)
                abs_idx_to_neg_inf = all_idx[~torch.isin(all_idx, abs_idx)]
                # pin the best rows of this layer with a huge score ...
                model_mlp_scores[abs_idx] = 1e+30 
                # ... and give the remaining ones a tiny score.
                # NOTE(review): the attention-head variant uses -1e-30 here; a positive
                # value still outranks rows whose score is exactly 0 - confirm which
                # sign bernoulli_score_sampling expects before aligning the two.
                model_mlp_scores[abs_idx_to_neg_inf] = 1e-30
                break
            keep_rows[layer_idx] = keep_rows_layer
        if valid:
            break
    tot_num_rows_kept = sum([keep_rows[i].numel() for i in range(num_transformer_blocks)])
    assert tot_num_rows_kept == global_num_rows_to_keep, \
        f"Total number of rows kept {tot_num_rows_kept} does not match the expected {global_num_rows_to_keep}."
    return keep_rows
