import argparse
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import time
import os
import visdom
from skimage.io import imsave
from PIL import Image

# Enable matplotlib interactive mode so the sample-grid figure built in main()
# can be redrawn (canvas.draw / flush_events) while the script keeps running.
plt.ion()

from modules.autoencoder import WAE
from modules.loss import mix_rbf_mmd, poly_mmd, LapLoss
from modules.utils import weighted_max, SubsetDataset
from modules.kernel_fn import *
from dataset import get_dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm

# device = 'cuda'

def get_samples(dataset, sample_size=-1):
    """Draw items from ``dataset`` in random order and stack their first fields.

    Args:
        dataset: indexable and sized; ``dataset[k][0]`` must be a tensor and all
            such tensors must share one shape.
        sample_size: number of items to take; any negative value means "all".

    Returns:
        A float32 tensor of shape ``(n, *item_shape)``.
    """
    shuffled = np.random.permutation(len(dataset))
    if sample_size < 0:
        picked = shuffled
    else:
        picked = shuffled[:sample_size]
    firsts = [dataset[k][0] for k in picked]
    return torch.stack(firsts).float()

def pre_process(image):
    """Rescale values from [0, 1] to [-1, 1] via the affine map x -> 2x - 1."""
    doubled = image * 2
    return doubled - 1

def post_process(logits):
    """Undo ``pre_process``: map [-1, 1] back to [0, 1], clipping out-of-range values."""
    unit_range = (logits + 1) / 2
    return unit_range.clamp(0., 1.)

def imq_mmd(X: torch.Tensor,
            Y: torch.Tensor):
    """Multi-scale MMD estimate using the inverse multiquadratic (IMQ) kernel.

    The kernel is ``k(a, b) = C / (C + ||a - b||^2)``, summed over
    ``C = 2 * d * s`` for ``s`` in ``{.1, .2, .5, 1, 2, 5, 10}``, where
    ``d = X.size(1)``.

    Normalization: the within-set terms (diagonal excluded) are divided by
    ``n - 1`` and the cross term by ``n``. The result is therefore ``n`` times
    the usual U-statistic estimate; ``train_wae`` divides by the batch size
    to compensate.

    Args:
        X: ``(n, d)`` samples.
        Y: ``(n, d)`` samples; must have exactly the same shape as ``X``.

    Returns:
        A 0-dim tensor on ``X``'s device.

    Raises:
        ValueError: if the inputs are not 2-D tensors of equal shape, or if
            ``n < 2`` (the within-set normalizer ``n - 1`` would be zero).
    """
    if X.dim() != 2 or X.shape != Y.shape:
        raise ValueError(
            "imq_mmd expects two 2-D tensors of equal shape, got {} and {}".format(
                tuple(X.shape), tuple(Y.shape)))
    batch_size, latent_dim = X.shape
    if batch_size < 2:
        raise ValueError("imq_mmd needs at least 2 samples per set, got {}".format(batch_size))

    # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
    norms_x = X.pow(2).sum(1, keepdim=True)  # batch_size x 1
    dists_x = norms_x + norms_x.t() - 2 * torch.mm(X, X.t())

    norms_y = Y.pow(2).sum(1, keepdim=True)  # batch_size x 1
    dists_y = norms_y + norms_y.t() - 2 * torch.mm(Y, Y.t())

    dists_c = norms_x + norms_y.t() - 2 * torch.mm(X, Y.t())

    # Mask out the k(x_i, x_i) diagonal. The mask is created on X's own
    # device/dtype (not via `.cuda()` whenever CUDA exists), so CPU inputs
    # also work on a GPU machine. It is built once, outside the scale loop.
    off_diag = 1 - torch.eye(batch_size, device=X.device, dtype=X.dtype)

    stats = 0
    for scale in [.1, .2, .5, 1., 2., 5., 10.]:
        C = 2 * latent_dim * 1.0 * scale
        res1 = (C / (C + dists_x) + C / (C + dists_y)) * off_diag
        res1 = res1.sum() / (batch_size - 1)
        res2 = (C / (C + dists_c)).sum() * 2. / batch_size
        stats += res1 - res2

    return stats

