import torch
from geomloss import SamplesLoss
import ot


def gaussian_kernel_matrix(X,Y,sigma=1):
    """Evaluate the Gaussian kernel on every pair of points from X and Y.

    Each entry is ``exp(-d**2 / (2 * sigma**2))`` where ``d`` is the pairwise
    distance returned by ``torch.cdist(X, Y)`` with its default settings.

    Args:
        X: first set of points (presumably one sample per row -- confirm
            against callers).
        Y: second set of points, same layout as X.
        sigma: kernel bandwidth.

    Returns:
        Tensor of pairwise kernel values, laid out like ``torch.cdist(X, Y)``.
    """
    squared_dist = torch.cdist(X, Y) ** 2
    bandwidth_term = 2 * sigma ** 2
    return torch.exp(-squared_dist / bandwidth_term)

def laplacian_kernel_matrix(X,Y,sigma=1):
    """Evaluate the Laplacian kernel on every pair of points from X and Y.

    Each entry is ``exp(-d1 / sigma)`` where ``d1`` is the pairwise distance
    returned by ``torch.cdist(X, Y, p=1)``.

    Args:
        X: first set of points (presumably one sample per row -- confirm
            against callers).
        Y: second set of points, same layout as X.
        sigma: kernel scale dividing the distance.

    Returns:
        Tensor of pairwise kernel values.
    """
    l1_dist = torch.cdist(X, Y, p=1)
    scaled = -l1_dist / sigma
    return scaled.exp()

def energy_kernel_matrix(X,Y):
    """Evaluate ``0.5 * (|x| + |y| - |x - y|)`` for every pair (x, y).

    The norms and the pairwise distance all come from ``torch.cdist`` with its
    default settings; the norms are obtained as distances to all-zero points.

    Args:
        X: first set of points (presumably one sample per row -- confirm
            against callers).
        Y: second set of points, same layout as X.

    Returns:
        Tensor of pairwise kernel values, laid out like ``torch.cdist(X, Y)``.
    """
    # Distance of every x to the origin, repeated once per point of Y.
    x_to_origin = torch.cdist(X, torch.zeros_like(Y))
    # Same for every y, transposed so that rows index X and columns index Y.
    y_to_origin = torch.cdist(Y, torch.zeros_like(X)).t()
    pairwise = torch.cdist(X, Y)
    return 0.5 * (x_to_origin + y_to_origin - pairwise)


def homemade_mmd(X,Y,kernel_matrix_func,**kwargs):
    """Return sqrt(mean k(X,X) + mean k(Y,Y) - 2 * mean k(X,Y)).

    Args:
        X: first sample, passed as-is to ``kernel_matrix_func``.
        Y: second sample, passed as-is to ``kernel_matrix_func``.
        kernel_matrix_func: callable ``(A, B, **kwargs) -> tensor`` returning
            the matrix of kernel values between the points of A and B.
        **kwargs: forwarded unchanged to ``kernel_matrix_func``.

    Returns:
        Scalar tensor holding the square root of the quantity above.
    """
    energy_x = kernel_matrix_func(X, X, **kwargs).mean()
    energy_y = kernel_matrix_func(Y, Y, **kwargs).mean()
    energy_xy = kernel_matrix_func(X, Y, **kwargs).mean()
    mmd_squared = energy_x + energy_y - 2 * energy_xy
    # Floating-point rounding can push the radicand marginally below zero
    # when the two samples are nearly identical; sqrt would then yield NaN.
    return torch.sqrt(torch.clamp(mmd_squared, min=0))

def energy_mmd(X,Y):
    """Return sqrt(mean k(X,X) + mean k(Y,Y) - 2 * mean k(X,Y)) for the energy kernel.

    The kernel matrices come from ``energy_kernel_matrix``.

    Args:
        X: first sample, passed as-is to ``energy_kernel_matrix``.
        Y: second sample, passed as-is to ``energy_kernel_matrix``.

    Returns:
        Scalar tensor holding the square root of the quantity above.
    """
    energy_x = energy_kernel_matrix(X, X).mean()
    energy_y = energy_kernel_matrix(Y, Y).mean()
    energy_xy = energy_kernel_matrix(X, Y).mean()
    mmd_squared = energy_x + energy_y - 2 * energy_xy
    # Floating-point rounding can push the radicand marginally below zero
    # when the two samples are nearly identical; sqrt would then yield NaN.
    return torch.sqrt(torch.clamp(mmd_squared, min=0))


