import torch
import numpy as np
from distances import gauss_dkt, Laplace_dkt
from geomloss import SamplesLoss

# Sanity check: compare geomloss kernel MMDs with the project's
# gauss_dkt / Laplace_dkt on a single-mode and a two-mode pair of point clouds.

# Kernel MMD estimators. blur = 0.5 / sqrt(0.5) gives 2 * blur**2 == 1.
# NOTE(review): geomloss's "gaussian" kernel is presumably
# exp(-|x - y|^2 / (2 * blur^2)), so this would be exp(-|x - y|^2); confirm it
# matches the bandwidth used inside gauss_dkt.
gaussian_MMD = SamplesLoss("gaussian", blur=0.5 / np.sqrt(0.5))
laplacian_MMD = SamplesLoss("laplacian", blur=0.5)

# Toy-problem geometry.
n, d = 100, 2  # rows / columns of each sample array (size=(n, d))
shift = 0.3    # scalar offset added to every coordinate of Y
dist = 10      # scalar offset that moves the second mode away from the first

# Draw X first and Y second so the seeded stream is consumed in a fixed order.
np.random.seed(42)
X = np.random.standard_normal(size=(n, d))
Y = np.random.standard_normal(size=(n, d)) + shift

# Second mode: the same clouds translated by `dist` along every axis, stacked
# row-wise under the first mode.
X2, Y2 = X + dist, Y + dist
X_concat = np.concatenate([X, X2], axis=0)
Y_concat = np.concatenate([Y, Y2], axis=0)

# Cast the arrays that are actually compared to float32 torch tensors.
X, Y = (torch.from_numpy(arr).float() for arr in (X, Y))
X_concat, Y_concat = (torch.from_numpy(arr).float() for arr in (X_concat, Y_concat))


def _run_checks(checks):
    """Print ``label`` followed by ``metric(a, b)`` for each tuple, in order.

    Each metric is evaluated right before its line is printed.
    """
    for label, metric, a, b in checks:
        print(label, metric(a, b))


_run_checks((
    ("gaussian_MMD ", gaussian_MMD, X, Y),
    ("gaussian_MMD 2 modes", gaussian_MMD, X_concat, Y_concat),
    ("gaussian dkt ", gauss_dkt, X, Y),
    ("gaussian dkt 2 modes", gauss_dkt, X_concat, Y_concat),
))
print("~Bonus~")
_run_checks((
    ("laplacian_MMD ", laplacian_MMD, X, Y),
    ("laplacian_MMD 2 modes", laplacian_MMD, X_concat, Y_concat),
    ("laplacian dkt ", Laplace_dkt, X, Y),
    ("laplacian dkt 2 modes", Laplace_dkt, X_concat, Y_concat),
))