import pdb
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy.stats import gaussian_kde
from matplotlib import cm
from matplotlib.colors import ListedColormap
import ot
from scipy.spatial import ConvexHull
from aux import sample_reward_no_triangle, closest_to_origin, mmd2_grad, plot_vector_field, plot_mean_displacement


# --- Figure styling ------------------------------------------------------------
# Shared size/width constants (some are only consumed by plotting code later in
# the script, e.g. linewidth and scatter_size).
length_ticks = 3
font_size = 10
linewidth = 1.2
scatter_size = 2
horizontal_size = 2.2
vertical_size = 2.2

# Global matplotlib defaults: hide the top/right spines and use a single font
# size and line width everywhere. One update call replaces the previous
# scattered assignments (which also set 'lines.linewidth' twice).
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size,
    'ytick.labelsize': font_size,
})
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Select the Tk backend used by the plt.show() calls in this script.
mpl.use('TkAgg')




# --- Optimisation hyper-parameters ---------------------------------------------
n_iterations = 1000   # recorded states per run: initial state + (n_iterations - 1) steps
n_repetitions = 1     # independent runs per kernel width
alpha = 4.5           # learning rate: step = alpha * MMD^2 gradient (was also tried at 4.5 * 10)
batch_size = 10       # reward samples drawn per gradient step
bw = 0.4              # only used here to set pad
pad = bw              # padding argument for sample_reward_no_triangle

# --- Reference ("true") value of the decoded statistic --------------------------
# Draw a large reward sample and keep only points with x < 9 and 0 < y < 4.
# Among those, closest_to_origin picks int(10%) of the points (presumably the
# ones nearest the origin -- confirm in aux), and the reference value is the
# mean x-coordinate of that subset. Particle estimates are compared to it.
reward = sample_reward_no_triangle(100000, pad)
in_region = reward[:, 0] < 9
in_region &= reward[:, 1] > 0
in_region &= reward[:, 1] < 4
reward_correct = reward[in_region]

n_reference_closest = int(reward_correct.shape[0] * 0.1)
idx_reward_closest_to_origin = closest_to_origin(reward_correct, n_reference_closest)
x_coords = reward_correct[:, 0]
R = x_coords[idx_reward_closest_to_origin]
# Alternatives tried: radius np.linalg.norm(reward_correct, axis=1)[idx] or
# np.std(np.abs(reward[:, 1][idx])).
true_decoded = np.mean(R)

# Debug view of the selected reference points:
# reward_selected = reward[idx_reward_closest_to_origin]
# plt.scatter(reward[:, 0], reward[:, 1])
# plt.scatter(reward_selected[:, 0], reward_selected[:, 1])
# plt.show()


# --- Initial particles -----------------------------------------------------------
# Regular sqrt(n_particles) x sqrt(n_particles) grid on [init_x, end_x) x
# [init_y, end_y); with a float step, np.mgrid does not include the stop value.
init_x = 0
end_x = 10
init_y = 0
end_y = 5
n_particles = 8 * 8
x, y = np.mgrid[init_x:end_x:(end_x - init_x) / np.sqrt(n_particles),
       init_y:end_y:(end_y - init_y) / np.sqrt(n_particles)]
# Recount from the grid actually produced (a float step can change the count).
n_particles = x.shape[0] * x.shape[1]
Nx = x.shape[0]
Ny = x.shape[1]
x = np.ndarray.flatten(x)
y = np.ndarray.flatten(y)
x = np.expand_dims(x, axis=1)
y = np.expand_dims(y, axis=1)
# (n_particles, 2) array of starting (x, y) positions.
particles_init = np.concatenate((x, y), axis=1)

# --- Grid for gradient evaluation ---------------------------------------------
# (particles_der and the gradient buffers are not referenced elsewhere in the
# visible part of this script.)
dx_der = 0.25
dy_der = 0.25
x_der, y_der = np.mgrid[-0.5:20.5:dx_der, -0.5:9.5:dy_der]
# x_der and y_der are both 2-D with shape (Nx_der, Ny_der). len() only returns
# the first axis, so len(y_der) used to give Nx_der instead of Ny_der.
Nx_der, Ny_der = x_der.shape
x_der_flat = np.expand_dims(np.ndarray.flatten(x_der), axis=1)
y_der_flat = np.expand_dims(np.ndarray.flatten(y_der), axis=1)
particles_der = np.concatenate((x_der_flat, y_der_flat), axis=1)
gradient_f1 = np.zeros((n_particles, 2))
gradient_f2 = np.zeros((n_particles, 2))