def normalised_MMD(X,Y,kernel_matrix_func,**kwargs):
    """Return sqrt(1 - 2 * mean k(X,Y) / (mean k(X,X) + mean k(Y,Y))).

    Args:
        X: first sample, passed as-is to ``kernel_matrix_func``.
        Y: second sample, passed as-is to ``kernel_matrix_func``.
        kernel_matrix_func: callable ``(A, B, **kwargs) -> tensor`` returning
            the matrix of kernel values between the points of A and B.
        **kwargs: forwarded unchanged to ``kernel_matrix_func``.

    Returns:
        Scalar tensor holding the square root of the quantity above.
    """
    energy_x = kernel_matrix_func(X, X, **kwargs).mean()
    energy_y = kernel_matrix_func(Y, Y, **kwargs).mean()
    energy_xy = kernel_matrix_func(X, Y, **kwargs).mean()
    ratio_gap = 1 - 2 * energy_xy / (energy_x + energy_y)
    # Floating-point rounding can push the radicand marginally below zero
    # when the two samples are nearly identical; sqrt would then yield NaN.
    return torch.sqrt(torch.clamp(ratio_gap, min=0))

def normalised_gaussian_MMD(X,Y,sigma=1):
    """Call ``normalised_MMD`` with ``gaussian_kernel_matrix`` as the kernel.

    Args:
        X: first sample.
        Y: second sample.
        sigma: bandwidth forwarded to ``gaussian_kernel_matrix``.

    Returns:
        Whatever ``normalised_MMD`` returns for this kernel.
    """
    return normalised_MMD(X, Y, gaussian_kernel_matrix, sigma=sigma)

def normalised_laplacian_MMD(X,Y,sigma=1):
    """Call ``normalised_MMD`` with ``laplacian_kernel_matrix`` as the kernel.

    Args:
        X: first sample.
        Y: second sample.
        sigma: scale forwarded to ``laplacian_kernel_matrix``.

    Returns:
        Whatever ``normalised_MMD`` returns for this kernel.
    """
    return normalised_MMD(X, Y, laplacian_kernel_matrix, sigma=sigma)

def aggregate_MMD(X,Y,sigma_list=(1.,2.,3.,4.,5.,6.,7.,8.,9.,10.)):
    """Average the geomloss ``SamplesLoss("gaussian")`` value over several blurs.

    Args:
        X: first sample, passed as-is to the ``SamplesLoss`` instance.
        Y: second sample, passed as-is to the ``SamplesLoss`` instance.
        sigma_list: sequence of ``blur`` values; must be non-empty because the
            sum is divided by its length. The default is an immutable tuple so
            it cannot be altered between calls.

    Returns:
        Mean of the per-blur loss values.
    """
    total = 0.
    for sigma in sigma_list:
        # A separate loss object is needed per bandwidth: blur is fixed at
        # construction time.
        total = total + SamplesLoss("gaussian", blur=sigma)(X, Y)
    return total / len(sigma_list)

def KFDA(gamma,X,Y,kernel_matrix_func,**kwargs):
    """Compute the gamma-regularised statistic built from the pooled kernel matrix.

    With ``Z = [X; Y]``, ``K = k(Z, Z)``, ``m = (-1/nx, ..., 1/ny, ...)`` and
    ``N`` the block-diagonal within-sample centring matrix, the value is::

        nx*ny / (gamma*n) * (m K m - (1/n) m K N (gamma I + N K N / n)^-1 N K m)

    (presumably the kernel Fisher discriminant two-sample statistic -- confirm
    against the reference the module follows).

    Args:
        gamma: regularisation added to the diagonal before solving.
        X: first sample, one point per row (``X.shape[0]`` is its size).
        Y: second sample, one point per row.
        kernel_matrix_func: callable ``(A, B, **kwargs) -> tensor`` returning
            the matrix of kernel values between the rows of A and B.
        **kwargs: forwarded unchanged to ``kernel_matrix_func``.

    Returns:
        Scalar tensor with the value of the expression above.
    """
    nx = X.shape[0]
    ny = Y.shape[0]
    n = nx + ny
    Z = torch.cat((X, Y), dim=0)
    K = kernel_matrix_func(Z, Z, **kwargs)
    # Build every helper tensor with K's dtype and device; default-constructed
    # tensors are float32 on CPU and cannot be multiplied with anything else.
    like_K = {"dtype": K.dtype, "device": K.device}
    m = torch.cat(
        (-torch.ones(nx, **like_K) / nx, torch.ones(ny, **like_K) / ny),
        dim=0,
    )
    P1 = torch.eye(nx, **like_K) - torch.ones(nx, nx, **like_K) / nx
    P2 = torch.eye(ny, **like_K) - torch.ones(ny, ny, **like_K) / ny
    N = torch.block_diag(P1, P2)
    regularised = gamma * torch.eye(n, **like_K) + N @ K @ N / n
    # Solve the linear system against a single vector rather than forming the
    # explicit inverse and chaining matrix-matrix products.
    left = (m @ K) @ N
    right = N @ (K @ m)
    correction = left @ torch.linalg.solve(regularised, right)
    return nx * ny / (gamma * n) * (m @ (K @ m) - correction / n)

