import numpy as np
import torch

def frob(A, B):
    """Return the Frobenius inner product <A, B>_F = sum_ij A_ij * B_ij.

    For same-shaped 2-D inputs this equals ``trace(A.T @ B)``. It is computed
    elementwise, so no intermediate (p, p) matrix product is materialised:
    O(n*p) work instead of O(n*p^2). It also works for inputs of any rank.

    Args:
        A, B: tensors expected to have the same shape.

    Returns:
        Scalar (0-d tensor for torch inputs) holding the inner product.
    """
    # trace(A.T @ B) would compute every entry of A.T @ B only to keep the
    # diagonal; the elementwise product + sum gives the same value directly.
    return (A * B).sum()
    
def CKA(A, B):
    """Return a centred alignment score between matrices ``A`` and ``B``.

    Each matrix is centred by subtracting its per-row mean (the mean over
    dim 1). The score is then the Frobenius cosine between the centred
    matrices:

        <Am, Bm>_F / sqrt(<Am, Am>_F * <Bm, Bm>_F)

    NOTE(review): only the rows are centred, which is equivalent to ``A @ H``
    with ``H = I - 11^T / n_cols``. Standard definitions differ:
    - If ``A``/``B`` are kernel (Gram) matrices, centred kernel alignment
      normally uses the double-centred ``H @ A @ H``.
    - If they are (samples x features) activations, linear CKA normally
      centres each feature column (mean over dim 0).
    Confirm which variant is intended before relying on the scores.

    Args:
        A, B: assumed to be 2-D tensors of the same shape - TODO confirm.

    Returns:
        0-d tensor in [-1, 1]. It is NaN (0/0) if either centred matrix is
        all zeros.
    """
    # Subtract each row's mean; [:, None] broadcasts the per-row means back
    # across the columns.
    Am = A - A.mean(1)[:, None]
    Bm = B - B.mean(1)[:, None]

    return frob(Am, Bm) / torch.sqrt((frob(Am, Am) * frob(Bm, Bm)))

def det_2x2_matrices(A):
    """Determinant of every 2x2 matrix stored in the trailing two axes of ``A``.

    The closed form ad - bc is evaluated elementwise, so any leading batch
    shape is supported. The result has shape ``A.shape[:-2]``. Works for both
    NumPy arrays and torch tensors, since only indexing and arithmetic are
    used.
    """
    assert A.shape[-2] == 2 and A.shape[-1] == 2, "Last two dimensions must be 2x2 matrices."

    top_left, top_right = A[..., 0, 0], A[..., 0, 1]
    bottom_left, bottom_right = A[..., 1, 0], A[..., 1, 1]

    # Main-diagonal product minus anti-diagonal product.
    return top_left * bottom_right - top_right * bottom_left

def inverse_2x2_matrices(A):
    """Invert every 2x2 matrix stored in the trailing two axes of ``A``.

    Uses the closed form inv([[a, b], [c, d]]) = [[d, -b], [-c, a]] / (ad - bc),
    applied elementwise over any leading batch shape.

    Args:
        A: NumPy array or torch tensor of shape (..., 2, 2).

    Returns:
        An object of the same kind as ``A`` with the same shape. NumPy input
        gives NumPy output; torch input gives a torch tensor on the same
        device. Integer inputs yield a floating-point result, because the
        entries come from true division.

    Raises:
        ValueError: if any matrix has a determinant exactly equal to zero.
    """
    assert A.shape[-2] == 2 and A.shape[-1] == 2, "Last two dimensions must be 2x2 matrices."

    # Extract elements
    a = A[..., 0, 0]
    b = A[..., 0, 1]
    c = A[..., 1, 0]
    d = A[..., 1, 1]

    # Compute determinant
    det = a * d - b * c

    # ``.any()`` exists on NumPy arrays/scalars and on torch tensors alike.
    # np.any would push a torch tensor through NumPy instead.
    if (det == 0).any():
        raise ValueError("One or more 2x2 matrices are singular and cannot be inverted.")

    # Assemble the result from the computed entries rather than writing into
    # np.zeros_like(A). That buffer inherited A's dtype, which truncated the
    # fractional entries of an integer input. It was also always a NumPy
    # array, even for torch (including CUDA / requires-grad) inputs.
    stack = torch.stack if isinstance(A, torch.Tensor) else np.stack
    top = stack([d / det, -b / det], -1)
    bottom = stack([-c / det, a / det], -1)
    return stack([top, bottom], -2)

def randomized_svd(A, n_components=5, n_oversamples=10, n_iter=3):
    """Approximate the leading singular values and left singular vectors of ``A``.

    Randomized range finder with power iterations, followed by an exact SVD
    of the small projected matrix.

    Args:
        A: 2-D tensor of shape (m, n).
        n_components: number of leading singular values/vectors to return.
        n_oversamples: extra random probe vectors, which improve accuracy.
        n_iter: number of power iterations. Each multiplies by ``A @ A.T`` and
            re-orthonormalises.

    Returns:
        (U, S). ``U`` has shape (m, k) with orthonormal columns. ``S`` holds
        the corresponding k approximate singular values in descending order
        (as returned by torch.linalg.svd). k is at most ``n_components``.
    """
    _, n = A.shape
    n_random = n_components + n_oversamples

    # Step 1: Gaussian test matrix. Match A's dtype as well as its device.
    # torch.randn defaults to float32, and a float32 P makes ``A @ P`` raise
    # for a float64 A.
    P = torch.randn(n, n_random, device=A.device, dtype=A.dtype)

    # Step 2: Sample the range of A.
    Z = A @ P

    # Step 3: Power iterations. Re-orthonormalise after each pass to keep
    # the columns numerically independent.
    for _ in range(n_iter):
        Z = A @ (A.T @ Z)
        Z, _ = torch.linalg.qr(Z)

    # Step 4: Orthonormal basis Q for the sampled range. This is still
    # required when n_iter == 0, because Z is then not yet orthonormal.
    Q, _ = torch.linalg.qr(Z)

    # Step 5: Project A onto the basis. B has only as many rows as Q has columns.
    B = Q.T @ A

    # Step 6: Exact SVD of the small matrix.
    U_B, S, _ = torch.linalg.svd(B, full_matrices=False)

    # Step 7: Map the left singular vectors back to the original row space.
    U = Q @ U_B

    return U[:, :n_components], S[:n_components]

def cos_normalize_arr(arr):
    """Convert an inner-product (Gram) matrix into a cosine-similarity matrix.

    Returns ``arr[i, j] / (sqrt(arr[i, i]) * sqrt(arr[j, j]))``.

    Each diagonal entry is square-rooted before the outer product is formed,
    so intermediate values stay on the scale of ``arr``. Multiplying the raw
    diagonal entries first can overflow (e.g. 1e20 * 1e20 in float32), which
    silently turns those entries into 0.

    Only torch ops on ``arr`` are used, so the computation stays on
    ``arr``'s device (CPU or CUDA).

    Args:
        arr: assumed to be a square 2-D tensor whose diagonal holds squared
            magnitudes (e.g. pairwise gradient inner products) - TODO confirm.

    Returns:
        Tensor of the same shape as ``arr``. Entries are NaN/inf wherever a
        diagonal entry is zero (or negative).
    """
    # Norm of each row's underlying vector: sqrt of its squared magnitude.
    norms = torch.sqrt(torch.diag(arr))
    # Outer product of the norms: denom[i, j] = norms[i] * norms[j].
    return arr / (norms[:, None] * norms[None, :])