from modules.autoencoder import SAE
from modules.flow import Kernel_Perron_Frobenius
from modules.loss import mix_rbf_mmd, poly_mmd, LapLoss
from modules.utils import weighted_max, SubsetDataset
from modules.kernel_fn import *
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
import os
import visdom
import plotly.graph_objects as go

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm

# Default torch device string for this script (module-level).
device = 'cuda'

def get_samples(dataset, sample_size=-1):
    """Stack the first element of randomly ordered dataset items into a float tensor.

    Args:
        dataset: indexable collection whose items are sequences; item[0] must
            be a tensor that ``torch.stack`` can combine with the others.
        sample_size: when non-negative, only the first ``sample_size`` indices
            of the random permutation are used; a negative value keeps all.

    Returns:
        A tensor of the stacked items, cast to float32.
    """
    order = np.random.permutation(len(dataset))
    chosen = order[:sample_size] if sample_size >= 0 else order
    items = [dataset[idx][0] for idx in chosen]
    return torch.stack(items).float()

def pre_process(image):
    """Linearly map values from [0, 1] to [-1, 1] (computes 2 * image - 1)."""
    return 2 * image - 1

def post_process(logits):
    """Map a tensor from [-1, 1] back to [0, 1], clamping out-of-range values."""
    rescaled = (logits + 1) / 2
    return torch.clamp(rescaled, min=0.0, max=1.0)

def _extract_features(ae, dataset, sample_size, batch_size, device, ignore_labels=False):
    """Encode a shuffled subset of ``dataset`` with ``ae`` and concatenate the results.

    Args:
        ae: autoencoder exposing ``encode``. Its per-batch output is regrouped
            component-wise via ``zip(*...)``, so it is assumed to return a
            list/tuple of tensors -- TODO confirm against ``SAE.encode``.
        dataset: dataset yielding ``(image, label)`` pairs.
        sample_size: forwarded to ``SubsetDataset``.
        batch_size: loader batch size.
        device: device the pre-processed images are moved to before encoding.
        ignore_labels: when True, labels are not collected.

    Returns:
        ``(features, labels)``: lists with one tensor per component, each the
        concatenation of all batches along dim 0. ``labels`` is empty when
        ``ignore_labels`` is True.
    """
    loader = DataLoader(SubsetDataset(dataset, sample_size, shuffle=True), batch_size=batch_size,
                        drop_last=False, pin_memory=True, num_workers=4)
    feature_list = []
    label_list = []
    with torch.no_grad():
        for X, y in tqdm(loader):
            X = pre_process(X).float().to(device)
            feature_list.append(ae.encode(X))
            if not ignore_labels:
                # Normalise labels to a sequence so they regroup like the features.
                if not isinstance(y, (list, tuple)):
                    y = [y]
                label_list.append(y)
    features = [torch.cat(f, dim=0) for f in zip(*feature_list)]
    labels = [torch.cat(l, dim=0) for l in zip(*label_list)]
    return features, labels


def _scatter_latent(coords, classes, title, legend=True):
    """Scatter-plot the first two coordinate columns coloured by class and show it.

    Args:
        coords: array whose columns 0 and 1 are used as x and y.
        classes: array of class labels (converted to strings for the hue).
        title: figure title.
        legend: when False, the seaborn legend is suppressed.
    """
    f, ax = plt.subplots()
    f.set_size_inches(6, 6, forward=True)
    f.tight_layout()
    # Only pass `legend` when disabling it, so seaborn's own default applies otherwise.
    legend_kw = {} if legend else {"legend": False}
    sns.scatterplot(x=coords[:, 0], y=coords[:, 1], hue=classes.astype(str),
                    hue_order=np.arange(0, 10).astype(str), palette="deep", ax=ax, **legend_kw)
    ax.axis('off')
    ax.set_title(title, fontsize=24, fontname='Sans')
    plt.show()


