import matplotlib.pyplot as plt
import numpy as np
import math
import torch
import torchvision.transforms as T
from torchvision.datasets import MNIST
import jax
import jax.numpy as jnp
import equinox as eqx
import os
from typing import Sequence
import optax

def invert_latent(loss_and_grad,
                  obs_img,
                  n_steps,
                  n_mc,
                  lr,
                  theta_init,
                  key):
    """Optimise parameters against an observed image with Adam.

    Starting from ``theta_init``, performs ``n_steps`` Adam updates using the
    loss and gradient supplied by ``loss_and_grad`` and returns the result.

    Args:
        loss_and_grad: Callable invoked as
            ``loss_and_grad(params, key, obs=..., n_mc=...)``; it must return
            a ``(loss, grad)`` pair.
        obs_img: Observed image. It is converted to a JAX array and flattened
            to one dimension before being handed to ``loss_and_grad``.
        n_steps: Number of optimisation steps to run.
        n_mc: Forwarded unchanged to ``loss_and_grad`` as ``n_mc``
            (presumably a Monte Carlo sample count -- confirm against the
            loss implementation).
        lr: Learning rate passed to ``optax.adam``.
        theta_init: Initial parameters; also used to initialise the optimiser
            state.
        key: JAX PRNG key. It is split once per step so every step receives
            a fresh subkey.

    Returns:
        The parameters after the final update. When the loop body never runs
        (``n_steps`` of 0), this is ``theta_init`` itself.
    """
    flat_obs = jnp.asarray(obs_img).reshape(-1)
    adam = optax.adam(lr)
    adam_state = adam.init(theta_init)

    @eqx.filter_jit
    def _adam_step(params, state, rng, target, samples, opt):
        # Carry one half of the split forward; spend the other on this step.
        next_rng, step_rng = jax.random.split(rng)
        loss, grad = loss_and_grad(params, step_rng, obs=target, n_mc=samples)
        updates, new_state = opt.update(grad, state, params)
        return eqx.apply_updates(params, updates), new_state, loss, next_rng

    params = theta_init
    for _ in range(n_steps):
        # The per-step loss is returned by the step but not used here.
        params, adam_state, _, key = _adam_step(
            params,
            adam_state,
            key,
            target=flat_obs,
            samples=n_mc,
            opt=adam,
        )
    return params

def load_model(save_dir,
               generator_name,
               generator_state_name,
               key,
               latent_dim,
               image_shape):
    """Restore a serialised ``Generator`` and its state from ``save_dir``.

    A freshly initialised generator and a matching ``eqx.nn.State`` are built
    as templates, and ``eqx.tree_deserialise_leaves`` fills them from the two
    files.

    Args:
        save_dir: Directory that contains both files.
        generator_name: File name of the serialised generator inside
            ``save_dir``.
        generator_state_name: File name of the serialised generator state
            inside ``save_dir``.
        key: JAX PRNG key; a subkey derived from it initialises the template
            generator.
        latent_dim: Passed to ``Generator`` when building the template.
        image_shape: Passed to ``Generator`` when building the template.

    Returns:
        A ``(gen, g_state)`` tuple holding the deserialised generator and
        state.
    """
    weights_file, state_file = (
        os.path.join(save_dir, file_name)
        for file_name in (generator_name, generator_state_name)
    )

    # Only the second half of the split is used, to initialise the template.
    _, template_key = jax.random.split(key)
    template = Generator(latent_dim, image_shape, key=template_key)
    state_template = eqx.nn.State(template)

    gen = eqx.tree_deserialise_leaves(weights_file, template)
    g_state = eqx.tree_deserialise_leaves(state_file, state_template)
    return gen, g_state
    
def get_loader(image_shape, batch_size):
    """Build a ``DataLoader`` over the MNIST training split.

    Each image is resized to the height and width taken from ``image_shape``,
    converted to a tensor, and normalised with mean 0.5 and std 0.5.

    Args:
        image_shape: Three-element sequence; the first two entries are used
            as the target height and width, the third is ignored here.
        batch_size: Number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` that iterates the dataset in order
        (no shuffling) and loads in the main process (``num_workers=0``).
        The dataset is stored under ``./data`` and downloaded if missing.
    """
    height, width, _ = image_shape
    pipeline = [
        T.Resize((height, width)),
        T.ToTensor(),
        T.Normalize((0.5,), (0.5,)),
    ]
    dataset = MNIST(
        root="./data",
        train=True,
        download=True,
        transform=T.Compose(pipeline),
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )
    return loader
    
