import torch
from transformers import AutoModel
from rebuttal_gen_X import load_embeddings
from rebuttal_model import GDMask
from rebuttal_utils import set_top_large_to_one, calculate_weight_matrix


def pruning_sub_M(M_q, j, Bs, rho):
    """Binarise one column group of a score matrix in place.

    Columns ``j`` to ``j + Bs`` of ``M_q`` are read as scores. With
    ``k = int((1 - rho) * n)``, where ``n`` is the number of entries in the
    group, the k-th largest score is used as threshold: entries strictly
    above it become 1, every other entry (the threshold entry included)
    becomes 0. When ``k`` is not positive the whole group is set to 1.

    Args:
        M_q: 2-D tensor holding the scores; it is modified in place.
        j: Index of the first column of the group.
        Bs: Number of columns in the group (clamped by slicing at the edge).
        rho: Ratio used to derive ``k`` as described above.

    Returns:
        ``M_q`` itself, with the selected columns replaced by 0/1 values.
    """
    group = M_q[:, j:j + Bs]  # a view: writes below land directly in M_q
    keep_count = int((1 - rho) * group.numel())

    if keep_count <= 0:
        # NOTE(review): k == 0 marks every entry 1 whereas k == 1 marks every
        # entry 0 -- confirm this discontinuity is intended.
        M_q[:, j:j + Bs] = 1
        return M_q

    # topk returns its values in descending order, so the last one is the
    # k-th largest score of the group.
    threshold = torch.topk(group.flatten(), keep_count).values[-1]

    # Evaluate both comparisons on the untouched scores before writing.
    at_or_below = group <= threshold
    above = group > threshold
    group[at_or_below] = 0
    group[above] = 1
    return M_q

def pruning_mask_sparse_gpt(X, W, rho, B, Bs, percdamp=0.01):
    """Select a pruning mask for ``W`` and compensate the remaining weights.

    Columns of ``W`` are processed left to right in blocks of ``B`` columns.
    Every ``Bs`` columns a new group of mask entries is chosen from the score
    ``W ** 2 / diag(U) ** 2``, where ``U`` is the upper Cholesky factor of
    the inverse of the damped matrix ``X @ X.T``. The weights dropped in a
    column are turned into an error term that is pushed onto the columns to
    its right: immediately inside the current block and, once per block, onto
    all later blocks.

    Args:
        X: Calibration inputs with ``W.shape[1]`` rows (one column per
            sample); converted to float32 for the accumulation.
        W: 2-D weight matrix. It is cloned, so the caller's tensor is left
            untouched.
        rho: Ratio forwarded to ``pruning_sub_M`` for every column group.
        B: Number of columns per block. The column count of ``W`` does not
            have to be a multiple of ``B``.
        Bs: Number of columns per mask-selection group. NOTE(review): when
            ``Bs`` does not divide ``B`` a group can reach into columns of the
            next block that have not yet received the current block's deferred
            update -- confirm callers always pass compatible values.
        percdamp: Fraction of the mean diagonal of ``X @ X.T`` added to its
            diagonal before the Cholesky factorisations.

    Returns:
        ``(mask, W_updated)``. ``mask`` has the shape and dtype of ``W`` and
        holds 0/1 values; entries equal to 1 contribute no error term, i.e.
        they are the weights that are kept. ``W_updated`` is the compensated
        copy of ``W``.

    Raises:
        AssertionError: If the updated weights contain inf or nan values.
    """
    mask = torch.ones_like(W)
    d_row = W.shape[0]
    d_col = W.shape[1]
    W = W.clone()

    # Accumulate H = X X^T in float32.
    H = torch.zeros((d_col, d_col), device=X.device)
    inp = X.float()
    H += inp @ inp.t()

    # Input features that are identically zero would make H singular; give
    # them a unit diagonal entry instead.
    dead = torch.diag(H) == 0
    H[dead, dead] = 1

    # Damping keeps the Cholesky factorisations well conditioned.
    damp = percdamp * torch.mean(torch.diag(H))
    diag_idx = torch.arange(d_col, device=X.device)
    H[diag_idx, diag_idx] += damp

    # Invert H through its Cholesky factor, then keep only the upper Cholesky
    # factor of the inverse: its rows drive the column-by-column updates.
    L = torch.linalg.cholesky(H)
    H_inv = torch.cholesky_inverse(L)
    H_inv = torch.linalg.cholesky(H_inv, upper=True)
    H_inv_diag = torch.diag(H_inv)

    for block_start in range(0, d_col, B):
        block_end = min(block_start + B, d_col)

        # Size the error buffer to the actual block width. A fixed (d_row, B)
        # buffer breaks the matmul below on a narrower trailing block.
        E = torch.zeros((d_row, block_end - block_start), device=X.device)

        for j in range(block_start, block_end):
            if j % Bs == 0:
                # Score the next group of columns and binarise it in place.
                pruning_metric = (W[:, j:j + Bs] ** 2) / (H_inv_diag[j:j + Bs] ** 2)
                mask[:, j:j + Bs] = pruning_metric
                mask = pruning_sub_M(mask, j, Bs, rho)

            col = j - block_start
            # Only dropped weights (mask == 0) produce an error term.
            E[:, col] = (1 - mask[:, j]) * (W[:, j] / H_inv_diag[j])

            # Propagate the error to the rest of the block. At column j itself
            # this subtracts the dropped weight, driving it to (numerically)
            # zero.
            W[:, j:block_end] -= E[:, col][:, None] * H_inv[j, j:block_end]

        # The columns of this block are final now, so validate them once here
        # instead of rescanning all of W after every column. An explicit raise
        # is used because ``assert`` statements are removed under ``-O``.
        if not torch.isfinite(W[:, block_start:block_end]).all():
            raise AssertionError(
                "non-finite values (inf/nan) in updated weights for columns "
                f"{block_start}:{block_end}"
            )

        # Deferred update of every column to the right of the block.
        W[:, block_end:] -= E @ H_inv[block_start:block_end, block_end:]

    return mask, W


