from augmentor import VAE
import torch
from torchvision.utils import save_image
import torchvision
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision.transforms.functional as F
# from scipy.optimize import root
import numpy as np
# from functools import partial
# from scipy.stats import multivariate_normal
import os
import shutil
import re
from vendi_score import image_utils
from PIL import Image as im


def _parse_ipc(checkpoint_name):
    """Return the prototypes-per-class count (ipc) encoded in *checkpoint_name*.

    Prefers an explicit ``ipc<N>`` tag (``final_ipc5_zdim64...`` -> 5) and falls
    back to the first integer in the name.

    Raises:
        ValueError: if the name contains no integer at all.
    """
    numbers = re.findall(r'ipc(\d+)', checkpoint_name) or re.findall(r'\d+', checkpoint_name)
    if not numbers:
        raise ValueError(f'cannot infer ipc from checkpoint name {checkpoint_name!r}')
    return int(numbers[0])


def _plot_prototypes(vae, image_syn, save_path):
    """Save the prototypes decoded with zero deviation as 10-column image grids.

    Writes ``prototype.png`` and a 4x pixel-repeated ``prototype_upsampled.png``
    into *save_path*.
    """
    eps = torch.zeros(image_syn.shape[:2], device=image_syn.device)
    prototypes = vae.forward_fixed(image_syn, eps)
    grid = torchvision.utils.make_grid(prototypes, nrow=10, normalize=True, scale_each=True)
    save_image(grid, os.path.join(save_path, 'prototype.png'))

    # Repeat every pixel 4x along height (dim 2) and width (dim 3).
    upsampled = prototypes.repeat_interleave(4, dim=2).repeat_interleave(4, dim=3)
    grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
    save_image(grid, os.path.join(save_path, 'prototype_upsampled.png'))


def _generate_tsne(vae, image_syn, num_class, ipc, save_path, n_images=1_000):
    """Sample latents around each prototype, embed them with t-SNE and plot them.

    Writes ``tsne.png`` (class-coloured scatter of the embedding) and
    ``tsne_grid.png`` (the scatter overlaid with the decoded latent nearest to
    each node of a 10x10 grid spanning the embedding) into *save_path*.
    """
    z_dim = image_syn.shape[1]
    # A single noise matrix is shared by all prototypes.
    eps = torch.randn(n_images, z_dim, device=image_syn.device)
    mean = image_syn[:, :, 0]
    std = image_syn[:, :, 1].mul(0.5).exp()
    # z = eps * std + mean for every (prototype, sample) pair:
    # (n_prototypes, n_images, z_dim), then flattened prototype-major.
    augmented_z = eps[None, :, :] * std[:, None, :] + mean[:, None, :]
    augmented_z = augmented_z.reshape(-1, z_dim)
    # Prototypes are laid out class-major (ipc consecutive prototypes per class).
    labels = torch.arange(num_class).repeat_interleave(n_images * ipc)
    palette = sns.color_palette("hls", num_class)

    print('t-SNE started!')
    # NOTE: scikit-learn >= 1.5 renames n_iter to max_iter (n_iter is removed in 1.7).
    tsne = TSNE(n_components=2, verbose=1, perplexity=50, n_iter=300,
                learning_rate='auto', init='pca')
    tsne_results = tsne.fit_transform(augmented_z.detach().cpu().numpy())
    print('t-SNE done!')

    plt.figure(figsize=(16, 10))
    sns.scatterplot(
        x=tsne_results[:, 0], y=tsne_results[:, 1],
        hue=labels,
        palette=palette,
        legend="full",
        alpha=1.0,
    )
    plt.xlabel("TSNE_0")
    plt.ylabel("TSNE_1")
    plt.savefig(os.path.join(save_path, 'tsne.png'))
    plt.close()
    print("TSNE done")

    # 10x10 grid over the bounding box of the embedding.
    x_nodes = np.linspace(np.min(tsne_results[:, 0]), np.max(tsne_results[:, 0]), 10)
    y_nodes = np.linspace(np.min(tsne_results[:, 1]), np.max(tsne_results[:, 1]), 10)
    xv, yv = np.meshgrid(x_nodes, y_nodes, indexing='ij')
    # Index of the embedded point nearest to each node, in row-major node order.
    nearest = np.array([
        np.argmin(np.linalg.norm(tsne_results - node, axis=1))
        for node in zip(xv.ravel(), yv.ravel())
    ])

    grid_images = vae.forward_given(augmented_z[torch.as_tensor(nearest, device=augmented_z.device)])
    # Repeat every pixel 2x along height and width.
    grid_images = grid_images.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)
    scale = 60
    node_coords = tsne_results[nearest] * scale

    plt.figure()
    ax = plt.gca()
    ax.set_xlim(np.min(tsne_results[:, 0]) * scale * 1.2, np.max(tsne_results[:, 0]) * scale * 1.2)
    ax.set_ylim(np.min(tsne_results[:, 1]) * scale * 1.2, np.max(tsne_results[:, 1]) * scale * 1.2)
    sns.scatterplot(
        x=tsne_results[:, 0] * scale, y=tsne_results[:, 1] * scale,
        hue=labels,
        palette=palette,
        legend="full",
        alpha=0.01,
        s=80,
    )
    for image, (tx, ty) in zip(grid_images, node_coords):
        # (C, H, W) -> (H, W, C) for imshow; each tile spans 64 plot units.
        image = np.transpose(image.detach().cpu().numpy(), (1, 2, 0))
        ax.imshow(image, extent=(tx - 32, tx + 32, ty - 32, ty + 32))
    plt.savefig(os.path.join(save_path, 'tsne_grid.png'), dpi=800)
    plt.close()


