import numpy as np
from scipy.linalg import sqrtm

"""
CODE ADAPTED FROM DING ET AL. 2021

Ding, F., Denain, J. S., & Steinhardt, J. (2021). Grounding Representation Similarity Through Statistical Testing. Advances in Neural Information Processing Systems, 34, 1556-1568.
"""

def lin_cka_dist_2(A, B):
    """
    Computes Linear CKA distance between representations A and B
    based on the reformulation of the Frobenius norm term from Kornblith et al. (2018)
    np.linalg.norm(A.T @ B, ord="fro") ** 2 == np.trace((A @ A.T) @ (B @ B.T))

    Only the (examples x examples) Gram matrices are formed, so the
    (neurons x neurons) product A.T @ B is never materialized.
    The inputs are used as given: this function does not center them.

    Parameters
    ----------
    A : examples x neurons
    B : examples x neurons

    Returns
    -------
    1 - similarity / normalization, where similarity is
    trace((A @ A.T) @ (B @ B.T)) and normalization is the product of the
    Frobenius norms of the two Gram matrices.

    Original Code from Ding et al. (2021)
    -------------
    similarity = np.linalg.norm(B @ A.T, ord="fro") ** 2
    normalization = np.linalg.norm(A @ A.T, ord="fro") * np.linalg.norm(B @ B.T, ord="fro")
    """
    # Build each Gram matrix once and reuse it for both terms.
    gram_a = A @ A.T
    gram_b = B @ B.T

    # gram_b is symmetric, so trace(gram_a @ gram_b) == sum(gram_a * gram_b);
    # the elementwise form avoids a full matrix product.
    similarity = np.sum(gram_a * gram_b)
    normalization = (np.linalg.norm(gram_a, ord='fro') *
                     np.linalg.norm(gram_b, ord='fro'))

    return 1 - similarity / normalization


def procrustes_2(A, B):
    """
    Computes Procrustes distance between representations A and B
    for when |neurons| >> |examples| and A.T @ B too large to fit in memory.
    Based on:
         np.linalg.norm(A.T @ B, ord="nuc") == np.sum(np.sqrt(np.linalg.eigvals((A @ A.T) @ (B @ B.T))))

    The absolute value of each eigenvalue is taken before the square root,
    so eigenvalues that come back slightly negative or complex do not
    produce NaNs or a complex result.

    Parameters
    ----------
    A : examples x neurons
    B : examples x neurons

    Returns
    -------
    sum(A ** 2) + sum(B ** 2) - 2 * nuc, where nuc is the nuclear norm of
    A.T @ B obtained from the (examples x examples) eigenvalue problem.

    Original Code
    -------------
    nuc = np.linalg.norm(A @ B.T, ord="nuc")  # O(p * p * n)
    """
    A_sq_frob = np.sum(A ** 2)
    B_sq_frob = np.sum(B ** 2)

    # Only the eigenvalues are needed; eigvals skips the eigenvector
    # computation that np.linalg.eig would perform and discard.
    eigenvalues = np.linalg.eigvals((A @ A.T) @ (B @ B.T))
    nuc = np.sum(np.sqrt(np.abs(eigenvalues)))
    return A_sq_frob + B_sq_frob - 2 * nuc


