import numpy as np
import torch

def torch_tile(tensor, dim, n):
    """Repeat a 2-D tensor ``n`` times along one axis.

    ``dim == 0``: every row is repeated ``n`` times *consecutively*
    (``[r0, r0, ..., r1, r1, ...]`` -- ``repeat_interleave`` semantics, not
    ``np.tile``'s block repetition); the result has ``R * n`` rows.

    Any other ``dim``: each row is concatenated with itself ``n`` times,
    like ``np.tile(tensor, [1, n])``; the result has ``C * n`` columns.
    """
    # (R, C) -> (R, 1, C): the inserted middle axis is what gets repeated.
    stacked = tensor.unsqueeze(1)
    if dim != 0:
        return stacked.repeat(1, 1, n).view(tensor.shape[0], -1)
    return stacked.repeat(1, n, 1).view(-1, tensor.shape[1])
    
def Gram_np(X, sg, N, Q):
    """Centred Gaussian (RBF) Gram matrix of the rows of ``X`` (NumPy version).

    Args:
        X: array whose ``N`` rows are the samples.
        sg: kernel bandwidth; the kernel is ``exp(-||x_i - x_j||^2 / (2 * sg^2))``.
        N: number of rows of ``X``.
        Q: matrix applied on both sides as ``Q @ G @ Q`` -- presumably the
           centring matrix ``I - 1/N``; confirm at call sites.

    Returns:
        ``(K + K.T) / 2`` with ``K = Q @ G @ Q``. The symmetrisation
        presumably removes round-off asymmetry.
    """
    sq_norms = np.sum(X * X, 1).reshape(N, 1)
    # ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>: broadcasting (N, 1) + (1, N)
    # builds the pairwise norm sums. Floor at zero because cancellation can
    # leave tiny negatives; a full zeros array (not scalar 0) keeps the
    # result dtype the same as before (float64 even for float32 input).
    sq_dists = np.maximum(sq_norms + sq_norms.T - 2 * np.dot(X, X.T), np.zeros((N, N)))
    gaussian = np.exp(-sq_dists / (2 * sg * sg))
    centred = np.dot(Q, np.dot(gaussian, Q))
    return (centred + centred.T) / 2

def Gram(X, sg, N, Q, device):
    """Centred Gaussian (RBF) Gram matrix of the rows of ``X`` (torch version).

    Args:
        X: tensor whose ``N`` rows are the samples.
        sg: kernel bandwidth (float or 0-d tensor); the kernel is
            ``exp(-||x_i - x_j||^2 / (2 * sg^2))``.
        N: number of rows of ``X``.
        Q: matrix applied on both sides as ``Q @ G @ Q`` (the callers in this
           module pass the centring matrix ``I - 1/N``). Its dtype must match
           the Gram matrix's dtype, because ``torch.matmul`` does not promote.
        device: device of ``X``; the zero floor tensor is allocated there.

    Returns:
        ``(K + K.T) / 2`` with ``K = Q @ G @ Q``.
    """
    sq_norms = torch.sum(X * X, 1).reshape(N, 1)
    # ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>; broadcasting (N, 1) + (1, N)
    # replaces tiling the norms into an explicit N x N matrix.
    sq_dists = sq_norms + sq_norms.T - 2 * torch.matmul(X, X.T)
    # Floor round-off negatives at zero. The zeros are allocated directly on
    # `device` instead of on the CPU followed by a copy. They keep the
    # default float dtype, so low-precision / integer inputs are still
    # promoted exactly as before.
    sq_dists = torch.max(sq_dists, torch.zeros((N, N), device=device))
    Gx = torch.exp(-sq_dists / (2 * sg * sg))
    Kx = torch.matmul(Q, torch.matmul(Gx, Q))
    return (Kx + Kx.T) / 2

