"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""

import argparse
import os
import time
import copy
import numpy as np
import torch as th
import torch.distributed as dist

from cm import dist_util, logger
from cm.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    #create_sr_model_and_diffusion,
    #decoder_, decoder__,
    cm_train_defaults,
    ctm_train_defaults,
    ctm_loss_defaults,
    ctm_data_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
from cm.random_util import get_generator
from cm.karras_diffusion import karras_sample
import blobfile as bf
from torchvision.utils import make_grid, save_image
#import classifier_lib


def _load_church_model_and_decoder(args, distillation):
    """Build the latent model, diffusion and decoder for ``data_name == 'church'``.

    Returns a ``(model, diffusion, decoder)`` tuple; the decoder is moved to
    the active device and put in eval mode.
    """
    # NOTE(review): `classifier_lib`, `create_sr_model_and_diffusion` and
    # `decoder_` only appear in commented-out imports in this file, so this
    # path raises NameError as written. `args.lambda_conditioned`,
    # `args.normalization_ckpt` and `args.mult` are also not declared in
    # create_argparser's own defaults -- confirm before using 'church'.
    model_and_diffusion_kwargs = args_to_dict(
        args, model_and_diffusion_defaults().keys()
    )
    model_and_diffusion_kwargs["distillation"] = distillation
    vpsde = classifier_lib.vpsde(
        beta_min=args.beta_min, beta_max=args.beta_max, multiplier=args.multiplier
    )
    model, diffusion = create_sr_model_and_diffusion(
        lambda_conditioned=args.lambda_conditioned,
        vpsde=vpsde,
        data_name=args.data_name,
        sigma_data=args.sigma_data,
        **model_and_diffusion_kwargs,
    )

    sd = th.load(args.model_path, map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]
    if 'author' in args.model_path:
        # Checkpoints whose path contains 'author' store the network under a
        # 'model.diffusion_model.' prefix; expose those weights under the
        # stripped name as well. A shallow copy is enough because
        # load_state_dict copies the values into the parameters.
        sd_ = dict(sd)
        for name in sd:
            print(name)
            parts = name.split('.')
            # Compare whole components: the previous `in 'model'` was a
            # substring test and `[1]` could raise IndexError.
            if parts[:2] == ['model', 'diffusion_model']:
                sd_['.'.join(parts[2:])] = sd[name]
        model.load_state_dict(sd_, strict=False)
        del sd_
    else:
        model.load_state_dict(sd)
    del sd

    sd = th.load(args.decoder_path, map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]
    decoder = decoder_(normalization_ckpt=args.normalization_ckpt, device=dist_util.dev())
    decoder.load_state_dict(sd, strict=False)
    decoder.to(dist_util.dev())
    decoder.eval()
    del sd
    return model, diffusion, decoder


def _parse_ts(sampler, ts_arg):
    """Turn the ``--ts`` string into the value handed to ``karras_sample``.

    * ``multistep``: a non-empty comma-separated list of ints is required.
    * ``exact``: same format, but an unparsable value falls back to ``[]``.
    * any other sampler: ``None``.

    Raises:
        ValueError: if ``sampler == "multistep"`` and ``ts_arg`` is empty or
            contains a non-integer entry.
    """
    if sampler == "multistep":
        if len(ts_arg) == 0:
            raise ValueError(
                "--ts must be a non-empty comma-separated list of integers "
                "for the multistep sampler"
            )
        return tuple(int(x) for x in ts_arg.split(","))
    if sampler == "exact":
        try:
            return tuple(int(x) for x in ts_arg.split(","))
        except (ValueError, AttributeError):
            return []
    return None


def _parse_checkpoint_tags(model_path):
    """Derive the ``(step, ema)`` tags used in the output directory name.

    ``step`` is the last six characters of the second-to-last dot-separated
    piece of the path, or ``'unknown'`` when the path contains no dot (e.g.
    the default empty path). ``ema`` is the second-to-last underscore-separated
    piece parsed as a float when it is one of the recognised EMA rates,
    otherwise the string ``'model'``.
    """
    pieces = model_path.split('.')
    step = pieces[-2][-6:] if len(pieces) >= 2 else 'unknown'
    try:
        ema = float(model_path.split('_')[-2])
    except (IndexError, ValueError):
        return step, 'model'
    # Explicit membership test instead of `assert`, which is stripped by -O.
    if ema not in (0.999, 0.9999, 0.9999432189950708):
        ema = 'model'
    return step, ema


