from models.singular_2d import Singular2D
from torch import Tensor
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import json
from scipy import stats


def plot_samples(p, ax, lim=None, n_grid=100, n_levels=10):
    """Scatter 2-D samples on ``ax`` over a Gaussian-KDE density estimate.

    The KDE of the samples is evaluated on an ``n_grid`` x ``n_grid`` grid
    covering ``[-lim, lim]^2``. It is drawn as filled contours with thin black
    contour lines on top, and the raw samples are scattered over it.

    Args:
        p: Samples of shape ``(n, 2)``. Either a torch tensor (it is detached
            and moved to the CPU first, so tensors requiring grad work) or
            anything ``np.asarray`` accepts.
        ax: Matplotlib axes to draw on.
        lim: Half-width of the square evaluation grid. ``None`` (default)
            falls back to the module-level ``L`` read at call time, which is
            what this helper always used before the parameter existed.
        n_grid: Number of grid points per axis (default 100).
        n_levels: Number of evenly spaced contour levels (default 10).
    """
    if lim is None:
        lim = L  # backward-compatible fallback to the script-level plot limit
    if hasattr(p, "detach"):
        # torch tensor: leave the autograd graph and move to host before numpy
        p = p.detach().cpu().numpy()
    p = np.asarray(p)
    x, y = p.T

    # complex "step" in mgrid means "number of points, endpoints inclusive"
    steps = complex(0, n_grid)
    X, Y = np.mgrid[-lim:lim:steps, -lim:lim:steps]
    positions = np.vstack([X.ravel(), Y.ravel()])
    kernel = stats.gaussian_kde(np.vstack([x, y]))
    # tiny offset keeps the density strictly positive everywhere on the grid
    Z = kernel(positions).reshape(X.shape) + 1.e-10

    z_min, z_max = Z.min(), Z.max()
    levels = np.linspace(z_min, z_max, n_levels)
    ax.contourf(X, Y, Z, levels=levels, cmap='YlGn', vmin=z_min, vmax=z_max)
    ax.contour(X, Y, Z, levels=levels, colors='black', alpha=1, linewidths=0.5)
    ax.scatter(x, y, alpha=0.8, s=0.5, c='#424242', edgecolor=None)

# Publication-style matplotlib defaults applied to every figure below:
# LaTeX-rendered Computer Modern serif text, enlarged fonts, 300 dpi output.
_PAPER_STYLE = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
    "font.size": 20,
    "axes.labelsize": 22,
    "axes.titlesize": 24,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 18,
    "figure.dpi": 300,
}
plt.rcParams.update(_PAPER_STYLE)

# Terminal time passed to Singular2D.
T = 0.9

# Extra keyword arguments for Singular2D, read from the experiment config.
with open('./conf/singular_blob.json', mode="r", encoding="utf-8") as file:
    kwargs = json.load(file)

# Pre-computed backward trajectories, Lyapunov exponents and Lyapunov vectors
# of the unperturbed model.
bw_traj = torch.load('./samples/trajectory/backward/singular-blob.pt')
lyap_exps = torch.load('./samples/lyap-exp/singular-blob.pt')
lyap_vects = torch.load('./samples/lyap-vec/singular-blob.pt')
# N (number of stored time steps) comes from dimension 0 of the trajectory
# and is what the model below is constructed with.
N, n_noise_realizations, n_samples, _  = list(bw_traj.shape)

singular = Singular2D(N, T, device = 'cpu', noise_schedule='cosine', **kwargs)
# NOTE: t, g, L, X, Y and pdfs are computed in the density-figure section;
# an earlier, unused 500x500 density evaluation that lived here was removed.

#######################################
# # Plot the pdf of the blob
#######################################

# Evaluate the model density singular.pdf at t = T/N on a 1000 x 1000 grid
# covering [-L, L]^2. X, Y, pdfs, levels, z_min, z_max and L are reused by
# later figures in this script.
t = T/N
g = singular.schedule_g(t)
L = 1.7
n_grid = 1000
grid = np.linspace(-L, L, n_grid)
X, Y = np.meshgrid(grid, grid)
xs, ys = X.reshape(n_grid**2), Y.reshape(n_grid**2)
pts = np.column_stack((xs, ys))
pdfs = singular.pdf(Tensor(pts), g).reshape((n_grid, n_grid))