def pairwise_distances(x, y=None):
    """Squared Euclidean distances between the rows of ``x`` and ``y``.

    Args:
        x: 2-D tensor of shape (n, d).
        y: optional 2-D tensor of shape (m, d); defaults to ``x``.

    Returns:
        Tensor of shape (n, m) with ``dist[i, j] = ||x[i] - y[j]||^2``.
        The values are *squared* distances, clamped to be non-negative.
    """
    x_norm = (x ** 2).sum(1).view(-1, 1)
    if y is None:
        y = x
        y_norm = x_norm.view(1, -1)
    else:
        y_norm = (y ** 2).sum(1).view(1, -1)
    # The expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b> is fast but
    # suffers from cancellation. Nearly coincident rows (including the
    # diagonal when y is x) can come out slightly negative, so clamp to the
    # valid range.
    dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    return dist.clamp(min=0.0)

def MedianDist(x):
    """Median pairwise Euclidean distance between the rows of ``x``.

    ``tr`` and ``HSIC_in`` in this module use it to set Gaussian-kernel
    bandwidths (the median heuristic). Only distinct pairs ``i < j`` are
    considered, and pairs at exactly zero distance (duplicate rows) are
    skipped. The squared-distance matrix is symmetric up to round-off, so
    the (lower) median over the upper triangle equals the median over all
    off-diagonal entries.

    Args:
        x: 2-D tensor whose rows are the samples.

    Returns:
        0-d tensor. If no pair has a nonzero distance (a single row, or all
        rows identical), ``torch.median`` receives an empty selection and
        the result is NaN.
    """
    sq_dists = pairwise_distances(x)
    n_rows, n_cols = sq_dists.shape
    # Exclude the diagonal explicitly. In floating point the self-distances
    # ||x_i||^2 + ||x_i||^2 - 2<x_i, x_i> are often tiny nonzero values, so a
    # nonzero() filter alone does not remove them and they pull the median
    # toward 0.
    rows, cols = torch.triu_indices(n_rows, n_cols, offset=1, device=sq_dists.device)
    # Round-off negatives are not real distances; clamping turns them into
    # zeros, which the next filter drops.
    pair_dists = sq_dists[rows, cols].clamp(min=0.0)
    # `!= 0` (not `> 0`) so NaN entries still propagate to the result.
    pair_dists = pair_dists[pair_dists != 0]
    return torch.sqrt(torch.median(pair_dists))

def tr(X, Z, Y, N, device, EPS=0.000001):
    """Return ``trace(Ky @ inv(Kz + EPS * N * I))`` for centred Gaussian Gram matrices.

    Args:
        X: samples used only to choose the bandwidth of Z's kernel
           (half of X's median pairwise distance).
        Z: samples (N rows) for the Gram matrix ``Kz``.
        Y: samples (N rows) for the Gram matrix ``Ky``; its bandwidth is
           Y's median pairwise distance. Assumed to share Z's dtype.
        N: number of samples.
        device: device for the identity / centring matrices.
        EPS: ridge regulariser; ``EPS * N`` is added to Kz's diagonal
             before inversion.

    Returns:
        0-d tensor.
    """
    sgx = 0.5 * MedianDist(X)
    sgy = MedianDist(Y)
    # torch.matmul does not promote dtypes, so I/Q must already have the
    # dtype Gram() produces. That dtype is Z's dtype promoted with the default
    # float dtype, because Gram floors distances against a default-dtype
    # zeros tensor. A plain torch.eye(N) breaks float64 inputs.
    dtype = torch.promote_types(Z.dtype, torch.get_default_dtype())
    I = torch.eye(N, dtype=dtype, device=device)
    Q = I - 1.0 / N  # centring matrix I - 11^T / N
    Kz = Gram(Z, sgx, N, Q, device)
    Ky = Gram(Y, sgy, N, Q, device)
    Kz_reg_inv = torch.inverse(Kz + EPS * N * I)
    # Sum of the elementwise product == trace(Ky^T @ Kz_reg_inv), and Ky is
    # symmetric, so this is trace(Ky @ Kz_reg_inv).
    return torch.sum(torch.sum(Ky * Kz_reg_inv, 0), 0)

