import math

import numpy as np
import torch
from torch import nn
from transformers import Conv1D
from transformers import DataCollatorForLanguageModeling
import torch
from tqdm import tqdm
import gc
import torch.nn.functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator
import matplotlib.pyplot as plt

def nested_getattr(obj, attr):
    """
    Resolve a dotted attribute path on an object.

    Args:
        obj: Object the lookup starts from.
        attr (str): Attribute path whose components are separated by '.',
            e.g. ``"self_attn.q_proj"``.

    Returns:
        The value reached after applying ``getattr`` once per path component.

    Raises:
        AttributeError: If any component of the path does not exist.
    """
    current = obj
    remaining = attr
    while True:
        # Peel off the leading component; `dot` is empty once the last one is reached.
        head, dot, remaining = remaining.partition('.')
        current = getattr(current, head)
        if not dot:
            return current


def find_layers(module, layers=None, name=''):
    """
    Recursively find the layers of a certain type in a module.

    Matching is done on the exact type (``type(module) in layers``), so
    subclasses of the listed types are not returned. Once a module matches,
    the search does not descend into its children.

    Args:
        module (nn.Module): PyTorch module.
        layers (Iterable[type] | None): Layer types to find. ``None`` selects
            the default pair ``(nn.Linear, Conv1D)``.
        name (str): Qualified name of ``module``; used as the key prefix for
            everything found below it.

    Returns:
        dict: Mapping from dotted module name to the matching module.
    """
    if layers is None:
        # Resolved per call rather than as a mutable default evaluated at definition time.
        layers = (nn.Linear, Conv1D)
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        qualified = f'{name}.{child_name}' if name != '' else child_name
        res.update(find_layers(child, layers=layers, name=qualified))
    return res


def prepare_calibration_input(model, dataloader, nsamples, seqlen):
    """
    Capture the hidden states entering the first decoder layer for calibration.

    The first decoder layer is temporarily replaced by a wrapper that records
    the tensor it receives and then aborts the forward pass. The wrapper is
    removed and ``model.config.use_cache`` is restored before returning, even
    if the forward pass fails.

    Args:
        model (nn.Module): Model exposing ``config.use_cache``,
            ``config.hidden_size`` and its decoder layers at ``model.layers``,
            ``base_model.layers`` or ``base_model.decoder.layers``.
        dataloader (Iterable): Yields batches whose first element is the
            input tensor passed to ``model``.
        nsamples (int): Number of samples to record; batches beyond this
            count are not consumed.
        seqlen (int): Sequence length of every recorded sample.

    Returns:
        inps (torch.Tensor): Recorded inputs, shape
            ``(nsamples, seqlen, hidden_size)``.
        outs (torch.Tensor): Zero tensor with the same shape/dtype as ``inps``.
        attention_mask: ``attention_mask`` keyword seen by the first layer on
            the last recorded batch, or ``None`` if it was not passed.
        position_ids: ``position_ids`` keyword seen by the first layer on the
            last recorded batch, or ``None`` if it was not passed.
    """

    class _CaptureDone(ValueError):
        """Raised by the wrapper to stop the forward pass after recording."""

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        try:
            layers = model.model.layers
        except AttributeError:
            try:
                layers = model.base_model.layers
            except AttributeError:
                layers = model.base_model.decoder.layers
        if "model.embed_tokens" in getattr(model, 'hf_device_map', {}):
            device = model.hf_device_map["model.embed_tokens"]
        else:
            device = model.device
        dtype = next(iter(model.parameters())).dtype
        inps = torch.zeros((nsamples, seqlen, model.config.hidden_size), dtype=dtype, device=device)
        inps.requires_grad = False
        cache = {'i': 0, 'attention_mask': None, "position_ids": None}

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
                # Mirror the attention sub-module on the wrapper; presumably so that
                # code inspecting layers[0].self_attn / .attn keeps working.
                if hasattr(module, "self_attn"):
                    self.self_attn = module.self_attn
                elif hasattr(module, "attn"):
                    self.attn = module.attn

            def forward(self, inp, **kwargs):
                inps[cache['i']] = inp
                cache['i'] += 1
                # Not every architecture passes both keywords to its decoder layers.
                cache['attention_mask'] = kwargs.get('attention_mask')
                cache['position_ids'] = kwargs.get('position_ids')
                raise _CaptureDone

        first_layer = layers[0]
        layers[0] = Catcher(first_layer)
        try:
            for batch in dataloader:
                if cache['i'] >= nsamples:
                    # `inps` is full; a further capture would index out of bounds.
                    break
                try:
                    model(batch[0].to(device))
                except _CaptureDone:
                    pass
        finally:
            layers[0] = first_layer
    finally:
        model.config.use_cache = use_cache

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    return inps, outs, attention_mask, position_ids