def visualize_images_grid(
    images,
    img_height,
    img_width,
    figsize_scale=1.2,
    save_path=None,
    title="Generated Images"
):
    """Display a batch of grayscale images on a near-square grid.

    Pixel values are mapped from ``[-1, 1]`` to ``[0, 1]`` (values outside
    that range are clipped) and every image is drawn with the ``gray``
    colormap on its own axis with the axis decorations turned off.

    Args:
        images: Array-like whose first axis indexes the images. Each image
            may be flattened or not, as long as it holds exactly
            ``img_height * img_width`` values.
        img_height: Height each image is reshaped to.
        img_width: Width each image is reshaped to.
        figsize_scale: Size given to each grid cell in the figure size.
        save_path: If truthy, the figure is saved to this path before being
            shown.
        title: Figure title; pass a falsy value to omit it.

    Returns:
        ``(fig, axes)`` where ``axes`` is the 2-D array of axes created by
        ``plt.subplots(..., squeeze=False)``.

    Raises:
        ValueError: If ``images`` has fewer than two dimensions, or an image
            cannot be reshaped to ``(img_height, img_width)``.
    """
    images_np = np.asarray(images)
    if images_np.ndim < 2:
        raise ValueError(
            "images must have a leading batch axis followed by pixel data, "
            f"got shape {images_np.shape}"
        )
    num_images = images_np.shape[0]

    # Rescale from the [-1, 1] range to [0, 1] for display.
    min_val, max_val = -1.0, 1.0
    images_processed = np.clip(
        (images_np - min_val) / (max_val - min_val), 0.0, 1.0
    )

    # max(1, ...) keeps an empty batch from dividing by zero or asking
    # matplotlib for a grid with no cells; it then yields one blank axis.
    ncols = max(1, math.ceil(math.sqrt(num_images)))
    nrows = max(1, math.ceil(num_images / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * figsize_scale, nrows * figsize_scale),
        squeeze=False,
    )

    for i, ax in enumerate(axes.flat):
        # Cells beyond the last image stay empty but are still hidden.
        if i < num_images:
            img_2d = images_processed[i].reshape((img_height, img_width))
            ax.imshow(img_2d, cmap='gray', vmin=0, vmax=1)
        ax.axis('off')

    if title:
        fig.suptitle(title, fontsize=16)

    # Leave headroom for the suptitle only when there is one.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95 if title else 1])

    if save_path:
        # Saving is best-effort: a failure is reported, and the figure is
        # still shown and returned.
        try:
            fig.savefig(save_path)
        except Exception as e:
            print(f"Error saving figure to {save_path}: {e}")
        else:
            print(f"Figure saved to {save_path}")

    plt.show()

    return fig, axes

def _conv_block(
    cin,
    cout,
    kernel,
    stride,
    pad,
    *,
    transpose,
    norm=None,
    activate=True,
    key,
):
    """Assemble a convolution, optional BatchNorm and optional activation.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
        kernel: Kernel size passed to the convolution.
        stride: Stride passed to the convolution.
        pad: Padding passed to the convolution.
        transpose: If true, use ``eqx.nn.ConvTranspose2d`` with a ReLU
            activation; otherwise use ``eqx.nn.Conv2d`` with a leaky ReLU
            of slope 0.2.
        norm: Whether to add a BatchNorm layer after the convolution. When
            ``None`` it defaults to the value of ``transpose``.
        activate: Whether to append the activation.
        key: JAX PRNG key used to initialise the convolution.

    Returns:
        A list with the layers in application order: the bias-free
        convolution, then BatchNorm (if enabled), then the activation
        (if enabled).
    """
    use_norm = transpose if norm is None else norm

    if transpose:
        conv_cls = eqx.nn.ConvTranspose2d
    else:
        conv_cls = eqx.nn.Conv2d

    block = [
        conv_cls(cin, cout, kernel, stride, padding=pad, use_bias=False, key=key)
    ]

    if use_norm:
        # The BatchNorm is bound to the axis name "batch".
        block.append(eqx.nn.BatchNorm(cout, axis_name="batch"))

    if activate:
        if transpose:
            block.append(jax.nn.relu)
        else:
            block.append(lambda x: jax.nn.leaky_relu(x, 0.2))

    return block

class Generator(eqx.Module):
    """Transposed-convolution generator with a tanh output.

    The network is three ``_conv_block`` stages built with ``transpose=True``
    (latent_dim -> 128 -> 64 -> 32 channels), followed by a bias-free
    ``ConvTranspose2d`` down to the image's channel count and a ``tanh``.
    """

    # Flat list of Equinox layers and plain activation callables, applied in
    # order by __call__.
    layers: Sequence

    def __init__(self, latent_dim, image_shape, key):
        """Build the layer stack.

        Args:
            latent_dim: Number of input channels of the first block.
            image_shape: Three-element sequence; only the last entry (the
                channel count) is used.
            key: JAX PRNG key, split into four: one per block plus one for
                the output convolution.
        """
        _, _, channels = image_shape
        *block_keys, out_key = jax.random.split(key, 4)

        # (cin, cout, kernel, stride, pad) for each transposed-conv block.
        block_specs = (
            (latent_dim, 128, 4, 1, 0),
            (128, 64, 4, 2, 1),
            (64, 32, 4, 2, 1),
        )

        stack = []
        for (cin, cout, kernel, stride, pad), block_key in zip(block_specs, block_keys):
            stack += _conv_block(
                cin, cout, kernel, stride, pad, transpose=True, key=block_key
            )
        stack += [
            eqx.nn.ConvTranspose2d(
                32, channels, 3, stride=1, padding=1, use_bias=False, key=out_key
            ),
            jax.nn.tanh,
        ]
        self.layers = stack

    def __call__(self, x, state):
        """Run ``x`` through every layer, threading ``state`` through BatchNorm.

        Args:
            x: Input to the first layer (assumed to be a single unbatched
                latent of shape ``(latent_dim, 1, 1)`` -- TODO confirm
                against callers).
            state: State object consumed and updated by the BatchNorm layers.

        Returns:
            ``(output, state)`` with the state as returned by the last
            BatchNorm layer.
        """
        for layer in self.layers:
            if not isinstance(layer, eqx.nn.BatchNorm):
                x = layer(x)
                continue
            # BatchNorm is the only layer that takes and returns state.
            x, state = layer(x, state=state)
        return x, state