import argparse
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import time
import os
import visdom
from skimage.io import imsave
from PIL import Image

plt.ion()

from modules.autoencoder import SAE, VAE
from modules.flow import Glow1d
from modules.loss import mix_rbf_mmd, poly_mmd, LapLoss
from modules.utils import weighted_max, SubsetDataset
from modules.kernel_fn import *
from dataset import CelebA64Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm

# device = 'cuda'

def get_samples(dataset, sample_size=-1):
    """Stack the inputs of a random subset of ``dataset`` into one float tensor.

    Args:
        dataset: indexable dataset whose items are tuples; only element 0 of
            each item is kept.
        sample_size: number of items to draw. A negative value (the default)
            keeps every item, in random order.

    Returns:
        ``torch.stack`` of the selected inputs, cast to float32.
    """
    order = np.random.permutation(len(dataset))
    chosen = order if sample_size < 0 else order[:sample_size]
    inputs = [dataset[idx][0] for idx in chosen]
    return torch.stack(inputs).float()

def pre_process(image):
    """Affinely map values from the range [0, 1] onto [-1, 1] (``2x - 1``)."""
    return 2 * image - 1

def post_process(logits):
    """Undo ``pre_process`` (map [-1, 1] back to [0, 1]) and clip to [0, 1].

    Args:
        logits: tensor of model outputs.

    Returns:
        ``(logits + 1) / 2`` clamped to the closed interval [0, 1].
    """
    rescaled = (logits + 1) / 2
    return rescaled.clamp(0., 1.)

def _visualize_ae(ae, X, X_corrupted, logits, latents, args, viz):
    """Send inputs, corrupted inputs, reconstructions and a latent interpolation to visdom.

    ``ae`` is switched to eval mode for the interpolation and is always switched
    back to train mode afterwards, even if a visdom call raises.
    """
    net = getattr(ae, 'module', ae)  # unwrap nn.DataParallel when present
    ae.eval()
    try:
        with torch.no_grad():
            interp = net.bilinear_interpolate(latents, steps=args.visdim[0])

        n_rows = args.visdim[0]
        vis_count = np.prod(args.visdim)

        def _to_images(t, count=-1):
            return post_process(t).view(count, *args.img_dim).detach().cpu()

        # Build every panel first, then send them, matching the original order of work.
        panels = (
            (_to_images(X), 'x', 'data'),
            (_to_images(X_corrupted), 'x_corrupted', 'data_corrupted'),
            (_to_images(logits), 'pred', 'prediction'),
            (_to_images(interp, n_rows * n_rows), 'interp', 'Interpolation'),
        )
        for imgs, win, title in panels:
            viz.images(imgs[:vis_count], nrow=n_rows, win=win, opts={'title': title})
    finally:
        ae.train()


