import argparse
import json
import os
import warnings
import datetime
import time
from shutil import copyfile

import math
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from datasets.lmdb_datasets import LMDBDataset
from torchvision.utils import make_grid, save_image
from torchvision import transforms
from tqdm import tqdm, trange

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.multiprocessing import Process
from torch.nn.parallel import DistributedDataParallel


def broadcast_params(params):
    """Overwrite every tensor in ``params`` with the copy held by rank 0.

    Args:
        params: iterable of objects exposing a ``.data`` tensor (e.g. the
            result of ``module.parameters()``). Each tensor is updated in
            place through ``torch.distributed.broadcast``.
    """
    source_rank = 0
    for tensor in params:
        dist.broadcast(tensor.data, src=source_rank)


def evaluate(dpm, args, device):
    """Generate ``args.num_images`` samples with ``dpm`` and score them.

    Args:
        dpm: sampler module; called as ``dpm(x_T, mode=..., m=...)`` on a
            batch of Gaussian noise and expected to return an image batch.
        args: namespace providing ``eval_batch``, ``num_gpus``,
            ``num_images``, ``img_channels``, ``img_size``,
            ``sampling_mode``, ``m``, ``fid_cache`` and ``fid_use_torch``.
        device: device the initial noise is moved to before sampling.

    Returns:
        ``((IS, IS_std), FID, images)`` where ``images`` is a NumPy array
        holding all generated samples after the ``(x + 1) / 2`` rescaling.
    """
    from metrics.both import get_inception_and_fid_score

    dpm.eval()
    try:
        with torch.no_grad():
            images = []
            desc = "generating images"
            # one "batch" here is spread over all GPUs, hence the multiplier
            eval_batch = args.eval_batch * args.num_gpus
            for i in trange(0, args.num_images, eval_batch, desc=desc):
                # the last batch may be smaller than eval_batch
                batch_size = min(eval_batch, args.num_images - i)
                x_T = torch.randn((batch_size, args.img_channels,
                                   args.img_size, args.img_size))
                x_T = x_T.to(device)
                batch_images = dpm(x_T, mode=args.sampling_mode, m=args.m)
                # presumably maps samples from [-1, 1] to [0, 1] -- TODO confirm.
                # Move each finished batch to host memory immediately so the
                # whole evaluation set never accumulates on the GPU.
                images.append(((batch_images + 1) / 2).cpu())
            images = torch.cat(images, dim=0).numpy()
    finally:
        # always hand the model back in training mode, even if sampling failed
        dpm.train()
    (IS, IS_std), FID = get_inception_and_fid_score(
        images, args.fid_cache, num_images=args.num_images,
        use_torch=args.fid_use_torch, verbose=True)
    return (IS, IS_std), FID, images