def _generate_random_examples(vae, image_syn, num_class, ipc, channel, im_size, device):
    """Decode random samples of every prototype and print a per-class Vendi score.

    ``5000 // ipc`` samples are drawn per prototype, so each class contributes
    about 5000 images.
    """
    n_prototypes, z_dim = image_syn.shape[:2]
    n_images = 5000 // ipc
    augmented_images = torch.empty(n_prototypes, n_images, channel, *im_size)
    for i in range(n_prototypes):
        eps = torch.randn(n_images, z_dim, device=image_syn.device)
        std = image_syn[i, :, 1].mul(0.5).exp()
        augmented_images[i] = vae.forward_fixed(image_syn[i], eps * std).cpu()

    # Group samples per class (prototypes are class-major, ipc per class).
    augmented_images = augmented_images.reshape(num_class, n_images * ipc, *augmented_images.shape[2:])
    # (class, sample, C, H, W) -> (class, sample, H, W, C): the layout PIL expects.
    augmented_images = augmented_images.permute(0, 1, 3, 4, 2)
    print(augmented_images.shape)

    all_images = []
    for class_images in augmented_images:
        # Clamp first so out-of-range values saturate instead of wrapping in uint8.
        pixels = (class_images.clamp(0, 1) * 255).to(torch.uint8).contiguous().numpy()
        all_images.append([im.fromarray(pixel_array) for pixel_array in pixels])
    print(len(all_images))

    pixel_vs = [image_utils.embedding_vendi_score(imgs, device=device) for imgs in all_images]
    print("vendi score in feature space")
    for pvs in pixel_vs:
        print(f"{pvs:.03f}")


def _generate_dimension_prototypes(vae, image_syn, channel, im_size, save_path):
    """For each prototype, decode one image per latent dimension shifted by 5 std.

    Writes ``prototype_single_dimension_augmented_<p>.png`` (8 images per row)
    for every prototype index ``p`` into *save_path*.
    """
    n_prototypes, z_dim = image_syn.shape[:2]
    for p in range(n_prototypes):
        std = image_syn[p, :, 1].mul(0.5).exp()
        augmented_images = torch.empty(z_dim, channel, *im_size, device=image_syn.device)
        for i in range(z_dim):
            deviation = torch.zeros(z_dim, device=image_syn.device)
            deviation[i] = 5. * std[i]
            augmented_images[i] = vae.forward_fixed(image_syn[p], deviation)
        upsampled = augmented_images.repeat_interleave(4, dim=2).repeat_interleave(4, dim=3)
        grid = torchvision.utils.make_grid(upsampled, nrow=8, normalize=True, scale_each=True)
        save_image(grid, os.path.join(save_path, f'prototype_single_dimension_augmented_{p}.png'))
    print("Dimension Augmentation done")