def HSIC_in(X, Z, Y, N, device, EPS=0.000001):
    """Kernel measure of the dependence of X and Y given Z.

    Computes ``trace(Ry Rx - 2 Ry Rx Rz + Ry Rz Rx Rz)`` where
    ``R = K @ inv(K + EPS * N * I)`` and the K's are the centred Gaussian
    Gram matrices of ``[X, Z]``, ``[Y, Z]`` and ``Z``. This appears to be the
    empirical normalised conditional HSIC (Fukumizu et al., 2008); confirm
    before relying on that name.

    Args:
        X, Y: samples (N rows) whose conditional dependence is measured.
        Z: conditioning samples (N rows). X, Y and Z are assumed to share a dtype.
        N: number of samples.
        device: device for the identity / centring matrices.
        EPS: ridge regulariser; ``EPS * N`` is added to each K's diagonal
             before inversion.

    Returns:
        0-d tensor.
    """
    sgx = MedianDist(X)
    # Z's kernel uses half of X's median distance, the same rule as tr().
    # (Previously this value was computed but Kz was built with sgx.)
    sgz = 0.5 * sgx
    sgy = MedianDist(Y)
    # torch.matmul does not promote dtypes: build I/Q in the dtype Gram()
    # produces (input dtype promoted with the default float dtype) so that
    # float64 inputs work.
    dtype = torch.promote_types(Z.dtype, torch.get_default_dtype())
    I = torch.eye(N, dtype=dtype, device=device)
    Q = I - 1.0 / N  # centring matrix I - 11^T / N

    def regularized(K):
        """Return K @ inv(K + EPS * N * I)."""
        return torch.matmul(K, torch.inverse(K + EPS * N * I))

    Rx = regularized(Gram(torch.cat((X, Z), 1), sgx, N, Q, device))
    Ry = regularized(Gram(torch.cat((Y, Z), 1), sgy, N, Q, device))
    Rz = regularized(Gram(Z, sgz, N, Q, device))
    term1 = torch.matmul(Ry, Rx)
    term2 = -2.0 * torch.matmul(term1, Rz)
    term3 = torch.matmul(torch.matmul(torch.matmul(Ry, Rz), Rx), Rz)
    return torch.trace(term1 + term2 + term3)

def cor(X, Y, n, device):
    """Normalised inner product of the double-centred distance matrices of X and Y.

    With ``D = pairwise_distances(.)`` and ``R = J D J`` its double-centred
    form (``J = I - 11^T / n``), this returns
    ``<RX, RY>_F / sqrt(<RX, RX>_F * <RY, RY>_F)``.

    NOTE(review): the original comment calls this distance correlation.
    However, ``pairwise_distances`` returns *squared* Euclidean distances,
    while textbook distance correlation (Szekely et al., 2007) double-centres
    unsquared distances. With squared distances the statistic only reflects
    linear association. Left as-is because callers may depend on the
    current values; confirm intent.

    Args:
        X, Y: 2-D tensors with ``n`` rows each.
        n: number of rows. It cancels out of the returned ratio.
        device: no longer needed (the computation runs on X's/Y's device);
                kept so existing call sites keep working.

    Returns:
        0-d tensor. Ill-defined (0/0 -> NaN) when X or Y has all rows identical.
    """
    def double_center(D):
        # Equivalent to J @ D @ J, but O(n^2) instead of two O(n^3) matmuls.
        # It also needs no separately built J, whose default float32 dtype
        # broke float64 inputs.
        return D - D.mean(0, keepdim=True) - D.mean(1, keepdim=True) + D.mean()

    RX = double_center(pairwise_distances(X))
    RY = double_center(pairwise_distances(Y))
    covXY = torch.mul(RX, RY).sum() / (n * n)
    covX = torch.mul(RX, RX).sum() / (n * n)
    covY = torch.mul(RY, RY).sum() / (n * n)
    return covXY / torch.sqrt(covX * covY)

