import argparse
import json
import os
import warnings
import datetime
from shutil import copyfile

import torch
from tensorboardX import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid, save_image
from torchvision import transforms
from tqdm import trange


def infiniteloop(dataloader):
    """Yield input batches from ``dataloader`` forever, restarting it after each pass.

    Every element of ``dataloader`` is unpacked as an ``(x, y)`` pair; only ``x``
    is yielded and the labels are discarded.

    Raises:
        ValueError: if a complete pass over ``dataloader`` produces no batches
            (e.g. ``drop_last=True`` with fewer samples than ``batch_size``).
            Without this check the generator would spin forever without
            yielding anything, silently hanging the caller.
    """
    while True:
        produced = False
        for x, _ in dataloader:
            produced = True
            yield x
        if not produced:
            raise ValueError(
                'dataloader yielded no batches; check dataset size, '
                'batch_size and drop_last')


def evaluate(ddpm, args, device):
    """Draw ``args.num_images`` samples from ``ddpm`` and score them with IS and FID.

    Sampling runs in eval mode under ``torch.no_grad()``, in chunks of
    ``args.eval_batch * args.num_gpus`` images; the model is switched back to
    train mode before scoring. Samples are mapped from the model's output
    range via ``(x + 1) / 2`` before being handed to the metric code.

    Args:
        ddpm: the (possibly ``DataParallel``-wrapped) diffusion model; called as
            ``ddpm(noise, mode='sample')``.
        args: parsed options; reads ``eval_batch``, ``num_gpus``, ``num_images``,
            ``img_channels``, ``img_size``, ``fid_cache`` and ``fid_use_torch``.
        device: device the initial noise is moved to.

    Returns:
        ``((IS, IS_std), FID, images)`` where ``images`` is the concatenated
        sample array as a NumPy array on the CPU.
    """
    from metrics.both import get_inception_and_fid_score

    chunk = args.eval_batch * args.num_gpus
    sample_shape = (args.img_channels, args.img_size, args.img_size)
    pieces = []

    ddpm.eval()
    with torch.no_grad():
        for offset in trange(0, args.num_images, chunk, desc="generating images"):
            n_this = min(chunk, args.num_images - offset)
            noise = torch.randn((n_this,) + sample_shape).to(device)
            generated = ddpm(noise, mode='sample')
            pieces.append((generated + 1) / 2)
        images = torch.cat(pieces, dim=0).cpu().numpy()
    ddpm.train()

    (IS, IS_std), FID = get_inception_and_fid_score(
        images, args.fid_cache, num_images=args.num_images,
        use_torch=args.fid_use_torch, verbose=True)
    return (IS, IS_std), FID, images


