import os
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
from geomloss import SamplesLoss

from distances import gauss_KBW, gauss_dkt

# Seed torch's global RNG once, before any sampling, so that every
# torch.randn draw in this script is repeatable from run to run.
_SEED = 42
torch.manual_seed(_SEED)

def routine_mean(dist, mean_list, n_samples=1000, dim=1):
    """Evaluate ``dist`` between a standard-normal sample and a shifted one.

    For each ``mean`` in ``mean_list`` two fresh, independent samples of shape
    ``(n_samples, dim)`` are drawn with ``torch.randn``: ``orig`` (standard
    normal) and ``conta`` (standard normal plus ``mean``, broadcast over every
    coordinate). ``dist(orig, conta)`` is evaluated and its ``.item()`` value
    is recorded.

    Args:
        dist: Callable taking two ``(n_samples, dim)`` tensors and returning an
            object with ``.item()`` (e.g. a one-element tensor).
        mean_list: Iterable of shifts to add to the second sample.
        n_samples: Number of points per sample. Defaults to 1000, the value
            previously hard-coded.
        dim: Dimension of each point. Defaults to 1, the value previously
            hard-coded.

    Returns:
        A list with one value per entry of ``mean_list``, in the same order.
    """
    results = []
    for mean in mean_list:
        # Keep the draw order (orig first, then conta): with a fixed seed this
        # reproduces exactly the same samples as before.
        orig = torch.randn(n_samples, dim)
        conta = mean + torch.randn(n_samples, dim)
        results.append(dist(orig, conta).item())
    return results

# Location shifts theta = 0.0, 0.5, ..., 10.0 (21 points). linspace with an
# explicit point count yields exactly the same grid as the former
# np.arange(10.5, step=0.5) without relying on a float step to hit the endpoint.
mean_vals = np.linspace(0.0, 10.0, 21)
print("mean_vals", mean_vals)

# NOTE(review): `samples` is not used anywhere below, but drawing it advances
# the seeded torch RNG. Deleting it would change every later sample (and the
# saved figure), so it is left in place on purpose.
samples = torch.randn(1000, 1)

# Each distance is evaluated (via routine_mean) on a standard-normal sample
# versus a sample shifted by theta, for every theta in mean_vals.
print("mmd")
mmd = SamplesLoss("gaussian", blur=1)
mmd_results = routine_mean(mmd, mean_vals)
print(mmd_results)
print("gauss_KBW")
gauss_kernel_wasserstein_results = routine_mean(gauss_KBW, mean_vals)
print(gauss_kernel_wasserstein_results)
print("dkt")
dkt_results = routine_mean(gauss_dkt, mean_vals)
print(dkt_results)

# Rescaled variants of the KBW curve, plotted alongside it for comparison.
double_gauss_kernel_wasserstein_results = [2 * x for x in gauss_kernel_wasserstein_results]
squared_gauss_kernel_wasserstein_results = [x * x for x in gauss_kernel_wasserstein_results]
# Only consumed by the commented-out "root dkt" curve below.
root_dkt_results = [np.sqrt(x) for x in dkt_results]

fig = plt.figure()
# plt.plot(mean_vals,root_dkt_results, color='indianred', label='root dkt', lw=2)
plt.plot(mean_vals, mmd_results, color='indianred', label='$MMD_{k}$', lw=2)
plt.plot(mean_vals, gauss_kernel_wasserstein_results, color='g', label='gauss KBW', lw=2)
plt.plot(mean_vals, squared_gauss_kernel_wasserstein_results, color='y', label='squared gauss KBW', lw=2)
plt.plot(mean_vals, double_gauss_kernel_wasserstein_results, color='k', label='double gauss KBW', lw=2)
plt.plot(mean_vals, dkt_results, color='cornflowerblue', label='$d_{KT}$', lw=2)

plt.grid()
plt.xlabel(r"$\theta$")
plt.legend()

title = "var_mean_gauss_KBW"
figpath = 'save/' + title + '.pdf'
# savefig does not create missing directories; make sure save/ exists first
# instead of failing with FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(figpath), exist_ok=True)
fig.savefig(figpath, dpi=700)
plt.show()

