import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import scipy
import math
import pdb
import matplotlib as mpl
import pdb
import scipy.stats
from scipy.stats import norm
from matplotlib.animation import FuncAnimation
from scipy.stats import gaussian_kde
from scipy.stats import multivariate_normal
from scipy.interpolate import RBFInterpolator
import matplotlib.animation as animation
from aux import particle_update

# Styling constants. Only font_size and linewidth are referenced by the code
# visible in this file; the others are kept as module-level names.
length_ticks = 3
scatter_size = 2
horizontal_size = 3.5
vertical_size = 3.5
font_size = 8
linewidth = 1.2

# Global matplotlib style, applied in one call: hide the top/right spines and
# use a single font size (text and tick labels) and a single line width.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size,
    'ytick.labelsize': font_size,
})
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Select the TkAgg backend for the figures shown below.
mpl.use('TkAgg')

# Parameters for distributional neural learning
n_iterations = 350  # the update loop below runs n_iterations - 1 times
batch_size = 25     # number of reward positions sampled per iteration
alpha = 1e-2        # learning rate of the particle update
gamma = 1000        # handed to particle_update (presumably interaction strength - confirm in aux)
lamb = 8e-2         # handed to particle_update (presumably interaction length scale - confirm in aux)
bw = 0.4            # bandwidth for the Gaussian KDE
pad = 2.5 + bw

# Generate hexagonal grid of candidate starting positions.
# Columns are spaced by `ratio`, rows by 1, and every other row is shifted by
# half a column so that neighbouring rows interleave.
ratio = np.sqrt(3) / 2  # sin(60°)
N = 121  # nominal node count, only used to size the lattice
N_X = int(np.sqrt(N) / ratio)
N_Y = N // N_X
xv, yv = np.meshgrid(np.arange(N_X), np.arange(N_Y), sparse=False, indexing='xy')
xv = xv * ratio
xv[::2, :] += ratio / 2  # offset rows 0, 2, 4, ... by half a column
xv = xv.flatten()
yv = yv.flatten()

# Keep only the nodes strictly inside the disc of radius 3 centred on (5, 5).
selected = np.where(((xv - 5) ** 2) + ((yv - 5) ** 2) < 9)[0]
xv = xv[selected]
yv = yv[selected]

# Plot grid
# plt.scatter(xv, yv)
# plt.show()


n_particles = len(selected)


# Stack the coordinates as an (n_particles, 2) array of (x, y) positions.
xv = xv[:, None]
yv = yv[:, None]
particles_init = np.concatenate((xv, yv), axis=1)
# Work on a copy: `particles_init` is the "before" snapshot used by every
# figure below and must not share memory with the array that gets updated.
particles = particles_init.copy()

# To compute gradient: regular evaluation lattice covering the bounding box of
# the particles, padded by 1 on each side. It is handed to particle_update in
# the update loop.
dx_der = 0.25
dy_der = 0.25
x_der, y_der = np.mgrid[np.min(particles[:, 0]) - 1:np.max(particles[:, 0]) + 1:dx_der,
               np.min(particles[:, 1]) - 1:np.max(particles[:, 1]) + 1:dy_der]
# x varies along axis 0 and y along axis 1 of both mgrid outputs, so the point
# counts come from the two axes of the shape. (len(y_der) would return the
# x count again, not the y count.)
Nx_der, Ny_der = x_der.shape
# Column vectors of the lattice coordinates, then one (x, y) row per node.
x_der_flat = np.expand_dims(np.ndarray.flatten(x_der), axis=1)
y_der_flat = np.expand_dims(np.ndarray.flatten(y_der), axis=1)
particles_der = np.concatenate((x_der_flat, y_der_flat), axis=1)
# Per-particle 2-D gradient buffers, zero-initialised.
gradient_f1 = np.zeros((n_particles, 2))
gradient_f2 = np.zeros((n_particles, 2))

# Reward is given in these locations:
means = [np.array([6.5, 6]), np.array([4.5, 3.5]), np.array([3, 4.5])]
cov = np.array([[1, 0], [0, 1]]) * 0.5  # isotropic spread (variance 0.5) around each location
reward = np.zeros((batch_size, 2))  # reused buffer: one (x, y) reward position per row

# `iteration` instead of `iter`, which would shadow the builtin.
for iteration in range(n_iterations - 1):

    # Assign each reward of the batch to one of the locations, uniformly at random
    sample_gauss = np.random.randint(len(means), size=batch_size)

    # Sample the reward positions from the Gaussian of the assigned location
    for i_gaussian, goal_mean in enumerate(means):
        pos_samples = np.where(sample_gauss == i_gaussian)[0]
        reward[pos_samples] = np.random.multivariate_normal(goal_mean, cov, len(pos_samples))

    # Update direction for all particles.
    # NOTE(review): particle_update is imported from aux and not visible here;
    # it is assumed to return an array with the same shape as `particles`.
    # The inline likelihood / interaction gradients that used to be computed
    # before this call were never used and have been dropped.
    gradient = particle_update(particles, reward, particles_der, x_der, y_der, dx_der, dy_der, bw, lamb, gamma)

    # Update particles (rebinds the name; the previous array is not modified)
    particles = particles - alpha * gradient

