import numpy as np
import jax.numpy as jnp
from jax import jit, vmap
import jax
jax.config.update("jax_enable_x64", True)
from utils import logsumexp_3d, calculate_weight_matrix, pruning_sub_M


# Small positive epsilon shared by the pruning routines: it is added to the
# exponentiated attention logits before renormalisation (so a row never sums
# to zero) and used as ridge damping on the Hessian before inversion.
little = 1.0e-10

def pruning_mask_gd(X_list, W, lam, rho, *, num_iters=101, momentum=0.9):
    """Learn a binary ``(d, d)`` pruning mask for ``W`` by momentum gradient descent.

    A continuous mask ``M`` (initialised to ones) is updated so that the
    causally masked, row-normalised attention pattern produced by ``M * W``
    moves towards the one produced by ``W``. The step uses the gradient of
    ``0.5 * ||f(M * W) - f(W)||^2`` summed over all samples and averaged by
    ``n * b``. It also applies an L2 penalty ``lam * M``. After optimisation,
    the ``int(rho * d**2)`` smallest mask entries are pruned.

    Args:
        X_list: Non-empty sequence of ``b`` arrays, each of shape ``(n, d)``.
        W: ``(d, d)`` weight matrix, used as ``X @ W @ X.T``.
        lam: Non-zero L2 regularisation strength; also sets the step size
            ``0.1 / lam``.
        rho: Fraction of the ``d * d`` entries to prune. If that is fewer than
            one entry (e.g. ``rho=0``) nothing is pruned; ``rho >= 1`` prunes all.
        num_iters: Number of gradient-descent iterations.
        momentum: Exponential-moving-average coefficient applied to the gradient.

    Returns:
        Integer ``(d, d)`` array with 0 for pruned entries and 1 for kept ones.
        Ties at the threshold are all pruned.
    """
    b = len(X_list)
    n, d = X_list[0].shape

    # Number of entries to prune. When rho * d**2 < 1 there is nothing to do;
    # using index int(rho*d**2) - 1 == -1 here would pick the *largest*
    # value as the threshold and prune every entry.
    n_prune = min(int(rho * d**2), d * d)
    if n_prune <= 0:
        return jnp.ones((d, d), dtype=int)

    X_array = jnp.array(X_list)            # (b, n, d)
    X_T = X_array.transpose(0, 2, 1)       # (b, d, n)
    causal = jnp.tril(jnp.ones((n, n)))    # keeps key positions j <= query i

    def _attention(weights):
        """Causally masked, row-normalised attention for every sample, (b, n, n)."""
        logits = X_array @ weights @ X_T
        u = jnp.exp(logits - logsumexp_3d(logits)) + little
        u = u * causal
        return u / u.sum(axis=2, keepdims=True)

    # The reference pattern depends only on W, never on M: compute it once.
    f_ref = _attention(W)

    M = jnp.ones((d, d), dtype=jnp.float64)
    eta = 0.1 / lam
    grad_mom = 0
    for _ in range(num_iters):
        f_masked = _attention(M * W)
        c = f_masked - f_ref                                   # (b, n, n)
        p1 = c * f_masked
        # Back-propagate through the row-wise normalisation (softmax Jacobian).
        p = p1 - f_masked * p1.sum(axis=2, keepdims=True)      # (b, n, n)
        # Chain rule through logits = X (M * W) X^T, summed over samples.
        grad = W * (X_T @ p @ X_array).sum(axis=0)             # (d, d)

        grad_mom = momentum * grad_mom + (1 - momentum) * grad
        M = M - eta * (grad_mom / (n * b) + lam * M)

    # Threshold is the n_prune-th smallest mask value: entries at or below it
    # are pruned (0), entries above it are kept (1).
    tau = jnp.sort(M.ravel())[n_prune - 1]
    return jnp.where(M <= tau, 0, 1)

