import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import ot
import pdb
from aux import sample_reward_no_triangle, closest_to_origin, particle_update, plot_vector_field, plot_mean_displacement

# Shared figure styling constants. `linewidth` and `scatter_size` are reused
# by the plotting code at the end of the script.
length_ticks = 3
font_size = 10
linewidth = 1.2
scatter_size = 2
horizontal_size = 2.2
vertical_size = 2.2

# Apply all rcParams overrides in one call: hide the top/right spines and use
# a single font size and line width for every figure.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size,
    'ytick.labelsize': font_size,
})
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Select the Tk backend for the figures shown by this script.
mpl.use('TkAgg')



# Experiment size and one result buffer per quantity tracked across runs
n_repetitions = 10
n_iterations = 1000
decoded_error_all_runs = np.zeros((n_repetitions, n_iterations))
global_wasserstein_all_runs = np.zeros(n_repetitions)
cumulative_wasserstein_all_runs = np.zeros(n_repetitions)

# Parameters
alpha = 0.001  # Learning rate
gamma = 5  # 0.01
batch_size = 25
bw = 0.4
lamb = 0.07
pad = bw

# Reference ("true") decoded value, estimated from one large reward sample:
# keep the samples with x < 9 and 0 < y < 4, let closest_to_origin pick from
# them (the count requested is 10% of the kept samples), and average the
# x-coordinate of the picked samples.
# NOTE(review): closest_to_origin comes from aux; presumably it returns the
# indices of the samples nearest the origin -- confirm against aux.
reward = sample_reward_no_triangle(100000, pad)
reward_x = reward[:, 0]
reward_y = reward[:, 1]
all_conditions = (reward_x < 9) & (reward_y > 0) & (reward_y < 4)
reward_correct = reward[all_conditions]
n_reward_closest = int(len(reward_correct) * 0.1)
idx_reward_closest_to_origin = closest_to_origin(reward_correct, n_reward_closest)
R = reward_correct[:, 0][idx_reward_closest_to_origin]
true_decoded = np.mean(R)


# Initial particles: a regular lattice over [init_x, end_x) x [init_y, end_y)
# with sqrt(n_particles) points requested along each axis.
n_particles = 8 * 8
init_x, end_x = 0, 10
init_y, end_y = 0, 5
points_per_axis = np.sqrt(n_particles)
step_x = (end_x - init_x) / points_per_axis
step_y = (end_y - init_y) / points_per_axis
x, y = np.mgrid[init_x:end_x:step_x, init_y:end_y:step_y]

# Re-derive the counts from the lattice that was actually generated
Nx, Ny = x.shape
n_particles = Nx * Ny

# Flatten each coordinate grid into a column vector and place the two columns
# side by side, giving one (x, y) row per particle.
x = np.expand_dims(x.flatten(), axis=1)
y = np.expand_dims(y.flatten(), axis=1)
particles_init = np.concatenate((x, y), axis=1)

# Create grid where we estimate the log likelihood
# delta x and y for grid on which to compute likelihood
dx_der = 0.25
dy_der = 0.25
# create grid: x_der and y_der are 2-D arrays of identical shape, with x
# varying along axis 0 and y varying along axis 1
x_der, y_der = np.mgrid[-0.5:20.5:dx_der, -0.5:9.5:dy_der]
# Number of x and y points in the grid. Because both arrays share the same
# 2-D shape, len() of either one only yields the x count; the y count has to
# be read from axis 1.
Nx_der, Ny_der = x_der.shape
# get x and y of all points in the grid, one (x, y) row per grid point
x_der_flat = np.expand_dims(x_der.flatten(), axis=1)
y_der_flat = np.expand_dims(y_der.flatten(), axis=1)
particles_der = np.concatenate((x_der_flat, y_der_flat), axis=1)



# Uniform particle weights for the optimal-transport problems. They are the
# same for every step and repetition, so build them once rather than twice
# per iteration.
uniform_weights = np.ones(n_particles) * (1.0 / n_particles)

