import argparse
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import time
import os
import visdom
from skimage.io import imsave
from PIL import Image

# Turn on matplotlib's interactive mode at import time; main() redraws and
# flushes its figures inside a loop rather than calling plt.show().
plt.ion()

from modules.autoencoder import VanillaAE
from modules.flow import Kernel_Perron_Frobenius
from modules.loss import mix_rbf_mmd, poly_mmd, LapLoss
from modules.utils import weighted_max, SubsetDataset
from modules.kernel_fn import *
from dataset import get_dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

from tqdm import tqdm

# device = 'cuda'

def get_samples(dataset, sample_size=-1):
    """Draw a random batch of inputs from ``dataset`` as one float tensor.

    The indices ``0 .. len(dataset) - 1`` are shuffled with
    ``np.random.permutation``.  A non-negative ``sample_size`` keeps only
    that many of the shuffled indices; a negative value (the default) keeps
    all of them.  Only element ``0`` of every dataset item is used.

    Returns:
        The selected items stacked along a new leading dimension and cast
        to float.
    """
    shuffled = np.random.permutation(len(dataset))
    chosen = shuffled[:sample_size] if sample_size >= 0 else shuffled

    picked = []
    for idx in chosen:
        item = dataset[idx]
        picked.append(item[0])

    stacked = torch.stack(picked)
    return stacked.float()

def pre_process(image):
    """Apply the affine map ``x -> 2 * x - 1`` (sends 0 to -1 and 1 to 1)."""
    doubled = image * 2
    return doubled - 1

def post_process(logits):
    """Apply ``x -> (x + 1) / 2`` and clamp the result to ``[0, 1]``.

    This undoes the ``2 * x - 1`` scaling; values that land outside the
    unit interval are clipped to its ends.
    """
    rescaled = (logits + 1) / 2
    return torch.clamp(rescaled, min=0., max=1.)

def _visualize_epoch(viz, ae, core, args, X, X_corrupted, logits, latents):
    """Send one training batch and a latent interpolation grid to visdom.

    Shows the clean inputs, the noise-corrupted inputs, the reconstructions
    and the output of ``core.bilinear_interpolate(latents, ...)``.  ``ae``
    is switched to eval mode for the duration and back to train mode at the
    end.
    """
    # Switch temporarily to eval mode
    ae.eval()

    with torch.no_grad():
        interp = core.bilinear_interpolate(latents, steps=args.visdim[0])

    def to_images(batch, count=-1):
        # Map back to [0, 1] and reshape to (count, *img_dim) on the CPU.
        return post_process(batch).view(count, *args.img_dim).detach().cpu()

    rows = args.visdim[0]
    vis_count = np.prod(args.visdim)
    panels = (
        ('x', 'data', to_images(X)),
        ('x_corrupted', 'data_corrupted', to_images(X_corrupted)),
        ('pred', 'prediction', to_images(logits)),
        ('interp', 'Interpolation', to_images(interp, rows * rows)),
    )
    for win, title, images in panels:
        viz.images(images[:vis_count], nrow=rows, win=win, opts={'title': title})

    ae.train()