def _generate_dimensions_scatter(image_syn, num_class, ipc, save_path, n_images=1000):
    """Scatter-plot sampled latents for each consecutive pair of dimensions.

    Writes ``dimension_<i>.png`` for i = 0, 2, 4, ...; a trailing unpaired
    dimension (odd z_dim) is skipped.
    """
    z_dim = image_syn.shape[1]
    mean = image_syn[:, :, 0]
    std = image_syn[:, :, 1].mul(0.5).exp()
    labels = torch.arange(num_class).repeat_interleave(n_images * ipc)
    for i in range(0, z_dim - 1, 2):
        eps = torch.randn(2, n_images, device=image_syn.device)
        # (n_prototypes, 1) broadcast against (n_images,) -> (n_prototypes, n_images).
        dim_x = (mean[:, i, None] + eps[0] * std[:, i, None]).cpu().numpy()
        dim_y = (mean[:, i + 1, None] + eps[1] * std[:, i + 1, None]).cpu().numpy()

        plt.figure(figsize=(16, 10))
        sns.scatterplot(
            x=dim_x.flatten(), y=dim_y.flatten(),
            hue=labels,
            palette=sns.color_palette("hls", num_class),
            legend="full",
            alpha=1.0,
        )
        plt.xlabel(f"Dim {i}")
        plt.ylabel(f"Dim {i+1}")
        plt.savefig(os.path.join(save_path, f'dimension_{i}.png'))
        plt.close()