def _class_labels(args):
    """Return the per-sample class-label tensor for one class-conditional batch.

    ``train_classes >= 0`` repeats that single class, ``train_classes == -2``
    tiles a fixed list of 32 class ids evenly over the batch, and any other
    value draws labels uniformly from ``[0, NUM_CLASSES)``.

    Raises:
        ValueError: if ``train_classes == -2`` and the batch size is not a
            multiple of the number of fixed classes.
    """
    device = dist_util.dev()
    if args.train_classes >= 0:
        return th.full(
            (args.batch_size,), int(args.train_classes), dtype=th.long, device=device
        )
    if args.train_classes == -2:
        fixed = [0, 1, 9, 11, 29, 31, 33, 55, 76, 89, 90, 130, 207, 250, 279, 281, 291, 323, 386, 387,
                 388, 417, 562, 614, 759, 789, 800, 812, 848, 933, 973, 980]
        if args.batch_size % len(fixed) != 0:
            raise ValueError(
                f"batch_size ({args.batch_size}) must be a multiple of {len(fixed)} "
                "when train_classes == -2"
            )
        per_class = args.batch_size // len(fixed)
        return th.tensor([x for x in fixed for _ in range(per_class)], device=device)
    return th.randint(low=0, high=NUM_CLASSES, size=(args.batch_size,), device=device)


def _draw_unused_sample_id(save_dir, png_prefix, max_tries=100):
    """Draw a random id in ``[0, 1000000)`` for this batch's output files.

    Redraws (up to ``max_tries`` times) while ``sample_<id>.npz`` or
    ``<png_prefix>sample_<id>.png`` already exists in ``save_dir``, so that a
    colliding id does not silently overwrite an earlier batch.
    """
    for _ in range(max(1, max_tries)):
        r = np.random.randint(1000000)
        candidates = (f"sample_{r}.npz", f"{png_prefix}sample_{r}.png")
        if not any(os.path.exists(os.path.join(save_dir, c)) for c in candidates):
            break
    return r