def effective_dim(gamma,sigma,r):
    """Return the order-``r`` vector norm of ``lambda_i / (gamma + lambda_i)``.

    ``lambda_i`` are the eigenvalues of ``sigma`` computed by a symmetric
    (Hermitian) eigenvalue routine, so ``sigma`` must be symmetric for the
    result to be meaningful.

    Args:
        gamma: regularisation added to every eigenvalue in the denominator.
        sigma: square symmetric matrix.
        r: order passed to ``torch.linalg.vector_norm``.

    Returns:
        Scalar tensor with the norm.
    """
    # Only the eigenvalues are used, so skip computing the eigenvectors.
    eigenvalues = torch.linalg.eigvalsh(sigma)
    ratios = eigenvalues / (gamma + eigenvalues)
    return torch.linalg.vector_norm(ratios, ord=r)

def normalised_KFDA(gamma,X,Y,kernel_matrix_func,**kwargs):
    """Return ``(KFDA - d1) / (sqrt(2) * d2)``.

    ``d1`` and ``d2`` are ``effective_dim`` of order 1 and 2 evaluated on the
    within-sample centred kernel matrix ``N K N / n``.

    Args:
        gamma: regularisation passed to ``KFDA`` and ``effective_dim``.
        X: first sample, one point per row (``X.shape[0]`` is its size).
        Y: second sample, one point per row.
        kernel_matrix_func: callable ``(A, B, **kwargs) -> tensor`` returning
            the matrix of kernel values between the rows of A and B.
        **kwargs: forwarded unchanged to ``kernel_matrix_func``.

    Returns:
        Tensor of shape ``(1,)`` holding the normalised statistic.
    """
    nx = X.shape[0]
    ny = Y.shape[0]
    n = nx + ny
    Z = torch.cat((X, Y), dim=0)
    K = kernel_matrix_func(Z, Z, **kwargs)
    # Helper tensors follow K's dtype/device instead of the float32/CPU default.
    like_K = {"dtype": K.dtype, "device": K.device}
    P1 = torch.eye(nx, **like_K) - torch.ones(nx, nx, **like_K) / nx
    P2 = torch.eye(ny, **like_K) - torch.ones(ny, ny, **like_K) / ny
    N = torch.block_diag(P1, P2)
    # effective_dim uses a symmetric eigen-solver, which reads only one
    # triangle of its input. N @ N.T @ K is not symmetric, so it would yield
    # eigenvalues of a different matrix. N is a symmetric idempotent
    # projection, hence N @ K @ N has exactly the same spectrum as
    # N @ N.T @ K and is symmetric whenever K is.
    sigma = N @ K @ N / n
    kfda = KFDA(gamma, X, Y, kernel_matrix_func, **kwargs)
    d1 = effective_dim(gamma, sigma, 1)
    d2 = effective_dim(gamma, sigma, 2)
    # One-element tensor on purpose: it keeps the (1,) return shape.
    root_two = torch.sqrt(torch.tensor([2.0], **like_K))
    return (kfda - d1) / (root_two * d2)

def gauss_KFDA(gamma,X,Y,**kwargs):
    """Call ``KFDA`` with ``gaussian_kernel_matrix`` as the kernel.

    Args:
        gamma: regularisation passed to ``KFDA``.
        X: first sample.
        Y: second sample.
        **kwargs: ``sigma`` (default 1) is used as the Gaussian bandwidth;
            any other keyword is ignored.

    Returns:
        Whatever ``KFDA`` returns for this kernel.
    """
    # Honour a caller-supplied bandwidth instead of silently forcing sigma=1.
    sigma = kwargs.get('sigma', 1)
    return KFDA(gamma, X, Y, gaussian_kernel_matrix, sigma=sigma)

