import numpy as np
from utils import logsumexp

# Small positive offset added to exp-normalized scores so that no
# attention entry is exactly zero before masking / row normalization.
little = 1.0e-10


def calculate_loss(X_list, W, M, lam, M_c):
    """Mean relative attention-map error caused by masking a joint weight W.

    For every X_j in X_list, an attention-style map A is built from the
    scores ``X_j @ W @ X_j.T`` and a second map wt_A from the scores with
    ``M * W`` substituted for W. The per-sample error
    ``||A - wt_A||_F**2 / ||A||_F**2`` is averaged over the list.

    Each map is formed by exp-normalizing the scores with ``logsumexp``,
    adding ``little``, multiplying elementwise by M_c, and dividing every
    row by its sum.

    Args:
        X_list: Sequence of input matrices. Each X_j presumably has shape
            (n_tokens, d); TODO confirm with callers.
        W: Weight matrix used in ``X_j @ W @ X_j.T``.
        M: Array multiplied elementwise with W (presumably a 0/1 pruning
            mask). It must broadcast against W.
        lam: Unused. It is kept only for backward compatibility with
            existing callers.
        M_c: Array multiplied elementwise with the normalized scores
            (presumably an attention/causal mask). It must broadcast against
            the (n_tokens, n_tokens) score matrix.

    Returns:
        The average relative squared Frobenius error over X_list.

    Raises:
        ZeroDivisionError: If X_list is empty.

    Note:
        If M_c zeroes out an entire row, that row sums to zero and the row
        normalization produces NaNs.
    """
    def _masked_attention(scores):
        # Exp-normalize for numerical stability. This assumes utils.logsumexp
        # reduces over the whole matrix; TODO confirm. `little` keeps
        # unmasked entries strictly positive. Then apply M_c and make every
        # row sum to 1.
        probs = (np.exp(scores - logsumexp(scores)) + little) * M_c
        return probs / probs.sum(axis=1, keepdims=True)

    # The masked weight is identical for every sample, so build it once.
    W_masked = M * W

    loss = 0
    for X_j in X_list:
        # Each score matrix is computed once and reused for both the
        # exponent and its logsumexp.
        A = _masked_attention(X_j @ W @ X_j.T)
        wt_A = _masked_attention(X_j @ W_masked @ X_j.T)

        sq_err = np.linalg.norm(A - wt_A, 'fro') ** 2
        loss += sq_err / (np.linalg.norm(A, 'fro') ** 2)

    loss /= len(X_list)
    return loss

def calculate_loss_wanda_setting(X_list, W_q, W_k, M_q, M_k, M_c):
    """Mean relative attention-map error from masking query/key weights.

    For every X_j in X_list, an attention-style map A is built from the
    scores ``(X_j @ W_q) @ (X_j @ W_k).T`` and a second map wt_A from the
    same scores with ``W_q * M_q`` and ``W_k * M_k`` substituted. The
    per-sample error ``||A - wt_A||_F**2 / ||A||_F**2`` is averaged over the
    list.

    Each map is formed by exp-normalizing the scores with ``logsumexp``,
    adding ``little``, multiplying elementwise by M_c, and dividing every
    row by its sum.

    Args:
        X_list: Sequence of input matrices. Each X_j presumably has shape
            (n_tokens, d); TODO confirm with callers.
        W_q: Query projection weight.
        W_k: Key projection weight.
        M_q: Array multiplied elementwise with W_q (presumably a pruning
            mask).
        M_k: Array multiplied elementwise with W_k (presumably a pruning
            mask).
        M_c: Array multiplied elementwise with the normalized scores
            (presumably an attention/causal mask).

    Returns:
        The average relative squared Frobenius error over X_list.

    Raises:
        ZeroDivisionError: If X_list is empty.

    Note:
        If M_c zeroes out an entire row, that row sums to zero and the row
        normalization produces NaNs.
    """
    def _masked_attention(scores):
        # Exp-normalize for numerical stability. This assumes utils.logsumexp
        # reduces over the whole matrix; TODO confirm. `little` keeps
        # unmasked entries strictly positive. Then apply M_c and make every
        # row sum to 1.
        probs = (np.exp(scores - logsumexp(scores)) + little) * M_c
        return probs / probs.sum(axis=1, keepdims=True)

    # The masked weights are identical for every sample, so build them once.
    W_q_masked = W_q * M_q
    W_k_masked = W_k * M_k

    loss = 0
    for X_j in X_list:
        # Each score matrix is computed once and reused for both the
        # exponent and its logsumexp.
        A = _masked_attention((X_j @ W_q) @ (X_j @ W_k).T)
        wt_A = _masked_attention((X_j @ W_q_masked) @ (X_j @ W_k_masked).T)

        # Relative squared reconstruction error of the attention map
        # (there is no regularization term).
        sq_err = np.linalg.norm(A - wt_A, 'fro') ** 2
        loss += sq_err / (np.linalg.norm(A, 'fro') ** 2)

    loss /= len(X_list)
    return loss

