import pdb
import numpy as np
import ot
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
import matplotlib as mpl
from scipy.stats import gaussian_kde
mpl.use('TkAgg')
import numpy as np
from scipy.stats import multivariate_normal
import ot
from aux import closest_to_origin, mmd2_grad, rbf_kernel
import numpy as np
from scipy.spatial import cKDTree
from scipy.special import digamma, gamma
from scipy.stats import multivariate_normal
import matplotlib.colors as mcol
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from matplotlib.colors import LinearSegmentedColormap, TwoSlopeNorm
import scipy
from matplotlib import cm
# Parameters for plots: tick/line sizes and figure dimensions (figsize is in inches).
# The original listed the header comment, ``length_ticks`` and ``scatter_size`` twice;
# the effective values are unchanged (scatter_size was 2, then immediately overwritten
# with 20).
length_ticks = 2
font_size = 9
linewidth = 1.2
scatter_size = 20
horizontal_size = 1.2
vertical_size = 1.2

# Global Matplotlib defaults: tick labels 5pt smaller and titles/legend 2pt smaller
# than font_size, and no top/right axis spines.
mpl.rcParams.update({
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size - 5,
    'ytick.labelsize': font_size - 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.titlesize': font_size - 2,
    'legend.fontsize': font_size - 2,
})






# Decoding-error figure: the DNL (W2) curve plus one MMD curve per kernel width sigma,
# each drawn as mean +/- standard error across runs.
# NOTE: the original created an extra, never-used figure here, which made plt.show()
# open an empty window; it has been dropped.

# The axis=0 reductions below assume each *_all_runs array is (n_runs, n_steps)
# -- TODO confirm against the code that writes these .npy files.
W2_cumulative_wasserstein_all_runs = np.load("W2_cumulative_wasserstein.npy")  # not plotted in this figure
W2_global_wasserstein_all_runs = np.load("W2_global_wasserstein.npy")  # not plotted in this figure
W2_decoded_error_all_runs = np.load("W2_decoded_errors.npy")  # Krupic_


def _nan_mean_sem(runs):
    """Return the NaN-aware mean and standard error of *runs* along axis 0.

    ``scipy.stats.sem`` propagates NaNs by default, so pairing it with
    ``np.nanmean`` gave a finite mean but a NaN error band wherever any run had a
    missing value. Both statistics here skip NaNs; for NaN-free input the result
    equals ``np.mean`` and ``scipy.stats.sem`` (ddof=1).
    """
    runs = np.asarray(runs)
    n_valid = np.count_nonzero(~np.isnan(runs), axis=0)
    mean = np.nanmean(runs, axis=0)
    sem = np.nanstd(runs, axis=0, ddof=1) / np.sqrt(n_valid)
    return mean, sem


W2_mean, W2_sem = _nan_mean_sem(W2_decoded_error_all_runs)

# Plot decoded error
fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
ax.spines['left'].set_linewidth(linewidth)
ax.spines['bottom'].set_linewidth(linewidth)
ax.tick_params(width=linewidth, length=length_ticks)

w2_steps = np.arange(len(W2_mean))
plt.plot(w2_steps, W2_mean, label="DNL")
plt.fill_between(w2_steps, W2_mean - W2_sem, W2_mean + W2_sem, alpha=0.1)

# Widths (sigma) of the MMD kernel; one "MMD_decoded_errors_<sigma>.npy" file per value.
sigma_list = [0.1, 0.25, 0.5, 0.75, 1]
# Only every mmd_stride-th time step of each MMD curve is drawn.
mmd_stride = 50

# Sample len(sigma_list) + 1 discrete shades and skip index 0 (the lightest, near white).
# The original sampled only len(sigma_list) shades but indexed up to len(sigma_list),
# which is out of range; matplotlib maps it to the "over" colour (the darkest shade), so
# the last two sigma curves came out identical. cm.get_cmap was also removed in
# Matplotlib 3.9, whereas plt.get_cmap is still provided.
cmap = plt.get_cmap("Purples", len(sigma_list) + 1)
colors = [cmap(i) for i in range(1, len(sigma_list) + 1)]

for i_sigma, sigma in enumerate(sigma_list):
    color = colors[i_sigma]
    mmd_decoded_error_all_runs = np.load("MMD_decoded_errors_" + str(sigma) + ".npy")  # Krupic_
    mmd_mean, mmd_sem = _nan_mean_sem(mmd_decoded_error_all_runs)

    every = slice(None, None, mmd_stride)
    mmd_steps = np.arange(len(mmd_mean))[every]
    mean_sub = mmd_mean[every]
    sem_sub = mmd_sem[every]
    plt.plot(mmd_steps, mean_sub, label="MMD " + r"$\sigma=$" + str(sigma), color=color)
    plt.fill_between(mmd_steps, mean_sub - sem_sub, mean_sub + sem_sub, alpha=0.1, color=color)

plt.xlabel("Time steps")
plt.ylabel("Decoding error")
plt.xlim(-1500, 30000)
plt.legend()
plt.savefig("decoding_error_value.svg")
plt.show()