# Ten evenly spaced contour levels spanning the full density range.
z_min = pdfs.min()
z_max = pdfs.max()
levels = np.linspace(z_min, z_max, 10)

fig, ax = plt.subplots(figsize=(4, 4))
cf = ax.contourf(X, Y, pdfs, levels=levels, cmap='YlGn', vmin=z_min, vmax=z_max)

for set_ticks in (ax.set_xticks, ax.set_yticks):
    set_ticks([-1, 0, 1])
ax.set_xlim(-L, L)
ax.set_ylim(-L, L)
ax.set_aspect('equal')

plt.tight_layout()
plt.savefig('./image/blob/blob_pdf.png')
plt.clf()


########################################
# # Plot the generated samples
########################################

# First 3000 entries at time index 0 of the backward trajectory, drawn over
# their KDE density. Written to its own file so it does not overwrite the
# density figure saved as blob_pdf.png above.
fig, ax = plt.subplots(figsize=(4, 4))
plot_samples(bw_traj[0, :3000, 0, :], ax)

ax.set_xticks([-1,0,1])
ax.set_yticks([-1,0,1])
ax.set_xlim(-L, L)
ax.set_ylim(-L, L)
ax.set_aspect('equal')

plt.tight_layout()
plt.savefig('./image/blob/blob_samples.png')
plt.clf()

#################################################
# # Plotting a few sample trajectories
#################################################

print("Plotting a few sample trajectories.\n")
fig, ax = plt.subplots(figsize=(4, 4))
n_trajectories = 5
offset = 0  # index (axis 1 of bw_traj) of the first trajectory to draw
for i in range(n_trajectories):
    j = i + offset
    # Whole path in grey, plus markers on the entries at time index 0
    # (labelled 'end') and time index -1 (labelled 'start'). Only the first
    # trajectory carries labels so the legend has one entry per marker.
    ax.plot(bw_traj[:, j, 0, 0], bw_traj[:, j, 0, 1], lw = 0.5, color = 'grey', alpha = 0.5)
    ax.plot(bw_traj[0, j, 0, 0], bw_traj[0, j, 0, 1], 'go', label=r'end' if i == 0 else "")
    ax.plot(bw_traj[-1, j, 0, 0], bw_traj[-1, j, 0, 1], 'ko', label=r'start' if i == 0 else "")

# Density background, using X, Y, pdfs and levels from the density figure.
cf = ax.contourf(X, Y, pdfs, levels=levels, cmap='YlGn', vmin=z_min, vmax=z_max)
ax.set_xticks([-1,0,1])
ax.set_yticks([-1,0,1])
ax.set_xlim(-L, L)
ax.set_ylim(-L, L)
ax.set_aspect('equal')
ax.legend(loc = 'lower left')
plt.tight_layout()
plt.savefig('./image/blob/blob_bw_traj.png')
plt.clf()


###################################################
# # Plotting Lyapunov Exponents
####################################################

# Time axis of N evenly spaced points in [T/N, T]. Only the first N // 2 time
# indices are plotted, for indices 0..n_samples_plot-1 along axis 1 of
# lyap_exps; axis 3 selects the exponent (0 -> lambda_1, 1 -> lambda_2).
times = np.linspace(T / N, T, num=N, endpoint=True)
n_samples_plot = 10
half = N // 2

fig, ax = plt.subplots(figsize=(8, 5))

for i in range(n_samples_plot):
    first = (i == 0)  # label only once: one legend entry per exponent
    ax.plot(times[:half], lyap_exps[:half, i, 0, 0], lw=1, color='darkgreen',
            label=r'$\lambda_1$' if first else "")
    ax.plot(times[:half], lyap_exps[:half, i, 0, 1], lw=1, color='black',
            label=r'$\lambda_2$' if first else "")