def cca_decomp_2(A, B, pen_a=0, pen_b=0):
    """
    Computes CCA vectors, correlations, and transformed matrices
    based on Tuzhilina et al. (2021)

    The inputs are used as given: no centering and no (1/n) scaling is
    applied. Below, k = min(a, b).

    Args:
        A: np.array of size n x a where a is the number of neurons and n is the dataset size
        B: np.array of size n x b where b is the number of neurons and n is the dataset size
        pen_a: ridge penalty added to the diagonal of A.T @ A before the inverse square root
        pen_b: ridge penalty added to the diagonal of B.T @ B before the inverse square root
    Returns:
        u: left singular vectors for the inner SVD problem, a x k array
        s: singular values of the inner SVD problem (the canonical
           correlation coefficients), length k
        vh: right singular vectors for the inner SVD problem, k x b array
        transformed_a: canonical vectors for matrix A, n x k array
        transformed_b: canonical vectors for matrix B, n x k array

    Tuzhilina et al. (2021) normalizes by (1/n), but that doesn't match
    Ding et al. (2021):

        A_cov_inv = np.linalg.inv(sqrtm((1/n) * A_cov + pen_a * np.identity(A.shape[1])))
        B_cov_inv = np.linalg.inv(sqrtm((1/n) * B_cov + pen_b * np.identity(B.shape[1])))

        objective_matrix = (A_cov_inv @ ((1/n) * AB_cov) @ B_cov_inv)

    """
    # Unnormalized second-moment matrices of each view and across views.
    gram_a = A.T @ A
    gram_b = B.T @ B
    cross = A.T @ B

    # Inverse square roots of the (optionally ridge-penalized) Gram matrices.
    whiten_a = np.linalg.inv(sqrtm(gram_a + pen_a * np.identity(A.shape[1])))
    whiten_b = np.linalg.inv(sqrtm(gram_b + pen_b * np.identity(B.shape[1])))

    # Inner SVD problem: singular values of the whitened cross term.
    u, s, vh = np.linalg.svd(whiten_a @ cross @ whiten_b, full_matrices=False)

    # Project each view onto its canonical directions.
    transformed_a = (u.T @ whiten_a @ A.T).T
    transformed_b = (vh @ whiten_b @ B.T).T

    return u, s, vh, transformed_a, transformed_b


def cca_decomp_kernel_trick(A, B, pen_a=0, pen_b=0):
    """
    Computes CCA vectors, modified correlations, and transformed matrices.
    Implements the kernel trick from Tuzhilina et al. (2021). Useful for n << a,b.
    The kernel trick replaces A and B in the objective function with
    A_R and B_R (A = A_R @ V.T). Replacing A with A_R and B with B_R,
    reduces the size of the covariance matrices, making working in high dimensions tractable.
    The CCA vectors and modified correlations are the same for solutions based on A and A_R.
    The only caveat is that the dimension of the CCA vectors are restricted to the size of the
    dataset (n).

    Below, r_a = min(n, a), r_b = min(n, b) and k = min(r_a, r_b).

    Args:
        A: np.array of size n x a where a is the number of neurons and n is the dataset size
        B: np.array of size n x b where b is the number of neurons and n is the dataset size
        pen_a: regularization penalty for A, required when a >= n
        pen_b: regularization penalty for B, required when b >= n
    Returns:
        u: left singular vectors for the inner SVD problem, r_a x k array
        s: singular values of the inner SVD problem (the modified canonical
           correlation coefficients), length k
        vh: right singular vectors for the inner SVD problem, k x r_b array
        transformed_a: canonical vectors for matrix A, n x k array
        transformed_b: canonical vectors for matrix B, n x k array

    Tuzhilina et al. (2021) normalizes by (1/n), but that doesn't match
    Ding et al. (2021):

        A_cov_inv = np.linalg.inv(sqrtm((1/n) * A_cov + pen_a * np.identity(A.shape[1])))
        B_cov_inv = np.linalg.inv(sqrtm((1/n) * B_cov + pen_b * np.identity(B.shape[1])))

        objective_matrix = (A_cov_inv @ ((1/n) * AB_cov) @ B_cov_inv)

    """
    # Thin SVD A = Au @ diag(As) @ Vh; the right singular vectors are not used.
    # Broadcasting scales column j of Au by As[j], i.e. Au @ np.diag(As)
    # without building the dense diagonal matrix.
    Au, As, _ = np.linalg.svd(A, full_matrices=False)
    A_R = Au * As

    Bu, Bs, _ = np.linalg.svd(B, full_matrices=False)
    B_R = Bu * Bs

    # Second-moment matrices of the reduced representations.
    A_cov = A_R.T @ A_R
    B_cov = B_R.T @ B_R
    AB_cov = A_R.T @ B_R

    # Inverse square roots of the (optionally ridge-penalized) reduced matrices.
    A_cov_inv = np.linalg.inv(sqrtm(A_cov + pen_a * np.identity(A_R.shape[1])))
    B_cov_inv = np.linalg.inv(sqrtm(B_cov + pen_b * np.identity(B_R.shape[1])))

    objective_matrix = A_cov_inv @ AB_cov @ B_cov_inv

    u, s, vh = np.linalg.svd(objective_matrix, full_matrices=False)

    # Project the reduced representations onto their canonical directions.
    transformed_a = (u.T @ A_cov_inv @ A_R.T).T
    transformed_b = (vh @ B_cov_inv @ B_R.T).T

    return u, s, vh, transformed_a, transformed_b