def calculate_loss_sparse(X_list, W_q, W_k, M_q, M_k, M_c, W_q_update, W_k_update):
    """Mean relative attention-map error of masked, updated query/key weights.

    For every X_j in X_list, a reference attention-style map A is built from
    the original weights, with scores ``(X_j @ W_q) @ (X_j @ W_k).T``. A
    second map wt_A is built from the updated, masked weights
    ``W_q_update * M_q`` and ``W_k_update * M_k``. The per-sample error
    ``||A - wt_A||_F**2 / ||A||_F**2`` is averaged over the list.

    Each map is formed by exp-normalizing the scores with ``logsumexp``,
    adding ``little``, multiplying elementwise by M_c, and dividing every
    row by its sum.

    Args:
        X_list: Sequence of input matrices. Each X_j presumably has shape
            (n_tokens, d); TODO confirm with callers.
        W_q: Original query projection weight (reference map only).
        W_k: Original key projection weight (reference map only).
        M_q: Array multiplied elementwise with W_q_update (presumably a
            pruning mask).
        M_k: Array multiplied elementwise with W_k_update (presumably a
            pruning mask).
        M_c: Array multiplied elementwise with the normalized scores
            (presumably an attention/causal mask).
        W_q_update: Updated query weight used for wt_A.
        W_k_update: Updated key weight used for wt_A.

    Returns:
        The average relative squared Frobenius error over X_list.

    Raises:
        ZeroDivisionError: If X_list is empty.

    Note:
        If M_c zeroes out an entire row, that row sums to zero and the row
        normalization produces NaNs.
    """
    def _masked_attention(scores):
        # Exp-normalize for numerical stability. This assumes utils.logsumexp
        # reduces over the whole matrix; TODO confirm. `little` keeps
        # unmasked entries strictly positive. Then apply M_c and make every
        # row sum to 1.
        probs = (np.exp(scores - logsumexp(scores)) + little) * M_c
        return probs / probs.sum(axis=1, keepdims=True)

    # The masked, updated weights are identical for every sample, so build
    # them once.
    W_q_masked = W_q_update * M_q
    W_k_masked = W_k_update * M_k

    loss = 0
    for X_j in X_list:
        # Each score matrix is computed once and reused for both the
        # exponent and its logsumexp.
        A = _masked_attention((X_j @ W_q) @ (X_j @ W_k).T)
        wt_A = _masked_attention((X_j @ W_q_masked) @ (X_j @ W_k_masked).T)

        # Relative squared reconstruction error of the attention map
        # (there is no regularization term).
        sq_err = np.linalg.norm(A - wt_A, 'fro') ** 2
        loss += sq_err / (np.linalg.norm(A, 'fro') ** 2)

    loss /= len(X_list)
    return loss