import math
import torch
import torch.nn as nn

from tqdm import tqdm
import tensorly as tl
from tensorly.decomposition import tucker
from transformers import AutoConfig
from matplotlib import pyplot as plt
from utils.data import get_loaders, get_loaders_sample
from .preprocess import prepare_calibration_input
from utils.layerwrapper import WrappedGPT, WrappedGPT_out
from transformers.models.llama.configuration_llama import *

# from model_impl.llama.configuration_llama_copy import LlamaConfig
from model_impl.llama import LlamaAttention_sparseMix_label_all_compression, LlamaRope, sparse_project
from transformers.models.llama.modeling_llama import *

import utils


def _k_low_rank_factors(inp_cov, inp_mean, reduced_rank):
    """Split the eigenbasis of ``inp_cov`` into kept and discarded parts.

    ``torch.linalg.eigh`` returns eigenvalues in ascending order, so the
    last ``reduced_rank`` eigenvector columns are the ones that are kept.

    Args:
        inp_cov: Square symmetric matrix handed to ``torch.linalg.eigh``.
        inp_mean: Tensor whose first dimension gives the full width
            (assumed to be a column vector matching ``inp_cov`` -- TODO
            confirm against the code that produces the statistics).
        reduced_rank: Number of eigenvector columns to keep.

    Returns:
        A ``(down, up, bias)`` tuple where ``up`` holds the kept
        eigenvector columns, ``down`` is its transpose, and ``bias`` is
        ``inp_mean`` projected onto the discarded columns.

    Raises:
        ValueError: If ``reduced_rank`` is outside ``[0, len(inp_mean)]``.
    """
    len_Q = len(inp_mean)
    # Without this guard a rank larger than the width produces a negative
    # slice start below, which silently selects the wrong columns.
    if not 0 <= reduced_rank <= len_Q:
        raise ValueError(
            f"k_high_rank must be in [0, {len_Q}], got {reduced_rank}"
        )

    _, Q = torch.linalg.eigh(inp_cov)
    split = len_Q - reduced_rank
    Q_remain = Q[:, split:]  # kept columns
    Q_2 = Q[:, :split]       # discarded columns

    # Component of the mean that lies in the discarded subspace; it is
    # installed as the bias of the up projection.
    b = (Q_2 @ Q_2.T @ inp_mean).squeeze(-1)
    return Q_remain.T, Q_remain, b


def prune_mix_label_sparse(model, layer_lst, llamaconfig:LlamaConfig, device):
    """Rebuild every layer's ``self_attn`` in compressed form, in place.

    For each layer in ``model.model.layers`` the function:

    1. Snapshots the fp16 weights of all modules found by
       ``utils.find_layers``.
    2. For every module whose name contains ``'k_proj'``, derives
       ``k_down_proj`` / ``k_up_proj`` weights and a ``k_up_proj`` bias
       from the statistics stored in ``layer_lst``.
    3. Replaces ``layer.self_attn`` with a fresh instance of the same
       class built from ``llamaconfig`` and the layer index.
    4. Copies the snapshot into every ``self_attn`` sub-module of the
       rebuilt layer.

    Args:
        model: Object exposing ``model.model.layers``.
        layer_lst: Indexed by layer number; each entry maps
            ``'<module name>mean'`` and ``'<module name>Q'`` to tensors
            (presumably the calibration mean and covariance -- verify
            against the producer). The ``k_rope`` entries are not
            consumed here.
        llamaconfig: Configuration used to build the new attention
            modules. It is mutated: ``is_compress`` is set to ``True``.
            ``k_high_rank`` is read as the number of components to keep.
        device: Device on which the statistics are processed and onto
            which the final weights are placed.

    Returns:
        None. ``model`` is modified in place.

    Raises:
        ValueError: If ``llamaconfig.k_high_rank`` is out of range for a
            layer's statistics.
        KeyError: If a rebuilt ``self_attn`` sub-module has no matching
            entry in the snapshot.
    """
    layers = model.model.layers
    llamaconfig.is_compress = True

    for i, layer in enumerate(tqdm(layers)):
        layer_subset = utils.find_layers(layer)
        layer_stats = layer_lst[i]

        new_subset = {}
        for name, module in layer_subset.items():
            new_subset[name] = module.weight.data.half()
            if 'k_proj' not in name:
                continue

            down_key = name.replace('k_proj', 'k_down_proj')
            up_key = name.replace('k_proj', 'k_up_proj')

            # Work on the caller-supplied device rather than a hard-coded
            # 'cuda' so CPU-only and multi-GPU setups behave as requested.
            inp_mean = layer_stats[name + 'mean'].float().to(device)
            inp_cov = layer_stats[name + 'Q'].float().to(device)

            down, up, bias = _k_low_rank_factors(
                inp_cov, inp_mean, llamaconfig.k_high_rank
            )

            new_subset[down_key] = down.half().contiguous()  # r x d
            new_subset[up_key] = up.half().contiguous()      # d x r
            new_subset[up_key + 'bias'] = bias.half().contiguous()

        # Re-instantiate the attention module so that it is constructed
        # with ``is_compress`` enabled.
        layer.self_attn = layer.self_attn.__class__(llamaconfig, i)

        subset = utils.find_layers(layer, [nn.Linear, sparse_project])
        for name, module in subset.items():
            if "self_attn" not in name:
                continue
            module.weight.data = new_subset[name].to(device)
            # NOTE(review): the snapshot only contains a bias entry for
            # k_up_proj; any other self_attn module that has a bias will
            # raise KeyError here.
            if module.bias is not None:
                module.bias.data = new_subset[name + 'bias'].to(device)