# --- MMD gradient flow experiment --------------------------------------------------
# For each kernel width sigma, move the particles by stochastic gradient descent
# on MMD^2 against mini-batches of reward samples, and record:
#   * the decoded statistic (mean x of a fixed subset of particles) per state,
#   * the summed per-step W2 distance (path length) and the W2 distance
#     between the initial and final particle clouds.
sigma_list = [6]  # widths that have been swept: [0.1, 0.25, 0.75, 1, 2, 3, 6]

# Uniform weights for both clouds in every exact OT problem below (loop-invariant).
uniform_weights = np.full(n_particles, 1.0 / n_particles)

for sigma in sigma_list:
    decoded_error_all_runs = np.zeros((n_repetitions, n_iterations))
    global_wasserstein_all_runs = np.zeros(n_repetitions)
    cumulative_wasserstein_all_runs = np.zeros(n_repetitions)

    # Isotropic kernel covariance passed to mmd2_grad.
    Sigma = np.diag([sigma, sigma])

    for rep in range(n_repetitions):
        particles = np.copy(particles_init)

        # Subset of particles the decoded statistic is computed on; chosen once
        # from the initial configuration and kept fixed for the whole run.
        idx_particles_closest_to_origin = closest_to_origin(particles, int(n_particles * 0.1))

        # Particle positions at every recorded state (state 0 = initial grid).
        particles_save = np.zeros((n_iterations, n_particles, 2))
        particles_save[0, :, :] = particles_init

        # Decoded statistic: mean x-coordinate of the tracked particles.
        # (Alternatives tried: radius np.linalg.norm(particles, axis=1)[idx],
        # a median, or the std of |y|.)
        R = particles[:, 0][idx_particles_closest_to_origin]
        decoded = [np.mean(R)]

        cumulative_wasserstein = 0

        for it in range(n_iterations - 1):
            reward_batch = sample_reward_no_triangle(batch_size, pad)
            gradient = mmd2_grad(particles, reward_batch, Sigma)

            # Gradient-descent step on MMD^2 (no clipping to the arena).
            particles_new = particles - alpha * gradient

            if it == 0:
                # Debug: mean absolute displacement of the first step.
                print(np.mean(np.abs(alpha * gradient)))

            # W2 distance travelled in this step: exact OT with squared
            # Euclidean cost, then square root; accumulated as a path length.
            M = ot.dist(particles_new, particles, metric='euclidean') ** 2
            wasserstein_distance = ot.emd2(uniform_weights, uniform_weights, M)
            cumulative_wasserstein += np.sqrt(wasserstein_distance)
            particles = particles_new

            R = particles[:, 0][idx_particles_closest_to_origin]
            decoded.append(np.mean(R))

            particles_save[it + 1, :, :] = particles

        # W2 distance between the initial and the final particle cloud.
        M_global = ot.dist(particles_init, particles, metric='euclidean') ** 2
        wasserstein_global = np.sqrt(ot.emd2(uniform_weights, uniform_weights, M_global))

        global_wasserstein_all_runs[rep] = wasserstein_global
        cumulative_wasserstein_all_runs[rep] = cumulative_wasserstein
        decoded = np.array(decoded)
        # Squared error of the decoded statistic w.r.t. the reference value.
        decoded_error_all_runs[rep, :] = (decoded - true_decoded) ** 2

    # Optional persistence of the per-sigma diagnostics:
    # np.save("Krupic_MMD_cumulative_wasserstein_" + str(sigma) + ".npy", cumulative_wasserstein_all_runs)
    # np.save("Krupic_MMD_global_wasserstein_" + str(sigma) + ".npy", global_wasserstein_all_runs)
    # np.save("Krupic_MMD_decoded_errors_" + str(sigma) + ".npy", decoded_error_all_runs)

    # Squared decoding error vs. state index, averaged over all repetitions
    # (previously only the last repetition was plotted).
    plt.plot(decoded_error_all_runs.mean(axis=0), label=str(sigma))
plt.legend()
plt.show()

# --- Final particle configuration (last sigma, last repetition) ----------------
# All particles in black; the tracked subset used for the decoded statistic is
# overlaid in the default colour.
plt.scatter(particles[:, 0], particles[:, 1], color="black", s=scatter_size * 0.5)
plt.scatter(particles[idx_particles_closest_to_origin, 0],
            particles[idx_particles_closest_to_origin, 1])
plt.xlabel("x")
plt.ylabel("y")
plt.show()

# Trajectory plots from the saved particle history of the last run. The output
# name passed to plot_vector_field now reflects the sigma actually used (the
# last value of the sweep) instead of a hard-coded 6; for sigma == 6 it is
# unchanged ("example_Krupic_MMD_sigma_6.pdf").
plot_vector_field(particles_save, linewidth, f"example_Krupic_MMD_sigma_{sigma}.pdf")

plot_mean_displacement(particles_save, linewidth)


