import torch
import numpy as np
import math

def centering(K):
    """Double-center a square matrix: return ``H @ K @ H``.

    ``H = I - (1/n) * ones((n, n))`` is the centering matrix, with ``n``
    taken from ``K.shape[0]``.

    Args:
        K: Square ``(n, n)`` matrix, either a ``torch.Tensor`` or an object
            NumPy can multiply (typically an ``np.ndarray``).

    Returns:
        The centered matrix, of the same backend as ``K``. For tensors the
        result lives on ``K``'s device and keeps ``K``'s floating dtype;
        integer/bool tensors are promoted to torch's default float dtype.
    """
    n = K.shape[0]
    if isinstance(K, torch.Tensor):
        # matmul requires matching dtypes, so H must be built in K's dtype
        # (the torch default of float32 breaks float64/half inputs).
        if not (K.is_floating_point() or K.is_complex()):
            K = K.to(torch.get_default_dtype())
        unit = torch.ones([n, n], device=K.device, dtype=K.dtype)
        I = torch.eye(n, device=K.device, dtype=K.dtype)
        H = I - unit / n
        return torch.matmul(torch.matmul(H, K), H)

    unit = np.ones([n, n])
    I = np.eye(n)
    H = I - unit / n
    return np.dot(np.dot(H, K), H)


def rbf(X, sigma=None):
    """Compute the RBF (Gaussian) kernel matrix between the rows of ``X``.

    The entry for rows ``i`` and ``j`` is ``exp(-d_ij / (2 * sigma**2))``,
    where ``d_ij = G_ii + G_jj - 2 * G_ij`` and ``G = X @ X.T``.

    Args:
        X: 2-D ``torch.Tensor`` or array NumPy can take ``np.dot`` of; one
            row per sample.
        sigma: Kernel bandwidth. When ``None`` it is set to the square root
            of the median of the non-zero entries of the distance matrix.
            If there are no non-zero entries (e.g. a single sample or all
            rows identical) it falls back to ``1``.

    Returns:
        The ``(n, n)`` kernel matrix, in the same backend as ``X``.
    """
    if isinstance(X, torch.Tensor):
        GX = torch.matmul(X, X.T)
        # diag(GX) broadcasts along rows; adding the transpose yields
        # G_ii + G_jj - 2 * G_ij for every pair (i, j).
        half = torch.diag(GX) - GX
        KX = half + half.T
        if sigma is None:
            # reshape (not view) so a non-contiguous KX cannot raise.
            non_zero = KX.reshape(-1)
            non_zero = non_zero[non_zero != 0]
            if non_zero.numel() > 0:
                sigma = torch.sqrt(torch.median(non_zero))
            else:
                sigma = torch.tensor(1.0, device=X.device)
        # Out-of-place scaling: an in-place multiply by a float fails when
        # KX has an integer dtype.
        KX = KX * (-0.5 / (sigma * sigma))
        return torch.exp(KX)

    GX = np.dot(X, X.T)
    half = np.diag(GX) - GX
    KX = half + half.T
    if sigma is None:
        non_zero = KX[KX != 0]
        # Mirror the torch branch: without this guard the median of an
        # empty selection is NaN and the whole kernel becomes NaN.
        sigma = math.sqrt(np.median(non_zero)) if non_zero.size > 0 else 1.0
    KX = KX * (-0.5 / (sigma * sigma))
    return np.exp(KX)


def kernel_HSIC(X, Y, sigma):
    """Return the unnormalised HSIC of ``X`` and ``Y`` under an RBF kernel.

    Each input is turned into an RBF kernel matrix with bandwidth ``sigma``
    (forwarded unchanged to ``rbf``), both matrices are centered, and the
    sum of their elementwise product is returned.

    Args:
        X: Sample matrix for the first representation.
        Y: Sample matrix for the second representation.
        sigma: Bandwidth passed to ``rbf`` for both inputs.

    Returns:
        A scalar: a 0-dim tensor when both inputs are ``torch.Tensor``,
        otherwise the result of ``np.sum``.
    """
    both_tensors = isinstance(X, torch.Tensor) and isinstance(Y, torch.Tensor)
    # Only the final reduction differs between the two backends.
    total = torch.sum if both_tensors else np.sum
    centered_kx = centering(rbf(X, sigma))
    centered_ky = centering(rbf(Y, sigma))
    return total(centered_kx * centered_ky)