def _visualize_wae(ae, X, X_corrupted, logits, latents, args, viz):
    """Show inputs, corrupted inputs, reconstructions and a latent interpolation in visdom.

    ``ae`` is switched to eval mode for the interpolation and is always put
    back into train mode afterwards, even if a visdom call raises.

    Args:
        ae: the autoencoder, optionally wrapped in ``nn.DataParallel``; the
            unwrapped model must provide ``bilinear_interpolate``.
        X, X_corrupted, logits: batch tensors in the model's [-1, 1] range.
        latents: the latents ``ae`` returned for that batch.
        args: namespace providing ``visdim`` and ``img_dim``.
        viz: a ``visdom.Visdom`` client.
    """
    model = getattr(ae, 'module', ae)
    ae.eval()
    try:
        with torch.no_grad():
            interp = model.bilinear_interpolate(latents, steps=args.visdim[0])

        n_side = args.visdim[0]
        vis_count = np.prod(args.visdim)
        panels = [
            (X, 'x', 'data'),
            (X_corrupted, 'x_corrupted', 'data_corrupted'),
            (logits, 'pred', 'prediction'),
        ]
        for tensor, win, title in panels:
            img = post_process(tensor).view(-1, *args.img_dim).detach().cpu()
            viz.images(img[:vis_count], nrow=n_side, win=win, opts={'title': title})

        interp_img = post_process(interp).view(n_side * n_side, *args.img_dim).detach().cpu()
        viz.images(interp_img[:vis_count], nrow=n_side, win='interp', opts={'title': 'Interpolation'})
    finally:
        ae.train()


def train_wae(ae, train_loader, epochs, lr, lr_epochs, lr_frac, args, device=torch.device('cuda'), visualize=True, viz=None):
    """Train a Wasserstein autoencoder with a reconstruction + IMQ-MMD loss.

    Each batch is mapped with ``pre_process`` (x -> 2x - 1). Gaussian noise is
    then added and the result is reconstructed.

    The loss has two parts:
    - the squared error, summed over dim 1 and averaged over the remaining dims;
    - the IMQ MMD between every latent tensor returned by ``ae`` and a
      standard-normal sample of the same shape, summed and divided by the
      batch size.

    The learning rate is multiplied by ``lr_frac`` every ``lr_epochs`` epochs.
    The unwrapped model's ``state_dict`` is written to
    ``<args.save_dir>/<args.experiment_id>-wae.pth`` every ``args.save_freq``
    epochs (1 if absent) and always after the final epoch.

    Args:
        ae: the autoencoder, typically wrapped in ``nn.DataParallel``.
            ``ae(x)`` must return ``(reconstruction, latents)``, where
            ``latents`` is an iterable of 2-D tensors (``imq_mmd`` requires 2-D).
        train_loader: iterable of ``(images, labels)`` batches (presumably in
            [0, 1], given ``pre_process``).
        epochs: number of passes over ``train_loader``.
        lr: initial Adam learning rate.
        lr_epochs: number of epochs between learning-rate decays.
        lr_frac: decay factor; converted with ``float`` so CLI strings work.
        args: namespace providing ``ae_input_noise_var``, ``visdim``,
            ``img_dim``, ``save_dir``, ``experiment_id`` and optionally ``save_freq``.
        device: device to train on.
        visualize: push images to visdom every 500 steps and at epoch end.
        viz: ``visdom.Visdom`` client; required when ``visualize`` is true.

    Raises:
        ValueError: if ``visualize`` is true but ``viz`` is None.
    """
    if visualize and viz is None:
        raise ValueError("train_wae: visualize=True requires a visdom client in `viz`")

    ae = ae.to(device)
    # Unwrap nn.DataParallel (if present) for checkpointing.
    model = getattr(ae, 'module', ae)

    optim = torch.optim.Adam(ae.parameters(), lr=lr)
    lr_frac = float(lr_frac)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim, lambda epoch: lr_frac ** (int(epoch) // lr_epochs))
    criterion = lambda x, y: torch.mean(torch.sum((x - y).pow(2), dim=1))

    save_freq = max(1, int(getattr(args, 'save_freq', 1)))
    save_path = os.path.join(args.save_dir, "{}-wae.pth".format(args.experiment_id))

    for epoch in range(epochs):
        last_batch = None  # stays None if the loader yields nothing this epoch
        for i, (X, y) in enumerate(train_loader):
            X = pre_process(X).float().to(device)

            # NOTE: `ae_input_noise_var` multiplies a standard-normal sample,
            # so it acts as the noise *standard deviation*, not the variance.
            noise = torch.randn_like(X) * args.ae_input_noise_var
            X_corrupted = X + noise

            logits, latents = ae(X_corrupted)
            rec_loss = criterion(logits, X)
            mmd_loss = torch.stack([imq_mmd(l, torch.randn_like(l)) for l in latents], dim=0).sum() / X.size(0)
            loss = rec_loss + mmd_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            print("Epoch {:d}, Step {:d}, Rec Loss: {:.4f}, MMD Loss: {:.4f}".format(
                epoch + 1, i, rec_loss.item(), mmd_loss.item()))

            last_batch = (X, X_corrupted, logits, latents)
            if visualize and i % 500 == 0:
                _visualize_wae(ae, *last_batch, args, viz)

        if visualize and last_batch is not None:
            _visualize_wae(ae, *last_batch, args, viz)

        scheduler.step()

        if (epoch + 1) % save_freq == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(), save_path)