def main(args):
    """Load (or train) the MNIST autoencoder, fit kernel Perron-Frobenius models
    on its latent features, plot data/sampled latent points and print an MMD score.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block). This
            function also sets ``args.img_dim`` and ``args.visdim``.
    """
    viz = visdom.Visdom(port=8097)

    args.img_dim = (1, 32, 32)
    args.visdim = (8, 8)

    t = transforms.Compose([
        transforms.Resize((*args.img_dim[-2:],)),
        transforms.ToTensor(),
    ])
    train_ds = datasets.MNIST("data", train=True, transform=t, download=True)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=8)

    ae = SAE(args.img_dim,
             n_l=args.ae_layers,
             depths=args.ae_depths,
             encoding_depth=3,
             scale_factor=1,
             groups=1)

    if args.ae_load_path is not None:
        # Load onto CPU first; the model is moved to its target device below.
        ae_state = torch.load(args.ae_load_path, map_location="cpu")
        ae.load_state_dict(ae_state)
    else:
        print("AE load path is None. Training a new AE...")
        from train import train_ae
        # `device` is the module-level default. This function must not assign
        # to that name anywhere, or this line raises UnboundLocalError.
        train_ae(nn.DataParallel(ae), train_loader, args.epochs, args.lr, args.lr_epochs, args.lr_frac, args, device=device, viz=viz)

    # Device used for feature extraction, the flow model and sampling.
    pf_device = 'cuda:1'

    ae.to(pf_device)
    # Inference only from here on: disable any train-mode layer behaviour.
    ae.eval()

    latent_dim = 3

    cached = args.pf_use_cached_features
    ignore_labels = False
    sample_size = args.pf_sample_size

    _cache_file = os.path.join(args.cache_dir, "{:s}-kernelPF-features.pth".format(args.experiment_id))
    if cached and os.path.exists(_cache_file):
        print("Restoring cached features from {:s}...".format(_cache_file))
        features, labels = torch.load(_cache_file)
    else:
        print("Getting features from dataset...")
        features, labels = _extract_features(ae, train_ds, sample_size, args.batch_size, pf_device,
                                             ignore_labels=ignore_labels)
        print("Cacheing features for later use...")
        torch.save((features, labels), _cache_file)

    y_coord = features[0]
    y_class = labels[0]
    from modules.utils import unit_sphere_to_cartesian_3d
    y_coord_cartesian = unit_sphere_to_cartesian_3d(y_coord).detach().cpu().numpy()

    _scatter_latent(y_coord_cartesian, y_class.detach().cpu().numpy(), "Data latent points")

    # (kernel, display name, epsilon passed to Kernel_Perron_Frobenius)
    kernels = [
        # (MixRBFKernel(latent_dim, sigma=[1/4]), "RBF", 1e-7),
        # (ArccosKernel(latent_dim, layers=1, deg=1, normalize=False), "Arc-cosine", 1e-6),
        # (ArccosKernel(latent_dim, layers=1, deg=2, normalize=False), "Arc-cosine", 1e-6),
        (NeuralKernel(latent_dim, 10000, n_layers=2, kernel_type='ntk'), "NTK", 1e-6)
    ]

    for k, kname, eps in kernels:
        model = Kernel_Perron_Frobenius(k,
                                        features,
                                        labels=None if ignore_labels else labels,
                                        preimage_method='gi',
                                        nystrom_compression=False,
                                        epsilon=eps,
                                        p_dim=-1,
                                        device=pf_device).to(pf_device)
        # Interactive pause: wait for Enter before sampling.
        input()

        print("Generating samples (full)")
        model.eval()
        sample_latent = []
        sample_latent_class = []
        for _ in tqdm(range(100)):
            latent, latent_top_inds = model.sample(100)
            # The decoded images are detached immediately, so no graph is needed.
            with torch.no_grad():
                samples = ae.decode(latent)
            samples = post_process(samples).view(-1, *args.img_dim).permute(0, 2, 3, 1).detach().cpu().numpy()

            sample_latent.append(latent[0].squeeze())
            sample_latent_class.append(y_class[latent_top_inds])

        sample_cart_coord = unit_sphere_to_cartesian_3d(torch.cat(sample_latent, dim=0)).detach().cpu().numpy()
        sample_classes = torch.cat(sample_latent_class, dim=0).detach().cpu().numpy()
        _scatter_latent(sample_cart_coord, sample_classes, "Sampled points ({:s})".format(kname), legend=False)

        # A Nyström-compressed variant (nystrom_compression=True, nystrom_points=2000,
        # epsilon=1e-7) was explored previously and can be re-added here.

        all_y = torch.cat([_y.view(model.y_size, -1) for _y in model.y], dim=-1)
        all_latent = torch.cat(sample_latent, dim=0).to(pf_device)
        mmd_score = poly_mmd(all_y, all_latent, deg=3)
        print("MMD:", mmd_score.item())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Paths and names
    parser.add_argument("--experiment_id", default="vanilla", type=str)
    parser.add_argument("--save_dir", default="./saved_models", type=str)
    parser.add_argument("--save_freq", default=10, type=int)
    parser.add_argument("--cache_dir", default="./cache", type=str)
    parser.add_argument("--sample_dir", default="./samples", type=str)

    # Autoencoder params
    parser.add_argument("--ae_load_path", default=None, type=str)
    parser.add_argument("--ae_depths", default=[32, 64, 128, 256], nargs='+', type=int)
    parser.add_argument("--ae_layers", default=[1, 1, 1], nargs='+', type=int)
    parser.add_argument("--ae_input_noise_var", default=0., type=float)

    # Flow model params
    parser.add_argument("--pf_use_cached_features", default=False, action='store_true')
    parser.add_argument("--pf_use_compression", default=False, action='store_true')
    parser.add_argument("--pf_sample_size", default=10000, type=int)

    # Training hyperparameters
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_epochs", default=30, type=int)
    # type=float: without it a value given on the command line stays a str.
    parser.add_argument("--lr_frac", default=0.5, type=float)

    args = parser.parse_args()

    # Samples are written to a per-experiment subdirectory.
    args.sample_dir = os.path.join(args.sample_dir, args.experiment_id)

    for _dir in (args.save_dir, args.cache_dir, args.sample_dir):
        os.makedirs(_dir, exist_ok=True)

    main(args)