# Overlay the initial and the final particle positions in one figure
# (default colour cycle: first the initial positions, then the final ones).
plt.figure()
for positions in (particles_init, particles):
    plt.scatter(positions[:, 0], positions[:, 1])
plt.show()

def _euclidean(point_a, point_b):
    """Return the Euclidean distance between two coordinate vectors."""
    return np.sqrt(np.sum((point_a - point_b) ** 2))


# Distance of each particle, before and after learning, to the goal that is
# closest to its *initial* position.
distance_before, distance_after = [], []

# Distance of each particle, before and after learning, to the first goal.
distance_before_goal_1, distance_after_goal_1 = [], []

for particle in range(n_particles):
    start = particles_init[particle, :]
    end = particles[particle, :]

    # The goal is chosen from the initial position and reused for the final one.
    squared_to_goals = np.sum((means - start) ** 2, axis=1)
    nearest_goal = means[np.argmin(squared_to_goals)]

    distance_before.append(_euclidean(nearest_goal, start))
    distance_after.append(_euclidean(nearest_goal, end))

    distance_before_goal_1.append(_euclidean(means[0], start))
    distance_after_goal_1.append(_euclidean(means[0], end))

# Smooth as in the data plot: Gaussian KDE of the distances to the closest
# goal, evaluated on a fixed axis and normalised so the samples sum to 1.
x_pdf = np.linspace(-0.2, 3, 50)
kernel_before = stats.gaussian_kde(distance_before, bw_method=0.5)
kernel_after = stats.gaussian_kde(distance_after, bw_method=0.5)

pdf_before = kernel_before(x_pdf)
pdf_before = pdf_before / pdf_before.sum()
pdf_after = kernel_after(x_pdf)
pdf_after = pdf_after / pdf_after.sum()

fig, ax = plt.subplots(figsize=(1, 1))
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(linewidth)
# Blue: before learning, red: after learning.
for curve, colour in ((pdf_before, "blue"), (pdf_after, "red")):
    ax.plot(x_pdf, curve, color=colour)
ax.set_xlabel("Distance to closest goal")
ax.set_ylabel("Density")
ax.set_xticks([])
ax.set_yticks([])
plt.show()

# Attraction towards the first goal as a function of the initial distance.
distance_before_goal_1 = np.asarray(distance_before_goal_1)
distance_after_goal_1 = np.asarray(distance_after_goal_1)
order = np.argsort(distance_before_goal_1)  # particles sorted by initial distance

# |before - after| / (before + after): relative change of the distance.
attraction = (np.abs(distance_before_goal_1 - distance_after_goal_1)
              / (distance_before_goal_1 + distance_after_goal_1))

fig, ax = plt.subplots(figsize=(1.5, 1.5))
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(linewidth)
ax.scatter(distance_before_goal_1[order], attraction[order], color="purple", marker="|")
ax.set_xlabel("Distance to goal")
ax.set_ylabel("Attraction strength")
ax.set_xticks([])
ax.set_yticks([0, 1])
plt.show()

# Proportion of particles lying within 0.8 of their closest goal, before and
# after learning.
distance_before = np.asarray(distance_before)
distance_after = np.asarray(distance_after)
n_particles_close_to_reward_before = int(np.count_nonzero(distance_before < 0.8))
n_particles_close_to_reward_after = int(np.count_nonzero(distance_after < 0.8))

fig, ax = plt.subplots(figsize=(0.5, 1))
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(linewidth)
proportions = [count / n_particles
               for count in (n_particles_close_to_reward_before, n_particles_close_to_reward_after)]
ax.bar(["Pre", "Post"], proportions, color=['blue', 'red'], width=0.5)
ax.set_ylabel("Prop. of fields at goals")
plt.show()

# Same attraction-vs-distance scatter as above, drawn in a 0.5 x 1 figure.
fig, ax = plt.subplots(figsize=(0.5, 1))
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(linewidth)
sorted_before = distance_before_goal_1[order]
sorted_after = distance_after_goal_1[order]
# Relative change of the distance to the first goal.
strength = np.abs(sorted_before - sorted_after) / (sorted_before + sorted_after)
ax.scatter(sorted_before, strength, color="purple", marker="|")
ax.set_xlabel("Distance to goal")
ax.set_ylabel("Attraction strength")
ax.set_xticks([])
ax.set_yticks([0, 1])
plt.show()

# Replicate paper figure: a cloud of synthetic "spikes" around every particle,
# blue around its initial position and red around its final position, with the
# goal locations marked in black.
fig, ax = plt.subplots(figsize=(1, 1))
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(linewidth)
n_spikes = 25
# Separate name so the reward covariance `cov` defined earlier is not overwritten.
spike_cov = np.array([[1, 0], [0, 1]]) * 0.1
for part in range(n_particles):
    # Draw order (initial first, then final) is kept so that seeded runs
    # produce the same random numbers as before.
    for centre, colour in ((particles_init[part, :], "blue"), (particles[part, :], "red")):
        spike_x, spike_y = np.random.multivariate_normal(centre, spike_cov, n_spikes).T
        ax.scatter(spike_x, spike_y, color=colour, s=0.1)

# Draw the goals on `ax` explicitly (not through pyplot's implicit current
# axes) and iterate over `means` instead of a hard-coded count.
for goal in means:
    ax.scatter(goal[0], goal[1], s=25, color="k")
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()
