import argparse
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import time
import os
import visdom
from skimage.io import imsave
from PIL import Image

# Enable matplotlib interactive mode so the sample-grid figures that ``main``
# redraws (canvas.draw / flush_events) update live without blocking.
plt.ion()

from modules.autoencoder import VanillaAE
from modules.flow import Kernel_Perron_Frobenius
from modules.loss import mix_rbf_mmd, poly_mmd, LapLoss
from modules.utils import weighted_max, SubsetDataset
from modules.kernel_fn import *
from dataset import CelebA64Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm

# device = 'cuda'

def get_samples(dataset, sample_size=-1):
    """Draw a random selection of inputs from ``dataset`` and stack them.

    Args:
        dataset: Indexable dataset whose items are ``(input, ...)`` tuples;
            only the first element of each item is used.
        sample_size: Number of items to draw. A negative value (default)
            takes every item, in random order.

    Returns:
        The selected inputs stacked along a new leading dimension, cast to
        ``float``.
    """
    order = np.random.permutation(len(dataset))
    chosen = order if sample_size < 0 else order[:sample_size]
    inputs = [dataset[idx][0] for idx in chosen]
    return torch.stack(inputs).float()

def pre_process(image):
    """Rescale pixel intensities from [0, 1] to [-1, 1] (affine map 2x - 1)."""
    return 2 * image - 1

def post_process(logits):
    """Invert ``pre_process``: map [-1, 1] back to [0, 1], clipping out-of-range values."""
    rescaled = (logits + 1) / 2
    return rescaled.clamp(min=0., max=1.)