def linear_HSIC(X, Y):
    """Return the unnormalised HSIC of ``X`` and ``Y`` under a linear kernel.

    The Gram matrices ``X @ X.T`` and ``Y @ Y.T`` are centered and the sum
    of their elementwise product is returned.

    Args:
        X: Sample matrix for the first representation.
        Y: Sample matrix for the second representation.

    Returns:
        A scalar: a 0-dim tensor when both inputs are ``torch.Tensor``,
        otherwise the result of ``np.sum``.
    """
    both_tensors = isinstance(X, torch.Tensor) and isinstance(Y, torch.Tensor)
    if not both_tensors:
        gram_x = np.dot(X, X.T)
        gram_y = np.dot(Y, Y.T)
        return np.sum(centering(gram_x) * centering(gram_y))

    gram_x = torch.matmul(X, X.T)
    gram_y = torch.matmul(Y, Y.T)
    return torch.sum(centering(gram_x) * centering(gram_y))


def linear_CKA(X, Y):
    """Compute linear CKA between two sets of representations.

    Returns ``HSIC(X, Y) / (sqrt(HSIC(X, X)) * sqrt(HSIC(Y, Y)))`` using
    ``linear_HSIC``.

    Args:
        X: ``np.ndarray`` or ``torch.Tensor``; one row per sample.
        Y: ``np.ndarray`` or ``torch.Tensor`` with the same number of rows.

    Returns:
        The CKA value as a Python float.

    NumPy inputs are converted to float32 tensors. They are placed on the
    device of the other input when that one is already a tensor; otherwise
    on CUDA when it is available and on the CPU when it is not. Tensor
    inputs are used as given.
    """
    x_is_array = isinstance(X, np.ndarray)
    y_is_array = isinstance(Y, np.ndarray)

    if x_is_array or y_is_array:
        # Keep both operands on one device; never require CUDA to exist.
        if isinstance(X, torch.Tensor):
            device = X.device
        elif isinstance(Y, torch.Tensor):
            device = Y.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if x_is_array:
            X = torch.from_numpy(X).to(device=device, dtype=torch.float32)
        if y_is_array:
            Y = torch.from_numpy(Y).to(device=device, dtype=torch.float32)

    hsic = linear_HSIC(X, Y)
    var1 = torch.sqrt(linear_HSIC(X, X))
    var2 = torch.sqrt(linear_HSIC(Y, Y))

    result = hsic / (var1 * var2)
    return result.item() if isinstance(result, torch.Tensor) else result


def kernel_CKA(X, Y, sigma=None):
    """Compute RBF-kernel CKA between two sets of representations.

    Returns ``HSIC(X, Y) / (sqrt(HSIC(X, X)) * sqrt(HSIC(Y, Y)))`` using
    ``kernel_HSIC``.

    Args:
        X: ``np.ndarray`` or ``torch.Tensor``; one row per sample.
        Y: ``np.ndarray`` or ``torch.Tensor`` with the same number of rows.
        sigma: RBF bandwidth forwarded to ``kernel_HSIC``; ``None`` lets the
            kernel choose it from the data.

    Returns:
        The CKA value as a Python float.

    NumPy inputs are converted to float32 tensors. They are placed on the
    device of the other input when that one is already a tensor; otherwise
    on CUDA when it is available and on the CPU when it is not. Tensor
    inputs are used as given.
    """
    x_is_array = isinstance(X, np.ndarray)
    y_is_array = isinstance(Y, np.ndarray)

    if x_is_array or y_is_array:
        # Keep both operands on one device; never require CUDA to exist.
        if isinstance(X, torch.Tensor):
            device = X.device
        elif isinstance(Y, torch.Tensor):
            device = Y.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if x_is_array:
            X = torch.from_numpy(X).to(device=device, dtype=torch.float32)
        if y_is_array:
            Y = torch.from_numpy(Y).to(device=device, dtype=torch.float32)

    hsic = kernel_HSIC(X, Y, sigma)
    var1 = torch.sqrt(kernel_HSIC(X, X, sigma))
    var2 = torch.sqrt(kernel_HSIC(Y, Y, sigma))

    result = hsic / (var1 * var2)
    return result.item() if isinstance(result, torch.Tensor) else result


