import argparse
import json
import os
import warnings

import torch
from torchvision.utils import make_grid, save_image
from tqdm import trange


def evaluate(dpm, args, device):
    """Sample ``args.num_images`` images from *dpm* and compute IS / FID.

    Images are generated in batches of ``args.eval_batch * args.num_gpus``
    starting from Gaussian noise of shape
    ``(batch, args.img_channels, args.img_size, args.img_size)``. Each batch is
    rescaled with ``(x + 1) / 2`` (presumably mapping the model's [-1, 1]
    output range to [0, 1] — TODO confirm against the DPM implementation) and
    offloaded to the CPU immediately.

    Args:
        dpm: the sampler module; called as ``dpm(x_T, mode=..., m=...)``.
        args: namespace providing ``eval_batch``, ``num_gpus``, ``num_images``,
            ``img_channels``, ``img_size``, ``sampling_mode``, ``m``,
            ``fid_cache`` and ``fid_use_torch``.
        device: device the noise tensors are moved to before sampling.

    Returns:
        ``((IS, IS_std), FID, images)`` where ``images`` is the concatenated
        numpy array of all generated samples.

    Raises:
        ValueError: if ``args.num_images`` is not positive.
    """
    from metrics.both import get_inception_and_fid_score

    if args.num_images <= 0:
        raise ValueError('num_images must be positive, got %d' % args.num_images)

    # Remember the caller's mode so it is restored even if sampling fails.
    was_training = dpm.training
    dpm.eval()
    try:
        with torch.no_grad():
            images = []
            desc = "generating images"
            eval_batch = args.eval_batch * args.num_gpus
            for i in trange(0, args.num_images, eval_batch, desc=desc):
                batch_size = min(eval_batch, args.num_images - i)
                # Noise is drawn on the CPU generator, then moved to `device`.
                x_T = torch.randn((batch_size, args.img_channels,
                                   args.img_size, args.img_size))
                x_T = x_T.to(device)
                batch_images = dpm(x_T, mode=args.sampling_mode, m=args.m)
                # Offload each batch right away so the full sample set never
                # has to be resident in GPU memory at once.
                images.append(((batch_images + 1) / 2).cpu())
            images = torch.cat(images, dim=0).numpy()
    finally:
        dpm.train(was_training)

    (IS, IS_std), FID = get_inception_and_fid_score(
        images, args.fid_cache, num_images=args.num_images,
        use_torch=args.fid_use_torch, verbose=True)
    return (IS, IS_std), FID, images


