import pdb
import numpy as np
import ot
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
import matplotlib as mpl
from scipy.stats import gaussian_kde
mpl.use('TkAgg')
import numpy as np
from scipy.stats import multivariate_normal
import ot
from aux import closest_to_origin, mmd2_grad, rbf_kernel
import numpy as np
from scipy.spatial import cKDTree
from scipy.special import digamma, gamma
from scipy.stats import multivariate_normal
import matplotlib.colors as mcol
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from matplotlib.colors import LinearSegmentedColormap, TwoSlopeNorm
# Parameters for plots
font_size = 9
linewidth = 1.2
length_ticks = 2  # tick length (points)
scatter_size = 20  # not passed by the scatter calls in this script (they use s=5)
horizontal_size = 1.2  # figure width (inches)
vertical_size = 1.2  # figure height (inches)

# Global matplotlib style: fonts scaled from `font_size`, no top/right spines.
mpl.rcParams.update({
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size - 5,
    'ytick.labelsize': font_size - 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.titlesize': font_size - 2,
    'legend.fontsize': font_size - 2,
})


# Parameters for MMD learning rules
n_interactions = 40000  # learning steps per repetition
batch_size = 10  # reward samples drawn per learning step
# Constant learning rate. An exponentially decaying alternative that was tried:
# 0.9999 ** np.arange(n_interactions)
learning_rates = 2 * np.ones(n_interactions)
n_repetitions = 10
n_particles = 12

# Initial distribution of the particles (also drawn as the background density)
mean_init = np.array([1, 1])
cov_init = np.eye(2) * 0.25

# Target (reward) distribution
mean = np.array([3, 3])
cov = np.eye(2)
n_samples_test = 100000
# NOTE: this module-level name shadows `scipy.special.gamma` imported above.
gamma = 1  # temporal discount factor

# Ground-truth decoded value: over the 25% of target samples closest to the
# origin, average column 1 discounted by gamma ** column 0.
X = np.random.multivariate_normal(mean, cov, size=n_samples_test)
idx_particles_closest_to_origin_init = closest_to_origin(X, int(n_samples_test * 0.25))
true_decoded = np.mean(
    X[idx_particles_closest_to_origin_init, 1]
    * gamma ** X[idx_particles_closest_to_origin_init, 0]
)

# Hand-picked initial particle positions, one row per particle
# (column 0 is plotted as "Time", column 1 as "Magnitude"). They replace a
# random draw such as
# np.random.multivariate_normal(mean_init, cov_init, n_particles).
particles_init = np.array([[0.85, 1.19],
                           [1.7, 0.6],
                           [1.4, 1.4],
                           [0.4, 1.58465806],
                           [1.44, 1],
                           [0.78, 0.71981831],
                           [0.4, 1.1],
                           [1.12305538, 1.33102621],
                           [1.1, 0.7],
                           [0.66, 1],
                           [0.60242665, 1.44190863],
                           [1.08, 0.54]])

# The OT marginals and the plots are sized by n_particles, so fail early if the
# hand-picked set is edited out of sync with it.
if particles_init.shape != (n_particles, 2):
    raise ValueError(
        f"particles_init has shape {particles_init.shape}, "
        f"expected ({n_particles}, 2)"
    )

# Kernel covariance scales: each value becomes the diagonal of `Sigma`,
# and the simulation is run once per value.
sigma_list = [0.1, 0.25, 0.5, 0.75, 1]

def _decode(points, idx):
    """Return the decoded value read out by units ``idx`` of ``points``.

    Column 1 is discounted by ``gamma ** column 0`` and averaged over the
    selected units (same rule used for ``true_decoded``).
    """
    return np.mean(points[idx, 1] * gamma ** points[idx, 0])


def _w2_uniform(points_a, points_b):
    """Return the 2-Wasserstein distance between two equally weighted clouds.

    The cost is the squared Euclidean distance, so ``ot.emd2`` yields W2**2
    and the square root gives W2. Uses the module-level ``uniform_weights``.
    """
    cost = ot.dist(points_a, points_b, metric='euclidean') ** 2
    return np.sqrt(ot.emd2(uniform_weights, uniform_weights, cost))


# Set to True to write the per-sigma summary arrays to .npy files.
save_results = False

# Loop-invariant quantities, hoisted out of the learning loop.
uniform_weights = np.full(n_particles, 1.0 / n_particles)
n_decoding_units = int(n_particles * 0.25)