def gauss_normalised_KFDA(gamma,X,Y,**kwargs):
    """Call ``normalised_KFDA`` with ``gaussian_kernel_matrix`` as the kernel.

    Args:
        gamma: regularisation passed to ``normalised_KFDA``.
        X: first sample.
        Y: second sample.
        **kwargs: ``sigma`` (default 1) is used as the Gaussian bandwidth;
            any other keyword is ignored.

    Returns:
        Whatever ``normalised_KFDA`` returns for this kernel.
    """
    # Honour a caller-supplied bandwidth instead of silently forcing sigma=1.
    sigma = kwargs.get('sigma', 1)
    return normalised_KFDA(gamma, X, Y, gaussian_kernel_matrix, sigma=sigma)

def gauss_kernel_cost_matrix(X,Y,sigma=1):
    """Return ``sqrt(2 - 2 * k(x, y))`` for every pair, with k the Gaussian kernel.

    Args:
        X: first set of points.
        Y: second set of points.
        sigma: bandwidth forwarded to ``gaussian_kernel_matrix``.

    Returns:
        Tensor of pairwise costs with the same layout as the kernel matrix.
    """
    similarity = gaussian_kernel_matrix(X, Y, sigma=sigma)
    return (2 - 2 * similarity).sqrt()

def gauss_kernel_wasserstein(X,Y,sigma=1):
    """Solve ``ot.emd2`` between uniform weights on X and Y.

    The ground cost is ``gauss_kernel_cost_matrix(X, Y, sigma)``.

    Args:
        X: first sample, one point per row (``X.shape[0]`` is its size).
        Y: second sample, one point per row.
        sigma: bandwidth forwarded to ``gauss_kernel_cost_matrix``.

    Returns:
        The value returned by ``ot.emd2`` for these weights and cost matrix.
    """
    n = X.shape[0]
    m = Y.shape[0]
    M = gauss_kernel_cost_matrix(X, Y, sigma)
    # Uniform marginals created with the cost matrix's dtype and device, so
    # the three arguments handed to the solver are of one kind.
    a = torch.ones((n,), dtype=M.dtype, device=M.device) / n
    b = torch.ones((m,), dtype=M.dtype, device=M.device) / m
    return ot.emd2(a, b, M)

def KBW(X,Y,kernel_matrix_func):
    """Return sqrt(tr(K_XX)/N + tr(K_YY)/M - 2 * sum(singular values of K_XY / sqrt(N*M))).

    N and M are the numbers of rows of X and Y.

    Args:
        X: first sample, one point per row.
        Y: second sample, one point per row; its size may differ from X's.
        kernel_matrix_func: callable ``(A, B) -> tensor`` returning the matrix
            of kernel values between the rows of A and B.

    Returns:
        Scalar tensor holding the square root of the quantity above.
    """
    n_x = X.shape[0]
    n_y = Y.shape[0]
    cross = kernel_matrix_func(X, Y) / (n_x * n_y) ** 0.5
    # The two traces are taken separately: K_XX is N x N and K_YY is M x M,
    # so the matrices themselves cannot be added when N != M.
    trace_term = (
        torch.trace(kernel_matrix_func(X, X)) / n_x
        + torch.trace(kernel_matrix_func(Y, Y)) / n_y
    )
    # Singular values are non-negative, so their plain sum is the nuclear norm.
    nuclear_norm = torch.linalg.svdvals(cross).sum()
    # Rounding can leave the radicand marginally negative for near-identical
    # samples; clamp so sqrt does not return NaN.
    return torch.sqrt(torch.clamp(trace_term - 2 * nuclear_norm, min=0))

def gauss_KBW(X,Y,sigma=1):
    """Return sqrt(2 - 2 * sum(singular values of K_XY) / sqrt(N*M)) for the Gaussian kernel.

    N and M are the numbers of rows of X and Y, and ``K_XY`` comes from
    ``gaussian_kernel_matrix(X, Y, sigma)``.

    Args:
        X: first sample, one point per row.
        Y: second sample, one point per row.
        sigma: bandwidth forwarded to ``gaussian_kernel_matrix``.

    Returns:
        Scalar tensor holding the square root of the quantity above.
    """
    K_XY = gaussian_kernel_matrix(X, Y, sigma)
    scale = (X.shape[0] * Y.shape[0]) ** 0.5
    # Only the singular values are needed, so skip computing the singular
    # vectors; they are non-negative, so no absolute value is required.
    nuclear_norm = torch.linalg.svdvals(K_XY).sum() / scale
    # Rounding can leave the radicand marginally negative for near-identical
    # samples; clamp so sqrt does not return NaN.
    return torch.sqrt(torch.clamp(2 - 2 * nuclear_norm, min=0))