def _split_linear_rows(proj, keep_mask):
    """Drop the output rows of ``proj`` not selected by ``keep_mask``; return the dropped rows as a new ``nn.Linear``."""
    keep_idx = torch.where(keep_mask)[0]
    prune_idx = torch.where(~keep_mask)[0]
    has_bias = proj.bias is not None

    # Container for the removed rows: weight shape is (num_pruned, in_features).
    removed = nn.Linear(proj.in_features, len(prune_idx), bias=has_bias)
    removed.weight.data = proj.weight.data[prune_idx].clone().detach()
    if has_bias:
        removed.bias.data = proj.bias.data[prune_idx].clone().detach()

    # Shrink the original projection in place to the retained rows.
    proj.weight.data = proj.weight.data[keep_idx]
    proj.out_features = int(keep_mask.sum().item())
    if has_bias:
        proj.bias.data = proj.bias.data[keep_idx]
    return removed


def _split_linear_columns(proj, keep_mask, mean_inp, device, bias):
    """Drop the input columns of ``proj`` not selected by ``keep_mask``; return the dropped columns as a bias-free ``nn.Linear``."""
    if bias:
        # Output of the full (un-pruned) weight for the pruned channels evaluated
        # at `mean_inp`; must be computed before the weight is sliced.
        compensation = (mean_inp * ~keep_mask.to(device)) @ proj.weight.data.T
    keep_idx = torch.where(keep_mask)[0]
    prune_idx = torch.where(~keep_mask)[0]

    # Container for the removed columns: weight shape is (out_features, num_pruned).
    removed = nn.Linear(len(prune_idx), proj.out_features, bias=False)
    removed.weight.data = proj.weight.data[:, prune_idx].clone().detach()

    kept_weight = proj.weight.data[:, keep_idx]
    proj.in_features = int(keep_mask.sum().item())
    if bias:
        if proj.bias is None:
            # The projection was built without a bias; register one so the
            # compensation term can be applied instead of failing on `None.data`.
            proj.bias = nn.Parameter(compensation.detach(), requires_grad=proj.weight.requires_grad)
        else:
            # NOTE: the existing bias values are replaced, not accumulated.
            proj.bias.data = compensation
    proj.weight.data = kept_weight
    return removed


def _update_attention_shape(layer, num_heads, num_kv_heads, head_dim):
    """Write the post-pruning head counts and hidden size onto the layer's attention module."""
    if hasattr(layer, 'attn'):
        attn = layer.attn
        sync_embed_dim = False
    elif hasattr(layer, 'self_attn'):
        attn = layer.self_attn
        sync_embed_dim = True
    else:
        raise KeyError("Layer does not have attn or self_attn attribute")

    attn.num_heads = num_heads
    attn.num_key_value_heads = num_kv_heads
    attn.num_key_value_groups = attn.num_heads // attn.num_key_value_heads
    attn.hidden_size = num_heads * head_dim
    if sync_embed_dim and hasattr(attn, 'embed_dim'):
        attn.embed_dim = num_heads * head_dim