def train_ae(ae, train_loader, epochs, lr, lr_epochs, lr_frac, args, device=torch.device('cuda'), viz=None):
    """Train an autoencoder on noise-corrupted inputs, saving weights every epoch.

    Args:
        ae: Module whose forward pass returns ``(logits, latents)``, where
            ``latents`` is a sequence of tensors that can be concatenated
            along the last dimension.  It may be wrapped in
            ``nn.DataParallel``; the wrapper is unwrapped for
            ``bilinear_interpolate`` and for the saved state dict.
        train_loader: Iterable yielding ``(X, y)`` batches; ``y`` is unused.
        epochs: Number of passes over ``train_loader``.
        lr: Initial learning rate for Adam.
        lr_epochs: The learning rate is scaled by ``lr_frac`` once every
            ``lr_epochs`` epochs.
        lr_frac: Multiplicative learning-rate decay factor.
        args: Namespace providing ``ae_input_noise_var``, ``visdim``,
            ``img_dim``, ``save_dir`` and ``experiment_id``.
        device: Device the model and every batch are moved to.
        viz: Optional visdom client.  When given, the last batch of each
            epoch is visualised.

    Side effects:
        Overwrites ``<save_dir>/<experiment_id>-ae.pth`` with the unwrapped
        model's ``state_dict`` after every epoch.
    """
    ae = ae.to(device)
    # Callers may pass either a bare model or a wrapper exposing ``.module``
    # (e.g. nn.DataParallel).  Model-specific methods and the checkpoint must
    # come from the underlying model in both cases.
    core = getattr(ae, 'module', ae)

    optim = torch.optim.Adam(ae.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim, lambda epoch: np.power(lr_frac, int(epoch) // lr_epochs))

    def criterion(x, y):
        # Squared error summed over dim 1, averaged over everything else.
        return torch.mean(torch.sum((x - y).pow(2), dim=1))

    checkpoint_path = os.path.join(args.save_dir, "{}-ae.pth".format(args.experiment_id))

    for epoch in range(epochs):
        train_iter = tqdm(train_loader)
        train_iter.set_description("Epoch {:d}".format(epoch + 1))

        last_batch = None
        for X, _ in train_iter:
            X = pre_process(X).float().to(device)

            # NOTE: despite its name, ae_input_noise_var multiplies standard
            # normal noise, so it acts as a standard deviation.
            noise = torch.randn_like(X) * args.ae_input_noise_var
            X_corrupted = X + noise

            logits, latents = ae(X_corrupted)
            rec_loss = criterion(logits, X)
            # Small L2 penalty on the latent codes.
            reg_loss = torch.cat(latents, dim=-1).pow(2).mean()
            loss = rec_loss + 1e-2 * reg_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            train_iter.set_postfix({'Rec Loss': rec_loss.item()})
            last_batch = (X, X_corrupted, logits, latents)

        # Nothing to show if the loader produced no batch this epoch.
        if viz is not None and last_batch is not None:
            _visualize_epoch(viz, ae, core, args, *last_batch)

        scheduler.step()

        torch.save(core.state_dict(), checkpoint_path)

def main(args):
    """Train or load the autoencoder, fit the kernel PF model and sample from it.

    Pipeline:
      1. Build the dataset and loader, then either restore the autoencoder
         from ``args.ae_load_path`` or train a new one with ``train_ae``.
      2. Encode ``args.pf_sample_size`` training items into latent features
         (or restore them from the cache file in ``args.cache_dir``).
      3. Fit ``Kernel_Perron_Frobenius`` on those features.
      4. Write sample grids and nearest-neighbour grids, individual sampled
         images (``images/``) and decodings of the model's stored latents
         (``images_rec/``) under ``args.sample_dir``.
      5. Print a degree-3 polynomial MMD between the stored latents and the
         sampled latents.

    Args:
        args: Parsed command-line namespace.  ``args.img_dim`` is read here
            but is not defined by the parser in this file -- presumably it is
            set by ``get_dataset``; TODO confirm.

    Requires CUDA: devices are hard-coded to GPUs.
    """
    device = 'cuda'

    viz = visdom.Visdom(port=8097) if args.visualize else None

    def get_transform(img_dim):
        # Resize to the last two entries of img_dim, then convert to a tensor.
        return transforms.Compose([
            transforms.Resize((*img_dim[-2:],)),
            transforms.ToTensor(),
        ])

    train_ds = get_dataset(args, get_transform)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=args.workers)

    # Train Autoencoder
    # NOTE(review): `SAE` is not imported by name in this file (only
    # `VanillaAE` is); it must come from a star import, otherwise this line
    # raises NameError -- confirm which class is intended.
    ae = SAE(args.img_dim,
             n_l=args.ae_layers,
             depths=args.ae_depths,
             encoding_depth=args.ae_encoding_depth,
             scale_factor=args.ae_scale_factor,
             groups=args.ae_groups,
             use_sn=True)

    if args.ae_load_path is not None:
        ae_state = torch.load(args.ae_load_path)
        ae.load_state_dict(ae_state)
    else:
        print("AE load path is None. Training a new AE...")
        train_ae(nn.DataParallel(ae), train_loader, args.epochs, args.lr, args.lr_epochs, args.lr_frac, args, device=device, viz=viz)

    # The PF model and the bare autoencoder run on the second GPU when one
    # exists; fall back to the first GPU instead of failing on single-GPU hosts.
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'

    ae.to(device)

    latent_dim = args.ae_encoding_depth * args.ae_groups

    # kernel = BachKernel(latent_dim, 2, 15)
    # kernel = MixRBFKernel(latent_dim, sigma=[1/4])
    # kernel = PolynomialKernel(latent_dim, deg=2, c=1)
    # kernel = ArccosKernel(latent_dim, layers=1, deg=1, normalize=True)
    # kernel = RandFeatureKernel(latent_dim, 5000, activation=nn.Tanh, n_layers=2, kernel_fn=lambda c: MixRBFKernel(latent_dim, sigma=[1/128], normalize=True))
    # kernel = RandFeatureKernel(latent_dim, 5000, n_layers=2, kernel_fn=lambda c: ArccosKernel(latent_dim, deg=1, layers=1, normalize=True))
    kernel = NeuralKernel(latent_dim, 10000, n_layers=4, kernel_type='ntk', activation='erf')

    cached = args.pf_use_cached_features
    ignore_labels = True
    sample_size = args.pf_sample_size

    pf_train_ds = SubsetDataset(train_ds, sample_size, shuffle=True)

    _cache_file = os.path.join(args.cache_dir, "{:s}-kernelPF-features.pth".format(args.experiment_id))
    if cached and os.path.exists(_cache_file):
        print("Restoring cached features from {:s}...".format(_cache_file))
        features, labels = torch.load(_cache_file)
    else:
        print("Getting features from dataset...")
        loader = DataLoader(pf_train_ds, batch_size=args.batch_size, shuffle=False, drop_last=False, pin_memory=True, num_workers=args.workers)
        feature_list = []
        label_list = []
        with torch.no_grad():
            for X, y in tqdm(loader):
                X = pre_process(X).float().to(device)
                feature_batch = ae.encode(X)
                # feature_batch = ae.get_latents(X)[0]

                feature_list.append(feature_batch)
                if not ignore_labels:
                    if not isinstance(y, (list, tuple)):
                        y = [y]
                    label_list.append(y)
        # Each batch entry is a sequence of tensors; transpose the list of
        # batches and concatenate every component along the batch dimension.
        features = [torch.cat(f, dim=0) for f in zip(*feature_list)]
        labels = [torch.cat(l, dim=0) for l in zip(*label_list)]
        print("Cacheing features for later use...")
        torch.save((features, labels), _cache_file)

    start = time.time()
    model = Kernel_Perron_Frobenius(kernel,
                                    features,
                                    labels=None if ignore_labels else labels,
                                    nystrom_compression=args.pf_use_compression,
                                    nystrom_points=args.pf_nystrom_points,
                                    preimage_method='gi',
                                    epsilon=1e-12,
                                    p_dim=-1,
                                    device=device).to(device)
    print("Elapsed time: {:4f}".format(time.time() - start))
    topk = 10

    fig_samples = plt.figure()
    fig_samples.set_size_inches(15, 15)
    grid_samples = ImageGrid(fig_samples, 111, args.visdim, axes_pad=0.)
    fig_samples.suptitle("Sampled")

    fig_nearest = plt.figure()
    fig_nearest.set_size_inches(15, 3 * topk)
    grid_nearest = ImageGrid(fig_nearest, 111, (5, topk + 1), axes_pad=0.)
    fig_nearest.tight_layout()

    figs = [fig_samples, fig_nearest]

    def imshow_axis(ax, img):
        ax.clear()
        ax.imshow(img, cmap=plt.cm.gray)
        ax.axis('off')

    def decode_to_numpy(latent):
        # Decoding is inference only, so no autograd graph is needed.
        # Output layout is (N, H, W, C) for plotting and saving.
        with torch.no_grad():
            decoded = ae.decode(latent)
        return post_process(decoded).view(-1, *args.img_dim).permute(0, 2, 3, 1).detach().cpu().numpy()

    print("Generating samples")
    model.eval()
    # ae.eval()
    for i in tqdm(range(10)):
        latent, latent_inds = model.sample(np.prod(args.visdim), topk)
        samples = decode_to_numpy(latent)

        for j, ax in enumerate(np.ravel(grid_samples)):
            imshow_axis(ax, np.squeeze(np.clip(samples[j], 0., 1.)))

        # One row per sample: the sample itself followed by the topk training
        # items that model.sample reported for it.
        for j in range(5):
            imshow_axis(grid_nearest[j * (topk + 1)], np.squeeze(np.clip(samples[j], 0., 1.)))
            for k in range(topk):
                imshow_axis(grid_nearest[j * (topk + 1) + k + 1], pf_train_ds[latent_inds[j, k]][0].permute(1, 2, 0).detach().cpu().numpy())

        for f in figs:
            f.canvas.draw()
            f.canvas.flush_events()

        sample_file = os.path.join(args.sample_dir, "sample_{:02d}.png".format(i))
        fig_samples.savefig(sample_file)

        nearest_file = os.path.join(args.sample_dir, "nearest_{:02d}.png".format(i))
        fig_nearest.savefig(nearest_file)

    # Number of batches and batch size used for the bulk image dumps below.
    n_batches, per_batch = 200, 50

    img_subdir = os.path.join(args.sample_dir, "images")
    os.makedirs(img_subdir, exist_ok=True)
    counter = 0
    latent_list = []

    elapsed_time = 0.
    for i in tqdm(range(n_batches)):
        with torch.no_grad():
            # Only the sampling call is timed, not decoding or file output.
            start = time.time()
            latent, _ = model.sample(per_batch)
            elapsed_time += time.time() - start
        latent_list.append(latent)
        samples = decode_to_numpy(latent)

        for s in samples:
            imsave(os.path.join(img_subdir, "sampled_{:d}.png".format(counter)), np.round(s * 255).astype(np.uint8))
            counter += 1
    print("Average sampling time {:f}".format(elapsed_time / (n_batches * per_batch)))

    img_subdir = os.path.join(args.sample_dir, "images_rec")
    os.makedirs(img_subdir, exist_ok=True)
    counter = 0
    for i in tqdm(range(n_batches)):
        lo, hi = i * per_batch, (i + 1) * per_batch
        latent = [_y[lo:hi] for _y in model.y]
        # Stop once the stored latents are exhausted; this happens when the
        # model holds fewer than n_batches * per_batch points.
        if all(len(part) == 0 for part in latent):
            break
        samples = decode_to_numpy(latent)

        for s in samples:
            imsave(os.path.join(img_subdir, "sampled_{:d}.png".format(counter)), np.round(s * 255).astype(np.uint8))
            counter += 1

    # Flatten every latent component per item and compare the stored latents
    # with the sampled ones.
    all_y = torch.cat([_y.view(model.y_size, -1) for _y in model.y], dim=-1)
    _tmp = [torch.cat(l, dim=0) for l in zip(*latent_list)]
    all_latent = torch.cat([_t.view(_t.size(0), -1) for _t in _tmp], dim=-1)
    mmd_score = poly_mmd(all_y, all_latent, deg=3)
    print("MMD:", mmd_score.item())


if __name__ == "__main__":
    # Command-line entry point: parse options, create output directories,
    # then run the full pipeline in main().
    parser = argparse.ArgumentParser()
    # Paths and names
    parser.add_argument("--experiment_id", default="vanilla", type=str)
    parser.add_argument("--save_dir", default="./saved_models", type=str)
    # NOTE(review): --save_freq is parsed but not read anywhere in this file;
    # the autoencoder checkpoint is written after every epoch.
    parser.add_argument("--save_freq", default=1, type=int)
    parser.add_argument("--cache_dir", default="./cache", type=str)
    parser.add_argument("--sample_dir", default="./samples", type=str)

    # Misc params
    # visdim is passed to ImageGrid as the grid shape, so two values are expected.
    parser.add_argument("--visdim", default=(8, 8), type=int, nargs='+')
    parser.add_argument("--visualize", default=False, action="store_true")

    # Dataset params
    parser.add_argument("--dataset", default="mnist", type=str)
    parser.add_argument("--workers", default=8, type=int)

    # Autoencoder params
    parser.add_argument("--ae_load_path", default=None, type=str)
    parser.add_argument("--ae_depths", default=[128, 256, 512, 1024], nargs='+', type=int)
    parser.add_argument("--ae_scale_factor", default=1, type=int)
    parser.add_argument("--ae_layers", default=[1, 1, 1], nargs='+', type=int)
    parser.add_argument("--ae_encoding_depth", default=64, type=int)
    parser.add_argument("--ae_groups", default=1, type=int)
    parser.add_argument("--ae_input_noise_var", default=0., type=float)

    # Flow model params
    parser.add_argument("--pf_use_cached_features", default=False, action='store_true')
    parser.add_argument("--pf_sample_size", default=10000, type=int)
    parser.add_argument("--pf_use_compression", default=False, action='store_true')
    parser.add_argument("--pf_nystrom_points", default=4096, type=int)

    # Training hyperparameters
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_epochs", default=30, type=int)
    # type=float is required: without it a value given on the command line
    # stays a string and the learning-rate schedule cannot exponentiate it.
    parser.add_argument("--lr_frac", default=0.5, type=float)

    args = parser.parse_args()

    # Keep each experiment's samples in its own sub-directory.
    args.sample_dir = os.path.join(args.sample_dir, args.experiment_id)

    for _dir in (args.save_dir, args.cache_dir, args.sample_dir):
        os.makedirs(_dir, exist_ok=True)
    print(args)

    main(args)