import math
import numpy as np



def centering(K):
    """Double-center the square matrix ``K``.

    Returns ``H @ K @ H`` where ``H = I - (1/n) * ones((n, n))`` is the
    centering matrix, so every row and column of the result sums to
    (numerically) zero.
    """
    size = np.shape(K)[0]
    centering_matrix = np.eye(size) - np.ones((size, size)) / size
    return centering_matrix @ K @ centering_matrix


def linear_HSIC(X, Y):
    """Return the unnormalized linear-kernel HSIC statistic of ``X`` and ``Y``.

    Builds the linear Gram matrices ``X X^T`` and ``Y Y^T``, double-centers
    each with ``centering`` and returns the sum of their elementwise product
    (their Frobenius inner product). No ``1/(n-1)^2`` scaling is applied.
    """
    centered_x = centering(X @ X.T)
    centered_y = centering(Y @ Y.T)
    return np.sum(centered_x * centered_y)


def rbf(X, sigma=None):
    """Return the Gaussian (RBF) kernel matrix over the rows of ``X``.

    Entry ``(i, j)`` is ``exp(-||x_i - x_j||^2 / (2 * sigma**2))``.

    Args:
        X: 2-D array whose rows are samples.
        sigma: Kernel bandwidth. When ``None``, the median heuristic is used:
            ``sigma = sqrt(median of the non-zero squared pairwise
            distances)``, with the median clipped to ``[1e-12, 1e12]``. If no
            non-zero distance exists (a single row, or all rows identical),
            the median is taken as 0, which the clip maps to ``sigma = 1e-6``.

    Returns:
        ``(n, n)`` symmetric kernel matrix with ones on the diagonal.
    """
    GX = X @ X.T
    sq_norms = np.diag(GX)
    # Squared Euclidean distances ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
    # Both broadcast directions are built explicitly: a torch-style
    # ``.transpose(0, 1)`` is the identity permutation on a 2-D ndarray.
    KX = sq_norms[:, None] + sq_norms[None, :] - 2.0 * GX
    # Rounding can leave tiny negative distances for near-duplicate rows.
    KX = np.maximum(KX, 0.0)
    if sigma is None:
        nonzero = KX[KX != 0]
        # np.median of an empty array returns nan (it does not raise), which
        # would poison the whole kernel, so handle that case explicitly.
        mdist = np.median(nonzero) if nonzero.size else 0.0
        sigma = math.sqrt(np.clip(mdist, 1e-12, 1e12))
    return np.exp(KX * (-0.5 / (sigma * sigma)))

def kernel_HSIC(X, Y, sigma=None):
    """Return the unnormalized RBF-kernel HSIC statistic of ``X`` and ``Y``.

    Builds an RBF kernel matrix for each input with ``rbf``, double-centers
    both with ``centering`` and returns the sum of their elementwise product.
    The same ``sigma`` is passed to both ``rbf`` calls; when it is ``None``,
    each call selects its own bandwidth from its own input.
    """
    centered_kx = centering(rbf(X, sigma))
    centered_ky = centering(rbf(Y, sigma))
    return np.sum(centered_kx * centered_ky)