import pdb
import numpy as np
import ot
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
import matplotlib as mpl
from scipy.stats import gaussian_kde
mpl.use('TkAgg')
import numpy as np
from scipy.stats import multivariate_normal
import ot
from aux import closest_to_origin,particle_update
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from matplotlib.colors import LinearSegmentedColormap, TwoSlopeNorm
import matplotlib.colors as mcol
# Parameters for plots (each assigned exactly once; earlier duplicate
# assignments of length_ticks and scatter_size were redundant/conflicting).
length_ticks = 2
font_size = 9
linewidth = 1.2
# NOTE(review): scatter_size is not referenced by the plotting code in this script.
scatter_size = 20
horizontal_size = 1.2  # figure width passed to plt.subplots(figsize=...)
vertical_size = 1.2    # figure height passed to plt.subplots(figsize=...)

# Global matplotlib style: small tick/title/legend fonts and no top/right spines.
mpl.rcParams.update({
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size - 5,
    'ytick.labelsize': font_size - 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.titlesize': font_size - 2,
    'legend.fontsize': font_size - 2,
})



# Regular evaluation grid over [-1, 5) x [-1, 5) with spacing (dx_der, dy_der).
# The grid arrays and spacings are passed to `particle_update` when computing
# the gradient.
dx_der = 0.1
dy_der = 0.1
x_der, y_der = np.mgrid[-1:5:dx_der, -1:5:dy_der]
Nx_der, Ny_der = x_der.shape[0], y_der.shape[1]

# Column vectors (shape (Nx_der * Ny_der, 1)) of the grid coordinates ...
x_der_flat = x_der.flatten()[:, np.newaxis]
y_der_flat = y_der.flatten()[:, np.newaxis]
# ... stacked side by side into one (x, y) point per row.
particles_der = np.hstack([x_der_flat, y_der_flat])

# Hyperparameters of the distributional neural learning rule.
# The training loop runs over steps 1 .. n_interactions - 1.
n_interactions = 30000
batch_size = 10   # rewards sampled per step (also tried: 100, 1000)
# bw, lamb and gamma are passed straight to `particle_update`.
# NOTE(review): bw is presumably a kernel bandwidth; confirm in aux.particle_update.
bw = 3
lamb = 0.1        # also tried: 0.01
gamma = 200       # also tried: 0.09 * 16
# Constant step size for every iteration. An exponentially decaying
# schedule, 0.1 * 0.99 ** np.arange(n_interactions), was also considered.
learning_rates = np.full(n_interactions, 0.1)


# Number of particles and parameters of the initial configuration.
n_particles = 12

# NOTE(review): n_circle, r, center and points look like leftovers from a
# centre-plus-ring initialisation; they are not read in the visible code.
n_circle = n_particles - 1
r = 1
center = np.array([1, 1])  # Center point
points = [center]

# Mean and covariance of the initial Gaussian (drawn in the background heatmap).
mean_init = np.array([1, 1])
cov_init = 0.25 * np.eye(2)

# Number of independent training runs.
n_repetitions = 1

# Per-run results filled in after each run of the training loop.
decoded_error_all_runs = np.zeros((n_repetitions, n_interactions))  # squared decoding error per step
global_wasserstein_all_runs = np.zeros(n_repetitions)               # W2(initial, final particles)
cumulative_wasserstein_all_runs = np.zeros(n_repetitions)           # summed step-to-step W2


# Reference ("true") decoded value, estimated by Monte Carlo from the target
# Gaussian N(mean, cov) that rewards are drawn from.
mean = np.array([3, 3])
cov = np.eye(2)
n_samples_test = 5_000_000
# Discount base applied as gamma_value ** x[:, 0]. With gamma_value = 1 the
# discount factor is 1 for every sample.
gamma_value = 1

X = np.random.multivariate_normal(mean, cov, size=n_samples_test)
# Keep the 25% of samples selected by `closest_to_origin`.
idx_particles_closest_to_origin = closest_to_origin(X, int(n_samples_test * 0.25))
selected = X[idx_particles_closest_to_origin]
sample_weight = 1.0 / len(idx_particles_closest_to_origin)
# Average of (second coordinate) * gamma_value ** (first coordinate) over the subset.
true_decoded = value = np.sum(selected[:, 1] * gamma_value ** (selected[:, 0]) * sample_weight)