def compress_residue(layer, attn_mask, mlp_mask, attn_mean_inp, mlp_mean_inp, device, mapping, bias=True, head_dim=128,
                     is_gqa=False):
    """
    Prune a model layer in place according to the given masks and return the
    removed parameters as separate nn.Linear layers.

    For attention, the rows of q/k/v and the columns of o that belong to
    masked-out heads are removed from ``layer``; for the MLP, the rows of
    u (and g, if mapped) and the columns of d. Each removed slice is returned
    as a Linear layer whose weight shape corresponds to the removed part.

    Args:
        layer (nn.Module): The model layer to compress. Modified in place.
        attn_mask (torch.Tensor | None): Per-head keep mask (expected to be
            boolean, True = keep; per key/value head when ``is_gqa`` is set).
            ``None`` skips attention pruning.
        mlp_mask (torch.Tensor | None): Per-channel keep mask for the MLP
            intermediate dimension (expected to be boolean, True = keep).
            ``None`` skips MLP pruning.
        attn_mean_inp (torch.Tensor): The mean input of the attention output
            projection; only used when ``bias`` is True.
        mlp_mean_inp (torch.Tensor): The mean input of the MLP down
            projection; only used when ``bias`` is True.
        device (torch.device): Device the masks are moved to when computing
            the bias compensation.
        mapping (dict): Dotted attribute paths relative to ``layer``:
            ``mapping['attn']`` with keys ``'q'``, ``'k'``, ``'v'``, ``'o'``
            and ``mapping['mlp']`` with keys ``'u'``, ``'d'``, optionally
            ``'g'``, plus ``'block'`` (when non-empty,
            ``layer.mlp.intermediate_size`` is updated).
        bias (bool, optional): Whether to write a compensation bias onto the
            o / d projections. Defaults to True.
        head_dim (int, optional): The hidden dimension per head. Defaults to 128.
        is_gqa (bool, optional): Whether using GQA (Grouped Query Attention). Defaults to False.

    Returns:
        dict: New nn.Linear layers holding the pruned parameters, keyed by
        ``'q_pruned'``, ``'k_pruned'``, ``'v_pruned'``, ``'o_pruned'``,
        ``'u_pruned'``, ``'g_pruned'`` and ``'d_pruned'`` (only the keys for
        the parts that were actually pruned are present).

    Raises:
        KeyError: If attention is pruned and ``layer`` has neither an
            ``attn`` nor a ``self_attn`` attribute.
    """
    pruned_layers = {}  # Linear layers holding each removed parameter slice

    # ----- Attention Weight Pruning -----
    if attn_mask is not None:
        q = nested_getattr(layer, mapping['attn']['q'])
        k = nested_getattr(layer, mapping['attn']['k'])
        v = nested_getattr(layer, mapping['attn']['v'])
        o = nested_getattr(layer, mapping['attn']['o'])

        if is_gqa:
            # Read the head layout from `self_attn` when present, else from `attn`,
            # matching the two attribute names supported when writing it back.
            attn_module = layer.self_attn if hasattr(layer, 'self_attn') else layer.attn
            repeat_count = attn_module.num_heads // attn_module.num_key_value_heads
            retain_heads_kv = int(torch.count_nonzero(attn_mask).item())
            retain_heads_qo = retain_heads_kv * repeat_count

            # Expand the per-head mask to per-channel masks; each key/value head
            # serves `repeat_count` consecutive query heads.
            kv_mask = attn_mask.repeat_interleave(head_dim)
            qo_mask = kv_mask.repeat_interleave(repeat_count)
            row_plan = (('k', k, kv_mask), ('v', v, kv_mask), ('q', q, qo_mask))
        else:
            # Without GQA, q, k and v all share the same per-channel mask.
            retain_heads_kv = retain_heads_qo = int(torch.count_nonzero(attn_mask).item())
            kv_mask = qo_mask = attn_mask.repeat_interleave(head_dim)
            row_plan = (('q', q, qo_mask), ('k', k, kv_mask), ('v', v, kv_mask))

        # q / k / v lose output rows; o loses the matching input columns.
        for name, proj, mask in row_plan:
            pruned_layers[name + '_pruned'] = _split_linear_rows(proj, mask)
        pruned_layers['o_pruned'] = _split_linear_columns(o, qo_mask, attn_mean_inp, device, bias)

        _update_attention_shape(layer, retain_heads_qo, retain_heads_kv, head_dim)

    # ----- MLP Weight Pruning -----
    if mlp_mask is not None:
        u = nested_getattr(layer, mapping['mlp']['u'])
        d = nested_getattr(layer, mapping['mlp']['d'])
        if 'g' in mapping['mlp']:
            g = nested_getattr(layer, mapping['mlp']['g'])
        else:
            g = None

        # u (and g, if present) lose output rows; d loses the matching input columns.
        for name, proj in (('u', u), ('g', g)):
            if proj is not None:
                pruned_layers[name + '_pruned'] = _split_linear_rows(proj, mlp_mask)
        pruned_layers['d_pruned'] = _split_linear_columns(d, mlp_mask, mlp_mean_inp, device, bias)

        if mapping['mlp']['block'] != '':
            layer.mlp.intermediate_size = int(mlp_mask.sum().item())

    torch.cuda.empty_cache()
    return pruned_layers


def cosine_similarity(A, B):
    """
    Cosine similarity between A and B along the last dimension.

    Returns:
        tuple: ``(sim, token_mean, sample_mean)`` where ``sim`` is the raw
        similarity, ``token_mean`` is ``sim`` averaged over dim 1 and
        ``sample_mean`` is ``token_mean`` averaged over dim 0.
    """
    per_token = F.cosine_similarity(A, B, dim=-1)
    per_sample = per_token.mean(dim=1)
    overall = per_sample.mean(dim=0)
    return per_token, per_sample, overall


def dot_product_similarity(A, B):
    """
    Dot product between A and B along the last dimension.

    Returns:
        tuple: ``(sim, token_mean, sample_mean)`` where ``sim`` is the
        element-wise product summed over the last dimension, ``token_mean``
        is ``sim`` averaged over dim 1 and ``sample_mean`` is ``token_mean``
        averaged over dim 0.
    """
    # Reduce the feature (last) dimension.
    per_token = torch.mul(A, B).sum(dim=-1)
    # Average over dim 1, then over dim 0.
    per_sample = per_token.mean(dim=1)
    overall = per_sample.mean(dim=0)
    return per_token, per_sample, overall