def _to_uint8_image(img):
    """Convert an ``(H, W, C)`` float array in [0, 1] to uint8.

    A singleton channel axis is dropped so grayscale images are written as
    plain 2-D arrays.
    """
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    return np.round(img * 255).astype(np.uint8)


def _concat_latent(latent_batched_list):
    """Join per-batch latent sequences into one ``(N, D)`` matrix.

    Each list element is the sequence of latent tensors for one batch.
    Matching positions are concatenated along dim 0, flattened per sample,
    and then joined along the feature dim.
    """
    per_position = [torch.cat(parts, dim=0) for parts in zip(*latent_batched_list)]
    return torch.cat([t.view(t.size(0), -1) for t in per_position], dim=-1)


def _save_sample_grids(ae, args, device, n_grids=10):
    """Draw ``n_grids`` grids of sampled images and save each as ``sample_XX.png``."""
    fig_samples = plt.figure()
    fig_samples.set_size_inches(15, 15)
    grid_samples = ImageGrid(fig_samples, 111, args.visdim, axes_pad=0.)
    fig_samples.suptitle("Sampled")

    figs = [fig_samples]

    print("Generating samples")
    # NOTE(review): sampling runs with the model in train mode (an `ae.eval()`
    # call was commented out here); confirm this is intentional, e.g. for
    # normalization layers.
    for i in tqdm(range(n_grids)):
        # no_grad: these samples are only displayed, so skip building a graph.
        with torch.no_grad():
            samples, _ = ae.sample(np.prod(args.visdim), device=device)
            samples = post_process(samples).view(-1, *args.img_dim).permute(0, 2, 3, 1).detach().cpu().numpy()

        for j, ax in enumerate(np.ravel(grid_samples)):
            ax.clear()
            ax.imshow(np.squeeze(np.clip(samples[j], 0., 1.)), cmap=plt.cm.gray)
            ax.axis('off')

        for f in figs:
            f.canvas.draw()
            f.canvas.flush_events()

        fig_samples.savefig(os.path.join(args.sample_dir, "sample_{:02d}.png".format(i)))


def _dump_sampled_images(ae, args, device, n_batches=200, batch_size=50):
    """Write ``n_batches * batch_size`` sampled images to ``<sample_dir>/images``.

    Returns:
        The latents returned by ``ae.sample``, one entry per batch.
    """
    img_subdir = os.path.join(args.sample_dir, "images")
    os.makedirs(img_subdir, exist_ok=True)

    counter = 0
    latent_list = []
    for _ in tqdm(range(n_batches)):
        with torch.no_grad():
            samples, latents = ae.sample(batch_size, device=device)
            latent_list.append(latents)
            samples = post_process(samples).view(-1, *args.img_dim).permute(0, 2, 3, 1).detach().cpu().numpy()

        for s in samples:
            imsave(os.path.join(img_subdir, "sampled_{:d}.png".format(counter)), _to_uint8_image(s))
            counter += 1
    return latent_list


def _dump_reconstructions(ae, dataset, args, device, n_batches=200, batch_size=50):
    """Reconstruct up to ``n_batches`` batches of ``dataset`` and write them to ``<sample_dir>/images_rec``.

    Stops early if the dataset has fewer than ``n_batches`` batches.

    Returns:
        The latents ``ae`` produced for each processed batch.
    """
    img_subdir = os.path.join(args.sample_dir, "images_rec")
    os.makedirs(img_subdir, exist_ok=True)

    loader = DataLoader(dataset, batch_size=batch_size, drop_last=False, pin_memory=True, num_workers=args.workers)
    # Bound by len(loader): calling next() past the end would raise StopIteration
    # on datasets smaller than n_batches * batch_size.
    n_total = min(n_batches, len(loader))

    counter = 0
    rec_latent_list = []
    data_iter = iter(loader)
    for _ in tqdm(range(n_total)):
        X, _y = next(data_iter)
        with torch.no_grad():
            X = pre_process(X).float().to(device)
            reconstructions, latents = ae(X)
            rec_latent_list.append(latents)
            reconstructions = post_process(reconstructions).view(-1, *args.img_dim).permute(0, 2, 3, 1).detach().cpu().numpy()

        for r in reconstructions:
            imsave(os.path.join(img_subdir, "reconstructed_{:d}.png".format(counter)), _to_uint8_image(r))
            counter += 1
    return rec_latent_list


