from models.half_moons import HalfMoons
from torch import Tensor
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

# Publication-style matplotlib defaults applied to every figure in this script.
# Fonts: LaTeX text rendering with a Computer Modern serif face.
_FONT_RC = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
}
# Enlarged font sizes for the axes labels, titles, ticks and legends.
_SIZE_RC = {
    "font.size": 20,
    "axes.labelsize": 22,
    "axes.titlesize": 24,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 18,
}
plt.rcParams.update({**_FONT_RC, **_SIZE_RC, "figure.dpi": 300})

def plot_samples(p, ax, lim=None, n_grid=100, n_levels=10):
    """Draw a filled-contour Gaussian-KDE density estimate of 2-D samples.

    Args:
        p: Samples with two coordinates per row (``p.T`` is unpacked into
            ``x, y``). Either a torch tensor (detached and moved to the CPU
            before conversion) or anything ``np.asarray`` accepts.
        ax: Matplotlib axes to draw on.
        lim: Half-width of the square evaluation window ``[-lim, lim]^2``.
            ``None`` (default) uses the module-level ``L`` at call time.
        n_grid: Number of KDE evaluation points along each axis.
        n_levels: Number of contour levels spanning the density min..max.

    Returns:
        Whatever ``ax.contourf`` returns (the filled contour set).
    """
    if lim is None:
        lim = L  # module-level plotting half-width
    if hasattr(p, "detach"):
        # torch tensor: drop autograd history and leave the GPU before converting,
        # otherwise ``.numpy()`` raises.
        p = p.detach().cpu().numpy()
    x, y = np.asarray(p).T

    # A complex step makes mgrid return exactly n_grid points, endpoints included.
    span = complex(0, n_grid)
    X, Y = np.mgrid[-lim:lim:span, -lim:lim:span]
    positions = np.vstack([X.ravel(), Y.ravel()])

    kernel = stats.gaussian_kde(np.vstack([x, y]))
    # Small positive floor; presumably keeps the density strictly positive — confirm.
    Z = np.reshape(kernel(positions).T, X.shape) + 1.e-10
    z_min, z_max = Z.min(), Z.max()
    levels = np.linspace(z_min, z_max, n_levels)
    contour = ax.contourf(X, Y, Z, levels=levels, cmap='YlGn', vmin=z_min, vmax=z_max)
    # contour_lines = ax.contour(X, Y, Z, levels=levels, colors='black', alpha=1, linewidths=0.5)
    # ax.scatter(x, y, alpha=0.8, s=0.5, c='#424242', edgecolor=None)
    return contour


# End time of the diffusion time grid (last entry of `times` below).
T = 0.9

# load data
bw_traj = torch.load('./samples/trajectory/backward/halfmoons.pt')
lyap_exps = torch.load('./samples/lyap-exp/halfmoons.pt')
lyap_vects = torch.load('./samples/lyap-vec/halfmoons.pt')
# N (length of axis 0) is read from the saved trajectory rather than hard-coded;
# it is the same N used to build the `times` grid for the exponent plot.
N, n_noise_realizations, n_samples, _ = list(bw_traj.shape)

singular = HalfMoons(N, T, device='cpu', noise_schedule='cosine')

##########################
# Plotting pdf
##########################

t = T/N
g = singular.schedule_g(t)
L = 1.7  # half-width of the square plotting window [-L, L]^2 used by every figure
n_grid = 1000
grid = np.linspace(-L, L, n_grid)
X, Y = np.meshgrid(grid, grid)
xs = np.reshape(X, n_grid**2)
ys = np.reshape(Y, n_grid**2)
pts = np.stack((xs, ys), axis=1)
pdfs = singular.pdf(Tensor(pts), g)
pdfs = pdfs.reshape((n_grid, n_grid))

# Ten evenly spaced contour levels spanning the full density range.
z_min, z_max = pdfs.min(), pdfs.max()
levels = np.linspace(z_min, z_max, 10)