# Size of each update step, alpha * mean |gradient|, appended by the training loop.
all_magnitude_updates = []


# Hand-picked initial particle positions: 12 rows (one per particle), columns (x, y).
particles_init = np.array([
    [0.85, 1.19],
    [1.7, 0.6],
    [1.4, 1.4],
    [0.4, 1.58465806],
    [1.44, 1],
    [0.78, 0.71981831],
    [0.4, 1.1],
    [1.12305538, 1.33102621],
    [1.1, 0.7],
    [0.66, 1],
    [0.60242665, 1.44190863],
    [1.08, 0.54],
])



for rep in range(n_repetitions):

    # Every repetition starts from the same hand-picked configuration.
    particles = np.copy(particles_init)

    # NOTE(review): gradient_f1 / gradient_f2 are allocated but not read by the
    # loop below; kept so the module-level names still exist.
    gradient_f1 = np.zeros((n_particles, 2))
    gradient_f2 = np.zeros((n_particles, 2))

    # particles_save[k] holds the particle positions after k updates.
    particles_save = np.zeros((n_interactions, particles.shape[0], particles.shape[1]))
    particles_save[0, :, :] = particles_init

    # Uniform weights over the particles, shared by every optimal-transport solve.
    uniform_weights = np.full(n_particles, 1.0 / n_particles)
    # Decode from 25% of the particles, mirroring how `true_decoded` is computed.
    n_decode = int(n_particles * 0.25)

    cumulative_wasserstein = 0
    decoded = []

    # Decode the initial configuration.
    idx_particles_closest_to_origin = closest_to_origin(particles, n_decode)
    value = np.sum(particles[idx_particles_closest_to_origin, 1]
                   * gamma_value ** (particles[idx_particles_closest_to_origin, 0])
                   * (1.0 / len(idx_particles_closest_to_origin)))
    decoded.append(value)

    for step in range(1, n_interactions):

        # Draw a fresh batch of rewards from the target Gaussian.
        reward = np.random.multivariate_normal(mean, cov, batch_size)

        gradient = particle_update(particles, reward, particles_der, x_der, y_der,
                                   dx_der, dy_der, bw, lamb, gamma)

        alpha = learning_rates[step]  # learning rate

        # Gradient step on all particles.
        particles_new = particles - alpha * gradient
        all_magnitude_updates.append(alpha * np.mean(np.abs(gradient)))

        # Local W2 between consecutive configurations. emd2 on squared
        # Euclidean costs returns W2^2, hence the sqrt. The running sum is the
        # length of the path travelled in W2.
        M = ot.dist(particles_new, particles, metric='euclidean') ** 2
        wasserstein_distance = ot.emd2(uniform_weights, uniform_weights, M)
        cumulative_wasserstein += np.sqrt(wasserstein_distance)

        particles = particles_new
        particles_save[step, :, :] = particles

        # Decode the *updated* configuration so decoded[k] matches particles_save[k].
        # The subset is recomputed because the particles move. The discount is
        # gamma_value, as in the initial decode and in `true_decoded`; a
        # hard-coded 0.9 here would bias the error.
        idx_particles_closest_to_origin = closest_to_origin(particles, n_decode)
        value = np.sum(particles[idx_particles_closest_to_origin, 1]
                       * gamma_value ** (particles[idx_particles_closest_to_origin, 0])
                       * (1.0 / len(idx_particles_closest_to_origin)))
        decoded.append(value)

    decoded = np.array(decoded)

    # Global W2 between the initial and final configurations. Only the final
    # value is kept, so solve it once after training instead of every step.
    M_global = ot.dist(particles_init, particles, metric='euclidean') ** 2
    wasserstein_global = np.sqrt(ot.emd2(uniform_weights, uniform_weights, M_global))

    global_wasserstein_all_runs[rep] = wasserstein_global
    cumulative_wasserstein_all_runs[rep] = cumulative_wasserstein
    decoded_error_all_runs[rep, :] = (decoded - true_decoded) ** 2

    # Squared decoding error over training for this repetition.
    plt.plot(decoded_error_all_runs[rep, :])

# Save (uncomment to persist the per-run results)
#np.save("W2_cumulative_wasserstein.npy", cumulative_wasserstein_all_runs)
#np.save("W2_global_wasserstein.npy", global_wasserstein_all_runs)
#np.save("W2_decoded_errors.npy", decoded_error_all_runs)
plt.show()


fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
ax.spines['left'].set_linewidth(linewidth)
ax.spines['bottom'].set_linewidth(linewidth)
ax.tick_params(width=linewidth, length=length_ticks)

# Background: difference between the initial and the target distribution.
# Both are evaluated on the same 0.01-spaced grid over [-1, 5) x [-1, 5).
x_initial, y_initial = np.mgrid[-1:5:.01, -1:5:.01]
pos_initial = np.dstack((x_initial, y_initial))
# The "final" grid is identical, so reuse the arrays instead of rebuilding them.
x_final, y_final, pos_final = x_initial, y_initial, pos_initial

# Target distribution: the Gaussian N(mean, cov) the rewards are sampled from,
# normalised so its grid values sum to 1.
rv_target = multivariate_normal(mean, cov)
pdf_target = rv_target.pdf(pos_final)
pdf_target = pdf_target / np.sum(pdf_target)

# Initial distribution N(mean_init, cov_init), normalised the same way.
# rv_initial must use mean_init (not the target mean) to describe this pdf.
rv_initial = multivariate_normal(mean_init, cov_init)
pdf_initial = rv_initial.pdf(pos_initial)
pdf_initial = pdf_initial / np.sum(pdf_initial)

# Positive values (more initial than target mass) map to orange, negative to blue.
Z = pdf_initial - pdf_target
# Robust symmetric colour scaling around 0 (tune the percentile within 98-99.9).
v = np.percentile(np.abs(Z), 98)
norm = TwoSlopeNorm(vmin=-v, vcenter=0.0, vmax=v)
# Mask |Z| < tau so those pixels are not painted. With tau = 1e-9 * v only
# near-zero differences are hidden; raise towards 0.005-0.02 * v to hide more.
tau = 1e-9 * v
Zm = np.ma.masked_where(np.abs(Z) < tau, Z)
orange = "#f16913"   # strong orange
blue = "#2b8cbe"     # strong blue
orange_blue = LinearSegmentedColormap.from_list("orange_blue", [blue, "white", orange], N=256)

# Transpose: mgrid's first axis is x, while imshow's first axis is image rows (y).
plt.imshow(Zm.T, extent=[-1, 5, -1, 5], origin="lower", cmap=orange_blue, norm=norm, interpolation="bilinear")


# Colour palettes.
# NOTE(review): flatui, animal_cmap, raster_cmap and vals are not read by the
# visible plotting code.
flatui = ["#9B59B6", "#3498DB", "#95A5A6", "#E74C3C", "#34495E", "#2ECC71"]
reward_cmap = plt.cm.jet(np.linspace(0., 1., 8)[:-1])  # 7 RGBA colours sampled from jet
animal_cmap = sns.color_palette(flatui)
raster_cmap = plt.cm.bone_r
# Two-colour gradient between the 2nd and the last jet sample, used to colour
# the particles. (The immediately-overwritten autumn_r assignment was dead.)
asym_cmap = mcol.LinearSegmentedColormap.from_list("MyCmapName", [reward_cmap[1], reward_cmap[-1]])
vals = np.linspace(0, 1, n_particles, endpoint=False)  # shape (N,)

# Learning trajectories: one thin grey line per particle, subsampled to every
# 50th saved step and always ending at the last saved position. Only the last
# particle's line gets a label, so the legend entry appears once.
for unit in range(n_particles):
    path = np.concatenate([particles_save[::50, unit, :], particles_save[-1:, unit, :]], axis=0)
    line_kwargs = {"label": "Learning trajectory"} if unit == n_particles - 1 else {}
    plt.plot(path[:, 0], path[:, 1], color="grey", linewidth=0.4, **line_kwargs)

# Colour every particle by its *initial* distance from the origin, so the same
# colour identifies the same particle before and after learning.
distance = np.linalg.norm(particles_init, axis=1)

# Initial positions, then final positions drawn one z-level above them.
for positions, layer in ((particles_init, n_particles), (particles, n_particles + 1)):
    plt.scatter(positions[:, 0], positions[:, 1], c=distance, cmap=asym_cmap, s=5, zorder=layer)

# Axis labels: the first coordinate is the one used as the discount exponent
# when decoding, the second is the one averaged.
plt.xlabel("Time")
plt.ylabel("Magnitude")
fig.savefig("example_transport_presentation_DNL.pdf")
plt.show()