def run_sparse_gpt(my_model, rho, B, Bs, device='cpu'):
    """Prune the query and key projections of ``my_model``.

    All embeddings returned by ``load_embeddings`` are converted to float32,
    stripped of a leading singleton dimension, concatenated along dim 0 and
    transposed. The resulting matrix is passed together with each projection
    weight to ``pruning_mask_sparse_gpt``; the updated weights and the masks
    it returns are then stored on the model.

    Args:
        my_model: Object exposing ``q_proj``/``k_proj`` (with ``.weight``)
            and ``q_proj_mask``/``k_proj_mask``.
        rho: Ratio forwarded to ``pruning_mask_sparse_gpt``.
        B: Block size forwarded to ``pruning_mask_sparse_gpt``.
        Bs: Mask-selection group size forwarded to
            ``pruning_mask_sparse_gpt``.
        device: Device the calibration data is moved to.

    Returns:
        ``my_model``, updated in place.
    """
    samples = []
    for raw in load_embeddings():
        sample = torch.tensor(raw, dtype=torch.float32)
        samples.append(sample.squeeze(0).to(device))

    # Stack along dim 0, then transpose before handing the data to the solver.
    stacked = torch.cat(samples, dim=0).to(device)
    X = stacked.T

    query_weight = my_model.q_proj.weight.data
    key_weight = my_model.k_proj.weight.data

    query_mask, query_weight_new = pruning_mask_sparse_gpt(X, query_weight, rho, B, Bs)
    key_mask, key_weight_new = pruning_mask_sparse_gpt(X, key_weight, rho, B, Bs)

    # Write back the compensated weights first, then the masks.
    my_model.q_proj.weight.data = query_weight_new
    my_model.k_proj.weight.data = key_weight_new
    my_model.q_proj_mask.data = query_mask
    my_model.k_proj_mask.data = key_mask

    return my_model