def train_ae(ae, train_loader, epochs, lr, lr_epochs, lr_frac, args, device=torch.device('cuda'), visualize=True, viz=None):
    """Train an autoencoder to reconstruct clean inputs from noise-corrupted ones.

    Each batch is mapped through ``pre_process``. It is then corrupted with
    ``torch.randn_like(X) * args.ae_input_noise_var``; note that this value
    scales a standard normal, despite the ``var`` in its name. The model is
    trained with Adam to reconstruct the clean batch.

    The loss is the squared error summed over dim 1 and averaged over all
    remaining dims. The learning rate is multiplied by ``lr_frac`` once every
    ``lr_epochs`` epochs.

    After every epoch, the (unwrapped) module's ``state_dict`` is saved to
    ``<args.save_dir>/<args.experiment_id>-ae.pth``.

    Args:
        ae: autoencoder whose forward returns ``(logits, latents)``. It may be
            wrapped in ``nn.DataParallel``; it must provide
            ``bilinear_interpolate`` when ``visualize`` is True.
        train_loader: iterable of ``(X, y)`` batches; ``y`` is ignored.
        epochs: number of passes over ``train_loader``.
        lr: initial Adam learning rate.
        lr_epochs: epochs between learning-rate decays.
        lr_frac: multiplicative decay factor. Strings are accepted and
            converted with ``float``.
        args: namespace providing ``ae_input_noise_var``, ``save_dir`` and
            ``experiment_id``, plus ``visdim`` and ``img_dim`` when visualizing.
        device: device the model and batches are moved to.
        visualize: push images to visdom every 500 steps and after each epoch.
        viz: visdom client; required when ``visualize`` is True.

    Raises:
        ValueError: if ``visualize`` is True but ``viz`` is None.
    """
    if visualize and viz is None:
        raise ValueError("train_ae: a visdom client `viz` is required when visualize=True")

    ae = ae.to(device)
    net = getattr(ae, 'module', ae)  # unwrap nn.DataParallel when present
    lr_frac = float(lr_frac)  # argparse may deliver this as a string
    optim = torch.optim.Adam(ae.parameters(), lr=lr)
    lr_lambda = lambda epoch: np.power(lr_frac, int(epoch) // lr_epochs)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)
    criterion = lambda x, y: torch.mean(torch.sum((x - y).pow(2), dim=1))
    save_path = os.path.join(args.save_dir, "{}-ae.pth".format(args.experiment_id))

    for epoch in range(epochs):
        last_batch = None  # stays None if the loader yields nothing this epoch
        for i, (X, _) in enumerate(train_loader):
            X = pre_process(X).float().to(device)
            X_corrupted = X + torch.randn_like(X) * args.ae_input_noise_var

            logits, latents = ae(X_corrupted)
            rec_loss = criterion(logits, X)

            optim.zero_grad()
            rec_loss.backward()
            optim.step()

            print("Epoch {:d}, Step {:d}, Rec Loss: {:.4f}".format(epoch + 1, i, rec_loss.item()))

            last_batch = (X, X_corrupted, logits, latents)
            if visualize and i % 500 == 0:
                _visualize_ae(ae, *last_batch, args, viz)

        if visualize and last_batch is not None:
            _visualize_ae(ae, *last_batch, args, viz)

        scheduler.step()
        torch.save(net.state_dict(), save_path)

def _get_transform(img_dim):
    """Return a transform that resizes to ``img_dim[-2:]`` and converts to a tensor."""
    return transforms.Compose([
        transforms.Resize((*img_dim[-2:],)),
        transforms.ToTensor(),
    ])


def _load_train_dataset(args):
    """Build the training dataset selected by ``args.dataset``.

    Side effect: sets ``args.img_dim`` to the image size used for that dataset.
    The rest of the script reads this value.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name.
    """
    if args.dataset == "mnist":
        args.img_dim = (1, 32, 32)
        return datasets.MNIST("data", train=True, transform=_get_transform(args.img_dim), download=True)
    if args.dataset == "lsun_bedroom":
        args.img_dim = (3, 256, 256)
        return SubsetDataset(datasets.LSUN("data/LSUN", classes=['bedroom_train'], transform=_get_transform(args.img_dim), target_transform=None), subset_size=60000)
    if args.dataset == "cifar":
        args.img_dim = (3, 32, 32)
        return datasets.CIFAR10("data", train=True, transform=_get_transform(args.img_dim), download=True)
    if args.dataset == "celeba256":
        args.img_dim = (3, 256, 256)
        return datasets.CelebA("data", split='train', transform=_get_transform(args.img_dim), download=True)
    if args.dataset == "celeba64":
        args.img_dim = (3, 64, 64)
        return CelebA64Dataset("data", transform=_get_transform(args.img_dim))
    raise ValueError("Dataset {:s} not available".format(args.dataset))


def _extract_features(ae, dataset, args, device):
    """Encode every item of ``dataset`` with ``ae.encode`` and return the concatenated latents."""
    print("Getting features from dataset...")
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, pin_memory=True, num_workers=args.workers)
    feature_list = []
    with torch.no_grad():
        for X, _ in tqdm(loader):
            X = pre_process(X).float().to(device)
            feature_list.append(ae.encode(X))
    # ae.encode appears to return a sequence of tensors (one per latent group);
    # concatenate each group across batches, then join the groups on the last dim.
    return torch.cat([torch.cat(group, dim=0) for group in zip(*feature_list)], dim=-1)


