import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import json
from models.trained_images import TrainedImage
import torch
import math
import os
import numpy as np

# Publication-style defaults applied to every figure this script produces.
# "text.usetex" makes matplotlib render all text through an external LaTeX
# installation, which therefore has to be available at run time.
_PLOT_STYLE = {
    # Typography: LaTeX-rendered Computer Modern serif.
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
    # Font sizes (points) for body text, axis labels, titles, ticks, legend.
    "font.size": 20,
    "axes.labelsize": 22,
    "axes.titlesize": 24,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 18,
    # Figure resolution in dots per inch.
    "figure.dpi": 300,
}
plt.rcParams.update(_PLOT_STYLE)


def plot_image_grid(images, dir, name, cols=10, figsize=(12, 6), 
                    labels = None, cmap=None, span = None):
    """
    Plot a list of images in a grid and save the figure to ``dir/name``.

    Parameters:
    - images: sequence of image arrays (H x W or H x W x C) accepted by imshow
    - dir: output directory
    - name: output file name, joined onto ``dir`` and passed to savefig
    - cols: number of columns in the grid
    - figsize: size of the entire figure
    - labels: optional sequence of titles, one per grid cell in row-major
      order; cells beyond ``len(labels)`` get no title
    - cmap: colormap for single-channel images (e.g. 'gray')
    - span: optional (vmin, vmax) pair applied to every image so all cells
      share one colour scale; when None each image is scaled on its own

    Raises:
    - ValueError: if ``images`` is empty
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("plot_image_grid: 'images' must not be empty")
    rows = math.ceil(n_images / cols)

    # squeeze=False always yields a 2-D array of Axes; without it a 1x1 grid
    # returns a bare Axes object that has no .flatten()/.flat.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    scale_kwargs = {} if span is None else {'vmin': span[0], 'vmax': span[1]}

    for i, ax in enumerate(axes.flat):
        # Turn axes off everywhere, including padding cells past the last image.
        ax.axis('off')
        # Only index labels that exist, so a short label list cannot IndexError
        # on the trailing empty cells.
        if labels is not None and i < len(labels):
            ax.set_title(labels[i])
        if i < n_images:
            ax.imshow(images[i], cmap=cmap, **scale_kwargs)
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    fig.tight_layout()
    fig.savefig(os.path.join(dir, name), bbox_inches='tight', pad_inches=0)
    # Close (not merely clear) so repeated calls don't accumulate open figures.
    plt.close(fig)

# Experiment configuration; every key is forwarded to TrainedImage below.
with open('./conf/mnist.json', mode="r", encoding="utf-8") as file:
    kwargs = json.load(file)
name = kwargs.get('model_name')
# Second positional argument of TrainedImage (presumably the final time,
# cf. the X_T labels used below -- TODO confirm).
T = 0.9

# Saved backward-sampling outputs: trajectory, Lyapunov exponents, vectors.
bw_traj = torch.load('./samples/trajectory/backward/mnist.pt')
bw_lyap = torch.load('./samples/lyap-exp/mnist.pt')
bw_vect = torch.load('./samples/lyap-vec/mnist.pt')
# The trajectory tensor is 4-D; its leading size N is handed to the model.
N, n_noise_realizations, n_samples, _ = bw_traj.shape

image_model = TrainedImage(N, T, **kwargs)

##############################
## Plotting bw trajectory
##############################

n_denoise = 5

# Pick n_denoise evenly spaced indices along the first axis of bw_traj.
# Slicing the range caps the count when N is not a multiple of n_denoise
# (the bare range would yield an extra frame), and max(1, ...) keeps the step
# non-zero when N < n_denoise.
step = max(1, N // n_denoise)
frame_ks = range(0, N, step)[:n_denoise]
# Shown from the largest index down to index 0.
imgs = [image_model.vec_to_image(bw_traj[k, 0, 0])[0] for k in reversed(frame_ks)]
plot_image_grid(imgs, dir = './image/mnist', name = 'mnist_bw_sample', 
                cols = n_denoise, cmap = 'Greys')

##############################
## Plotting Lyapunov Spectrum
##############################

# Heat map of the Lyapunov exponents over the first axis of bw_lyap. The
# exponents are ordered by their values at index 0, and that same permutation
# is applied at every index, so each row follows one exponent index.
# Axis layout of bw_lyap is assumed to mirror bw_traj -- TODO confirm.
# NOTE: `sorted` shadows the builtin of the same name; it is kept because the
# sections below read this ascending 1-D tensor.
plt.figure(figsize=(20,4))
sorted, idxs = torch.sort(bw_lyap[0, 0, 0, :])
indices = idxs[None, :].expand(N, -1)
spectrum = torch.gather(bw_lyap[:, 0, 0, :], dim = -1, index=indices)

# spectrum.T puts exponent index on the y-axis and the first bw_lyap axis on
# the x-axis; the colour range is padded by 0.5 around the index-0 extremes.
plt.pcolormesh(spectrum.T, vmin = sorted[0]-0.5, vmax=sorted[-1]+0.5)
plt.colorbar()
plt.savefig('./image/mnist/mnist_spectogram.png')
# Close rather than just clear, so the finished figure is released.
plt.close()


#############################
## Lyapunov Exponents
#############################
i_target = 21  # example index, used only by the commented-out annotation below
lam_target = sorted[-i_target]
# `sorted` is ascending (torch.sort default), so its first/last entries are the
# extremes; indexing avoids the element-by-element Python loop that builtin
# min()/max() run over a tensor.
min_lam = sorted[0].item()
max_lam = sorted[-1].item()

# Integer y-limits. int() truncates toward zero; with the -1/+1 offsets the
# limits still always enclose [min_lam, max_lam].
lower_lim = int(min_lam - 1)
upper_lim = int(max_lam + 1)

plt.figure(figsize=(6, 4))
# Plot in descending order, so index i = 0 is the largest exponent.
plt.plot(sorted.flip(0), '.', color='black', markersize=2)
plt.xlabel(r"$i$")
plt.ylabel(r"$\lambda_i$")
plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.5)

# Optional annotation of the i_target-th exponent:
# plt.axvline(x=i_target, color='black', linestyle='--', linewidth=0.5, alpha = 1)
# plt.plot([i_target, i_target], [lower_lim, lam_target], color='black', linestyle='--', linewidth=0.5, alpha = 1)

# plt.plot(i_target, lam_target, 'o', color='black')
# plt.annotate(f"$({i_target}, {lam_target:.3f})$", xy=(i_target, lam_target), xytext=(i_target + 50, lam_target),
#              fontsize=15, color='black')

plt.ylim([lower_lim, upper_lim])
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# ax.set_title('Lyapunov exponents for sampling dynamics.')

plt.tight_layout()
plt.savefig('./image/mnist/mnist_lyap_spectrum.svg')
# Release the figure instead of only clearing it.
plt.close()


#########################
## Sample perturbed by Backward Lyapunov vectors
#########################

n_tangents = 5
eps = 10
img = bw_traj[0, 0, 0]
# Every cell uses the unperturbed sample's value range, so all images share
# one colour scale. Tensor reductions replace builtin min()/max(), which
# iterate the tensor element by element in Python.
m = img.min().item()
M = img.max().item()
span = (m, M)

lbl = r"$X_T$"
# First row: X_T plus v_1 .. v_{n_tangents-1}; X_T takes the remaining slot.
imgs1 = [image_model.vec_to_image(img + eps*bw_vect[0, 0, :, k])[0] for k in range(n_tangents - 1)]
lbls1 = [r"$X_T+\epsilon v_{" + str(k+1) + "}$" for k in range(n_tangents - 1)]

# Second row: v_{offset+1} .. v_{offset+n_tangents}.
offset = 20
imgs2 = [image_model.vec_to_image(img + eps*bw_vect[0, 0, :, k])[0] for k in range(offset, offset + n_tangents)]
lbls2 = [r"$X_T+\epsilon v_{" + str(k+1) + "}$" for k in range(offset, offset + n_tangents)]

imgs = [image_model.vec_to_image(img)[0]] + imgs1 + imgs2
lbls = [lbl] + lbls1 + lbls2

# One row per group of n_tangents images.
plot_image_grid(imgs, dir = './image/mnist', name = 'mnist_sample_perturbed', 
                labels =lbls, cols = n_tangents, cmap = 'Greys', span=span)


#########################
## Backward Lyapunov vectors
#########################


n_tangents = 5
img = bw_traj[0, 0, 0]

# Raw Lyapunov vectors v_1..v_{n_tangents} and v_{offset+1}..v_{offset+n_tangents};
# `offset` is defined in the perturbed-sample section above. No span is
# passed, so each vector gets its own colour scale.
imgs1 = [image_model.vec_to_image(bw_vect[0, 0, :, k])[0] for k in range(n_tangents)]
lbls1 = [r"$v_{" + str(k+1) + "}$" for k in range(n_tangents)]

imgs2 = [image_model.vec_to_image(bw_vect[0, 0, :, k])[0] for k in range(offset, offset + n_tangents)]
lbls2 = [r"$v_{" + str(k+1) + "}$" for k in range(offset, offset + n_tangents)]

imgs = imgs1 + imgs2
lbls = lbls1 + lbls2

# One row per group of n_tangents vectors.
plot_image_grid(imgs, dir = './image/mnist', name = 'bw_vects_mnist', 
                labels = lbls, cols = n_tangents, cmap = 'YlGn')


#################################
### Combined Perturb + Vect
#################################

# Fixed colour scale shared by the trajectory and perturbed images.
span = (-1.5, 1.5)

bw_indices = [0, 1, 19, 49, 99]  # 0-based Lyapunov-vector indices, labelled k+1
eps = 1
# Perturbation size relative to the largest |entry| of the sample.
scale = bw_traj[0, 0, 0].abs().max()


def _max_normalized(v):
    """Return ``v`` divided by its largest absolute entry (peak magnitude 1)."""
    return v / v.abs().max()


n_denoise = len(bw_indices)
# Evenly spaced frame indices; slicing caps the count at n_denoise when N is
# not a multiple of it, and max(1, ...) avoids a zero step when N < n_denoise.
step = max(1, N // n_denoise)
frame_ks = range(0, N, step)[:n_denoise]
# Ordered from the largest frame index down to 0.
imgs_traj = [image_model.vec_to_image(bw_traj[k, 0, 0])[0] for k in reversed(frame_ks)]
lbls_traj = [r"$X_{" + str(N-k) + "}$" for k in reversed(frame_ks)]

imgs_vec = [image_model.vec_to_image(_max_normalized(bw_vect[0, 0, :, k]))[0] for k in bw_indices]
lbls_vec = [r"$v_{" + str(k+1) + "}$" for k in bw_indices]

# `img` is the sample bw_traj[0, 0, 0] set in the sections above.
imgs_pert = [image_model.vec_to_image(img + eps*scale*_max_normalized(bw_vect[0, 0, :, k]))[0]
             for k in bw_indices]
lbls_pert = [r"$X_T+\epsilon v_{" + str(k+1) + "}$" for k in bw_indices]

plot_image_grid(imgs_traj, dir = './image/mnist', name = 'mnist_traj', labels=lbls_traj, 
                cols = len(bw_indices), cmap = 'gray', span=span)
plot_image_grid(imgs_vec, dir = './image/mnist', name = 'mnist_lvect', labels=lbls_vec, 
                cols = len(bw_indices), cmap = 'gray')
plot_image_grid(imgs_pert, dir = './image/mnist', name = 'mnist_lvect_pert', labels=lbls_pert, 
                cols = len(bw_indices), cmap = 'gray', span = span)

###############################
# Plot specific images
###############################

# Save the sample and two perturbed versions (along v_1 and v_100, each
# normalised to unit peak magnitude) as standalone SVGs. `img`, `eps` and
# `scale` come from the combined-grid section above.
sample = image_model.vec_to_image(bw_traj[0, 0, 0])[0]
v_1 = bw_vect[0, 0, :, 0]
v_100 = bw_vect[0, 0, :, 99]
pert_1 = image_model.vec_to_image(img + eps*scale*v_1/v_1.abs().max())[0]
pert_100 = image_model.vec_to_image(img + eps*scale*v_100/v_100.abs().max())[0]
cmap = 'gray'
imgs = [sample, pert_1, pert_100]
fnames = ['mnist_sample.svg', 'mnist_pert_1.svg', 'mnist_pert_100.svg']
out_dir = './image/mnist'  # not `dir`, which would shadow the builtin
span = (-1.5, 1.5)

for im, fname in zip(imgs, fnames):
    fig, ax = plt.subplots(figsize = (4,4))
    ax.imshow(im, cmap=cmap, vmin = span[0], vmax = span[1])
    ax.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, fname), bbox_inches='tight', pad_inches=0)
    plt.close(fig)


#######################################
# Combined perts + spectrum
#######################################


fig = plt.figure(figsize=(10, 3))
# Manual [left, bottom, width, height] placement in figure fractions: three
# image panels of equal width, then a narrower panel for the spectrum.
panel_rects = [
    [0.05, 0.1, 0.2, 0.8],
    [0.30, 0.1, 0.2, 0.8],
    [0.55, 0.1, 0.2, 0.8],
    [0.80, 0.1, 0.15, 0.8],
]
axs = [fig.add_axes(rect) for rect in panel_rects]

# `imgs`, `cmap` and `span` are the three images and display settings from
# the single-image section above.
for ax_img, im in zip(axs[:3], imgs):
    ax_img.imshow(im, cmap=cmap, vmin=span[0], vmax=span[1])
    ax_img.axis('off')

# `sorted` is ascending, so flip(0) plots the exponents in descending order.
axs[3].plot(sorted.flip(0), '.', color='black', markersize=2)
axs[3].set_xlabel(r"$i$")
axs[3].set_ylabel(r"$\lambda_i$")
axs[3].grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.5)

fig.savefig('./image/mnist/mnist_combined.png')
# Release the figure; it was previously left open.
plt.close(fig)


##########################################
# Plot a grid of samples
##########################################

n_samps = 20
# Separate file of generated samples. Presumably laid out like bw_traj (TODO
# confirm). We plot index 0 of the first and third axes for each of the first
# n_samps entries along axis 1.
bw_traj_ = torch.load('./samples/trajectory/backward/mnist-samples.pt')

# Clamp so a file with fewer than n_samps entries along axis 1 still plots
# instead of raising IndexError.
n_show = min(n_samps, bw_traj_.shape[1])
imgs = [image_model.vec_to_image(bw_traj_[0, k, 0])[0] for k in range(n_show)]
plot_image_grid(imgs, dir = './image/mnist', name = 'mnist_samples', 
                cols = 5, cmap = 'Greys')