def eval(args, steps_list=(4, 10, 20, 50, 100, 1000)):
    """Evaluate a pretrained UNet at several reduced sampling-step counts.

    Loads the UNet weights from ``args.pre_unet``. For each entry of
    *steps_list* it:

    - builds a ``DPM`` sampler using a reduced beta schedule;
    - runs ``evaluate`` on it;
    - appends one JSON line of metrics to
      ``<args.logdir>/<args.sampling_mode>/evaluation.txt``;
    - saves a 16-column grid of the first 256 samples next to that file.

    NOTE: the name shadows the builtin ``eval`` within this module; it is kept
    for compatibility with existing callers.

    Args:
        args: parsed command-line namespace (see the argparse options at the
            bottom of this file).
        steps_list: sampling-step counts to evaluate. The default is the
            schedule that used to be hard-coded here.
    """
    from modules.models.ddpm import UNet
    from modules.diffusion import DPM, get_beta_schedule, get_reduced_beta_schedule

    # Fall back to CPU so the script still runs (slowly) without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Number of available gpus: ', torch.cuda.device_count())
    logdir = os.path.join(args.logdir, args.sampling_mode)
    # The metrics file and sample grids are written here; make sure it exists.
    os.makedirs(logdir, exist_ok=True)

    # model setup
    unet = UNet(
        T=args.T, ch=args.ngf, ch_mult=args.ch_mult, attn=args.attn,
        num_res_blocks=args.num_res_blocks, dropout=args.dropout, img_size=args.img_size)
    # load_state_dict copies values into the (CPU-constructed) UNet parameters,
    # so loading the checkpoint onto the CPU is sufficient and works without CUDA.
    ckpt = torch.load(args.pre_unet, map_location='cpu')
    unet.load_state_dict(ckpt)
    print('loaded unet from ' + args.pre_unet)

    betas = get_beta_schedule(args.beta_schedule, args.beta_1, args.beta_T, args.T)

    for steps in steps_list:
        reduced_betas = get_reduced_beta_schedule(args.reduced_beta_schedule, betas, steps)
        dpm = DPM(M=unet, reduced_betas=reduced_betas, nz=100,
                  mean_type=args.mean_type, var_type=args.var_type).to(device)
        dpm = torch.nn.DataParallel(dpm)

        print('======> start evaluating [%s] with [%d] [%s] sampling steps'
              % (args.sampling_mode, steps, args.reduced_beta_schedule))
        (IS, IS_std), FID, samples = evaluate(dpm, args, device)
        metrics = {
            # NOTE(review): hard-coded id; the default checkpoint name is
            # ddpm_unet_700000.pth (700000, not 70000) — confirm intended value.
            'unet_id': 70000,
            'sampling_steps': steps,
            'schedule_type': args.reduced_beta_schedule,
            'sampling_mode': args.sampling_mode,
            'm': args.m,
            'IS': IS,
            'IS_std': IS_std,
            'FID': FID,
        }
        print("%d steps    : IS:%6.3f(%.3f), FID:%7.3f"
              % (steps, IS, IS_std, FID))
        with open(os.path.join(logdir, 'evaluation.txt'), 'a') as f:
            f.write(json.dumps(metrics) + "\n")
        save_image(torch.tensor(samples[:256]),
                   os.path.join(logdir, '%s_with_%d_%s_m=%.2f_steps.png' % (args.sampling_mode, steps, args.reduced_beta_schedule, args.m)),
                   nrow=16)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('dpm_cifar parameters')
    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
    parser.add_argument('--img_channels', type=int, default=3, help='channel of image')
    parser.add_argument('--img_size', type=int, default=32, help='size of image')
    # Diffusion
    parser.add_argument('--beta_1', type=float, default=1e-4, help='start beta value')
    parser.add_argument('--beta_T', type=float, default=0.02, help='end beta value')
    parser.add_argument('--T', type=int, default=1000, help='total diffusion steps')
    parser.add_argument('--beta_schedule', type=str, default='linear', choices=['linear', 'cosine', 'vpsde'],
                        help='beta schedule type')
    parser.add_argument('--mean_type', type=str, default='epsilon', choices=['xprev', 'xstart', 'epsilon'],
                        help='predict variable')
    parser.add_argument('--var_type', type=str, default='fixedlarge', choices=['fixedlarge', 'fixedsmall'],
                        help='denoising kernel variance type')
    parser.add_argument('--sampling_mode', type=str, default='iddim_sample', choices=['ddpm_sample', 'iddim_sample'],
                        help='sampleing mode')
    parser.add_argument('--reduced_beta_schedule', type=str, default='quadratic', choices=['linear', 'quadratic'],
                        help='reduced beta schedule type')
    parser.add_argument('--m', type=float, default=1.0, help='replace rate')
    # DDPM unet
    parser.add_argument('--ngf', type=int, default=128, help='base channel of UNet')
    parser.add_argument('--ch_mult', nargs='+', type=int, default=[1, 2, 2, 2], help='channel multiplier')
    parser.add_argument('--attn', default=[16], help='add attention to these levels')
    parser.add_argument('--num_res_blocks', type=int, default=2, help='number of resnet blocks per scale')
    parser.add_argument('--dropout', type=float, default=0., help='drop-out rate')
    parser.add_argument('--pre_unet', type=str, default='./pretrained_model/ddpm_unet_700000.pth',
                        help='path of the pretrained ddpm unet model')
    # Logging & Sampling & Save
    parser.add_argument('--logdir', type=str, default='./logs/DPM_CIFAR10_EVAL', help='log directory')
    # Evaluation
    parser.add_argument('--eval_batch', type=int, default=100)
    parser.add_argument('--num_images', type=int, default=50000,
                        help='the number of generated images for evaluation')
    parser.add_argument('--fid_use_torch', action='store_true', default=False,
                        help='calculate IS and FID on gpu')
    parser.add_argument('--fid_cache', type=str, default='./data/stats/cifar10.npz',
                        help='FID cache')
    # Distributed data parallel
    parser.add_argument('--master_address', type=str, default='localhost',
                        help='address for master')
    parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
    parser.add_argument('--num_gpus', type=int, default=4, help='number of gpus')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    args.world_size = world_size = args.num_gpus

    # suppress annoying inception_v3 initialization warning
    warnings.simplefilter(action='ignore', category=FutureWarning)

    # eval(args)
    eval(args)