def run_wanda(my_model, rho, device='cpu'):
    """Derive query/key masks from calibration data and store them on the model.

    The embeddings returned by ``load_embeddings`` are stripped of a leading
    singleton dimension and concatenated along dim 0 (no transpose here). For
    each projection, ``calculate_weight_matrix`` turns the data and the
    projection weight into a score matrix, which ``set_top_large_to_one``
    then converts using ``rho``. Both helpers are defined elsewhere; their
    exact semantics are not visible in this file.

    Args:
        my_model: Object exposing ``q_proj``/``k_proj`` (with ``.weight``)
            and ``q_proj_mask``/``k_proj_mask``.
        rho: Ratio forwarded to ``set_top_large_to_one``.
        device: Device the concatenated calibration data is moved to.

    Returns:
        ``my_model`` with ``q_proj_mask.data`` and ``k_proj_mask.data``
        replaced. The projection weights themselves are not modified here.
    """
    samples = [torch.tensor(embedding).squeeze(0) for embedding in load_embeddings()]
    X = torch.cat(samples, dim=0).to(device)

    # The scores are computed directly; no zero-filled placeholders are
    # allocated first, since they would be overwritten immediately.
    M_q = calculate_weight_matrix(X, my_model.q_proj.weight)
    M_k = calculate_weight_matrix(X, my_model.k_proj.weight)

    M_q = set_top_large_to_one(M_q, rho)
    M_k = set_top_large_to_one(M_k, rho)

    my_model.q_proj_mask.data = M_q
    my_model.k_proj_mask.data = M_k

    return my_model


def run_gd_mask(my_model, rho, lr=1, lam=0.001, device='cpu', epochs=100):
    """Train the query/key masks by gradient descent, then threshold them.

    For ``epochs`` steps, AdamW updates ``my_model.q_proj_mask`` and
    ``my_model.k_proj_mask`` on a loss accumulated over all embeddings and
    divided by their count. For one embedding the loss is

        0.5 * sum(||f - f_m||_F over dims (2, 3))
        + 0.5 * lam * (||q_proj_mask||_F + ||k_proj_mask||_F)

    where ``f, f_m`` are the two outputs of ``my_model(embedding)``
    (presumably the unmasked and masked results -- confirm against the model).
    After training, both masks are passed through ``set_top_large_to_one``
    with ``rho`` and written back.

    Args:
        my_model: Module exposing ``q_proj_mask``/``k_proj_mask`` and
            returning a pair of tensors with at least four dimensions.
        rho: Ratio forwarded to ``set_top_large_to_one``.
        lr: AdamW learning rate.
        lam: Weight of the mask-norm penalty.
        device: Device the model and the embeddings are moved to.
        epochs: Number of optimisation steps (one per pass over the data).

    Returns:
        ``my_model`` with thresholded masks. It is left in training mode.
    """
    # Load the embeddings.
    embeddings = load_embeddings()
    num_embeddings = len(embeddings)

    # Use the embeddings to train the mask.
    my_model.to(device)
    my_model.train()

    # The conversion and device transfer do not depend on the epoch, so do
    # them once instead of once per epoch.
    inputs = [
        torch.tensor(embedding, dtype=torch.float32).to(device)
        for embedding in embeddings
    ]

    optimizer = torch.optim.AdamW([my_model.q_proj_mask, my_model.k_proj_mask], lr=lr)
    for _ in range(epochs):
        # Gradients are produced only by the single backward() below, so one
        # reset per epoch is sufficient.
        optimizer.zero_grad()
        total_loss = 0
        for embedding in inputs:
            f, f_m = my_model(embedding)
            f_norms = torch.norm(f - f_m, p='fro', dim=(2, 3))

            mask_penalty = (
                torch.norm(my_model.q_proj_mask, p='fro')
                + torch.norm(my_model.k_proj_mask, p='fro')
            )
            loss_func = 0.5 * torch.sum(f_norms) + 0.5 * lam * mask_penalty

            total_loss = total_loss + loss_func

        total_loss = total_loss / num_embeddings
        total_loss.backward()
        optimizer.step()

    # Threshold both trained masks before writing either of them back.
    q_proj_mask = set_top_large_to_one(my_model.q_proj_mask, rho)
    k_proj_mask = set_top_large_to_one(my_model.k_proj_mask, rho)

    my_model.q_proj_mask.data = q_proj_mask
    my_model.k_proj_mask.data = k_proj_mask

    return my_model