def angular_distance(A, B):
    """
    Angle (in radians) corresponding to the mean cosine similarity of A and B.

    The overall mean returned by ``cosine_similarity`` is clamped to
    ``[-1, 1]`` before ``acos`` so that rounding error cannot leave the
    function's domain.
    """
    _, _, mean_cos = cosine_similarity(A, B)
    return torch.acos(mean_cos.clamp(-1.0, 1.0))


def l2_distance(A, B):
    """
    Norm of ``A - B`` taken over all elements (``torch.norm`` defaults).

    Returns:
        tuple: The same scalar tensor twice; both entries are the global norm.
    """
    diff_norm = torch.norm(A - B)
    return diff_norm, diff_norm


def std(A, B):
    """
    Absolute difference between the standard deviations of A and B.

    Returns:
        tuple: ``(std_d, std_token_wise)``. ``std_d`` compares the standard
        deviation over all elements; ``std_token_wise`` compares the standard
        deviation taken over dims (1, 2) and then averaged over dim 0.
    """
    global_gap = torch.abs(torch.std(A) - torch.std(B))

    per_sample_a = torch.std(A, dim=(1, 2)).mean(dim=0)
    per_sample_b = torch.std(B, dim=(1, 2)).mean(dim=0)
    per_sample_gap = torch.abs(per_sample_a - per_sample_b)

    return global_gap, per_sample_gap


def mean(A, B):
    """
    Absolute difference between the means of A and B.

    Returns:
        tuple: ``(mean_d, mean_token_wise)``. ``mean_d`` compares the mean
        over all elements; ``mean_token_wise`` compares the mean taken over
        dims (1, 2) and then averaged over dim 0.
    """
    global_gap = torch.abs(torch.mean(A) - torch.mean(B))

    per_sample_a = torch.mean(A, dim=(1, 2)).mean(dim=0)
    per_sample_b = torch.mean(B, dim=(1, 2)).mean(dim=0)
    per_sample_gap = torch.abs(per_sample_a - per_sample_b)

    return global_gap, per_sample_gap

def greedy_by_value(x, i0) -> torch.Tensor:
    """
    Order the indices of ``x`` by greedy expansion from a starting index.

    Starting at ``i0``, the visited window is grown one position at a time:
    the immediate left and right neighbours of the window are compared and
    the one with the smaller value is visited next (the left one on ties).
    Expansion continues until both ends of ``x`` have been reached.

    Args:
        x (torch.Tensor): Tensor indexed along dim 0; intended for 1-D input.
        i0 (int): Index the expansion starts from.

    Returns:
        torch.Tensor: Long tensor on ``x.device`` with the visiting order;
        ``x[order]`` yields the values in that order.
    """
    N = x.size(0)
    # One host transfer up front instead of one `.item()` sync per probed neighbour.
    values = x.tolist()
    order = [i0]
    left = right = i0

    # Keep expanding until both sides have reached the ends.
    while left > 0 or right < N - 1:
        candidates = []
        if left > 0:
            candidates.append((values[left - 1], left - 1))    # left neighbour
        if right < N - 1:
            candidates.append((values[right + 1], right + 1))  # right neighbour

        # Take the neighbour with the smaller value; min() keeps the first on ties.
        _, idx = min(candidates, key=lambda t: t[0])

        order.append(idx)
        # Move whichever boundary was extended.
        if idx < left:
            left = idx
        else:
            right = idx

    return torch.tensor(order, dtype=torch.long, device=x.device)

def show(tensor_list, label, title):
    """
    Plot one value per list entry against its index and save it as a PNG.

    The figure is written to ``f'{title}.png'`` relative to the current
    working directory and is closed afterwards.

    Args:
        tensor_list (Iterable[torch.Tensor]): Values to plot, one per x
            position (the x axis is labelled 'Layer Index').
        label (str): Legend label of the plotted line.
        title (str): Plot title; also used as the output file name stem.
    """
    # Detach so tensors that still track gradients can be converted for plotting.
    tensor_list = [t.detach().cpu() for t in tensor_list]
    fig, ax1 = plt.subplots(figsize=(14, 7))
    try:
        ax1.set_xlabel('Layer Index')
        ax1.set_ylabel('Tensor Value', color='red')
        ax1.plot(range(len(tensor_list)), tensor_list, marker='o', color='red', label=label)

        ax1.set_title(title)
        fig.legend(loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes)

        ax1.grid(True, linestyle='--', alpha=0.7)
        # Lay out before saving so the written image reflects the adjusted margins.
        fig.tight_layout()
        fig.savefig(f'{title}.png')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)