def batch_linear_CKA(X, Y, batch_size=1000):
    """Compute linear CKA, processing large NumPy inputs in row batches.

    When ``X`` is an ``np.ndarray`` with more than ``batch_size`` rows, the
    rows are split into consecutive batches. ``linear_HSIC`` is evaluated
    per batch for (X, Y), (X, X) and (Y, Y), the three quantities are summed
    over batches, and ``sum_xy / (sqrt(sum_xx) * sqrt(sum_yy))`` is returned.
    Each batch only sees its own rows, so sample pairs that fall in
    different batches do not contribute.

    In every other case the call is forwarded to ``linear_CKA(X, Y)``.

    Args:
        X: ``np.ndarray`` or ``torch.Tensor``; one row per sample.
        Y: Array or tensor with the same number of rows as ``X``.
        batch_size: Maximum number of rows per batch.

    Returns:
        The CKA value as a float.

    Raises:
        ValueError: If the batched path is taken with ``batch_size < 1``.
    """
    if not (isinstance(X, np.ndarray) and X.shape[0] > batch_size):
        return linear_CKA(X, Y)

    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Use the GPU when there is one, but do not require it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n = X.shape[0]
    hsic_total = 0.0
    var1_total = 0.0
    var2_total = 0.0

    for start_idx in range(0, n, batch_size):
        end_idx = min(start_idx + batch_size, n)

        # as_tensor accepts both ndarray and tensor slices.
        X_batch = torch.as_tensor(X[start_idx:end_idx]).to(device=device, dtype=torch.float32)
        Y_batch = torch.as_tensor(Y[start_idx:end_idx]).to(device=device, dtype=torch.float32)

        hsic_total += linear_HSIC(X_batch, Y_batch).item()
        var1_total += linear_HSIC(X_batch, X_batch).item()
        var2_total += linear_HSIC(Y_batch, Y_batch).item()

    return hsic_total / (np.sqrt(var1_total) * np.sqrt(var2_total))


def batch_kernel_CKA(X, Y, sigma=None, batch_size=1000):
    """Compute RBF-kernel CKA, processing large NumPy inputs in row batches.

    When ``X`` is an ``np.ndarray`` with more than ``batch_size`` rows, the
    rows are split into consecutive batches. ``kernel_HSIC`` is evaluated
    per batch for (X, Y), (X, X) and (Y, Y), the three quantities are summed
    over batches, and ``sum_xy / (sqrt(sum_xx) * sqrt(sum_yy))`` is returned.
    Each batch only sees its own rows, so sample pairs that fall in
    different batches do not contribute.

    In every other case the call is forwarded to ``kernel_CKA(X, Y, sigma)``.

    Args:
        X: ``np.ndarray`` or ``torch.Tensor``; one row per sample.
        Y: Array or tensor with the same number of rows as ``X``.
        sigma: RBF bandwidth forwarded to ``kernel_HSIC`` for every batch.
            With ``None`` the bandwidth is chosen per kernel matrix, so it
            can differ from batch to batch.
        batch_size: Maximum number of rows per batch.

    Returns:
        The CKA value as a float.

    Raises:
        ValueError: If the batched path is taken with ``batch_size < 1``.
    """
    if not (isinstance(X, np.ndarray) and X.shape[0] > batch_size):
        return kernel_CKA(X, Y, sigma)

    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Use the GPU when there is one, but do not require it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n = X.shape[0]
    hsic_total = 0.0
    var1_total = 0.0
    var2_total = 0.0

    for start_idx in range(0, n, batch_size):
        end_idx = min(start_idx + batch_size, n)

        # as_tensor accepts both ndarray and tensor slices.
        X_batch = torch.as_tensor(X[start_idx:end_idx]).to(device=device, dtype=torch.float32)
        Y_batch = torch.as_tensor(Y[start_idx:end_idx]).to(device=device, dtype=torch.float32)

        hsic_total += kernel_HSIC(X_batch, Y_batch, sigma).item()
        var1_total += kernel_HSIC(X_batch, X_batch, sigma).item()
        var2_total += kernel_HSIC(Y_batch, Y_batch, sigma).item()

    return hsic_total / (np.sqrt(var1_total) * np.sqrt(var2_total))