def main():
    """Sample images batch by batch and write them to disk.

    Parses the command line, sets up the distributed backend and logger,
    builds the model/diffusion pair (plus a decoder for ``data_name ==
    'church'``), then repeatedly calls ``karras_sample`` until at least
    ``num_samples`` images have been produced. On rank 0 each batch is saved
    into a run-specific sub-directory of ``save_dir`` as a ``.npz`` archive
    (``save_format == 'npz'``) and/or a PNG grid (``save_format == 'png'``,
    and always for the first batch).

    Raises:
        NotImplementedError: for ``data_name == 'ImageNet256'``.
        ValueError: for a non-positive ``batch_size``, an empty ``--ts`` with
            the multistep sampler, or an incompatible batch size when
            ``train_classes == -2``.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist(args.device_id)
    logger.configure()

    if args.batch_size <= 0:
        # The sampling loop below would otherwise never terminate.
        raise ValueError(f"batch_size must be positive, got {args.batch_size}")

    distillation = "consistency" in args.training_mode

    logger.log("creating model and diffusion...")
    decoder = None
    if args.data_name in ['church']:
        model, diffusion, decoder = _load_church_model_and_decoder(args, distillation)
    elif args.data_name in ['ImageNet256']:
        raise NotImplementedError
    else:
        if args.training_mode == 'edm':
            model, diffusion = create_model_and_diffusion(args, teacher=True)
        else:
            model, diffusion = create_model_and_diffusion(args)
        try:
            model.load_state_dict(
                dist_util.load_state_dict(args.model_path, map_location="cpu")
            )
        except Exception as err:
            # Deliberately best-effort: sampling continues with the freshly
            # initialised weights, but the reason is now reported.
            print("model path not loaded")
            logger.log(f"could not load checkpoint {args.model_path!r}: {err!r}")
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("sampling...")
    ts = _parse_ts(args.sampler, args.ts)

    if args.stochastic_seed:
        args.seed = np.random.randint(1000000)
    generator = get_generator(args.generator, args.num_samples, args.seed)

    # Output directory name encodes the sampling setup and checkpoint tags.
    step, ema = _parse_checkpoint_tags(args.model_path)
    run_name = f'{args.training_mode}_{args.sampler}_sampler_{args.steps}_steps_{step}_itrs_{ema}_ema'
    if args.sampler in ['multistep', 'exact']:
        run_name += '_' + "".join(str(i) for i in ts)
    save_dir = os.path.join(args.save_dir, run_name)
    os.makedirs(save_dir, exist_ok=True)

    # Loop-invariant sampling settings.
    batch_shape = (args.batch_size, args.in_channels, args.image_size, args.image_size)
    if args.data_name in ['church']:
        clip_denoised = False
    elif args.training_mode == 'edm':
        clip_denoised = True
    else:
        clip_denoised = args.clip_denoised
    use_ctm = args.training_mode.lower() == 'ctm'
    png_prefix = f"class_{args.train_classes}_" if args.class_cond else ""

    itr = 0
    num_samples = 0
    while itr * args.batch_size < args.num_samples:
        x_T = generator.randn(*batch_shape, device=dist_util.dev()) * args.sigma_max
        # NOTE(review): this draw is only used when class_cond is off (and
        # then never saved); it is kept because it presumably advances the
        # generator state, so removing it would change later noise -- confirm.
        classes = generator.randint(0, 1000, (args.batch_size,))
        print("x_T: ", x_T[0][0][0][0])
        current = time.time()
        model_kwargs = {}
        if args.class_cond:
            classes = _class_labels(args)
            model_kwargs["y"] = classes
            print("classes: ", model_kwargs)
        with th.no_grad():
            x = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=batch_shape,
                steps=args.steps,
                model_kwargs=model_kwargs,
                device=dist_util.dev(),
                clip_denoised=clip_denoised,
                sampler=args.sampler,
                sigma_min=args.sigma_min,
                sigma_max=args.sigma_max,
                s_churn=args.s_churn,
                s_tmin=args.s_tmin,
                s_tmax=args.s_tmax,
                s_noise=args.s_noise,
                generator=generator,
                ts=ts,
                teacher=False,
                clip_output=args.clip_output,
                ctm=use_ctm,
                # The explicit start noise is only passed for seeded runs.
                x_T=None if args.stochastic_seed else x_T,
            )
            print(x[0])

            if decoder is not None:
                x = decoder(x, teacher=(args.training_mode == 'edm'), mult=args.mult)

        # Map from [-1, 1] to uint8 and move channels last.
        sample = ((x + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1).contiguous()
        itr += 1

        if dist.get_rank() == 0:
            sample = sample.cpu().detach()
            print(f"{itr * args.batch_size} sampling complete")
            r = _draw_unused_sample_id(save_dir, png_prefix)
            if args.save_format == 'npz':
                arrays = [sample.numpy()]
                if args.class_cond:
                    arrays.append(classes.cpu().detach().numpy())
                np.savez(os.path.join(save_dir, f"sample_{r}.npz"), *arrays)
            if args.save_format == 'png' or itr == 1:
                print("x range: ", x.min(), x.max())
                nrow = int(np.sqrt(sample.shape[0]))
                image_grid = make_grid((x + 1.) / 2., nrow, padding=2)
                with bf.BlobFile(os.path.join(save_dir, f"{png_prefix}sample_{r}.png"), "wb") as fout:
                    save_image(image_grid, fout)
        num_samples += sample.shape[0]
        print(f"sample {num_samples} time {time.time() - current} sec")

    dist.barrier()
    logger.log("sampling complete")

def create_argparser():
    """Build the command-line parser for this sampling script.

    The script's own defaults are declared first; the model/diffusion, CM
    training, CTM training, CTM loss and CTM data default dictionaries are
    then merged on top in that order, so a later dictionary wins whenever a
    key is defined more than once. The merged dictionary is handed to
    ``add_dict_to_argparser`` to populate a fresh ``ArgumentParser``.
    """
    script_defaults = {
        "training_mode": "edm",
        "generator": "determ",
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "sampler": "heun",
        "s_churn": 0.0,
        "s_tmin": 0.0,
        "s_tmax": float("inf"),
        "s_noise": 1.0,
        "steps": 40,
        "model_path": "",
        "decoder_path": "",
        "seed": 42,
        "ts": "",
        "save_dir": "",
        "device_id": 0,
        "classes": -1,
        "num_classes": 10,
        "save_format": "png",
        "stochastic_seed": False,
        "parametrization": "euler",
        "data_name": "cifar10",
        "beta_min": 0.1,
        "beta_max": 20.,
        "multiplier": 1.,
        "latent_decoder_path": "",
        "clip_output": True,
        "sigma_data": 0.5,
        "schedule_sampler": "lognormal",
    }

    # Merge order matters: each provider may override keys set before it.
    shared_default_providers = (
        model_and_diffusion_defaults,
        cm_train_defaults,
        ctm_train_defaults,
        ctm_loss_defaults,
        ctm_data_defaults,
    )
    for provide_defaults in shared_default_providers:
        script_defaults.update(provide_defaults())

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, script_defaults)
    return arg_parser


# Run the sampling entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