def general_gauss_kernel_matrix(X,Y,sigma=1):
    """Assemble the complex (N+M) x (N+M) block matrix of Gaussian kernels.

    The layout is::

        [[      K_XX / N        , i * K_XY / sqrt(N*M) ],
         [ i * K_XY^T / sqrt(N*M),      -K_YY / M       ]]

    where N and M are the numbers of rows of X and Y and the transpose is a
    plain (non-conjugating) transpose.

    Args:
        X: first sample, one point per row.
        Y: second sample, one point per row.
        sigma: bandwidth forwarded to ``gaussian_kernel_matrix``.

    Returns:
        Complex tensor with the block structure above.
    """
    N = X.shape[0]
    M = Y.shape[0]

    kXX = gaussian_kernel_matrix(X, X, sigma)
    kXY = gaussian_kernel_matrix(X, Y, sigma)
    kYY = gaussian_kernel_matrix(Y, Y, sigma)

    # zeros_like keeps the zero parts on the kernel's dtype and device;
    # torch.complex requires both parts to agree.
    KXX = (1 / N) * torch.complex(kXX, torch.zeros_like(kXX))
    KXY = (1 / (N * M) ** 0.5) * torch.complex(torch.zeros_like(kXY), kXY)
    KYY = (-1 / M) * torch.complex(kYY, torch.zeros_like(kYY))

    top = torch.cat((KXX, KXY), dim=1)
    bottom = torch.cat((KXY.t(), KYY), dim=1)
    return torch.cat((top, bottom), dim=0)

def general_kernel_matrix(gk_mat_func,X,Y):
    """Assemble the complex (N+M) x (N+M) block matrix for an arbitrary kernel.

    The layout is::

        [[      K_XX / N        , i * K_XY / sqrt(N*M) ],
         [ i * K_XY^T / sqrt(N*M),      -K_YY / M       ]]

    where N and M are the numbers of rows of X and Y and the transpose is a
    plain (non-conjugating) transpose.

    Args:
        gk_mat_func: callable ``(A, B) -> real tensor`` returning the matrix
            of kernel values between the rows of A and B.
        X: first sample, one point per row.
        Y: second sample, one point per row.

    Returns:
        Complex tensor with the block structure above.
    """
    N = X.shape[0]
    M = Y.shape[0]

    kXX = gk_mat_func(X, X)
    kXY = gk_mat_func(X, Y)
    kYY = gk_mat_func(Y, Y)

    # zeros_like keeps the zero parts on the kernel's dtype and device;
    # torch.complex requires both parts to agree.
    KXX = (1 / N) * torch.complex(kXX, torch.zeros_like(kXX))
    KXY = (1 / (N * M) ** 0.5) * torch.complex(torch.zeros_like(kXY), kXY)
    KYY = (-1 / M) * torch.complex(kYY, torch.zeros_like(kYY))

    top = torch.cat((KXX, KXY), dim=1)
    bottom = torch.cat((KXY.t(), KYY), dim=1)
    return torch.cat((top, bottom), dim=0)

def gauss_dkt(X,Y,sigma=1):
    """Sum the absolute values of the eigenvalues of the Gaussian block matrix.

    The matrix is ``general_gauss_kernel_matrix(X, Y, sigma)`` and its
    eigenvalues come from ``torch.linalg.eigvals``.

    Args:
        X: first sample.
        Y: second sample.
        sigma: bandwidth forwarded to ``general_gauss_kernel_matrix``.

    Returns:
        Scalar tensor with the sum of eigenvalue magnitudes.
    """
    spectrum = torch.linalg.eigvals(general_gauss_kernel_matrix(X, Y, sigma))
    return spectrum.abs().sum()

def dkt(gk_mat_func,X,Y):
    """Sum the absolute values of the eigenvalues of the kernel block matrix.

    The matrix is ``general_kernel_matrix(gk_mat_func, X, Y)`` and its
    eigenvalues come from ``torch.linalg.eigvals``.

    Args:
        gk_mat_func: kernel-matrix callable forwarded to
            ``general_kernel_matrix``.
        X: first sample.
        Y: second sample.

    Returns:
        Scalar tensor with the sum of eigenvalue magnitudes.
    """
    block_matrix = general_kernel_matrix(gk_mat_func, X, Y)
    spectrum = torch.linalg.eigvals(block_matrix)
    return spectrum.abs().sum()

def Laplace_dkt(X,Y):
    """Call ``dkt`` with ``laplacian_kernel_matrix`` (default scale) as the kernel.

    Args:
        X: first sample.
        Y: second sample.

    Returns:
        Whatever ``dkt`` returns for this kernel.
    """
    kernel = laplacian_kernel_matrix
    return dkt(kernel, X, Y)