def main(checkpoint_name,
         plot_prototypes: bool = False,
         generate_TSNE: bool = False,
         generate_random_examples: bool = False,
         generate_dimension_prototypes: bool = False,
         generate_dimensions_scatter: bool = False
         ):
    """Load a distilled VAE checkpoint and write the requested visualisations.

    The checkpoint ``checkpoints/<checkpoint_name>`` must provide the keys
    ``kernel_num``, ``z_dim``, ``vae_state_dict`` and ``syn_img``. Outputs are
    written to ``data/<checkpoint_name>/``, which is deleted and recreated on
    every call.

    Args:
        checkpoint_name: File name of the checkpoint inside ``checkpoints/``;
            the prototypes-per-class count (ipc) is parsed from it.
        plot_prototypes: Save grids of the prototypes decoded with zero deviation.
        generate_TSNE: Save a t-SNE scatter of sampled latents plus a grid of
            decoded images over it.
        generate_random_examples: Decode random samples and print a per-class
            Vendi score.
        generate_dimension_prototypes: Save one image per latent dimension
            shifted by 5 std for every prototype (only when ipc < 5).
        generate_dimensions_scatter: Save scatter plots of consecutive latent
            dimension pairs.

    Raises:
        ValueError: if the checkpoint's prototype count does not match
            ``num_class * ipc``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    distill_data = torch.load(os.path.join('checkpoints', checkpoint_name), map_location=device)

    kernel_num = distill_data['kernel_num']
    z_dim = distill_data['z_dim']
    num_class = 10
    ipc = _parse_ipc(checkpoint_name)
    n_prototypes = num_class * ipc

    im_size = (32, 32)
    channel = 3

    vae = VAE("test", im_size[0], channel, kernel_num=kernel_num, z_size=z_dim)
    vae.load_state_dict(distill_data['vae_state_dict'])
    vae.to(device)
    vae.eval()
    # Tensor.to() returns a new tensor instead of moving in place; keep the result.
    image_syn = distill_data['syn_img'].to(device)
    # Every label vector below assumes num_class * ipc class-major prototypes.
    if image_syn.shape[0] != n_prototypes:
        raise ValueError(
            f'checkpoint holds {image_syn.shape[0]} prototypes, expected '
            f'{num_class} classes x {ipc} ipc = {n_prototypes}')

    # Validate before wiping previous results.
    save_path = os.path.join('data', checkpoint_name)
    if os.path.isdir(save_path):
        shutil.rmtree(save_path)
    # makedirs (unlike mkdir) also creates 'data/' when it does not exist yet.
    os.makedirs(save_path)

    # Everything below is inference only; no autograd graph is needed.
    with torch.no_grad():
        if plot_prototypes:
            _plot_prototypes(vae, image_syn, save_path)
        if generate_TSNE:
            _generate_tsne(vae, image_syn, num_class, ipc, save_path)
        if generate_random_examples:
            _generate_random_examples(vae, image_syn, num_class, ipc, channel, im_size, device)
        # Writes one file per prototype, so it is only run for small ipc.
        if ipc < 5 and generate_dimension_prototypes:
            _generate_dimension_prototypes(vae, image_syn, channel, im_size, save_path)
        if generate_dimensions_scatter:
            _generate_dimensions_scatter(image_syn, num_class, ipc, save_path)

    return

# def plot_midpoints():
#     def midpoint(mean_1, mean_2, logvar_1, logvar_2, x):
#         sigma_1 = np.exp(logvar_1)
#         sigma_2 = np.exp(logvar_2)
#         det_ratio = np.sqrt(np.prod(sigma_2)/np.prod(sigma_1))
#         log_det_ratio = 2 * np.log(det_ratio)
#         distance = np.inner(1./sigma_1 * (x - mean_1), (x - mean_1)) - np.inner(1./sigma_2 * (x - mean_2), (x - mean_2))
#         return distance - log_det_ratio

#     image_syn_np = image_syn.detach().cpu().numpy()
#     estimate = (image_syn_np[0, :, 0] + image_syn_np[1, :, 0])/2
#     sol_x = root(partial(midpoint, image_syn_np[0, :, 0], image_syn_np[1, :, 0], image_syn_np[0, :, 1], image_syn_np[1, :, 1]), 
#                  estimate,
#                  method='excitingmixing')

#     sigma_1 = np.exp(image_syn_np[0, :, 1])
#     sigma_2 = np.exp(image_syn_np[1, :, 1])
#     check_solution = midpoint(image_syn_np[0, :, 0], image_syn_np[1, :, 0], image_syn_np[0, :, 1], image_syn_np[1, :, 1], sol_x.x)
#     # multivariate_normal_1 = np.sqrt((2 * np.pi)**64) * multivariate_normal.pdf(sol_x.x, mean=image_syn_np[0, :, 0], cov=sigma_1)
#     # multivariate_normal_2 = np.sqrt((2 * np.pi)**64) * multivariate_normal.pdf(sol_x.x, mean=image_syn_np[1, :, 0], cov=sigma_2)

#     multivariate_normal_1 = multivariate_normal.pdf(sol_x.x, mean=image_syn_np[0, :, 0], cov=sigma_1)
#     multivariate_normal_2 = multivariate_normal.pdf(sol_x.x, mean=image_syn_np[1, :, 0], cov=sigma_2)

#     print(check_solution, np.sqrt((2 * np.pi)**64), multivariate_normal_1, multivariate_normal_2)
#     print(sol_x)

if __name__ == "__main__":
    # Alternative source list: every entry of os.listdir('data/').
    checkpoints = ['final_ipc5_zdim64_knum512_kl10.00.pt']
    for checkpoint in checkpoints:
        print(checkpoint)
        # NOTE: with ipc=5, main's `ipc < 5` guard skips the dimension-prototype pass.
        main(
            checkpoint_name=checkpoint,
            plot_prototypes=False,
            generate_TSNE=False,
            generate_random_examples=True,
            generate_dimension_prototypes=True,
            generate_dimensions_scatter=False,
        )