def train(args):
    """Train a DDPM UNet on CIFAR-10, periodically sampling and checkpointing.

    Creates ``args.logdir/<timestamp>/`` containing a copy of this script,
    ``flagfile.txt`` (a JSON dump of ``args``), TensorBoard event files,
    sample grids under ``sample/`` every ``args.sample_step`` steps, a rolling
    ``ckpt.pth`` and per-step ``unet_<step>.pth`` weight files every
    ``args.save_step`` steps. The ``unet_<step>.pth`` files are written while
    the optimizer is swapped to its EMA parameters.

    Sets ``args.model_dir`` to the timestamp so that a subsequent
    ``eval(args)`` in the same process evaluates this run.

    Args:
        args: parsed options from the ``__main__`` block. The number of steps
            is taken from ``args.total_steps`` or, if that attribute is
            missing, from ``args.total_step``.
    """
    from modules.models.ddpm import UNet
    from modules.diffusion import DDPM, get_beta_schedule
    from modules.ema import EMA
    device = torch.device('cuda:0')

    # The CLI historically registered ``--total_step`` while this loop needs
    # ``total_steps``; resolve it up front so a mismatch cannot crash the run
    # after the log directory has already been created.
    total_steps = getattr(args, 'total_steps', None)
    if total_steps is None:
        total_steps = args.total_step

    # dataset: images are normalised per channel with mean/std 0.5, i.e. mapped
    # from ToTensor's [0, 1] to [-1, 1].
    dataset = CIFAR10(
        root='./data/cifar', train=True, download=True,
        transform=transforms.Compose([
            transforms.Resize(args.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, drop_last=True)
    data_looper = infiniteloop(data_loader)

    # model setup
    unet = UNet(
        T=args.T, ch=args.ngf, ch_mult=args.ch_mult, attn=args.attn,
        num_res_blocks=args.num_res_blocks, dropout=args.dropout, img_size=args.img_size)
    optimizer = torch.optim.Adam(unet.parameters(), lr=args.lr)
    optimizer = EMA(optimizer, ema_decay=args.ema_decay)

    def warmup_lr(step):
        # Linear warm-up factor from 0 to 1 over ``args.warmup`` steps. A
        # non-positive warm-up means "no warm-up" instead of a ZeroDivisionError.
        if args.warmup <= 0:
            return 1.0
        return min(step, args.warmup) / args.warmup

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lr)
    betas = get_beta_schedule(args.beta_schedule, args.beta_1, args.beta_T, args.T)
    ddpm = DDPM(model=unet, betas=betas,
                mean_type=args.mean_type, var_type=args.var_type).to(device)
    print('Number of available gpus: ', torch.cuda.device_count())
    ddpm = torch.nn.DataParallel(ddpm)

    # log setup
    time_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    args.model_dir = time_now
    logdir = os.path.join(args.logdir, time_now)
    os.makedirs(os.path.join(logdir, 'sample'))
    copyfile(__file__, os.path.join(logdir, os.path.basename(__file__)))
    writer = SummaryWriter(logdir)
    # Fixed noise so periodic samples are comparable across training steps.
    fixed_x_T = torch.randn(args.sample_size, args.img_channels,
                            args.img_size, args.img_size)
    fixed_x_T = fixed_x_T.to(device)
    grid = (make_grid(next(iter(data_loader))[0][:args.sample_size]) + 1) / 2
    writer.add_image('real_sample', grid)
    writer.flush()

    # backup all arguments
    with open(os.path.join(logdir, 'flagfile.txt'), 'a') as f:
        f.write(json.dumps(args.__dict__, indent=2) + '\n')

    # show model size (NOTE: divides by 1024**2, not 1e6)
    model_size = 0
    for param in unet.parameters():
        model_size += param.data.nelement()
    print('Model params: %.2f M' % (model_size / 1024 / 1024))

    # start training
    with trange(total_steps, dynamic_ncols=True) as pbar:
        for global_step in pbar:
            unet.zero_grad()
            x_0 = next(data_looper).to(device)
            # ``.mean()`` reduces whatever DataParallel gathers from the replicas
            # (presumably per-sample losses — confirm against DDPM.forward).
            loss = ddpm(x_0, mode='train').mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(unet.parameters(), args.grad_clip)
            optimizer.step()
            scheduler.step()

            # log a plain Python float rather than the graph-attached tensor
            loss_value = loss.item()
            pbar.set_postfix(loss='%.4f' % loss_value)
            writer.add_scalar('loss', loss_value, global_step + 1)

            # sample: swap in the EMA parameters, sample, then swap back
            # (the second identical swap call restores the live parameters).
            if args.sample_step > 0 and (global_step + 1) % args.sample_step == 0:
                optimizer.swap_parameters_with_ema(store_params_in_ema=True)
                ddpm.eval()
                with torch.no_grad():
                    x_0 = ddpm(fixed_x_T, mode='sample')
                    grid = (make_grid(x_0) + 1) / 2
                    path = os.path.join(
                        logdir, 'sample', '%d.png' % (global_step + 1))
                    save_image(grid, path)
                    writer.add_image('sample', grid, global_step + 1)
                ddpm.train()
                optimizer.swap_parameters_with_ema(store_params_in_ema=True)

            # save: full training state first (live parameters), then the
            # EMA-swapped UNet weights that ``eval`` loads.
            if args.save_step > 0 and (global_step + 1) % args.save_step == 0:
                ckpt = {
                    'unet': unet.state_dict(),
                    'schedular': scheduler.state_dict(),  # key spelling kept for compatibility
                    'optimizer': optimizer.state_dict(),
                    'step': global_step,
                    'fixed_x_T': fixed_x_T,
                    'args': args
                }
                torch.save(ckpt, os.path.join(logdir, 'ckpt.pth'))
                optimizer.swap_parameters_with_ema(store_params_in_ema=True)
                torch.save(unet.state_dict(),
                           os.path.join(logdir, 'unet_%d.pth' % (global_step + 1)))
                optimizer.swap_parameters_with_ema(store_params_in_ema=True)

    writer.close()


def eval(args):
    """Evaluate saved UNet weights from ``args.logdir/args.model_dir`` with IS and FID.

    NOTE: this shadows the builtin ``eval`` at module level; the name is kept
    because the ``__main__`` block calls it.

    For each checkpoint index ``i``, loads ``unet_<save_step * i>.pth`` and
    samples and scores images via ``evaluate``. The metrics are appended as one
    JSON line to ``evaluation.txt``, and the first 256 samples are saved as
    ``eval_sample_<step>.png``. The indices come from the optional attribute
    ``args.eval_ckpt_ids``. When that attribute is unset or empty, only
    ``i = 14`` is evaluated (the previously hard-coded ``range(14, 15)``).

    Args:
        args: parsed options; besides the model/diffusion hyper-parameters it
            reads ``logdir``, ``model_dir``, ``save_step`` and everything
            ``evaluate`` needs.
    """
    from modules.models.ddpm import UNet
    from modules.diffusion import DDPM, get_beta_schedule
    device = torch.device('cuda:0')

    # model setup (must match the architecture used in ``train``)
    unet = UNet(
        T=args.T, ch=args.ngf, ch_mult=args.ch_mult, attn=args.attn,
        num_res_blocks=args.num_res_blocks, dropout=args.dropout, img_size=args.img_size)

    betas = get_beta_schedule(args.beta_schedule, args.beta_1, args.beta_T, args.T)
    ddpm = DDPM(model=unet, betas=betas,
                mean_type=args.mean_type, var_type=args.var_type).to(device)
    print('Number of available gpus: ', torch.cuda.device_count())
    ddpm = torch.nn.DataParallel(ddpm)

    # load model and evaluate
    logdir = os.path.join(args.logdir, args.model_dir)
    ckpt_ids = getattr(args, 'eval_ckpt_ids', None) or [14]
    for i in ckpt_ids:
        step = args.save_step * i
        print('======> start evaluating unet_%d.pth' % step)
        # map_location keeps loading working even if the weights were saved
        # from a different GPU than ``device``.
        ckpt = torch.load(os.path.join(logdir, 'unet_%d.pth' % step),
                          map_location=device)
        # ``unet`` is the module wrapped by ``ddpm``, so this updates the model
        # that ``evaluate`` samples from.
        unet.load_state_dict(ckpt)
        (IS, IS_std), FID, samples = evaluate(ddpm, args, device)
        metrics = {
            'IS': IS,
            'IS_std': IS_std,
            'FID': FID,
        }
        print("unet_%d     : IS:%6.3f(%.3f), FID:%7.3f"
              % (step, IS, IS_std, FID))
        with open(os.path.join(logdir, 'evaluation.txt'), 'a') as f:
            metrics['unet_id'] = step
            f.write(json.dumps(metrics) + "\n")
        save_image(torch.tensor(samples[:256]),
                   os.path.join(logdir, 'eval_sample_%d.png' % step),
                   nrow=16)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('ddpm_cifar parameters')
    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
    parser.add_argument('--train', action='store_true', default=False, help='train from scratch')
    parser.add_argument('--eval', action='store_true', default=False, help='load model.pt and evaluate FID and IS')
    parser.add_argument('--img_channels', type=int, default=3, help='channel of image')
    parser.add_argument('--img_size', type=int, default=32, help='size of image')
    # Gaussian diffusion
    parser.add_argument('--beta_1', type=float, default=1e-4, help='start beta value')
    parser.add_argument('--beta_T', type=float, default=0.02, help='end beta value')
    parser.add_argument('--T', type=int, default=1000, help='total diffusion steps')
    parser.add_argument('--beta_schedule', type=str, default='linear', choices=['linear', 'cosine', 'vpsde'],
                        help='beta schedule type')
    parser.add_argument('--mean_type', type=str, default='epsilon', choices=['xprev', 'xstart', 'epsilon'],
                        help='predict variable')
    parser.add_argument('--var_type', type=str, default='fixedlarge', choices=['fixedlarge', 'fixedsmall'],
                        help='denoising kernel variance type')
    # DDPM unet
    parser.add_argument('--ngf', type=int, default=128, help='base channel of UNet')
    parser.add_argument('--ch_mult', nargs='+', type=int, default=[1, 2, 2, 2], help='channel multiplier')
    # nargs/type so a CLI value parses to a list of ints like the default,
    # instead of a bare string.
    parser.add_argument('--attn', nargs='+', type=int, default=[16], help='add attention to these levels')
    parser.add_argument('--num_res_blocks', type=int, default=2, help='number of resnet blocks per scale')
    parser.add_argument('--dropout', type=float, default=0.1, help='drop-out rate')
    # Training
    parser.add_argument('--lr', type=float, default=2e-4, help='target learning rate')
    parser.add_argument('--grad_clip', type=float, default=1., help='gradient norm clipping')
    # ``train`` reads ``args.total_steps``; ``--total_step`` is kept as an alias
    # so existing command lines keep working.
    parser.add_argument('--total_steps', '--total_step', type=int, default=800000,
                        help='total training steps')
    parser.add_argument('--warmup', type=int, default=5000, help='learning rate warmup')
    parser.add_argument('--batch_size', type=int, default=128, help='input batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='workers of Dataloader')
    parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
    # Logging & Sampling & Save
    parser.add_argument('--logdir', type=str, default='./logs/DDPM_CIFAR10_EPS', help='log directory')
    parser.add_argument('--sample_size', type=int, default=64, help='sampling size of images')
    parser.add_argument('--sample_step', type=int, default=5000, help='frequency of sampling')
    parser.add_argument('--save_step', type=int, default=50000,
                        help='frequency of saving checkpoints, 0 to disable during training')
    # Evaluation
    parser.add_argument('--eval_batch', type=int, default=100)
    parser.add_argument('--num_images', type=int, default=50000,
                        help='the number of generated images for evaluation')
    parser.add_argument('--fid_use_torch', action='store_true', default=False,
                        help='calculate IS and FID on gpu')
    parser.add_argument('--fid_cache', type=str, default='./data/stats/cifar10.npz',
                        help='FID cache')
    parser.add_argument('--model_dir', type=str, default='',
                        help='the dir name of the trained model')
    # Distributed data parallel
    parser.add_argument('--master_address', type=str, default='localhost',
                        help='address for master')
    parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
    parser.add_argument('--num_gpus', type=int, default=3, help='number of gpus')
    args = parser.parse_args()

    # Expose the first ``num_gpus`` devices unless the caller already selected
    # devices. This must run before CUDA is first initialised to take effect.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES',
                          ','.join(str(i) for i in range(args.num_gpus)))
    args.world_size = args.num_gpus

    # ``--seed`` was parsed but never applied; seed torch's CPU and CUDA RNGs.
    torch.manual_seed(args.seed)

    # suppress annoying inception_v3 initialization warning
    warnings.simplefilter(action='ignore', category=FutureWarning)

    if args.train:
        train(args)
    if args.eval:
        eval(args)
    if not args.train and not args.eval:
        print('Add --train and/or --eval to execute corresponding tasks')