fig, ax = plt.subplots(figsize=(4, 4))
cf = ax.contourf(X, Y, pdfs, levels=levels, cmap='YlGn', vmin=z_min, vmax=z_max)
# cbar = fig.colorbar(cf, ax=ax)
# cbar.set_label(r"Probability Density", rotation=270, labelpad=15)

# ax.set_title(r"Density")
ax.set_xticks([-1, 0, 1])
ax.set_yticks([-1, 0, 1])
ax.set_xlim(-L, L)
ax.set_ylim(-L, L)
ax.set_aspect('equal')

fig.tight_layout()
fig.savefig('./image/halfmoons/sgm_pdf.png')
# Close (not merely clear) so the figure is released before the next one is created.
plt.close(fig)


##########################
# Plotting pdf + Lyapunov Vectors
##########################
n_plot = 15

# Same density background as above: z_min, z_max and levels are reused unchanged.
fig, ax = plt.subplots(figsize=(4, 4))
cf = ax.contourf(X, Y, pdfs, levels=levels, cmap='YlGn', vmin=z_min, vmax=z_max)
# cbar = fig.colorbar(cf, ax=ax)
# cbar.set_label(r"Probability Density", rotation=270, labelpad=15)

# ax.set_xlabel(r"$x$")
# ax.set_ylabel(r"$y$")
# ax.set_title(r"\textbf{Probability Density Function}")
# for i in range(n_plot):
#     plt.plot(bw_traj[:, i, 0, 0], bw_traj[:, i, 0, 1], lw = 0.5, color = 'grey', alpha = 0.5)
#     plt.plot(bw_traj[0, i, 0, 0], bw_traj[0, i, 0, 1], 'b.')
#     plt.plot(bw_traj[-1, i, 0, 0], bw_traj[-1, i, 0, 1], 'r.')

# Base points: axis-0 index 0, first n_plot noise realizations, sample 0.
pts = bw_traj[0, 0:n_plot, 0, :]
# NOTE(review): lyap_vects is sliced on axis 0 by realization count here, unlike
# bw_traj; assumes its axis order differs — confirm against the saving code.
lvects = lyap_vects[0:n_plot, 0, :, :]

x_pts = pts[:, 0]
y_pts = pts[:, 1]
# Column 0 of each per-point matrix is drawn as the top Lyapunov vector.
u_pts_top = lvects[:, 0, 0]
v_pts_top = lvects[:, 1, 0]

# u_pts_bot = lvects[:, 0, 1]
# v_pts_bot = lvects[:, 1, 1]

ax.quiver(x_pts, y_pts, u_pts_top, v_pts_top,
          color='black', scale=10, width=0.01,
          headwidth=3, headlength=3, headaxislength=1, label=r'$v_1$')
# plt.quiver(x_pts, y_pts, u_pts_bot, v_pts_bot,
#            scale = 20, width = 0.005, headwidth = 2,
#            headlength = 2, headaxislength = 1, label = r'$v_2$')
# plt.legend(loc='upper right', frameon=False)
ax.plot(x_pts, y_pts, '.', color='black')

# ax.set_title('Top Lyapunov Vector')
ax.set_xticks([-1, 0, 1])
ax.set_yticks([-1, 0, 1])
ax.set_xlim(-L, L)
ax.set_ylim(-L, L)
ax.set_aspect('equal')
fig.tight_layout()
fig.savefig('./image/halfmoons/sgm_lvects.png')
plt.close(fig)

#############################
# Plotting Lyapunov Exponents for multiple noise realizations
#############################

times = np.linspace(T / N, T, num=N, endpoint=True)
n_samples_plot = 10

fig, ax = plt.subplots(figsize=(8, 5))