def train_ae(ae, train_loader, epochs, lr, lr_epochs, lr_frac, args, device=torch.device('cuda'), visualize=True, viz=None):
    """Train an autoencoder with a reconstruction + latent-norm objective.

    Inputs are rescaled with ``pre_process`` and optionally corrupted with
    Gaussian noise before being fed to the model; the reconstruction target is
    the uncorrupted input. Optimization uses Adam with a step learning-rate
    decay (multiply by ``lr_frac`` every ``lr_epochs`` epochs).

    Args:
        ae: Autoencoder wrapped so the underlying model is reachable as
            ``ae.module`` (``main`` passes an ``nn.DataParallel``). Its forward
            must return ``(logits, latents)`` where ``latents`` is a sequence
            of tensors concatenable along dim 1. When ``visualize`` is True the
            wrapped model must also provide ``bilinear_interpolate``.
        train_loader: Iterable of ``(X, y)`` batches; ``y`` is ignored.
        epochs: Number of passes over ``train_loader``.
        lr: Initial Adam learning rate.
        lr_epochs: Decay step size, in epochs.
        lr_frac: Multiplicative decay factor applied every ``lr_epochs`` epochs.
        args: Namespace read for ``ae_input_noise_var``, ``visdim``,
            ``img_dim``, ``save_dir``, ``experiment_id`` and, if present,
            ``save_freq``.
        device: Device the model and batches are moved to.
        visualize: Push images to visdom every 500 steps and at each epoch end.
        viz: ``visdom.Visdom`` client; required when ``visualize`` is True.
    """
    assert not visualize or viz is not None, "viz must be provided when visualize=True"

    ae = ae.to(device)
    optim = torch.optim.Adam(ae.parameters(), lr=lr)

    def lr_lambda(epoch):
        # Step decay: lr = base_lr * lr_frac ** (epoch // lr_epochs).
        return np.power(lr_frac, int(epoch) // lr_epochs)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)

    def criterion(x, y):
        # Squared error summed over dim 1, then averaged over all remaining elements.
        return torch.mean(torch.sum((x - y).pow(2), dim=1))

    # Checkpoint every ``save_freq`` epochs (CLI ``--save_freq``) and always
    # after the final epoch; defaults to every epoch when args lacks save_freq.
    save_freq = max(1, int(getattr(args, 'save_freq', 1)))
    save_path = os.path.join(args.save_dir, "{}-ae.pth".format(args.experiment_id))

    def _visualize(X, X_corrupted, logits, latents):
        """Send clean, corrupted, reconstructed and interpolated images to visdom."""
        # Switch temporarily to eval mode; restored to train mode below.
        ae.eval()

        with torch.no_grad():
            interp = ae.module.bilinear_interpolate(latents, steps=args.visdim[0])

        X_img = post_process(X).view(-1, *args.img_dim).detach().cpu()
        X_corrupted_img = post_process(X_corrupted).view(-1, *args.img_dim).detach().cpu()
        pred_img = post_process(logits).view(-1, *args.img_dim).detach().cpu()
        interp_img = post_process(interp).view(args.visdim[0] * args.visdim[0], *args.img_dim).detach().cpu()

        vis_count = np.prod(args.visdim)

        viz.images(X_img[:vis_count], nrow=args.visdim[0], win='x', opts={'title': 'data'})
        viz.images(X_corrupted_img[:vis_count], nrow=args.visdim[0], win='x_corrupted', opts={'title': 'data_corrupted'})
        viz.images(pred_img[:vis_count], nrow=args.visdim[0], win='pred', opts={'title': 'prediction'})
        viz.images(interp_img[:vis_count], nrow=args.visdim[0], win='interp', opts={'title': 'Interpolation'})

        ae.train()

    for epoch in range(epochs):
        # Most recent (X, X_corrupted, logits, latents), kept for the end-of-epoch visualization.
        last_batch = None
        for i, (X, y) in enumerate(train_loader):
            X = pre_process(X).float().to(device)

            # NOTE: despite its name, ae_input_noise_var scales a standard normal
            # sample, i.e. it acts as the noise standard deviation.
            noise = torch.randn_like(X) * args.ae_input_noise_var
            X_corrupted = X + noise

            logits, latents = ae(X_corrupted)
            rec_loss = criterion(logits, X)
            # L2 norm of the concatenated latent codes along dim 1, averaged over the rest.
            reg_loss = torch.norm(torch.cat(latents, dim=1), p=2, dim=1).mean()
            loss = rec_loss + reg_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            print("Epoch {:d}, Step {:d}, Rec Loss: {:.4f}, Reg Loss: {:.4f}".format(epoch + 1, i, rec_loss.item(), reg_loss.item()))

            # Detached copies: visualization needs only values, not the autograd graph.
            last_batch = (X, X_corrupted, logits.detach(), [z.detach() for z in latents])
            if visualize and i % 500 == 0:
                _visualize(*last_batch)

        # The loader may yield no batches (e.g. drop_last with a tiny dataset).
        if visualize and last_batch is not None:
            _visualize(*last_batch)

        scheduler.step()

        if (epoch + 1) % save_freq == 0 or epoch + 1 == epochs:
            torch.save(ae.module.state_dict(), save_path)

def main(args):
    """Train or load an autoencoder, fit a kernel Perron-Frobenius model on its
    latents, then save samples and report an MMD score.

    Steps:
      1. Build the dataset selected by ``args.dataset``; this also sets
         ``args.img_dim`` to ``(C, H, W)``.
      2. Load the autoencoder from ``args.ae_load_path`` or train a new one.
      3. Encode a subset (``SubsetDataset``, size ``args.pf_sample_size``) of
         the training images, or restore those features from the cache file
         in ``args.cache_dir`` when ``args.pf_use_cached_features`` is set.
      4. Fit ``Kernel_Perron_Frobenius`` on the features.
      5. Save sample grids, individually sampled and reconstructed images under
         ``args.sample_dir``, and print the degree-3 polynomial MMD between
         ``model.y`` and the sampled latents.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    device = 'cuda'

    viz = visdom.Visdom(port=8097)

    def get_transform(img_dim):
        return transforms.Compose([
            transforms.Resize((*img_dim[-2:],)),
            transforms.ToTensor(),
        ])

    if args.dataset == "mnist":
        args.img_dim = (1, 32, 32)
        train_ds = datasets.MNIST("data", train=True, transform=get_transform(args.img_dim), download=True)
    elif args.dataset == "lsun_bedroom":
        args.img_dim = (3, 256, 256)
        train_ds = SubsetDataset(datasets.LSUN("data/LSUN", classes=['bedroom_train'], transform=get_transform(args.img_dim), target_transform=None), subset_size=60000)
    elif args.dataset == "cifar":
        args.img_dim = (3, 32, 32)
        train_ds = datasets.CIFAR10("data", train=True, transform=get_transform(args.img_dim), download=True)
    elif args.dataset == "celeba256":
        args.img_dim = (3, 256, 256)
        train_ds = datasets.CelebA("data", split='train', transform=get_transform(args.img_dim), download=True)
    elif args.dataset == "celeba64":
        args.img_dim = (3, 64, 64)
        train_ds = CelebA64Dataset("data", transform=get_transform(args.img_dim))
    else:
        raise ValueError("Dataset {:s} not available".format(args.dataset))
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=args.workers)

    # Train Autoencoder
    ae = VanillaAE(args.img_dim,
                   n_l=args.ae_layers,
                   depths=args.ae_depths,
                   encoding_depth=args.ae_encoding_depth,
                   scale_factor=args.ae_scale_factor,
                   groups=args.ae_groups)

    if args.ae_load_path is not None:
        # Load onto the CPU: the checkpoint may have been written from another
        # GPU, and the model is moved to its target device below anyway.
        ae_state = torch.load(args.ae_load_path, map_location='cpu')
        ae.load_state_dict(ae_state)
    else:
        print("AE load path is None. Training a new AE...")
        train_ae(nn.DataParallel(ae), train_loader, args.epochs, args.lr, args.lr_epochs, args.lr_frac, args, device=device, viz=viz)

    # The autoencoder and the flow model run on a second GPU when available;
    # fall back to the default GPU on single-GPU machines.
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda'

    ae.to(device)
    # Everything from here on is inference (feature extraction and decoding),
    # so switch any train/eval-dependent layers to eval mode before encoding.
    ae.eval()

    latent_dim = args.ae_encoding_depth * args.ae_groups
    # Alternative kernels, kept for experimentation:
    # kernel = BachKernel(latent_dim, 2, 15)
    # kernel = MixRBFKernel(latent_dim)
    # kernel = PolynomialKernel(latent_dim, deg=2, c=1)
    # kernel = ArccosKernel(latent_dim, layers=1, deg=1, normalize=False)
    # kernel = RandFeatureKernel(latent_dim, 5000, activation=nn.Tanh, n_layers=2, kernel_fn=lambda c: MixRBFKernel(latent_dim, sigma=[1/128], normalize=True))
    # kernel = RandFeatureKernel(latent_dim, 5000, n_layers=2, kernel_fn=lambda c: ArccosKernel(latent_dim, deg=1, layers=1, normalize=True))
    # kernel = NeuralKernel(latent_dim, 10000, n_layers=4, kernel_type='nngp')
    kernel = NeuralKernel(latent_dim, 10000, n_layers=4, kernel_type='ntk', activation='relu')

    cached = args.pf_use_cached_features
    ignore_labels = True
    sample_size = args.pf_sample_size

    # NOTE(review): the cache key is only experiment_id; a cache written with a
    # different pf_sample_size or AE is silently reused.
    _cache_file = os.path.join(args.cache_dir, "{:s}-kernelPF-features.pth".format(args.experiment_id))
    if cached and os.path.exists(_cache_file):
        print("Restoring cached features from {:s}...".format(_cache_file))
        features, labels = torch.load(_cache_file)
    else:
        print("Getting features from dataset...")
        loader = DataLoader(SubsetDataset(train_ds, sample_size, shuffle=True), batch_size=args.batch_size, drop_last=False, pin_memory=True, num_workers=args.workers)
        feature_list = []
        label_list = []
        with torch.no_grad():
            for X, y in tqdm(loader):
                X = pre_process(X).float().to(device)
                # ae.encode returns a sequence of tensors (one per latent group);
                # they are concatenated per position across batches below.
                feature_list.append(ae.encode(X))
                if not ignore_labels:
                    if not isinstance(y, (list, tuple)):
                        y = [y]
                    label_list.append(y)
        features = [torch.cat(f, dim=0) for f in zip(*feature_list)]
        labels = [torch.cat(l, dim=0) for l in zip(*label_list)]
        print("Cacheing features for later use...")
        torch.save((features, labels), _cache_file)

    model = Kernel_Perron_Frobenius(kernel,
                                    features,
                                    labels=None if ignore_labels else labels,
                                    nystrom_compression=args.pf_use_compression,
                                    nystrom_points=args.pf_nystrom_points,
                                    preimage_method='fm',
                                    epsilon=5e-8,
                                    p_dim=-1,
                                    device=device).to(device)
    model.eval()

    def _decode_to_images(latent):
        """Decode ``latent`` and return an (N, H, W, C) numpy array in [0, 1]."""
        with torch.no_grad():
            decoded = ae.decode(latent)
        return post_process(decoded).view(-1, *args.img_dim).permute(0, 2, 3, 1).cpu().numpy()

    def _save_images(images, subdir, start):
        """Write ``images`` as uint8 PNGs numbered from ``start``; return the next index."""
        for offset, img in enumerate(images):
            imsave(os.path.join(subdir, "sampled_{:d}.png".format(start + offset)), np.round(img * 255).astype(np.uint8))
        return start + len(images)

    fig_samples = plt.figure()
    fig_samples.set_size_inches(15, 15)
    grid_samples = ImageGrid(fig_samples, 111, args.visdim, axes_pad=0.)
    fig_samples.suptitle("Sampled")

    figs = [fig_samples]

    print("Generating samples")
    for i in tqdm(range(10)):
        with torch.no_grad():
            latent, _ = model.sample(np.prod(args.visdim))
        samples = _decode_to_images(latent)

        for j, ax in enumerate(np.ravel(grid_samples)):
            ax.clear()
            ax.imshow(np.squeeze(np.clip(samples[j], 0., 1.)), cmap=plt.cm.gray)
            ax.axis('off')

        for f in figs:
            f.canvas.draw()
            f.canvas.flush_events()

        sample_file = os.path.join(args.sample_dir, "sample_{:02d}.png".format(i))
        fig_samples.savefig(sample_file)

    # 200 batches of 50 freshly sampled images.
    img_subdir = os.path.join(args.sample_dir, "images")
    os.makedirs(img_subdir, exist_ok=True)
    counter = 0
    latent_list = []
    for i in tqdm(range(200)):
        with torch.no_grad():
            latent, _ = model.sample(50)
        latent_list.append(latent)
        counter = _save_images(_decode_to_images(latent), img_subdir, counter)

    # Decode the latents stored in the model (model.y), in batches of 50.
    img_subdir = os.path.join(args.sample_dir, "images_rec")
    os.makedirs(img_subdir, exist_ok=True)
    counter = 0
    for i in tqdm(range(200)):
        latent = [_y[i * 50: (i + 1) * 50] for _y in model.y]
        if not latent or latent[0].size(0) == 0:
            # Fewer than 200 * 50 stored latents: nothing left to decode.
            break
        counter = _save_images(_decode_to_images(latent), img_subdir, counter)

    # Flatten each latent group per sample and concatenate groups, for both the
    # stored latents and the sampled ones, then compare their distributions.
    all_y = torch.cat([_y.view(model.y_size, -1) for _y in model.y], dim=-1)
    _tmp = [torch.cat(l, dim=0) for l in zip(*latent_list)]
    all_latent = torch.cat([_t.view(_t.size(0), -1) for _t in _tmp], dim=-1)
    mmd_score = poly_mmd(all_y, all_latent, deg=3)
    print("MMD:", mmd_score.item())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Paths and names
    parser.add_argument("--experiment_id", default="vanilla", type=str)
    parser.add_argument("--save_dir", default="./saved_models", type=str)
    parser.add_argument("--save_freq", default=1, type=int)
    parser.add_argument("--cache_dir", default="./cache", type=str)
    parser.add_argument("--sample_dir", default="./samples", type=str)

    # Misc params
    parser.add_argument("--visdim", default=(8, 8), type=int, nargs='+')

    # Dataset params
    parser.add_argument("--dataset", default="mnist", type=str)
    parser.add_argument("--workers", default=8, type=int)

    # Autoencoder params
    parser.add_argument("--ae_load_path", default=None, type=str)
    parser.add_argument("--ae_depths", default=[32, 64, 128, 256, 512], nargs='+', type=int)
    parser.add_argument("--ae_scale_factor", default=1, type=int)
    parser.add_argument("--ae_layers", default=[1, 1, 1, 1], nargs='+', type=int)
    parser.add_argument("--ae_encoding_depth", default=128, type=int)
    parser.add_argument("--ae_groups", default=1, type=int)
    # Scales standard-normal input noise (used as a standard deviation).
    parser.add_argument("--ae_input_noise_var", default=0., type=float)

    # Flow model params
    parser.add_argument("--pf_use_cached_features", default=False, action='store_true')
    parser.add_argument("--pf_sample_size", default=10000, type=int)
    parser.add_argument("--pf_use_compression", default=False, action='store_true')
    parser.add_argument("--pf_nystrom_points", default=4096, type=int)

    # Training hyperparameters
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_epochs", default=30, type=int)
    # type=float is required: without it a CLI value arrives as a string and
    # breaks the np.power-based learning-rate schedule in train_ae.
    parser.add_argument("--lr_frac", default=0.5, type=float)

    args = parser.parse_args()

    args.sample_dir = os.path.join(args.sample_dir, args.experiment_id)

    for _dir in (args.save_dir, args.cache_dir, args.sample_dir):
        os.makedirs(_dir, exist_ok=True)

    main(args)