def pruning_mask_wanda(X_list, W_q, W_k, rho):
    """Build binary pruning masks for ``W_q`` and ``W_k`` from summed scores.

    For each matrix, ``calculate_weight_matrix(X, W)`` is summed over every
    ``X`` in ``X_list``. The ``int(rho * d**2)`` entries with the lowest total
    score are then pruned.

    Args:
        X_list: Non-empty sequence of ``(n, d)`` input arrays.
        W_q, W_k: Weight matrices passed to ``calculate_weight_matrix``;
            presumably ``(d, d)`` — TODO confirm against utils.
        rho: Fraction of the ``d * d`` entries to prune. If that is fewer than
            one entry (e.g. ``rho=0``) nothing is pruned; ``rho >= 1`` prunes all.

    Returns:
        Tuple ``(M_q, M_k)`` of float ``(d, d)`` arrays, 0.0 for pruned and
        1.0 for kept entries. Ties at the threshold are all pruned.
    """
    d = X_list[0].shape[1]

    # Number of entries to prune. Guard rho * d**2 < 1: an index of
    # int(rho*d**2) - 1 == -1 would select the largest score as threshold
    # and prune everything.
    n_prune = min(int(rho * d**2), d * d)
    if n_prune <= 0:
        return np.ones((d, d)), np.ones((d, d))

    def _keep_mask(scores):
        # Threshold is the n_prune-th smallest score; entries at or below it
        # are pruned (0.0), entries above it are kept (1.0).
        tau = np.sort(scores, axis=None)[n_prune - 1]
        return np.where(scores > tau, 1.0, 0.0)

    # Accumulate per-entry scores over all samples.
    score_q = np.zeros((d, d))
    score_k = np.zeros((d, d))
    for X in X_list:
        score_q += calculate_weight_matrix(X, W_q)
        score_k += calculate_weight_matrix(X, W_k)

    return _keep_mask(score_q), _keep_mask(score_k)

def _sparse_gpt_prune(W, H_inv, rho, B, Bs):
    """Prune one weight matrix column by column with OBS error compensation.

    ``W`` has shape ``(rows, d)`` with columns indexing the same inputs as the
    ``(d, d)`` inverse Hessian ``H_inv``. It is modified in place. Returns
    ``(M, W)`` where ``M`` is the ``(rows, d)`` mask assembled from
    ``pruning_sub_M`` calls.
    """
    rows, d = W.shape
    H_inv_diag = np.diag(H_inv)
    M = np.ones((rows, d))
    E = np.zeros((rows, B))  # per-column errors of the current block

    for start in range(0, d, B):
        # The last block may be narrower than B when B does not divide d.
        end = min(start + B, d)
        for j in range(start, end):
            if j % Bs == 0:
                # Saliency w^2 / [H^-1]_cc^2, where c is the weight's *column*
                # (input) index, the same index the update below uses.
                cols = slice(j, j + Bs)
                M[:, cols] = W[:, cols] ** 2 / H_inv_diag[cols] ** 2
                M[:, cols] = pruning_sub_M(M, j, Bs, rho)

            # Error from zeroing the pruned weights of column j...
            E[:, j - start] = (1 - M[:, j]) * (W[:, j] / H_inv_diag[j])
            # ...compensated on the remaining columns of this block.
            W[:, j:end] -= E[:, j - start][:, np.newaxis] * H_inv[j, j:end]

        # Lazy batch update of every column after this block.
        width = end - start
        W[:, end:] -= E[:, :width] @ H_inv[start:end, end:]

    return M, W


def pruning_mask_sparse_gpt(X_list, W_q, W_k, rho, B, Bs):
    """Prune ``W_q`` and ``W_k`` SparseGPT-style and return masks and updated weights.

    The calibration inputs are stacked into ``X`` and the damped Hessian
    ``X.T @ X + little * I`` is inverted once. Both matrices see the same
    inputs, so they share that inverse. Each matrix is then processed on its
    transpose, column by column, in lazy-update blocks of width ``B``. A new
    mask group is selected every ``Bs`` columns via ``pruning_sub_M``.

    NOTE(review): the reference SparseGPT implementation uses the upper
    Cholesky factor of H^-1, so each row reflects the inverse Hessian after
    earlier columns are eliminated. The plain inverse used here is an
    approximation — confirm this is intended. The 1e-10 damping is also very
    weak if ``X`` has fewer rows than ``d``.

    Args:
        X_list: Sequence of ``(n, d)`` calibration inputs, stacked row-wise.
        W_q, W_k: Weight matrices with ``d`` rows (``(d, d)`` in the original
            use). They are copied, not modified.
        rho: Sparsity parameter forwarded unchanged to ``pruning_sub_M``.
        B: Lazy-update block width; need not divide ``d``.
        Bs: Mask-selection group width.

    Returns:
        ``(M_q, M_k, W_q_update, W_k_update)`` in the orientation of the
        inputs. The masks hold whatever ``pruning_sub_M`` returns, presumably
        0 = pruned / 1 = kept (TODO confirm). The W_q mask is fully computed
        before the W_k mask, which only matters if ``pruning_sub_M`` is stateful.
    """
    X = np.vstack(X_list)
    d = X.shape[1]
    H_inv = np.linalg.inv(X.T @ X + little * np.eye(d))

    M_q, W_q_T = _sparse_gpt_prune(np.array(W_q).T, H_inv, rho, B, Bs)
    M_k, W_k_T = _sparse_gpt_prune(np.array(W_k).T, H_inv, rho, B, Bs)

    return M_q.T, M_k.T, W_q_T.T, W_k_T.T

