import torch 
from models.block import (
    MultiHeadAttentionPruned, prune_heads, unprune_heads, get_active_heads
)
from pruning.pruning_utils import get_grads
import types 

import copy

from torch.utils.data import DataLoader
from typing import Dict, Union, Optional
import warnings

def _replace_with_pruned_attention(block, attn_attr, keep_heads, active_heads):
    """Swap `block.<attn_attr>` for a `MultiHeadAttentionPruned` built from its weights."""
    mha = getattr(block, attn_attr)
    pruned_mha = MultiHeadAttentionPruned(
        keep_heads,
        active_heads,
        mha.embed_dim,
        mha.in_proj_weight.data,
        mha.out_proj.weight.data,
        mha.in_proj_bias.data,
        mha.out_proj.bias.data,
    )
    setattr(block, attn_attr, pruned_mha)


def structurally_prune_attention_heads(
        vit_model: torch.nn.Module,
        num_heads_to_prune: Union[int, Dict[int, list[int]]],
        strategy: str,
        model_name: str = "ViT",
        prune_text_encoder: bool = False,
        current_active_heads: Optional[Dict[int, torch.Tensor]] = None,
        dataloader: Optional[DataLoader] = None,
        device: Optional[torch.device] = None,
        text_logits = None,
        verbose: bool = False,
        use_bias: bool = False
    ) -> torch.nn.Module:
    """
    Structurally prunes attention heads from a Vision Transformer (ViT) model based on a specified strategy.
    The model is deep-copied; every attention module of the copy is replaced by a
    `MultiHeadAttentionPruned` built from the original projection weights and biases.
    The embedding dimension is passed through unchanged, only the set of heads is reduced.

    Args:
    -----------
    vit_model : torch.nn.Module
        A Vision Transformer model (e.g., ViT-B/16). The attention modules are looked up at
        `encoder.layers[i].self_attention` for "ViT" and at
        `visual.transformer.resblocks[i].attn` for "OpenCLIP".

    num_heads_to_prune : Union[int, Dict[int, list[int]]]
        What to prune.
        - int (strategy "gradient"): the same number of heads is pruned in every layer.
        - dict with strategy "gradient": keys are the number of heads to prune, values are the
          lists of layer indices to prune that number from. Layers not listed are left unpruned.
        - dict with strategy "predefined": keys are layer indices, values are the lists of head
          indices to prune. Layers not listed keep all of their heads.
          When `prune_text_encoder` is True, keys 0..11 address the vision layers and keys >= 12
          address text layer `key - 12`.

    strategy : str
        Pruning strategy. One of:
        - "predefined": uses a predefined list of heads to prune.
        - "gradient": uses absolute gradient-weight product.

    model_name : str, default="ViT"
        Name of the model architecture. Currently supports "ViT" and "OpenCLIP".
        This is used to determine the naming convention for the model's layers.

    prune_text_encoder : bool, default=False
        If True, prunes the text encoder of CLIP models as well. Only implemented for
        strategy "predefined" with model_name "OpenCLIP"; with any other strategy a
        NotImplementedError is raised.
        NOTE: this path assumes 12 vision layers and 12 text layers, ignores
        `current_active_heads`, and returns the pruned copy without attaching the
        `prune_heads` / `unprune_heads` / `get_active_heads` helper methods.

    current_active_heads : Optional[Dict[int, torch.Tensor]], default=None
        A dictionary mapping layer indices to tensors of currently active head indices.
        If None, all heads are considered active; this is only accepted when every layer
        reports the same `num_heads`, otherwise it must be provided.
        NOTE: if you are pruning an already pruned model you can use 'model.get_active_heads()'
        to get the current active heads.

    dataloader : Optional[torch.utils.data.DataLoader], default=None
        Required for the "gradient" strategy. Forwarded to the gradient computation.

    device : Optional[torch.device], default=None
        Device the pruned copy is moved to. Required for the "gradient" strategy; otherwise
        defaults to the device of the model's first parameter.

    text_logits : Optional, default=None
        The text logits forwarded to the gradient computation for gradient-based pruning.

    verbose : bool, default=False
        Whether to print per-layer pruning info.

    use_bias : bool, default=False
        Forwarded to `get_heads_scores`: whether the in-projection biases contribute to the
        head scores of the "gradient" strategy.

    Returns:
    --------
    vit_model_pruned : torch.nn.Module
        A structurally pruned deep copy of `vit_model`. `vit_model` itself and
        `num_heads_to_prune` are not modified.

    Raises:
    -------
    AssertionError
        On an unknown `model_name` or `strategy`, a wrongly typed `num_heads_to_prune`,
        missing `dataloader`/`device` for the "gradient" strategy, or a request to prune
        at least as many heads as a layer has active.
    NotImplementedError
        If `prune_text_encoder` is True with a strategy other than "predefined".

    Notes:
    ------
    For gradient-based strategies, head scores (and therefore gradients) are recomputed at each
    layer iteration to get new estimates after the pruning of the previous layers.
    """

    models_naming = {
        "OpenCLIP" : ["visual.transformer.resblocks", "attn"],
        "ViT" : ["encoder.layers", "self_attention"]
    }

    allowed_strategies = ["gradient", "predefined"]

    assert model_name in models_naming, \
        f"Model name must be one of {list(models_naming.keys())}, but got {model_name}."

    if prune_text_encoder and strategy != "predefined":
        raise NotImplementedError("Pruning the text encoder is only implemented for 'predefined' strategy.")

    if strategy == "predefined" and model_name == "OpenCLIP":
        # Check the type first so that a bad argument fails with a clear message
        # instead of an AttributeError on `.keys()`.
        assert isinstance(num_heads_to_prune, dict), \
            f"When strategy is 'predefined', num_heads_to_prune must be a dict, but got {type(num_heads_to_prune)}."
        layer_indices = list(num_heads_to_prune.keys())
        assert not any(idx > 11 for idx in layer_indices) or prune_text_encoder, \
            f"Layer indices in num_heads_to_prune must be between 0 and 11 for the vision transformer, " \
            f"but got {layer_indices}. If you want to prune the text encoder, set prune_text_encoder=True."
        if prune_text_encoder:
            # NOTE: both towers are assumed to have exactly 12 layers.
            num_layers = 12

            original_num_heads_vision = [vit_model.visual.transformer.resblocks[i].attn.num_heads for i in range(num_layers)]
            original_num_heads_text = [vit_model.transformer.resblocks[i].attn.num_heads for i in range(num_layers)]
            current_active_heads_vision = {layer: torch.arange(num_heads, device=device) for layer, num_heads in enumerate(original_num_heads_vision)}
            current_active_heads_text = {layer: torch.arange(num_heads, device=device) for layer, num_heads in enumerate(original_num_heads_text)}

            layer_to_heads_txt = {}
            layer_to_heads_vis = {}

            # Turn "heads to prune" into "heads to keep"; keys above 11 address the text tower.
            for layer_idx, heads_to_prune in num_heads_to_prune.items():
                if layer_idx <= 11:
                    all_heads = range(original_num_heads_vision[layer_idx])
                    layer_to_heads_vis[layer_idx] = [head for head in all_heads if head not in heads_to_prune]
                else:
                    txt_layer_idx = layer_idx - 12
                    all_heads = range(original_num_heads_text[txt_layer_idx])
                    layer_to_heads_txt[txt_layer_idx] = [head for head in all_heads if head not in heads_to_prune]

            # Layers that were not mentioned keep every head.
            for layer_idx in range(num_layers):
                if layer_idx not in layer_to_heads_vis:
                    layer_to_heads_vis[layer_idx] = list(range(original_num_heads_vision[layer_idx]))
                if layer_idx not in layer_to_heads_txt:
                    layer_to_heads_txt[layer_idx] = list(range(original_num_heads_text[layer_idx]))

            if verbose:
                for layer, heads_to_keep in layer_to_heads_vis.items():
                    if len(heads_to_keep) != original_num_heads_vision[layer]:
                        print(f"[Vision Layer {layer}] Keeping heads ({len(heads_to_keep)}): {heads_to_keep}")

                for layer, heads_to_keep in layer_to_heads_txt.items():
                    if len(heads_to_keep) != original_num_heads_text[layer]:
                        print(f"[Text Layer {layer}] Keeping heads ({len(heads_to_keep)}): {heads_to_keep}")

            vit_model_pruned = copy.deepcopy(vit_model).to(device)

            for layer_idx, block_new in enumerate(vit_model_pruned.visual.transformer.resblocks):
                layer_keep = torch.tensor(layer_to_heads_vis[layer_idx]).to(device)
                _replace_with_pruned_attention(
                    block_new, "attn", layer_keep, current_active_heads_vision[layer_idx]
                )

            for layer_idx, block_new in enumerate(vit_model_pruned.transformer.resblocks):
                layer_keep = torch.tensor(layer_to_heads_txt[layer_idx]).to(device)
                _replace_with_pruned_attention(
                    block_new, "attn", layer_keep, current_active_heads_text[layer_idx]
                )
            return vit_model_pruned

    blocks_path, attn_attr = models_naming[model_name]

    assert strategy in allowed_strategies, \
        f"Strategy must be one of {allowed_strategies}, instead got{strategy}"
    assert strategy == "predefined" or (dataloader is not None and device is not None), \
        f"Gradient required for {strategy} pruning, but dataloader or device not provided."
    assert isinstance(num_heads_to_prune, (dict,int)), \
        f"num_heads_to_prune must be either an int or a dict, but got {type(num_heads_to_prune)}."
    if strategy == "predefined":
        assert isinstance(num_heads_to_prune, dict), \
            f"When strategy is 'predefined', num_heads_to_prune must be a dict, but got {type(num_heads_to_prune)}."
        if dataloader is not None:
            warnings.warn(
                f"Using dataloader with strategy {strategy} is not necessary, as no gradients are computed"
                " and it will be ignored.",
                UserWarning
            )
        if device is not None:
            warnings.warn(
                f"Using device with strategy {strategy} is not necessary, as no gradients are computed"
                " and it will be ignored.",
                UserWarning
            )

    device = device or next(vit_model.parameters()).device

    transformer_blocks = get_module_by_path(vit_model, blocks_path)
    num_layers = len(transformer_blocks)
    original_num_heads = [getattr(block, attn_attr).num_heads for block in transformer_blocks]

    is_num_heads_same = len(set(original_num_heads)) == 1
    if current_active_heads is None:
        assert is_num_heads_same, \
            f"The number of heads per is not the same across all layers {original_num_heads}, so current_active_heads must be provided. " 
        current_active_heads = {layer: torch.arange(num_heads) for layer, num_heads in enumerate(original_num_heads)}
    current_active_heads = {layer: active_heads.to(device) for layer, active_heads in current_active_heads.items()}

    # Build the per-layer plan. For "predefined" the plan holds the list of heads to KEEP,
    # for "gradient" it holds the NUMBER of heads to prune.
    if isinstance(num_heads_to_prune, int):
        layer_to_heads = {layer: num_heads_to_prune for layer in range(num_layers)}
    elif strategy == "predefined":
        requested = {
            layer_idx: [head for head in range(original_num_heads[layer_idx]) if head not in heads_to_prune]
            for layer_idx, heads_to_prune in num_heads_to_prune.items()
        }
        # Layers that were not mentioned keep all of THEIR heads (sized by the layer's own
        # head count, not by the number of layers).
        layer_to_heads = {
            layer: requested.get(layer, list(range(original_num_heads[layer])))
            for layer in range(num_layers)
        }
    else:
        requested = {layer: num_heads for num_heads, layers in num_heads_to_prune.items() for layer in layers}
        # Layers that were not mentioned are left unpruned.
        layer_to_heads = {layer: requested.get(layer, 0) for layer in range(num_layers)}

    if strategy == "gradient":
        for layer_idx, num_heads in layer_to_heads.items():
            assert num_heads < len(current_active_heads[layer_idx]), \
                f"Number of heads to prune must be < {len(current_active_heads[layer_idx])}, but got {num_heads} for layer {layer_idx}."

    vit_model_pruned = copy.deepcopy(vit_model).to(device)
    transformer_blocks = get_module_by_path(vit_model_pruned, blocks_path)

    # number of heads per layer might be different, so score vectors are sized by the
    # largest active head index
    max_head_idx = max(h.max().item() for h in current_active_heads.values())
    keep_heads = {}

    for layer_idx, block_new in enumerate(transformer_blocks):
        if strategy == "gradient":
            # gradients are recomputed for every layer, i.e. after the previous layers were pruned
            layer_heads_scores = get_heads_scores(
                mha = getattr(block_new, attn_attr),
                vit_model = vit_model_pruned,
                strategy = strategy,
                dataloader = dataloader,
                device = device,
                text_logits = text_logits,
                compute_grads = True,
                use_bias = use_bias
            )

            # Scatter the scores to their head indices; slots left at 0 (inactive heads)
            # get a tiny negative score so that top-k never prefers them over a scored head.
            padded_scores = torch.zeros(max_head_idx + 1, device = device)
            padded_scores[current_active_heads[layer_idx]] = layer_heads_scores
            padded_scores[padded_scores == 0] = -1e-30

            layer_keep = get_layer_heads_pruning(
                padded_scores,
                len(current_active_heads[layer_idx]),
                layer_to_heads[layer_idx]
            )
        else:
            layer_keep = torch.tensor(layer_to_heads[layer_idx])
        keep_heads[layer_idx] = layer_keep.to(device)
        _replace_with_pruned_attention(
            block_new, attn_attr, keep_heads[layer_idx], current_active_heads[layer_idx]
        )

    if verbose:
        for layer_idx, keep_heads_layer in keep_heads.items():
            print(f"[Layer {layer_idx}] Keeping heads ({keep_heads_layer.numel()}): {keep_heads_layer.sort()[0].tolist()}") 
        total_num_heads_kept = sum(kept.numel() for kept in keep_heads.values())
        print(f"Total number of heads kept: {total_num_heads_kept}") 

    # Expose the pruning helpers as bound methods of the pruned copy.
    vit_model_pruned.prune_heads = types.MethodType(prune_heads, vit_model_pruned)
    vit_model_pruned.unprune_heads = types.MethodType(unprune_heads, vit_model_pruned)
    vit_model_pruned.get_active_heads = types.MethodType(get_active_heads, vit_model_pruned)

    return vit_model_pruned