for rep in range(n_repetitions):

    # Every repetition restarts from the same initial lattice
    particles = np.copy(particles_init)

    # Particle subset picked once from the initial positions (the count
    # requested is 10% of the particles); the same indices are tracked for
    # the whole repetition.
    idx_particles_closest_to_origin = closest_to_origin(particles, int(n_particles * 0.1))

    decoded_error = []

    # Save the particles over time
    particles_save = np.zeros((n_iterations, n_particles, 2))
    particles_save[0, :, :] = particles_init

    # Decoded value at step 0: mean x-coordinate of the tracked particles
    R = particles[:, 0][idx_particles_closest_to_origin]
    r_med = np.mean(R)
    decoded_error.append(r_med)

    cumulative_wasserstein = 0

    # Named `step` (not `iter`) so the builtin iter() is not shadowed
    for step in range(n_iterations - 1):

        # Get reward samples
        reward = sample_reward_no_triangle(batch_size, pad)

        # Get gradient
        gradient = particle_update(particles, reward, particles_der, x_der, y_der,
                                   dx_der, dy_der, bw, lamb, gamma)

        # Update particles
        particles_new = particles - alpha * gradient

        if step == 0:
            # Mean absolute displacement of the first update, printed as a
            # quick check of the step size
            print(np.mean(np.abs(alpha * gradient)))

        # Local Wasserstein-2 distance between consecutive particle sets:
        # squared Euclidean costs go into the transport problem, and the
        # square root of the optimal cost is accumulated.
        M = ot.dist(particles_new, particles, metric='euclidean') ** 2
        wasserstein_distance = ot.emd2(uniform_weights, uniform_weights, M)
        cumulative_wasserstein += np.sqrt(wasserstein_distance)
        particles = particles_new

        # Decoded value after this update (mean x of the tracked particles)
        R = particles[:, 0][idx_particles_closest_to_origin]
        r_med = np.mean(R)
        decoded_error.append(r_med)

        # Save particles
        particles_save[step + 1, :, :] = particles

    # Global Wasserstein-2 distance between the initial and final particle sets
    M_global = ot.dist(particles_init, particles, metric='euclidean') ** 2
    wasserstein_global = np.sqrt(ot.emd2(uniform_weights, uniform_weights, M_global))

    global_wasserstein_all_runs[rep] = wasserstein_global
    cumulative_wasserstein_all_runs[rep] = cumulative_wasserstein

    # Squared deviation of the decoded value from the reference, per step
    decoded = np.array(decoded_error)
    squared_error = (decoded - true_decoded) ** 2
    decoded_error_all_runs[rep, :] = squared_error

    plt.plot(squared_error)
    # NOTE(review): this line marks the reference decoded value, but the curve
    # plotted above is the squared error -- confirm whether the reference
    # line is still wanted on this axis.
    plt.axhline(y=true_decoded)
plt.show()

#plt.hist(R)
#plt.show()

#pdb.set_trace()
# Save to numpy files
#np.save("Krupic_W2_cumulative_wasserstein.npy",cumulative_wasserstein_all_runs)
#np.save("Krupic_W2_global_wasserstein.npy",global_wasserstein_all_runs)
#np.save("Krupic_W2_decoded_errors.npy",decoded_error_all_runs)


# Plot final distribution: all particles left over from the last repetition
# as small black dots, with the tracked subset drawn on top in the default
# colour.
final_x = particles[:, 0]
final_y = particles[:, 1]
tracked = idx_particles_closest_to_origin
plt.scatter(final_x, final_y, color="black", s=0.5 * scatter_size)
plt.scatter(particles[tracked, 0], particles[tracked, 1])

#plt.xticks([])
#plt.yticks([])
plt.xlabel("x")
plt.ylabel("y")
plt.show()

# Trajectory figures built from the saved history of the last repetition.
# NOTE(review): both helpers come from aux; what they draw is not visible
# from this script.
plot_vector_field(particles_save, linewidth, "")
plot_mean_displacement(particles_save, linewidth)