ax.set_xlabel(r"$t$")
ax.set_ylabel(r"$\lambda_i$")
ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.5)
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
ax.legend(loc='upper right', frameon=False)

plt.tight_layout()
plt.savefig('./image/blob/blob_lyap_exps.png')
plt.clf()
plt.close()

################################################
# # Plotting perturbed distributions
###############################################

# One KDE + scatter figure per perturbation strength. eps is reused by the
# angle-histogram section below.
print('Plotting perturbed distributions')
eps = ['0.1', '0.5', '1.0']
for ep in eps:
    print(' eps=',ep)
    bw_traj_ = torch.load(f'./samples/trajectory/backward/singular-blob-{ep}.pt')
    fig, ax = plt.subplots(figsize=(4, 4))
    plot_samples(bw_traj_[0, :3000, 0, :], ax)

    ax.set_xticks([-1,0,1])
    ax.set_yticks([-1,0,1])
    ax.set_xlim(-L, L)
    ax.set_ylim(-L, L)
    ax.set_aspect('equal')

    plt.tight_layout()
    plt.savefig(f'./image/blob/blob_samples_{ep}.png')
    plt.clf()

##############################################
# Angle between BLV and tangent space
##############################################

print('Histograms of angles as function of eps:')
k=0
batched_dot = torch.func.vmap(torch.dot)  # row-wise dot product of two (n, d) tensors
for ep in eps:
    print(' eps = ', ep)
    bw_traj_ = torch.load(f'./samples/trajectory/backward/singular-blob-{ep}.pt')
    lyap_vects_ = torch.load(f'./samples/lyap-vec/singular-blob-{ep}.pt')
    # Points at time index 0 (slot k) and the parameter of the closest point
    # on the model curve, as returned by singular.find_closest.
    endpoints = bw_traj_[0, :, k, :]
    ts = singular.find_closest(endpoints)
    stable_blvs = lyap_vects_[:, k, :, 0]
    tangents = singular.Dcurve(ts)
    tangents = tangents/(torch.linalg.norm(tangents, axis = 1).unsqueeze(1))

    # Clamp into arccos's domain: for unit vectors rounding can push |dot|
    # slightly above 1, and a single NaN makes hist fail. (Assumes the stored
    # Lyapunov vectors are unit-norm -- TODO confirm.)
    cosines = torch.clamp(batched_dot(tangents, stable_blvs), -1.0, 1.0)
    thetas = torch.arccos(cosines)
    # Map (pi/2, pi] to (-pi/2, 0] so that v and -v give the same angle.
    thetas = torch.where(thetas > torch.pi/2, thetas - torch.pi, thetas)

    # Own figure per eps (same 4x4 size as the previous section's figures)
    # instead of drawing on whatever figure happens to be current.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.hist(thetas, bins=100, density=True, color='darkgreen')
    ax.set_xlim(-1.5, 1.5)
    ax.set_title(f"Angles eps = {ep}")
    fig.savefig(f'./image/blob/blob_thetas_{ep}.png')
    plt.close(fig)



###############################
# Plot samples next to the perturbed samples
###############################

bw_traj = torch.load('./samples/trajectory/backward/singular-blob.pt')
bw_traj_pert = torch.load('./samples/trajectory/backward/singular-blob-1.0.pt')

n_plot_samples = 3000

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
plot_samples(bw_traj[0, :n_plot_samples, 0, :], axs[0])
plot_samples(bw_traj_pert[0, :n_plot_samples, 0, :], axs[1])

# Overlay the model curve, sampled at n_s parameter values in [0, 1].
n_s = 500
s_grid = np.linspace(0, 1, n_s)
curve = singular.curve(s_grid).detach().numpy()
for panel in axs:
    panel.plot(curve[:,0], curve[:,1], color = 'darkgreen', alpha = 1)
    panel.set_xlim(-L, L)
    panel.set_ylim(-L, L)
axs[0].set_title('Unperturbed ')
axs[1].set_title(r'Perturbed $\varepsilon = 1$')

