import torch
import torch.nn as nn
from models.cnn_vae import CNNVAE
import numpy as np
import PIL.Image as Image
from Record.file_management import read_obj_dumps, strip_instance
from data_utils.vae_dataset import ImageDataset
import glob


def _resolve_latent_dim(model, latent_dim):
    """Return the latent size to sample from, or raise ValueError if unknown."""
    if latent_dim is not None:
        return int(latent_dim)
    inferred = getattr(model, "latent_dim", None)
    if inferred is None:
        # Legacy path: earlier callers defined a module-level ``latent_dim``
        # and called ``plot_latent_space(model)`` without passing it.
        inferred = globals().get("latent_dim")
    if inferred is None:
        raise ValueError(
            "latent_dim could not be determined; pass latent_dim=... explicitly"
        )
    return int(inferred)


def _resolve_device(model, device):
    """Return the device latent samples should be moved to before decoding."""
    if device is not None:
        return torch.device(device)
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    # Legacy path: fall back to a module-level ``device`` if one exists.
    return torch.device(globals().get("device", "cpu"))


def plot_latent_space(model, scale=1.0, n=10, digit_size=64, figsize=15,
                      latent_dim=None, device=None,
                      save_path="latent_space.png"):
    """Decode n*n random latent samples and tile them into one RGB image.

    Each tile is produced by drawing ``z ~ N(0, I)`` with ``torch.randn`` and
    passing it through ``model.decode``. The decoded tensor is indexed as
    ``[0]`` and permuted ``(1, 2, 0)``, multiplied by 255 and written into a
    ``digit_size`` x ``digit_size`` x 3 cell of the output mosaic.

    Args:
        model: Object exposing ``decode(z)``. The decoded tile must fit a
            ``(digit_size, digit_size, 3)`` cell after the permute, otherwise
            the assignment into the mosaic raises.
        scale: Currently unused; kept for interface compatibility.
        n: Number of tiles per row and per column.
        digit_size: Side length, in pixels, of each tile.
        figsize: Currently unused; kept for interface compatibility.
        latent_dim: Size of the latent vector. When ``None`` it is taken from
            ``model.latent_dim`` if present, then from a module-level
            ``latent_dim`` (legacy behaviour).
        device: Device the latent sample is moved to. When ``None`` it is taken
            from the model's first parameter, then from a module-level
            ``device`` (legacy behaviour), then defaults to CPU.
        save_path: Where to write the image. ``None`` skips writing.

    Returns:
        The mosaic as a ``uint8`` array of shape
        ``(digit_size * n, digit_size * n, 3)``.

    Raises:
        ValueError: If the latent dimensionality cannot be determined.
    """
    latent_dim = _resolve_latent_dim(model, latent_dim)
    device = _resolve_device(model, device)

    figure = np.zeros((digit_size * n, digit_size * n, 3))

    # Pure inference: no autograd graph is needed for decoding.
    with torch.no_grad():
        for row in range(n):
            for col in range(n):
                # Sample on CPU, then move, so the draw sequence under a fixed
                # seed is the same regardless of the target device.
                z_sample = torch.randn(1, latent_dim).to(device)
                x_decoded = model.decode(z_sample)
                recon = x_decoded[0].cpu().permute(1, 2, 0).numpy() * 255.0
                figure[
                    row * digit_size:(row + 1) * digit_size,
                    col * digit_size:(col + 1) * digit_size,
                    :,
                ] = recon

    # Clip before the cast: out-of-range floats would otherwise wrap around
    # when converted to uint8.
    pixels = np.clip(figure, 0.0, 255.0).astype(np.uint8)

    if save_path is not None:
        Image.fromarray(pixels).save(save_path)
    return pixels



if __name__ == "__main__":
    ##### WARNING: THIS PART DOESN'T WORK ANYMORE. PLEASE UPDATE ACCORDING TO record_vae_state.py #####
    # `device` and `latent_dim` are kept as module-level names because
    # plot_latent_space looks them up as globals.
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    latent_dim = 10
    model = CNNVAE(latent_dim=latent_dim, nc=3)
    model.to(device)

    # Load the model weights from a checkpoint.
    # Alternative run: "data/20240109-123713_cnn_vae_z10_small_dataset/best_model.pth"
    checkpoint_path = "data/20240109-130700_cnn_vae_z10_small_dataset/best_model.pth"
    # map_location remaps tensors saved on another device (e.g. a GPU
    # checkpoint opened on a CPU-only host) onto the selected device.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Inference only: switch layers such as dropout / batch norm to eval mode.
    model.eval()

    plot_latent_space(model)