for sigma in sigma_list:

    Sigma = np.diag([sigma, sigma])

    # Per-repetition results for this kernel width
    decoded_error_all_runs = np.zeros((n_repetitions, n_interactions))
    global_wasserstein_all_runs = np.zeros(n_repetitions)
    cumulative_wasserstein_all_runs = np.zeros(n_repetitions)

    for rep in range(n_repetitions):

        # Every repetition starts from the same hand-picked particles.
        particles = np.copy(particles_init)
        particles_save = np.zeros((n_interactions,) + particles.shape)
        particles_save[0] = particles_init
        cumulative_wasserstein = 0.0
        decoded = np.empty(n_interactions)
        # Mean absolute step size per update; not used below, kept for inspection.
        all_magnitude_updates = []

        # NOTE: the decoding units are chosen once, from the initial positions,
        # and the same units are read out at every step (presumably intended:
        # units keep a fixed identity during learning).
        idx_particles_closest_to_origin = closest_to_origin(particles, n_decoding_units)
        decoded[0] = _decode(particles, idx_particles_closest_to_origin)

        for step in range(1, n_interactions):
            # Draw a mini-batch of rewards from the target and take one MMD step.
            reward = np.random.multivariate_normal(mean, cov, batch_size)
            gradient = mmd2_grad(particles, reward, Sigma)
            alpha = learning_rates[step]
            particles_new = particles - alpha * gradient

            all_magnitude_updates.append(alpha * np.mean(np.abs(gradient)))
            decoded[step] = _decode(particles_new, idx_particles_closest_to_origin)

            # Local transport cost of this step, accumulated along the path.
            cumulative_wasserstein += _w2_uniform(particles_new, particles)

            particles = particles_new
            particles_save[step] = particles

        # Direct transport cost from the start to the end of learning
        global_wasserstein_all_runs[rep] = _w2_uniform(particles_init, particles)
        cumulative_wasserstein_all_runs[rep] = cumulative_wasserstein
        decoded_error_all_runs[rep] = (decoded - true_decoded) ** 2

        plt.plot(decoded_error_all_runs[rep])

    if save_results:
        np.save("MMD_cumulative_wasserstein_" + str(sigma) + ".npy", cumulative_wasserstein_all_runs)
        np.save("MMD_global_wasserstein_" + str(sigma) + ".npy", global_wasserstein_all_runs)
        np.save("MMD_decoded_errors_" + str(sigma) + ".npy", decoded_error_all_runs)

# Squared decoding error over learning, all sigmas and repetitions overlaid
plt.show()


# Transport figure: initial-minus-target density difference as the background,
# overlaid with the learning trajectories of the last simulated run
# (last repetition of the last sigma in sigma_list).
fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
ax.spines['left'].set_linewidth(linewidth)
ax.spines['bottom'].set_linewidth(linewidth)
ax.tick_params(width=linewidth, length=length_ticks)

# One evaluation grid over [-1, 5) x [-1, 5), shared by both densities
x_grid, y_grid = np.mgrid[-1:5:.01, -1:5:.01]
pos_initial = pos_final = np.dstack((x_grid, y_grid))

# Target distribution, normalised to sum to one over the grid
rv_target = multivariate_normal(mean, cov)
pdf_target = rv_target.pdf(pos_final)
pdf_target = pdf_target / np.sum(pdf_target)

# Initial distribution, normalised the same way
rv_initial = multivariate_normal(mean_init, cov_init)
pdf_initial = rv_initial.pdf(pos_initial)
pdf_initial = pdf_initial / np.sum(pdf_initial)

# Signed difference: positive (initial-heavy) -> orange, negative (target-heavy) -> blue
Z = pdf_initial - pdf_target
# Robust symmetric colour scaling around 0
v = np.percentile(np.abs(Z), 98)
norm = TwoSlopeNorm(vmin=-v, vcenter=0.0, vmax=v)
# Mask only numerically-zero differences so the far background renders white.
# Raising tau (e.g. to 0.5-2% of v) would also hide small differences.
tau = 0.000000001 * v
Zm = np.ma.masked_where(np.abs(Z) < tau, Z)
orange = "#f16913"
blue = "#2b8cbe"
orange_blue = LinearSegmentedColormap.from_list("orange_blue", [blue, "white", orange], N=256)
ax.imshow(Zm.T, extent=[-1, 5, -1, 5], origin="lower", cmap=orange_blue, norm=norm,
          interpolation="bilinear")

# Two-colour map for the units, built from two entries of the jet colormap
reward_cmap = plt.cm.jet(np.linspace(0., 1., 8)[:-1])
asym_cmap = mcol.LinearSegmentedColormap.from_list("MyCmapName", [reward_cmap[1], reward_cmap[-1]])

# Learning trajectories: every 50th saved step plus the final position
trajectories = np.concatenate([particles_save[::50], particles_save[-1:]], axis=0)
for i in range(n_particles):
    ax.plot(trajectories[:, i, 0], trajectories[:, i, 1], color="grey", linewidth=0.4,
            label="Learning trajectory" if i == n_particles - 1 else None)

# Colour each unit by its initial distance from the origin, so the same colour
# marks the same unit before and after learning.
distance = np.linalg.norm(particles_init, axis=1)
ax.scatter(particles_init[:, 0], particles_init[:, 1], c=distance, cmap=asym_cmap, s=5,
           zorder=n_particles)
ax.scatter(particles[:, 0], particles[:, 1], c=distance, s=5, cmap=asym_cmap,
           zorder=n_particles + 1)
ax.set_xlabel("Time")
ax.set_ylabel("Magnitude")
fig.savefig("example_transport_presentation_MMD_sigma_" + str(sigma) + ".pdf")
plt.show()