axs[0].set_yticks([-1, 0, 1])
axs[1].set_yticks([])

plt.savefig('./image/blob/blob_samples_perturbed.png')
plt.close(fig)  # this figure was previously left open


###############################
# Plot histograms of angles
###############################

bw_traj = torch.load('./samples/trajectory/backward/singular-blob.pt')
lyap_vects = torch.load('./samples/lyap-vec/singular-blob.pt')
bw_traj_1 = torch.load('./samples/trajectory/backward/singular-blob-1.0.pt')
lyap_vects_1 = torch.load('./samples/lyap-vec/singular-blob-1.0.pt')

N, n_noise_realizations, n_samples, _  = list(bw_traj.shape)

batched_dot = torch.func.vmap(torch.dot)  # row-wise dot product of two (n, d) tensors


def _folded_angles(trajectory, lvects):
    """Angle between the model-curve tangent and the stored Lyapunov vector.

    Takes points at time index 0 (slot 0) of ``trajectory`` and finds their
    closest curve parameter with ``singular.find_closest``. It then compares
    the unit tangent ``singular.Dcurve`` there with ``lvects[:, 0, :, 0]``.
    Angles are folded into (-pi/2, pi/2] so that v and -v agree.
    """
    endpoints = trajectory[0, :, 0, :]
    ts = singular.find_closest(endpoints)
    tangents = singular.Dcurve(ts)
    tangents = tangents / torch.linalg.norm(tangents, axis=1).unsqueeze(1)
    # Clamp into arccos's domain: rounding can push |dot| of unit vectors
    # slightly above 1, producing NaN (assumes unit-norm vectors -- TODO confirm).
    cosines = torch.clamp(batched_dot(tangents, lvects[:, 0, :, 0]), -1.0, 1.0)
    thetas = torch.arccos(cosines)
    return torch.where(thetas > torch.pi / 2, thetas - torch.pi, thetas)


thetas = _folded_angles(bw_traj, lyap_vects)
thetas_1 = _folded_angles(bw_traj_1, lyap_vects_1)

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.hist(thetas, bins = 100, alpha = 1, color = 'black', density=True, label=r"$\varepsilon=0$", histtype = 'step')
ax.hist(thetas_1, bins = 100, alpha = 0.8, color = 'green', density=True, label=r"$\varepsilon=1$" , histtype = 'step')
ax.set_xlabel(r"$\theta$")
ax.legend(loc='upper right', frameon = False)
plt.tight_layout()
plt.savefig('./image/blob/blob_thetas_hist.png')
plt.close(fig)

##########################
# Plotting pdf + Lyapunov Vectors
##########################
n_plot = 15

# Density background: X, Y, pdfs and L come from the density-figure section;
# the ten contour levels are recomputed from the same pdfs.
z_min, z_max = pdfs.min(), pdfs.max()
levels = np.linspace(z_min, z_max, 10)

fig, ax = plt.subplots(figsize=(4, 4))
cf = ax.contourf(X, Y, pdfs, levels=levels, cmap='YlGn', vmin=z_min, vmax=z_max)

# First n_plot points at time index 0 and their Lyapunov vectors.
# lvects[:, 0, 0] / lvects[:, 1, 0] are the (u, v) arrow components of the
# vector at index 0 of the last axis.
pts = bw_traj[0, :n_plot, 0, :]
lvects = lyap_vects[:n_plot, 0, :, :]
x_pts, y_pts = pts[:, 0], pts[:, 1]
u_pts_top, v_pts_top = lvects[:, 0, 0], lvects[:, 1, 0]

ax.quiver(x_pts, y_pts, u_pts_top, v_pts_top,
          color='black', scale=10, width=0.01,
          headwidth=3, headlength=3, headaxislength=1, label=r'$v_1$')
ax.plot(x_pts, y_pts, '.', color='black')

ticks = [-1, 0, 1]
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xlim(-L, L)
ax.set_ylim(-L, L)
ax.set_aspect('equal')
plt.tight_layout()
plt.savefig('./image/blob/blob_lvects.png')
plt.clf()


