import numpy as np

def subspace_alignment_with_tikhonov(original, ordinal, eta=1e-3):
    """
    Map `ordinal` (n x d2) into the space of `original` (n x d1) via ridge regression.

    Both inputs are first centered column-wise. With Xc and Yc denoting the
    centered `original` and `ordinal`, the projection matrix A (d2 x d1) solves

        argmin_A ||Xc - Yc @ A||_F^2 + eta ||A||_F^2,

    whose closed form is A = (Yc^T Yc + eta I)^(-1) Yc^T Xc. The mean of
    `original` is added back after projecting, so the result lives in the same
    (uncentered) coordinates as `original`.

    Args:
        original (np.ndarray): Original embedding matrix (n x d1), the target.
        ordinal (np.ndarray): Ordinal embedding matrix (n x d2), the source.
        eta (float): Tikhonov (ridge) regularization strength.

    Returns:
        np.ndarray: `ordinal` expressed in the space of `original` (n x d1).
    """
    target_mu = np.mean(original, axis=0)
    Xc = original - target_mu
    Yc = ordinal - np.mean(ordinal, axis=0)

    # Regularized normal equations; solve() avoids forming an explicit inverse.
    gram = Yc.T @ Yc
    ridge = eta * np.eye(Yc.shape[1])
    projection = np.linalg.solve(gram + ridge, Yc.T @ Xc)

    # Re-anchor the projected points at the target's centroid.
    return Yc @ projection + target_mu

def normalized_procrustes_distance(X, Y, eta=1e-3):
    """Compute the normalized Procrustes distance between X and an aligned Y.

    Y is mapped into the space of X with `subspace_alignment_with_tikhonov`.
    The Frobenius norm of the residual X - Y_aligned is then divided by the
    Frobenius norm of the column-centered X.

    Because A = 0 is a feasible choice in the ridge problem, the optimal
    residual is never larger than ||X_centered||_F. The result therefore lies
    in [0, 1], up to floating-point rounding.

    Args:
        X (np.array): of shape Nxd1 (the reference embedding).
        Y (np.array): of shape Nxd2 (the embedding to align to X).
        eta (float, optional): Amount of Tikhonov regularization to add during subspace alignment. Defaults to 1e-3.

    Returns:
        float: Normalized Procrustes Distance between Y_aligned and X (between 0 and 1).
            Returns 0.0 when X has no variation around its column means
            (e.g. every column is constant, or N == 1). In that case the
            ratio would be 0/0, and the aligned Y reproduces X exactly.
    """
    X_centered = X - np.mean(X, axis=0)
    scale = np.linalg.norm(X_centered, ord='fro')
    if scale == 0:
        # Centered X is all zeros, so the projection matrix is zero and the
        # aligned Y equals X's mean, which is X itself. Report zero distance
        # instead of returning nan from 0/0.
        return 0.0

    Y_aligned = subspace_alignment_with_tikhonov(X, Y, eta)
    raw_distance = np.linalg.norm(X - Y_aligned, 'fro')
    return float(raw_distance / scale)

def compute_psnr(X, Y, eta=1e-3):
    """
    Align the ordinal embedding to the original, then compute their PSNR.

    Unlike `compute_psnr_aligned`, this function performs the alignment itself
    via `subspace_alignment_with_tikhonov`. Pass the raw (unaligned) ordinal
    embedding as `Y`.

    Args:
        X (np.ndarray): Original embedding matrix (Nxd1); also supplies the dynamic range.
        Y (np.ndarray): Ordinal embedding matrix (Nxd2), not yet aligned.
        eta (float): Regularization parameter for Tikhonov regularization.

    Returns:
        float: PSNR value in decibels (dB) between X and the aligned Y.
    """
    # Bring Y into X's space so the element-wise error is meaningful.
    aligned = subspace_alignment_with_tikhonov(X, Y, eta)

    return compute_psnr_aligned(X, aligned)

def compute_psnr_aligned(original, aligned):
    """
    Computes the Peak Signal-to-Noise Ratio (PSNR) between the original and aligned embeddings.
    Only use this after aligning the ordinal embedding to the original embedding
    (e.g. with `subspace_alignment_with_tikhonov`), so both inputs share a shape.

    PSNR = 20 * log10(R / sqrt(MSE)), where R = max(original) - min(original)
    is the dynamic range of `original` and MSE is the element-wise mean
    squared error between the two inputs.

    Args:
        original (np.ndarray): Original embedding matrix (n x d1).
        aligned (np.ndarray): Aligned embedding matrix (n x d1).

    Returns:
        float: PSNR value in decibels (dB). `inf` when the inputs are identical;
            `-inf` when `original` is constant (zero dynamic range) but the
            inputs differ.

    Raises:
        ValueError: If `original` and `aligned` do not have the same shape.
    """
    original = np.asarray(original)
    aligned = np.asarray(aligned)
    # Without this check NumPy would silently broadcast, for example, an
    # (n x 1) array against an (n x d1) one and report a meaningless MSE.
    if original.shape != aligned.shape:
        raise ValueError(
            f"original and aligned must have the same shape, "
            f"got {original.shape} and {aligned.shape}"
        )

    mse = np.mean((original - aligned) ** 2)

    # Perfect alignment: avoid dividing by zero.
    if mse == 0:
        return float('inf')

    dynamic_range = original.max() - original.min()
    if dynamic_range == 0:
        # log10(0) is -inf; return it explicitly rather than triggering a
        # divide-by-zero RuntimeWarning inside NumPy.
        return float('-inf')

    return float(20 * np.log10(dynamic_range / np.sqrt(mse)))
