import torch
from tqdm import tqdm

import torch.nn as nn
from utils.layerwrapper import *
from utils.data import get_loaders, get_loaders_sample
from transformers.models.llama.configuration_llama import *
from model_impl.llama import LlamaRope
import utils

def prepare_calibration_input(model, dataloader, nsamples, device, seqlen=4096):
    """Capture the hidden states that enter the first decoder layer.

    Temporarily replaces ``model.model.layers[0]`` with a catcher module that
    records its positional input plus the ``attention_mask`` / ``position_ids``
    keyword arguments, then aborts the forward pass so later layers never run.
    The original layer and ``model.config.use_cache`` are restored even if the
    forward pass fails.

    Args:
        model: causal LM exposing ``model.model.layers``, ``model.config`` and
            ``model.parameters()``.
        dataloader: iterable of batches whose first element is a tensor of
            input ids. Each batch is assumed to hold a single sequence of length
            ``seqlen`` so that it fills one row of the buffer (TODO confirm
            against the loader).
        nsamples: number of rows allocated in the capture buffer.
        device: device the input ids are moved to before the forward pass.
        seqlen: sequence length of the capture buffer (defaults to 4096, the
            value that used to be hard-coded).

    Returns:
        ``(inps, outs, attention_mask, position_ids)``:
        - ``inps`` is a CPU tensor of shape ``(nsamples, seqlen, hidden_size)``
          in the model's parameter dtype, holding the captured layer-0 inputs.
        - ``outs`` is a zero tensor with the same shape, dtype and device.
        - ``attention_mask`` and ``position_ids`` are the keyword arguments seen
          on the last captured call, or ``None`` if none were seen.
    """

    class _StopForward(Exception):
        """Private signal used to abort the forward pass once layer 0's input is recorded."""

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    dtype = next(iter(model.parameters())).dtype
    # The buffer is kept on CPU (presumably to save accelerator memory for large nsamples).
    inps = torch.zeros((nsamples, seqlen, model.config.hidden_size), dtype=dtype, device='cpu')
    inps.requires_grad = False
    cache = {'i': 0, 'attention_mask': None, 'position_ids': None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs.get('attention_mask')
            cache['position_ids'] = kwargs.get('position_ids')
            # Use a private exception so genuine errors from the model still propagate.
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in dataloader:
            input_ids = batch[0]
            try:
                model(input_ids.to(device))
            except _StopForward:
                pass
    finally:
        # Always undo the monkey-patching, even if the forward pass raised.
        layers[0] = layers[0].module
        model.config.use_cache = use_cache

    outs = torch.zeros_like(inps)
    return inps, outs, cache['attention_mask'], cache['position_ids']


def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
    """Build a per-row mask from a cumulative-metric budget of ``alpha * sum_before``.

    For each row, count how many entries of ``tmp_metric`` are <= the row budget.
    The sorted value at that position, ``sort_res[0][row, count - 1]``, becomes the
    row threshold, and every entry of ``W_metric`` <= that threshold is marked.
    Rows where no entry fits the budget get an all-False mask instead of crashing
    ``torch.gather`` with index -1.

    Args:
        alpha: fraction of ``sum_before`` used as each row's budget.
        sort_res: result of sorting ``W_metric`` along dim 1; only ``sort_res[0]``
            (the sorted values) is read. Assumed ascending — confirm with caller.
        W_metric: 2-D score tensor to threshold.
        tmp_metric: presumably the cumulative sum of ``sort_res[0]`` along dim 1
            (confirm with caller).
        sum_before: per-row totals; flattened to one value per row.

    Returns:
        ``(W_mask, cur_sparsity)``: a boolean mask shaped like ``W_metric``, and
        the fraction of True entries as a 0-dim float tensor.
    """
    thres_cumsum = sum_before * alpha
    sort_mask = tmp_metric <= thres_cumsum.reshape((-1, 1))
    # Number of leading sorted entries per row whose cumulative value fits the budget.
    counts = sort_mask.sum(dim=1, keepdim=True)
    # gather() rejects negative indices; clamp, then exclude rows where nothing fits.
    thres = torch.gather(sort_res[0], dim=1, index=(counts - 1).clamp(min=0))
    W_mask = (W_metric <= thres) & (counts > 0)
    cur_sparsity = W_mask.sum() / W_mask.numel()
    return W_mask, cur_sparsity

def feature_base_key_allocate_label_sparse(model, tokenizer, nsamples, llamaconfig:LlamaConfig, device=torch.device("cuda:0"), seed=42):
    """Collect per-layer statistics of the key-projection / key-RoPE submodules.

    Calibration sequences are sampled from C4 via ``get_loaders_sample`` with
    ``seqlen=4096``. The sequences are captured at the input of the first
    decoder layer and then propagated layer by layer. For each decoder layer,
    forward hooks on submodules whose names contain ``'k_rope'`` or ``'k_proj'``
    feed their inputs and outputs into wrapper objects (``add_batch``). The
    wrappers' accumulated statistics are then copied to CPU.

    Args:
        model: Llama-style causal LM exposing ``model.model.layers`` and
            ``model.config``.
        tokenizer: tokenizer forwarded to the calibration data loader.
        nsamples: number of calibration sequences.
        llamaconfig: model config; only ``num_key_value_heads`` is read.
        device: device used for the initial (layer-0 input) capture pass.
        seed: seed forwarded to the calibration sampler.

    Returns:
        A list with one dict per decoder layer. For a ``'k_rope'`` module
        ``name`` the keys are ``name + 'mean'``, ``name + 'Q'``,
        ``name + 'inp_mean'`` and ``name + 'inp_Q'``; ``'inp_Q'`` is stored in
        half precision. For a ``'k_proj'`` module only ``name + 'mean'`` and
        ``name + 'Q'`` are stored. The values come from the wrappers'
        ``out_mean`` / ``out_cov`` (and ``inp_mean`` / ``inp_cov``), presumably
        mean and covariance estimates — confirm in ``utils.layerwrapper``. All
        tensors are moved to CPU.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        print("loading calibdation data")
        dataloader, _ = get_loaders_sample("c4", nsamples=nsamples, seed=seed, seqlen=4096, tokenizer=tokenizer)
        print("dataset loading complete")

        with torch.no_grad():
            inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, nsamples, device)

        layer_lst = []
        for i, layer in enumerate(tqdm(model.model.layers)):
            subset = utils.find_layers(layer, [LlamaRope, nn.Linear])
            # Follow the device of this layer's weights rather than the `device`
            # argument (presumably so layers spread across devices work).
            layer_device = layer.self_attn.q_proj.weight.data.device

            wrapped_layers = {}
            for name, module in subset.items():
                if 'k_rope' in name:
                    wrapped_layers[name] = WrappedGPT_after_rope_perhead(module, num_head=llamaconfig.num_key_value_heads, device=layer_device)
                if 'k_proj' in name:
                    wrapped_layers[name] = WrappedGPT(module, i)

            def add_batch(name):
                # Factory so each hook binds its own `name` (avoids late binding).
                def tmp(_, inp, out):
                    wrapped_layers[name].add_batch(inp[0].data, out.data)

                return tmp

            # The mask / position ids are shared by all samples; move them once.
            # They may be None (e.g. if the capture never saw them).
            layer_kwargs = {
                'attention_mask': None if attention_mask is None else attention_mask.to(layer_device),
                'position_ids': None if position_ids is None else position_ids.to(layer_device),
            }

            # Hook only the wrapped (k_rope / k_proj) submodules.
            handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
            try:
                with torch.no_grad():
                    for j in range(nsamples):
                        outs[j] = layer(inps[j].unsqueeze(0).to(layer_device), **layer_kwargs)[0]
            finally:
                # Never leave hooks attached, even if the forward pass fails.
                for h in handles:
                    h.remove()

            new_subset = {}
            for name, wrapped in wrapped_layers.items():
                if 'k_rope' in name:
                    new_subset[name + 'mean'] = wrapped.out_mean.to('cpu')
                    new_subset[name + 'Q'] = wrapped.out_cov.to('cpu')
                    new_subset[name + 'inp_mean'] = wrapped.inp_mean.to('cpu')
                    new_subset[name + 'inp_Q'] = wrapped.inp_cov.half().to('cpu')
                elif 'k_proj' in name:
                    new_subset[name + 'mean'] = wrapped.out_mean.to('cpu')
                    new_subset[name + 'Q'] = wrapped.out_cov.to('cpu')
            layer_lst.append(new_subset)

            # This layer's outputs feed the next layer; reuse the old input buffer for outputs.
            inps, outs = outs, inps
    finally:
        model.config.use_cache = use_cache

    torch.cuda.empty_cache()
    return layer_lst