def _train_flow(features, latent_dim, args, device, epochs=30):
    """Fit a ``Glow1d`` by maximum likelihood on ``features``.

    The model is saved to ``<args.save_dir>/<args.experiment_id>-flow.pth``
    after every epoch, and the trained model is returned.
    """
    flow_train_loader = DataLoader(TensorDataset(features.detach().cpu()), batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=args.workers)

    model = Glow1d(latent_dim, 10).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)

    model_save_path = os.path.join(args.save_dir, "{}-flow.pth".format(args.experiment_id))
    # Wall-clock time of the optimisation steps only. CUDA work is asynchronous,
    # so this is approximate.
    elapsed_time = 0.
    for i in range(epochs):
        for j, (X,) in enumerate(flow_train_loader):
            X = X.to(device)
            start = time.time()
            logp = model(X)
            loss = -logp.mean()

            optim.zero_grad()
            loss.backward()
            optim.step()

            elapsed_time += time.time() - start
            print("Epoch {:d}, step {:d}, -logp: {:.4f}".format(i, j, loss.item()))
        torch.save(model.state_dict(), model_save_path)
    print(elapsed_time)
    return model


def _decode_to_images(model, ae, args, *sample_args):
    """Draw flow samples, decode them with ``ae`` and return a (N, H, W, C) numpy array."""
    with torch.no_grad():
        latent = model.sample(*sample_args)
        samples = ae.decode(torch.split(latent, args.ae_encoding_depth, dim=1))
    return post_process(samples).view(-1, *args.img_dim).permute(0, 2, 3, 1).cpu().numpy()


def _save_sample_grids(model, ae, args, n_grids=10):
    """Save ``n_grids`` grids of decoded flow samples to ``args.sample_dir``."""
    fig_samples = plt.figure()
    fig_samples.set_size_inches(15, 15)
    grid_samples = ImageGrid(fig_samples, 111, args.visdim, axes_pad=0.)
    fig_samples.suptitle("Sampled")

    def imshow_axis(ax, img):
        ax.clear()
        ax.imshow(img, cmap=plt.cm.gray)
        ax.axis('off')

    print("Generating samples")
    for i in tqdm(range(n_grids)):
        samples = _decode_to_images(model, ae, args, np.prod(args.visdim))

        for j, ax in enumerate(np.ravel(grid_samples)):
            imshow_axis(ax, np.squeeze(np.clip(samples[j], 0., 1.)))

        fig_samples.canvas.draw()
        fig_samples.canvas.flush_events()

        sample_file = os.path.join(args.sample_dir, "sample_{:02d}.png".format(i))
        fig_samples.savefig(sample_file)


def _save_sample_images(model, ae, args, n_batches=200, batch_size=50, sample_param=0.6):
    """Save ``n_batches * batch_size`` decoded samples as individual PNGs under ``<sample_dir>/images``."""
    img_subdir = os.path.join(args.sample_dir, "images")
    os.makedirs(img_subdir, exist_ok=True)
    counter = 0
    for _ in tqdm(range(n_batches)):
        # sample_param is the second positional argument of Glow1d.sample;
        # presumably a sampling temperature -- TODO confirm.
        samples = _decode_to_images(model, ae, args, batch_size, sample_param)
        for s in samples:
            imsave(os.path.join(img_subdir, "sampled_{:d}.png".format(counter)), np.round(s * 255).astype(np.uint8))
            counter += 1