def train(rank, gpu, args):
    """Run the adversarial training loop of generator ``G`` and energy model ``E`` on one worker.

    Each iteration first updates ``E`` on real and generated samples (with an
    optional gradient penalty on the real samples), then updates ``G`` against
    the frozen ``E``. Rank 0 additionally writes TensorBoard logs, image grids
    and checkpoints below ``args.logdir``.

    Args:
        rank: global rank of this process; offsets the random seed, selects
            the data shard and decides whether this process does the logging.
        gpu: index of the CUDA device this worker uses.
        args: parsed command-line namespace. On rank 0 ``args.model_dir`` is
            overwritten with the timestamp that names the run directory.
    """
    from modules.models.generator import NCSNppGenerator
    from modules.models.ebm import LargeEBM
    from modules.diffusion import DPM, get_beta_schedule, get_reduced_beta_schedule
    from modules.ema import EMA

    # the seed is offset by rank, so every worker uses a different random stream
    torch.manual_seed(args.seed + rank)
    torch.cuda.manual_seed(args.seed + rank)
    torch.cuda.manual_seed_all(args.seed + rank)
    device = torch.device('cuda:{}'.format(gpu))

    # dataset
    # Normalize with mean/std 0.5 maps ToTensor's [0, 1] output to [-1, 1]
    train_transform = transforms.Compose([
        transforms.Resize(args.img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = LMDBDataset(root='./data/celebahq/celeba-lmdb/', name='celebahq',
                          train=True, transform=train_transform)
    # the sampler shards the dataset across workers; shuffle=False below
    # because ordering is delegated to the sampler
    train_sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True,
        sampler=train_sampler, drop_last=True)

    # model setup
    G = NCSNppGenerator(args).to(device)
    E = LargeEBM(image_channels=args.img_channels * 2, ndf=args.ndf,
                 reduced_T=args.reduced_T, temb_dim=args.t_emb_dim).to(device)
    # start every worker from rank 0's initial weights
    broadcast_params(G.parameters())
    broadcast_params(E.parameters())
    optimizerG = torch.optim.Adam(G.parameters(), lr=args.lrG, betas=(0.5, 0.9))
    optimizerE = torch.optim.Adam(E.parameters(), lr=args.lrE, betas=(0.5, 0.9))
    # project-defined EMA wrapper around the generator optimizer; used below
    # through step(), state_dict() and swap_parameters_with_ema()
    optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)
    # cosine schedules with T_max=800; they are stepped once per epoch (see
    # the bottom of the epoch loop)
    schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, 800, eta_min=1e-5)
    schedulerE = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerE, 800, eta_min=1e-5)
    G = DistributedDataParallel(G, device_ids=[gpu])
    E = DistributedDataParallel(E, device_ids=[gpu])

    betas = get_beta_schedule(args.beta_schedule, args.beta_1, args.beta_T, args.T)
    reduced_betas = get_reduced_beta_schedule(args.reduced_beta_schedule, betas, args.reduced_T)
    # DPM is built without a model and the DDP-wrapped generator is attached
    # afterwards as dpm.M
    dpm = DPM(M=None, reduced_betas=reduced_betas, nz=args.nz, 
            mean_type=args.mean_type, var_type=args.var_type).to(device)
    dpm.M = G

    # logdir, writer and fixed_x_T only exist on rank 0; every later use is
    # guarded by the same rank check
    if rank == 0:
        # log setup
        time_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        args.model_dir = time_now
        logdir = os.path.join(args.logdir, time_now)
        os.makedirs(os.path.join(logdir, 'sample'))
        os.makedirs(os.path.join(logdir, 'display'))
        # keep a copy of this script next to the logs
        copyfile(__file__, os.path.join(logdir, os.path.basename(__file__)))
        writer = SummaryWriter(logdir)
        # fixed noise reused for every periodic sample grid
        fixed_x_T = torch.randn(args.sample_size, args.img_channels,
                                args.img_size, args.img_size)
        fixed_x_T = fixed_x_T.to(device)
        # log one grid of real images, rescaled with (x + 1) / 2
        grid = (make_grid(next(iter(data_loader))[0][:args.sample_size]) + 1) / 2
        writer.add_image('real_sample', grid)
        writer.flush()
        # backup all arguments
        with open(os.path.join(logdir, 'flagfile.txt'), 'a') as f:
            f.write(json.dumps(args.__dict__, indent=2) + '\n')
        # show model size
        # NOTE: the element count is divided by 1024 * 1024, not by 1e6
        model_size = 0
        for param in G.parameters():
            model_size += param.data.nelement()
        print('G params: %.2f M' % (model_size / 1024 / 1024))
        model_size = 0
        for param in E.parameters():
            model_size += param.data.nelement()
        print('E params: %.2f M' % (model_size / 1024 / 1024))

    # start training
    # whole epochs are run and there is no early exit, so the final epoch may
    # take global_step past args.total_iters
    total_epochs = math.ceil(args.total_iters / len(data_loader))
    global_step = 0
    with trange(total_epochs, dynamic_ncols=True) as pbar:
        for epoch in pbar:
            # pass the epoch to the sampler so it can vary its ordering per epoch
            train_sampler.set_epoch(epoch)

            ppbar = tqdm(data_loader)
            for x, _ in ppbar:
                x_0 = x.to(device, non_blocking=True)
                bs = x_0.shape[0]

                # train EBM
                E.requires_grad_(True)
                G.requires_grad_(False)
                E.zero_grad()

                # t is drawn uniformly from {1, ..., reduced_T}; the networks
                # receive the zero-based index t - 1
                t = torch.randint(args.reduced_T, size=(bs,), device=device) + 1
                x_t = dpm.q_sample(x_0, t)
                # enabled after q_sample so the gradient penalty below can
                # differentiate E_real with respect to x_0
                x_0.requires_grad = True

                E_real = E(x_0, t - 1, x_t).view(-1)
                lossE_real = F.softplus(-E_real).mean()
                # retain the graph: it is reused by autograd.grad below
                lossE_real.backward(retain_graph=True)

                # gradient penalty r1 / 2 * mean(||dE_real/dx_0||^2), applied
                # every step when lazy_reg is None, otherwise only every
                # lazy_reg-th step
                if args.lazy_reg is None:
                    grad_real = torch.autograd.grad(outputs=E_real.sum(),
                                                    inputs=x_0, create_graph=True)[0]
                    grad_penalty = (grad_real.view(bs, -1).norm(2, dim=1) ** 2).mean()
                    grad_penalty = args.r1 / 2 * grad_penalty
                    grad_penalty.backward()
                else:
                    if global_step % args.lazy_reg == 0:
                        grad_real = torch.autograd.grad(outputs=E_real.sum(),
                                                        inputs=x_0, create_graph=True)[0]
                        grad_penalty = (grad_real.view(bs, -1).norm(2, dim=1) ** 2).mean()
                        grad_penalty = args.r1 / 2 * grad_penalty
                        grad_penalty.backward()

                # generated samples for the EBM step; G's parameters are
                # frozen here, so only E accumulates gradients
                u = torch.randn(bs, args.nz, device=device)
                x_0_fake = G(x_t, t - 1, u)

                E_fake = E(x_0_fake, t - 1, x_t).view(-1)
                lossE_fake = F.softplus(E_fake).mean()
                lossE_fake.backward()

                # torch.nn.utils.clip_grad_norm_(E.parameters(), args.grad_clip)
                optimizerE.step()

                # train Generator
                E.requires_grad_(False)
                G.requires_grad_(True)
                G.zero_grad()

                # a fresh timestep, noisy input and latent are drawn for the
                # generator update
                t = torch.randint(args.reduced_T, size=(bs,), device=device) + 1
                x_0.requires_grad = False

                x_t = dpm.q_sample(x_0, t)
                u = torch.randn(bs, args.nz, device=device)
                x_0_fake = G(x_t, t - 1, u)

                E_gen = E(x_0_fake, t - 1, x_t).view(-1)
                lossG = F.softplus(-E_gen).mean()

                lossG.backward()
                # torch.nn.utils.clip_grad_norm_(G.parameters(), args.grad_clip)
                optimizerG.step()

                global_step += 1
                ppbar.set_postfix(rank=rank, E_real='%.4f' % E_real.mean(), E_fake='%.4f' % E_fake.mean(),
                                  E_gen='%.4f' % E_gen.mean())

                if rank == 0:
                    writer.add_scalar('lossE_real', lossE_real, global_step)
                    # note: this scalar is logged with a negated sign
                    writer.add_scalar('lossE_fake', -lossE_fake, global_step)
                    writer.add_scalar('lossG', lossG, global_step)
                    writer.add_scalar('E_real', E_real.mean(), global_step)
                    writer.add_scalar('E_fake', E_fake.mean(), global_step)
                    writer.add_scalar('E_gen', E_gen.mean(), global_step)
                    # display
                    # x_0_fake and x_t here are the tensors from the
                    # generator step above
                    if args.display_step > 0 and global_step % args.display_step == 0:
                        grid = (make_grid(torch.cat((x_0, x_0_fake), dim=0), nrow=bs) + 1) / 2
                        path = os.path.join(
                                logdir, 'display', '%d_x_0&x_0_fake.png' % global_step)
                        save_image(grid, path)
                        writer.add_image('x_0&x_0_fake', grid, global_step)
                        grid = (make_grid(x_t, nrow=bs) + 1) / 2
                        path = os.path.join(
                                logdir, 'display', '%d_x_t.png' % global_step)
                        save_image(grid, path)
                        writer.add_image('x_t', grid, global_step)
                    # sample
                    # presumably the first swap puts the EMA weights into G and
                    # the second identical call restores the live weights --
                    # confirm against modules.ema
                    if args.sample_step > 0 and global_step % args.sample_step == 0:
                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
                        G.eval()
                        with torch.no_grad():
                            # x_0 is rebound to the generated samples here; the
                            # next loop iteration assigns it again from the loader
                            x_0 = dpm(fixed_x_T, mode=args.sampling_mode, m=args.m)
                            grid = (make_grid(x_0) + 1) / 2
                            path = os.path.join(
                                logdir, 'sample', '%d.png' % global_step)
                            save_image(grid, path)
                            writer.add_image('sample', grid, global_step)
                        G.train()
                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
                    # save
                    # checkpoints are only written from step 250000 onwards
                    if args.save_step > 0 and global_step >= 250000 and global_step % args.save_step == 0:
                        # full training state, overwritten at every save; the
                        # 'schedular*' key spelling is part of the file format
                        ckpt = {
                            'G': G.state_dict(),
                            'E': E.state_dict(),
                            'optimizerG': optimizerG.state_dict(),
                            'optimizerE': optimizerE.state_dict(),
                            'schedularG': schedulerG.state_dict(),
                            'schedularE': schedulerE.state_dict(),
                            'step': global_step,
                            'fixed_x_T': fixed_x_T,
                            'args': args
                        }
                        torch.save(ckpt, os.path.join(logdir, 'ckpt.pth'))
                        # per-step weight files are written between the two
                        # swap calls (presumably so G_%d.pth holds the EMA
                        # weights -- confirm against modules.ema); .module
                        # unwraps the DDP container
                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
                        torch.save(G.module.state_dict(),
                                os.path.join(logdir, 'G_%d.pth' % global_step))
                        torch.save(E.module.state_dict(),
                                os.path.join(logdir, 'E_%d.pth' % global_step))
                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)

            # learning-rate schedules advance once per epoch
            schedulerG.step()
            schedulerE.step()

    if rank == 0:
        writer.close()