def get_module_by_path(model, path):
    """
    Resolve a dotted attribute path (e.g. "encoder.layers") starting from `model`.

    Each dot-separated component is looked up with `getattr` on the result of the
    previous lookup; the object reached by the last component is returned.
    """
    current = model
    components = path.split('.')
    for component in components:
        current = getattr(current, component)
    return current

def get_heads_scores(
    mha: torch.nn.Module,
    strategy: str,
    compute_grads: bool,
    vit_model: torch.nn.Module = None,
    dataloader: DataLoader = None,
    device: torch.device = None,
    text_logits: torch.Tensor = None,
    use_bias : bool = False
):
    """
    Given a MultiHeadAttention module that follows the design of torchvision or OpenCLIP
    (it must expose `embed_dim`, `num_heads`, `head_dim`, `in_proj_weight` and, when
    `use_bias` is set, `in_proj_bias`), compute one score per attention head.

    The packed in-projection is split into its query/key/value thirds, each third is
    reduced to a per-head mean, and the three per-head means are averaged.

    Args:
    -----------
    mha : torch.nn.Module
        A MultiHeadAttention module from a Vision Transformer model.
    strategy : str
        The strategy to use for scoring the attention heads. One of:
        - "magnitude": mean absolute value of the weights.
        - "gradient": mean absolute gradient-weight product.
        - "gradient_squared": mean squared gradient-weight product.
    compute_grads : bool
        If True, computes new gradients even if they were already computed.
        This is useful for example when pruning layer by layer to get new estimates after the pruning of the previous layers.
        If False, gradients are only computed when `mha.in_proj_weight.grad` is None.
    vit_model : torch.nn.Module
        The Vision Transformer model from which the MultiHeadAttention module is taken.
        Needed for gradient-based strategies when gradients have to be computed.
    dataloader : torch.utils.data.DataLoader
        A DataLoader forwarded to the gradient computation.
        In general this should be either the training or validation dataloader.
    device : torch.device
        The device forwarded to the gradient computation, if needed.
    text_logits : torch.Tensor
        The text logits forwarded to the gradient computation for gradient-based strategies and CLIP models.
    use_bias : bool, default=False
        Whether to add the per-head score of the in-projection bias terms. (not tested yet)

    Returns:
    --------
    scores : torch.Tensor
        Tensor of shape (num_heads,) with one score per head.

    Raises:
    -------
    ValueError
        If `strategy` is not one of the supported strategies.
    RuntimeError
        If a gradient-based strategy finds no gradient on `mha.in_proj_weight`
        even after the gradient computation ran.
    """

    supported_strategies = ["magnitude", "gradient", "gradient_squared"]
    if strategy not in supported_strategies:
        raise ValueError(f"Unknown strategy: {strategy}. Supported strategies are: {supported_strategies}.")

    num_heads = mha.num_heads
    head_dim = mha.head_dim

    def _per_head_mean(values: torch.Tensor) -> torch.Tensor:
        # Split the leading dimension into `num_heads` groups of `head_dim` entries
        # and average everything belonging to one head.
        return values.reshape(num_heads, head_dim, -1).mean(dim=(1, 2))

    weights = mha.in_proj_weight.chunk(3, dim=0)
    biases = mha.in_proj_bias.chunk(3, dim=0) if use_bias else None

    if strategy == "magnitude":
        weight_terms = [w.abs() for w in weights]
        bias_terms = [b.abs() for b in biases] if use_bias else None
    else:
        if compute_grads or mha.in_proj_weight.grad is None:
            get_grads(vit_model, dataloader, device, text_logits)
        weight_grad = mha.in_proj_weight.grad
        if weight_grad is None:
            raise RuntimeError(
                "No gradient available on in_proj_weight; gradient-based scoring needs populated gradients."
            )
        score_func = torch.abs if strategy == "gradient" else torch.square
        weight_terms = [score_func(w * g) for w, g in zip(weights, weight_grad.chunk(3, dim=0))]
        if use_bias:
            bias_grads = mha.in_proj_bias.grad.chunk(3, dim=0)
            bias_terms = [score_func(b * g) for b, g in zip(biases, bias_grads)]
        else:
            bias_terms = None

    head_scores = [_per_head_mean(term) for term in weight_terms]
    if use_bias:
        head_scores = [score + _per_head_mean(term) for score, term in zip(head_scores, bias_terms)]

    q_score, k_score, v_score = head_scores
    scores = (q_score + k_score + v_score) / 3
    return scores