def main(args):
    """Train (or load) a WAE, dump sampled and reconstructed images, and print a latent MMD.

    Outputs under ``args.sample_dir``:
    - sample-grid PNGs;
    - individual sampled images in ``images/``;
    - training-set reconstructions in ``images_rec/``.

    Finally it prints ``mix_rbf_mmd`` between the encoder latents of the
    reconstructed data and the latents returned by ``ae.sample``.
    """
    # Fall back to CPU so the script can still run without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    viz = visdom.Visdom(port=8097)

    def get_transform(img_dim):
        t = transforms.Compose([
            transforms.Resize((*img_dim[-2:],)),
            transforms.ToTensor(),
        ])
        return t

    train_ds = get_dataset(args, get_transform)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=args.workers)

    # Train Autoencoder
    ae = WAE(args.img_dim,
             n_l=args.ae_layers,
             depths=args.ae_depths,
             encoding_depth=args.ae_encoding_depth,
             scale_factor=args.ae_scale_factor,
             groups=args.ae_groups)
    if args.ae_load_path is not None:
        # map_location='cpu' lets a GPU-saved checkpoint load on any machine;
        # the model is moved to `device` below.
        ae_state = torch.load(args.ae_load_path, map_location='cpu')
        ae.load_state_dict(ae_state)
    else:
        print("AE load path is None. Training a new AE...")
        train_wae(nn.DataParallel(ae), train_loader, args.epochs, args.lr, args.lr_epochs, args.lr_frac, args, device=device, viz=viz)

    ae.to(device)

    _save_sample_grids(ae, args, device)
    latent_list = _dump_sampled_images(ae, args, device)
    rec_latent_list = _dump_reconstructions(ae, train_ds, args, device)

    all_y = _concat_latent(rec_latent_list)
    all_latent = _concat_latent(latent_list)
    # A small dataset can yield fewer reconstruction latents than sampled ones;
    # compare equal-sized sets.
    # NOTE(review): assumes mix_rbf_mmd (modules.loss) expects equal counts — confirm.
    n = min(all_y.size(0), all_latent.size(0))
    mmd_score = mix_rbf_mmd(all_y[:n], all_latent[:n], sigma_list=[1.])
    print("MMD:", mmd_score.item())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Paths and names
    parser.add_argument("--experiment_id", default="vanilla", type=str)
    parser.add_argument("--save_dir", default="./saved_models", type=str)
    parser.add_argument("--save_freq", default=1, type=int)
    parser.add_argument("--cache_dir", default="./cache", type=str)
    parser.add_argument("--sample_dir", default="./samples", type=str)

    # Misc params
    parser.add_argument("--visdim", default=(8, 8), type=int, nargs='+')

    # Dataset params
    parser.add_argument("--dataset", default="mnist", type=str)
    parser.add_argument("--workers", default=8, type=int)

    # Autoencoder params
    parser.add_argument("--ae_load_path", default=None, type=str)
    parser.add_argument("--ae_depths", default=[32, 64, 128, 256, 512], nargs='+', type=int)
    parser.add_argument("--ae_scale_factor", default=1, type=int)
    parser.add_argument("--ae_layers", default=[1, 1, 1, 1], nargs='+', type=int)
    parser.add_argument("--ae_encoding_depth", default=128, type=int)
    parser.add_argument("--ae_groups", default=1, type=int)
    # NOTE: despite its name, train_wae multiplies a standard-normal sample by
    # this value, so it acts as the noise standard deviation.
    parser.add_argument("--ae_input_noise_var", default=0., type=float)

    # Training hyperparameters
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_epochs", default=30, type=int)
    # type=float so a command-line value is parsed as a number instead of
    # being left as a str for the learning-rate schedule.
    parser.add_argument("--lr_frac", default=0.5, type=float)

    args = parser.parse_args()
    # NOTE(review): train_wae/main read args.img_dim, which is not set here;
    # presumably get_dataset() fills it in — confirm.

    args.sample_dir = os.path.join(args.sample_dir, args.experiment_id)

    for _dir in (args.save_dir, args.cache_dir, args.sample_dir):
        os.makedirs(_dir, exist_ok=True)

    main(args)