import numpy as np
import jax.numpy as jnp
from jax import jit

def logsumexp(X, axis=1):
    """Return a numerically stable ``log(sum(exp(X)))`` along ``axis``.

    The reduced axis is kept (``keepdims=True``), so for a 2-D ``X`` and the
    default ``axis=1`` the result has shape ``(X.shape[0], 1)``.

    Each slice is shifted by its maximum before exponentiating, to avoid
    overflow. If that maximum is not finite (e.g. a slice that is entirely
    ``-inf``), the shift is 0 instead, because ``-inf - -inf`` would be NaN.
    An all ``-inf`` slice therefore yields ``-inf``, and a slice containing
    ``+inf`` yields ``+inf``.

    Args:
        X: Array-like of values.
        axis: Axis to reduce over (default 1, matching the original behavior).

    Returns:
        NumPy array with the log-sum-exp of each slice, reduced axis kept.
    """
    max_X = np.max(X, axis=axis, keepdims=True)
    # Non-finite maxima would turn X - max_X into NaN; shift those slices by 0.
    shift = np.where(np.isfinite(max_X), max_X, 0)
    # log(0) for an all -inf slice is the correct answer (-inf); silence its warning.
    with np.errstate(divide="ignore"):
        summed = np.sum(np.exp(X - shift), axis=axis, keepdims=True)
        return shift + np.log(summed)

@jit
def logsumexp_3d(X):
    """JIT-compiled, numerically stable ``log(sum(exp(X)))`` along axis 2.

    The reduced axis is kept (``keepdims=True``). Each slice is shifted by
    its maximum before exponentiating. If that maximum is not finite (e.g. a
    slice that is entirely ``-inf``), the shift is 0 instead, because
    ``-inf - -inf`` would be NaN. An all ``-inf`` slice therefore yields
    ``-inf`` rather than NaN.

    Args:
        X: JAX/NumPy array with at least 3 dimensions (axis 2 is reduced);
            presumably rank 3 given the name -- confirm against callers.

    Returns:
        JAX array with the log-sum-exp over axis 2, that axis kept with size 1.
    """
    max_X = jnp.max(X, axis=2, keepdims=True)
    # Non-finite maxima would turn X - max_X into NaN; shift those slices by 0.
    shift = jnp.where(jnp.isfinite(max_X), max_X, 0)
    return shift + jnp.log(jnp.sum(jnp.exp(X - shift), axis=2, keepdims=True))

def calculate_weight_matrix(X, W):
    """Scale each row of ``|W|`` by the L2 norm of the matching column of ``X``.

    Computes ``M[i, j] = |W[i, j]| * ||X[:, i]||_2``. Only the first
    ``W.shape[0]`` column norms of ``X`` are used.

    Args:
        X: 2-D array; the L2 norm of each of its columns is computed.
        W: 2-D array; row ``i`` is scaled by the norm of column ``i`` of ``X``.

    Returns:
        NumPy array with the same shape as ``W``. Floating/complex ``W`` keeps
        its dtype. Integer/bool ``W`` yields a floating result, because storing
        into ``W``'s own dtype would truncate the products.

    Raises:
        ValueError: if ``W`` is not 2-D or ``X`` has fewer columns than ``W``
            has rows.
    """
    W = np.asarray(W)
    if W.ndim != 2:
        raise ValueError(f"W must be 2-D, got shape {W.shape}")
    column_l2_norm = np.linalg.norm(X, axis=0)
    n_rows = W.shape[0]
    if np.ndim(column_l2_norm) != 1 or column_l2_norm.shape[0] < n_rows:
        raise ValueError(
            f"X must have at least {n_rows} columns to scale the rows of W"
        )
    # Broadcast one norm per row of W instead of looping element by element.
    M = np.abs(W) * column_l2_norm[:n_rows, np.newaxis]
    if np.issubdtype(W.dtype, np.inexact):
        # Keep the caller's float/complex precision, as np.zeros_like(W) did.
        M = M.astype(W.dtype, copy=False)
    return M

def pruning_sub_M(M_q, j, Bs, rho):
    """Turn the column block ``M_q[:, j:j + Bs]`` into a 0/1 pruning mask.

    The ``int(rho * block.size)`` smallest entries of the block are set to 0
    and every other entry to 1. Entries tied with the threshold value are
    also set to 0, so ties can zero slightly more than that count.
    ``rho * block.size < 1`` zeroes nothing; ``rho >= 1`` zeroes everything.

    NOTE(review): For a NumPy ``M_q`` the slice is a view, so the mask is
    written *in place* into ``M_q`` and that same view is returned. This
    side effect is kept from the original implementation; confirm whether
    callers rely on it.

    Args:
        M_q: 2-D array of scores (values are compared, smallest zeroed first).
        j: First column of the block.
        Bs: Block width in columns (the last block may be narrower).
        rho: Fraction of the block's entries to zero.

    Returns:
        The block ``M_q[:, j:j + Bs]``, now holding only 0 and 1.
    """
    sub_mat = M_q[:, j:j + Bs]
    # Use the block's real size: the final block can be narrower than Bs.
    n_prune = min(int(rho * sub_mat.size), sub_mat.size)
    if n_prune <= 0:
        # Nothing to prune (or an empty block): keep every entry.
        sub_mat[...] = 1
        return sub_mat
    tau = np.sort(sub_mat, axis=None)[n_prune - 1]
    # Build the mask before writing, since sub_mat is overwritten below.
    keep = sub_mat > tau
    sub_mat[...] = keep
    return sub_mat