def get_layer_heads_pruning(
        heads_layer_scores: torch.Tensor,
        original_num_heads: int,
        num_heads_to_prune: int,
) -> torch.Tensor:
    """
    Get the heads to keep for pruning at the layer level.

    The `original_num_heads - num_heads_to_prune` highest-scoring entries of
    `heads_layer_scores` are selected and their indices returned.

    Args:
    -----------
    heads_layer_scores : torch.Tensor
        1-D tensor containing the score of each attention head in the layer, indexed by head.
        It may be longer than `original_num_heads` (e.g. padded with low scores for
        inactive heads); only the top-scoring entries are returned.
    original_num_heads : int
        The number of heads the layer currently has.
    num_heads_to_prune : int
        Number of heads to prune from the layer. The actual number of heads kept will be
        original_num_heads - num_heads_to_prune.

    Returns:
    --------
    keep_heads : torch.Tensor
        Indices (into `heads_layer_scores`) of the heads to keep, on the same device
        as `heads_layer_scores`.
    """
    new_num_heads = original_num_heads - num_heads_to_prune

    # top-k is used even when nothing is pruned: the score vector may contain padded
    # slots, so "keep everything" is not necessarily the index range 0..original_num_heads-1.
    keep_heads = torch.topk(heads_layer_scores, k=new_num_heads, largest=True).indices
    assert keep_heads.numel() == new_num_heads, \
        f"Number of heads kept {keep_heads.numel()} does not match the expected {new_num_heads}."

    return keep_heads