def main(args):
    """Train or load an SAE, fit a Glow flow on its latents, then write samples.

    Pipeline:
      1. Load the dataset named by ``args.dataset`` (this sets ``args.img_dim``).
      2. Load the SAE from ``args.ae_load_path``, or train it with ``train_ae``.
      3. Encode ``args.pf_sample_size`` training items into latent features.
      4. Fit a ``Glow1d`` on those features and save it under ``args.save_dir``.
      5. Decode flow samples into image grids and individual PNGs under
         ``args.sample_dir``.

    Note: AE training runs on ``'cuda'`` through DataParallel. Everything after
    that is hard-wired to ``'cuda:1'``, so at least two GPUs are required.
    """
    device = 'cuda'

    viz = visdom.Visdom(port=8097)

    train_ds = _load_train_dataset(args)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=args.workers)

    # Train Autoencoder
    ae = SAE(args.img_dim,
             n_l=args.ae_layers,
             depths=args.ae_depths,
             encoding_depth=args.ae_encoding_depth,
             scale_factor=args.ae_scale_factor,
             groups=args.ae_groups,
             use_sn=True)

    if args.ae_load_path is not None:
        ae.load_state_dict(torch.load(args.ae_load_path))
    else:
        print("AE load path is None. Training a new AE...")
        train_ae(nn.DataParallel(ae), train_loader, args.epochs, args.lr, args.lr_epochs, args.lr_frac, args, device=device, viz=viz)

    device = 'cuda:1'
    ae.to(device)
    # From here on the AE is used only for inference. Eval mode stops the
    # spectral-norm power iteration (and any train-only layers) from changing
    # state while features are extracted and samples are decoded.
    ae.eval()

    latent_dim = args.ae_encoding_depth * args.ae_groups

    flow_train_ds = SubsetDataset(train_ds, args.pf_sample_size, shuffle=True)
    features = _extract_features(ae, flow_train_ds, args, device)
    print(features.shape)

    model = _train_flow(features, latent_dim, args, device)

    _save_sample_grids(model, ae, args)
    _save_sample_images(model, ae, args)


if __name__ == "__main__":
    # Command-line entry point: parse options, create output dirs, run main().
    parser = argparse.ArgumentParser()
    # Paths and names
    parser.add_argument("--experiment_id", default="vanilla", type=str)
    parser.add_argument("--save_dir", default="./saved_models", type=str)
    parser.add_argument("--save_freq", default=1, type=int)
    parser.add_argument("--cache_dir", default="./cache", type=str)
    parser.add_argument("--sample_dir", default="./samples", type=str)

    # Misc params
    # visdim: rows and columns of the visualisation/sample grid.
    parser.add_argument("--visdim", default=(8, 8), type=int, nargs='+')

    # Dataset params
    parser.add_argument("--dataset", default="mnist", type=str)
    parser.add_argument("--workers", default=8, type=int)

    # Autoencoder params
    parser.add_argument("--ae_load_path", default=None, type=str)
    parser.add_argument("--ae_depths", default=[32, 64, 128, 256, 512], nargs='+', type=int)
    parser.add_argument("--ae_scale_factor", default=1, type=int)
    parser.add_argument("--ae_layers", default=[1, 1, 1, 1], nargs='+', type=int)
    parser.add_argument("--ae_encoding_depth", default=128, type=int)
    parser.add_argument("--ae_groups", default=1, type=int)
    parser.add_argument("--ae_input_noise_var", default=0.2, type=float)

    # Flow model params
    parser.add_argument("--pf_use_cached_features", default=False, action='store_true')
    parser.add_argument("--pf_sample_size", default=10000, type=int)
    parser.add_argument("--pf_use_compression", default=False, action='store_true')
    parser.add_argument("--pf_nystrom_points", default=4096, type=int)

    # Training hyperparameters
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_epochs", default=30, type=int)
    # type=float is required: without it a CLI-supplied value stays a string
    # and breaks the learning-rate schedule's np.power call.
    parser.add_argument("--lr_frac", default=0.5, type=float)

    args = parser.parse_args()

    # Keep each experiment's samples in its own subdirectory.
    args.sample_dir = os.path.join(args.sample_dir, args.experiment_id)

    for _dir in (args.save_dir, args.cache_dir, args.sample_dir):
        os.makedirs(_dir, exist_ok=True)

    main(args)