def eval(args):
    """Score saved generator checkpoints with IS and FID.

    Builds the generator and sampler from ``args``, then, for checkpoint
    indices 116 to 118, loads ``G_<save_step * index>.pth`` from
    ``args.logdir/args.model_dir``, runs ``evaluate`` on it, appends the
    metrics as one JSON line to ``evaluation.txt`` and saves a grid of the
    first 48 generated samples.

    Args:
        args: parsed command-line namespace (model, diffusion and evaluation
            options, plus ``logdir``, ``model_dir`` and ``save_step``).
    """
    from modules.models.generator import NCSNppGenerator
    from modules.diffusion import DPM, get_beta_schedule, get_reduced_beta_schedule

    device = torch.device('cuda:0')
    print('%d available gpus for evaluation ' % torch.cuda.device_count())

    # model setup
    generator = NCSNppGenerator(args)
    full_betas = get_beta_schedule(args.beta_schedule, args.beta_1, args.beta_T, args.T)
    reduced_betas = get_reduced_beta_schedule(
        args.reduced_beta_schedule, full_betas, args.reduced_T)
    sampler = DPM(M=generator, reduced_betas=reduced_betas, nz=args.nz,
                  mean_type=args.mean_type, var_type=args.var_type)
    sampler = torch.nn.DataParallel(sampler.to(device))

    # load each checkpoint in turn and evaluate it
    run_dir = os.path.join(args.logdir, args.model_dir)
    for ckpt_index in range(116, 119):
        step = args.save_step * ckpt_index
        print('======> start evaluating G_%d.pth' % step)
        state = torch.load(os.path.join(run_dir, 'G_%d.pth' % step))
        generator.load_state_dict(state)

        (inception, inception_std), fid, samples = evaluate(sampler, args, device)
        metrics = {
            'G_id': step,
            'm': args.m,
            'IS': inception,
            'IS_std': inception_std,
            'FID': fid,
        }
        print("G_%d     : IS:%6.3f(%.3f), FID:%7.3f"
              % (step, inception, inception_std, fid))

        # results accumulate across checkpoints and runs (append mode)
        with open(os.path.join(run_dir, 'evaluation.txt'), 'a') as f:
            f.write(json.dumps(metrics) + "\n")

        preview_name = 'm=%.2f_eval_sample_%d.png' % (args.m, step)
        save_image(torch.tensor(samples[:48]),
                   os.path.join(run_dir, preview_name),
                   nrow=6)