# Only the first half of the time grid is shown; each legend entry is added once.
for i in range(n_samples_plot):
    ax.plot(times[:N // 2], lyap_exps[:N // 2, i, 0, 0], lw=1, color='darkgreen', label=r'$\lambda_1$' if i == 0 else "")
    ax.plot(times[:N // 2], lyap_exps[:N // 2, i, 0, 1], lw=1, color='black', label=r'$\lambda_2$' if i == 0 else "")

ax.set_xlabel(r"$t$")
ax.set_ylabel(r"$\lambda_i$")
ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.5)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.legend(loc='upper right', frameon=False)

fig.tight_layout()
fig.savefig('./image/halfmoons/sgm_lyap_exps.png')
plt.close(fig)


###############################
# Plot samples and perturbed samples
###############################

# bw_traj was already loaded from the unperturbed file at the top and has not
# been modified since, so only the perturbed run (file suffix -1.0) is loaded here.
bw_traj_pert = torch.load('./samples/trajectory/backward/halfmoons-1.0.pt')

n_plot_samples = 3000
# Curves from HalfMoons._upper_curve / _lower_curve (presumably the two moon
# centre lines); only used by the commented-out overlays below.
n_s = 500
upp_grid = np.linspace(-0.5, np.pi, n_s)
low_grid = np.linspace(0, np.pi+0.5, n_s)
upper_curve = HalfMoons._upper_curve(Tensor(upp_grid)).detach().numpy()
lower_curve = HalfMoons._lower_curve(Tensor(low_grid)).detach().numpy()

fig, ax = plt.subplots(figsize=(4, 4))
plot_samples(bw_traj[0, :n_plot_samples, 0, :], ax)

# ax.plot(upper_curve[:,0], upper_curve[:,1], color = 'red', alpha = 1)
# ax.plot(lower_curve[:,0], lower_curve[:,1], color = 'red', alpha = 1)
# ax.set_title('Unperturbed ')

ax.set_xticks([-1, 0, 1])
ax.set_yticks([-1, 0, 1])
ax.set_ylim([-L, L])
ax.set_xlim([-L, L])
ax.set_aspect('equal')
fig.tight_layout()
fig.savefig('./image/halfmoons/sgm_samples.png')
plt.close(fig)

fig, ax = plt.subplots(figsize=(4, 4))
plot_samples(bw_traj_pert[0, :n_plot_samples, 0, :], ax)

# ax.plot(upper_curve[:,0], upper_curve[:,1], color = 'red', alpha = 1)
# ax.plot(lower_curve[:,0], lower_curve[:,1], color = 'red', alpha = 1)
# ax.set_title(r'Perturbed $\varepsilon = 1$')

ax.set_xticks([-1, 0, 1])
ax.set_yticks([-1, 0, 1])
ax.set_ylim([-L, L])
ax.set_xlim([-L, L])
ax.set_aspect('equal')
fig.tight_layout()
fig.savefig('./image/halfmoons/sgm_samples_pert.png')
plt.close(fig)

#########################################
# Plot score at various timesteps
#########################################

# Times passed to plt_score (and on to singular.schedule_g).
t1, t2, t3 = 0.01, 0.1, 0.2

def plt_score(t, n_grid=1000, n_score_grid=20):
    """Save the density at time ``t`` with its score field overlaid.

    Draws filled contours of ``singular.pdf`` on a dense ``n_grid x n_grid``
    grid over ``[-L, L]^2``. Overlays a quiver of ``singular.batched_score``
    on a coarser ``n_score_grid x n_score_grid`` grid. Writes the figure to
    ``./image/halfmoons/sgm_score_{t}.png``.

    Args:
        t: Time passed to ``singular.schedule_g``; also embedded in the filename.
        n_grid: Points per axis for the density contours (default 1000).
        n_score_grid: Points per axis for the score arrows (default 20).
    """
    g = singular.schedule_g(t)

    # Dense grid for the density contours. Row-major flattening means the
    # reshape((n_grid, n_grid)) below restores the meshgrid layout.
    grid = np.linspace(-L, L, n_grid)
    X, Y = np.meshgrid(grid, grid)
    pts = np.stack((np.reshape(X, n_grid**2), np.reshape(Y, n_grid**2)), axis=1)
    pdfs = singular.pdf(Tensor(pts), g).reshape((n_grid, n_grid))
    z_min, z_max = pdfs.min(), pdfs.max()
    levels = np.linspace(z_min, z_max, 10)

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.contourf(X, Y, pdfs, levels=levels, cmap='YlGn', vmin=z_min, vmax=z_max)

    # Coarser grid for the arrows so the quiver stays legible.
    score_axis = np.linspace(-L, L, n_score_grid)
    SX, SY = np.meshgrid(score_axis, score_axis)
    score_pts = np.stack((np.reshape(SX, n_score_grid**2),
                          np.reshape(SY, n_score_grid**2)), axis=1)
    # A leading batch axis of size 1 is added for batched_score and indexed away again.
    scores = singular.batched_score(Tensor(score_pts).unsqueeze(0), g)
    scores_u = scores[0, :, 0].reshape((n_score_grid, n_score_grid))
    scores_v = scores[0, :, 1].reshape((n_score_grid, n_score_grid))

    ax.quiver(SX, SY, scores_u, scores_v)
    # cbar = fig.colorbar(cf, ax=ax)
    # cbar.set_label(r"Probability Density", rotation=270, labelpad=15)

    # ax.set_title(f"Score at time {t}")
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_xlim(-L, L)
    ax.set_ylim(-L, L)
    ax.set_aspect('equal')

    fig.tight_layout()
    fig.savefig(f'./image/halfmoons/sgm_score_{t}.png')
    # Close rather than just clear: each call creates a new figure, and
    # plt.clf() would leave every one of them open.
    plt.close(fig)


# Score fields at the three selected times.
for t_score in (t1, t2, t3):
    plt_score(t_score)


####################################
# Plot histogram differences
####################################

bw_traj_lg = torch.load('./samples/trajectory/backward/halfmoons.pt')
bw_traj_lg_pert = torch.load('./samples/trajectory/backward/halfmoons-1.0.pt')
# Axis-0 index 0 and sample 0 for every noise realization; assumes the saved
# tensors are (time, noise realization, sample, coordinate) — TODO confirm.
bw_traj_lg = bw_traj_lg[0, :, 0, :].numpy()
bw_traj_lg_pert = bw_traj_lg_pert[0, :, 0, :].numpy()

hist_grid = 30

hist_bw, x_edges, y_edges = np.histogram2d(bw_traj_lg[:, 0],
                                           bw_traj_lg[:, 1],
                                           bins=hist_grid,
                                           range=[[-L, L], [-L, L]], density=True)
# Reuse the unperturbed bin edges so both densities share identical bins.
# numpy ignores `range` when explicit edges are given, so it is omitted.
hist_bw_lg, _, _ = np.histogram2d(bw_traj_lg_pert[:, 0],
                                  bw_traj_lg_pert[:, 1],
                                  bins=[x_edges, y_edges], density=True)

# np.histogram2d puts x bins along axis 0 and y bins along axis 1, whereas
# imshow draws axis 0 as image rows starting from the top. Transpose and use
# origin='lower' so x runs horizontally and y increases upwards, matching `extent`.
hist_diff = hist_bw - hist_bw_lg
fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(hist_diff.T, origin='lower', cmap='YlGn', extent=(-L, L, -L, L))
cbar = fig.colorbar(im, ax=ax, shrink=0.5)
# ax.set_title(r"$\pi - \pi^\varepsilon$")
ax.set_xticks([-1, 0, 1])
ax.set_yticks([-1, 0, 1])
ax.set_aspect('equal')
fig.tight_layout()
fig.savefig('./image/halfmoons/sgm_hist_diff.png')
plt.close(fig)