import os

import matplotlib.pyplot as plt
import torch
from geomloss import SamplesLoss

from distances import gauss_dkt

params = {
    # Font sizes (in points) for the legend, axes title/labels and tick labels.
    "legend.fontsize": 18,
    "axes.titlesize": 16,
    "axes.labelsize": 16,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    # Font handling for vector output: Type 42 (TrueType) fonts embedded in
    # PDFs, and SVG text kept as text objects instead of being converted to paths.
    "pdf.fonttype": 42,
    "svg.fonttype": "none",
}
plt.rcParams.update(params)

def mmd(X, Y, sigma=1):
    """Return geomloss's Gaussian-kernel discrepancy between samples X and Y.

    A fresh ``SamplesLoss("gaussian", blur=sigma)`` is built on every call
    and applied to ``(X, Y)``; ``sigma`` is forwarded as ``blur``.

    NOTE(review): per the geomloss docs, kernel losses return the kernel norm
    1/2 * ||alpha - beta||_k^2, i.e. half the *squared* MMD rather than the MMD
    itself — confirm this is the intended quantity for the plot label.
    """
    return SamplesLoss("gaussian", blur=sigma)(X, Y)

def routine_sigma_kernel(dist, orig, conta, sigma_list):
    """Evaluate ``dist(orig, conta, sigma=s)`` for each bandwidth ``s``.

    Args:
        dist: callable accepting two samples and a ``sigma`` keyword and
            returning an object with ``.item()`` (e.g. a 0-d tensor).
        orig: first sample, passed unchanged to ``dist``.
        conta: second sample, passed unchanged to ``dist``.
        sigma_list: iterable of bandwidth values, evaluated in order.

    Returns:
        A list with one ``.item()`` scalar per entry of ``sigma_list``, in the
        same order.
    """
    return [dist(orig, conta, sigma=s).item() for s in sigma_list]

# --- Experiment: how MMD_k and d_KT vary with the kernel bandwidth sigma ---
# Two samples of 1000 one-dimensional points: `orig` from torch.randn
# (standard normal) and `conta`, the same kind of draw shifted by +5.
orig = torch.randn(1000, 1)
conta = 5 + torch.randn(1000, 1)

# Roughly log-spaced bandwidths, from well below to well above the shift of 5.
sigma_list = [0.1, 0.3, 1, 3, 10, 30, 100, 300, 1000, 3000, 10000]

print("mmd")
mmd_results = routine_sigma_kernel(mmd, orig, conta, sigma_list)
print(mmd_results)
print("d1")
d1_results = routine_sigma_kernel(gauss_dkt, orig, conta, sigma_list)
print(d1_results)

fig = plt.figure()
plt.plot(sigma_list, mmd_results, color='indianred', label='$MMD_{k}$', lw=2)
plt.plot(sigma_list, d1_results, color='cornflowerblue', label='$d_{KT}$', lw=2)
plt.legend()
plt.xscale('log')
plt.grid()
plt.xlabel(r"$\sigma$")
name = "var_kernel_sigma"
figpath = 'save/' + name + '.pdf'
# Save *before* showing. With an interactive backend, show() blocks and the
# figure may be destroyed once its window is closed. savefig() also raises
# if the output directory is missing, so create it first.
os.makedirs(os.path.dirname(figpath), exist_ok=True)
fig.savefig(figpath, dpi=700)
plt.show()