def init_process(rank, size, fn, args):
    """Join the NCCL process group, run ``fn`` on this worker, then leave the group.

    Args:
        rank: global rank of this process within the group.
        size: total number of processes in the group (world size).
        fn: worker entry point, called as ``fn(rank, gpu, args)``.
        args: namespace providing ``master_address`` and ``local_rank``; the
            latter selects the CUDA device used by this process.
    """
    # the env:// rendezvous reads the master endpoint from these variables
    os.environ.update(MASTER_ADDR=args.master_address, MASTER_PORT='6020')

    device_index = args.local_rank
    torch.cuda.set_device(device_index)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=rank,
        world_size=size,
    )

    fn(rank, device_index, args)

    # wait until every worker has finished before tearing the group down
    dist.barrier()
    dist.destroy_process_group()


#%%
def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    accepts the usual spellings instead.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``true``/``false``, ``yes``/``no``, ``1``/``0`` (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('tgdim_celebahq parameters')
    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
    parser.add_argument('--train', action='store_true', default=False, help='train from scratch')
    parser.add_argument('--eval', action='store_true', default=False, help='load model.pt and evaluate FID and IS')
    parser.add_argument('--img_channels', type=int, default=3, help='channel of image')
    parser.add_argument('--img_size', type=int, default=256, help='size of image')
    # Gaussian diffusion
    parser.add_argument('--beta_1', type=float, default=1e-4, help='start beta value')
    parser.add_argument('--beta_T', type=float, default=0.02, help='end beta value')
    parser.add_argument('--T', type=int, default=1000, help='total diffusion steps')
    parser.add_argument('--reduced_T', type=int, default=4, help='reduced diffusion steps')
    parser.add_argument('--beta_schedule', type=str, default='linear', choices=['linear', 'cosine', 'vpsde'],
                        help='beta schedule type')
    parser.add_argument('--mean_type', type=str, default='xstart', choices=['xstart', 'epsilon'],
                        help='predict variable')
    parser.add_argument('--var_type', type=str, default='fixedlarge', choices=['fixedlarge', 'fixedsmall'],
                        help='denoising kernel variance type')
    parser.add_argument('--sampling_mode', type=str, default='iddim_sample', choices=['ddpm_sample', 'iddim_sample'],
                        help='sampleing mode')
    parser.add_argument('--reduced_beta_schedule', type=str, default='quadratic', choices=['linear', 'quadratic'],
                        help='reduced beta schedule type')
    parser.add_argument('--m', type=float, default=1.0, help='replace rate')
    # NCSN++
    parser.add_argument('--ngf', type=int, default=64, help='base channel of UNet')
    parser.add_argument('--ch_mult', nargs='+', type=int, default=[1, 1, 2, 2, 4, 4], help='channel multiplier')
    # parsed as a list of ints (like --ch_mult) so a value given on the
    # command line has the same type as the default
    parser.add_argument('--attn', nargs='+', type=int, default=[16], help='add attention to these levels')
    parser.add_argument('--num_res_blocks', type=int, default=2, help='number of resnet blocks per scale')
    parser.add_argument('--dropout', type=float, default=0., help='drop-out rate')
    parser.add_argument('--act', type=str, default='swish', choices=['elu', 'relu', 'lrelu', 'swish'],
                        help='activation functions in NCSN++')
    # boolean options use str2bool: with type=bool, "--fir False" would be True
    parser.add_argument('--resamp_with_conv', type=str2bool, default=True, help='always up/down sampling with conv')
    parser.add_argument('--fir', type=str2bool, default=True, help='FIR')
    parser.add_argument('--fir_kernel', nargs='+', type=int, default=[1, 3, 3, 1], help='FIR kernel')
    parser.add_argument('--init_scale', type=float, default=0., help='scale for initialization')
    parser.add_argument('--skip_rescale', type=str2bool, default=True, help='skip rescale')
    parser.add_argument('--resblock_type', type=str, default='biggan', choices=['ddpm', 'biggan'],
                        help='tyle of resnet block, choice in biggan and ddpm')
    parser.add_argument('--progressive', type=str, default='none', choices=['none', 'output_skip', 'residual'],
                        help='progressive type for output')
    parser.add_argument('--progressive_input', type=str, default='residual', choices=['none', 'input_skip', 'residual'],
                        help='progressive type for input')
    parser.add_argument('--progressive_combine', type=str, default='sum', choices=['sum', 'cat'],
                        help='progressive combine method.')
    parser.add_argument('--embedding_type', type=str, default='positional', choices=['positional', 'fourier'],
                        help='type of time embedding')
    parser.add_argument('--fourier_scale', type=float, default=16., help='scale of fourier transform')
    # Generator & EBM
    parser.add_argument('--nz', type=int, default=100, help='dim of the latent variable')
    parser.add_argument('--z_emb_dim', type=int, default=256, help='embedding dim of z')
    parser.add_argument('--n_mlp', type=int, default=3, help='number of mlp layers for z')
    parser.add_argument('--use_tanh', type=str2bool, default=True)
    parser.add_argument('--ndf', type=int, default=64)
    parser.add_argument('--t_emb_dim', type=int, default=256)
    # Training
    parser.add_argument('--lrG', type=float, default=2e-4, help='learning rate for generator')
    parser.add_argument('--lrE', type=float, default=1e-4, help='learning rate for ebm')
    parser.add_argument('--grad_clip', type=float, default=1., help='gradient norm clipping')
    parser.add_argument('--total_iters', type=int, default=700000, help='total training iterations')
    parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='workers of Dataloader')
    parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
    parser.add_argument('--lazy_reg', type=int, default=10, help='lazy regulariation.')
    parser.add_argument('--r1', type=float, default=2.0, help='coef for penalizing gradients')
    # Logging & Sampling
    parser.add_argument('--logdir', type=str, default='./logs/tGDIM_CELEBAHQ_XSTART', help='log directory')
    parser.add_argument('--sample_size', type=int, default=8, help='sampling size of images')
    parser.add_argument('--display_step', type=int, default=2500, help='frequency of displaying')
    parser.add_argument('--sample_step', type=int, default=5000, help='frequency of sampling')
    parser.add_argument('--save_step', type=int, default=5000,
                        help='frequency of saving checkpoints, 0 to disable during training')
    # Evaluation
    parser.add_argument('--eval_batch', type=int, default=32)
    parser.add_argument('--num_images', type=int, default=30000,
                        help='the number of generated images for evaluation')
    parser.add_argument('--fid_use_torch', action='store_true', default=False,
                        help='calculate IS and FID on gpu')
    parser.add_argument('--fid_cache', type=str, default='./data/stats/celebahq.npz',
                        help='FID cache')
    parser.add_argument('--model_dir', type=str, default='',
                        help='the dir name of the trained model')
    # Distributed data parallel
    parser.add_argument('--master_address', type=str, default='localhost',
                        help='address for master')
    parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
    parser.add_argument('--num_gpus', type=int, default=4, help='number of gpus')
    args = parser.parse_args()

    # NOTE: the visible devices are fixed to GPUs 0-3 regardless of
    # --num_gpus or any value already present in the environment
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    args.world_size = world_size = args.num_gpus

    # suppress annoying inception_v3 initialization warning
    warnings.simplefilter(action='ignore', category=FutureWarning)

    if args.train:
        if world_size > 1:
            # one process per GPU; args.local_rank is set right before each
            # start() so every worker receives its own device index
            processes = []
            for rank in range(world_size):
                args.local_rank = rank
                p = Process(target=init_process, args=(rank, world_size, train, args))
                p.start()
                processes.append(p)
            for p in processes:
                p.join()
        else:
            # single worker: run in the current process
            print('starting in debug mode')
            init_process(0, 1, train, args)
    if args.eval:
        eval(args)
    if not args.train and not args.eval:
        print('Add --train and/or --eval to execute corresponding tasks')
