from pathlib import Path

import matplotlib.pyplot as plt
import torch
from geomloss import SamplesLoss

from distances import gauss_dkt


def routine_sigma_std(dist, D, sigma_list, n_samples=1000, generator=None):
    """Evaluate ``dist`` between a centred and a shifted sample for each sigma.

    For every ``sigma`` in ``sigma_list``, two independent samples of shape
    ``(n_samples, 1)`` are drawn as ``sigma * torch.randn(...)``. The second
    ("contaminated") sample is translated by ``D``. ``dist`` is called on the
    pair, and its scalar result is collected.

    Args:
        dist: callable ``dist(x, y)`` returning a one-element tensor
            (anything exposing ``.item()``).
        D: shift added to the second sample.
        sigma_list: iterable of scale factors to sweep.
        n_samples: number of points per sample (default 1000, the value
            that was previously hard-coded).
        generator: optional ``torch.Generator`` used for both draws, for
            reproducible sweeps. ``None`` uses torch's global RNG.

    Returns:
        list of Python scalars (from ``.item()``), one per entry of
        ``sigma_list``, in the same order.
    """
    results = []
    for sigma in sigma_list:
        # Draw order (orig, then conta) is kept so seeded runs are stable.
        orig = sigma * torch.randn(n_samples, 1, generator=generator)
        conta = D + sigma * torch.randn(n_samples, 1, generator=generator)
        results.append(dist(orig, conta).item())
    return results

### Vary sigma (scale sweep; both samples share the same sigma)

D = 100  # shift between the clean and the contaminated sample

sigma_list = [0.1, 0.3, 1, 3, 10, 30, 100]
print("sigma_list", sigma_list)

print("mmd")
mmd = SamplesLoss("gaussian", blur=1)
mmd_results = routine_sigma_std(mmd, D, sigma_list)
print(mmd_results)
print("dkt")
dkt_results = routine_sigma_std(gauss_dkt, D, sigma_list)
print(dkt_results)

fig, ax = plt.subplots()
ax.plot(sigma_list, mmd_results, color='indianred', label='$MMD_{k}$', lw=2)
ax.plot(sigma_list, dkt_results, color='cornflowerblue', label='$d_{KT}$', lw=2)
ax.legend()
ax.grid()
ax.set_xlabel("s")
ax.set_xscale('log')
title = "var_std"
save_dir = Path("save")
# savefig does not create missing directories; without this the script
# would crash with FileNotFoundError after all distances were computed.
save_dir.mkdir(parents=True, exist_ok=True)
figpath = save_dir / (title + '.pdf')
fig.savefig(figpath, dpi=700)
plt.show()

