import argparse
import os

import numpy as np
# Print NumPy arrays in full: the summarization threshold is raised to
# sys.maxsize (``np.sys`` is the ``sys`` module reached through NumPy's
# namespace, used here because ``sys`` itself is imported further down).
np.set_printoptions(threshold=np.sys.maxsize)
import torch
# Seed the PyTorch and NumPy random number generators with fixed values at
# import time.
torch.manual_seed(500)
np.random.seed(1)
import torch.nn.functional as F
from PIL import Image

import sys
sys.path.append("..")
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    classifier_defaults,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
)
from guided_diffusion.pruned_script_util import create_model_and_diffusion,model_and_diffusion_defaults
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
from guided_diffusion.script_util import model_and_diffusion_defaults as std_model_and_diffusion_defaults
from guided_diffusion.script_util import create_model_and_diffusion as std_create_model_and_diffusion


def create_argparser():
    """Build the command-line parser for the sampling script.

    The parser is populated in two passes:

    1. Script-level defaults merged with the defaults returned by
       ``model_and_diffusion_defaults()`` and ``classifier_defaults()``
       (later sources override earlier ones), registered through
       ``add_dict_to_argparser``.
    2. Options specific to this script, registered one by one.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    defaults = {
        **{
            "clip_denoised": True,
            "num_samples": 10000,
            "batch_size": 16,
            "use_ddim": False,
            "model_path": "",
            "classifier_path": "",
            "classifier_scale": 1.0,
        },
        **model_and_diffusion_defaults(),
        **classifier_defaults(),
    }

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)

    # (flag, keyword arguments for ``add_argument``), in registration order.
    script_options = (
        ("--label", dict(type=int, default=207, help="Label for pruning")),
        ("--steps", dict(type=int, default=20, help="dpm-solver steps")),
        ("--order", dict(type=int, default=2, help="dpm-solver order")),
        ("--save_path", dict(type=str, default=" ")),
        (
            "--standard_model_path",
            dict(type=str, help="The path to load the standard model"),
        ),
        (
            "--start_step_of_pruned",
            dict(
                type=int,
                default=-1,
                help="The first id of steps to use pruned model. If it's -1, use original model completely.",
            ),
        ),
        ("--size_of_group", dict(type=int, default=5)),
    )
    for flag, options in script_options:
        parser.add_argument(flag, **options)
    return parser

def save_torch_example(img, name, idxs, base_path):
    """Save selected images from a batch tensor as PNG files.

    Each saved file is named ``<name>id_in_batch_is_<i>.png`` and is placed
    inside ``base_path``.

    Args:
        img: Batch tensor. Values are mapped with ``(img + 1) * 127.5`` and
            clamped to ``[0, 255]``; axis 1 is moved last before conversion,
            i.e. a channels-first 4-D layout is expected.
        name: File-name prefix for every image of this batch.
        idxs: Iterable of batch indices to write out.
        base_path: Output directory. It is created if it does not exist; an
            empty string means the current working directory.
    """
    # Make sure the target directory exists so Image.save does not fail with
    # FileNotFoundError ('' is skipped because makedirs rejects it).
    if base_path:
        os.makedirs(base_path, exist_ok=True)

    pixels = ((img + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    # Channels-last layout for PIL; the tensor is already uint8, so the NumPy
    # array needs no further cast.
    pixels = pixels.permute(0, 2, 3, 1).contiguous().cpu().numpy()

    for i in idxs:
        file_name = name + 'id_in_batch_is_' + str(i) + '.png'
        # os.path.join avoids writing to the filesystem root when base_path
        # is empty and avoids doubled separators when it ends with one.
        Image.fromarray(pixels[i]).save(os.path.join(base_path, file_name))

def num_of_para(ckpt):
    """Count the total number of scalar elements stored in a checkpoint.

    Args:
        ckpt: Mapping whose values expose a ``shape`` attribute (for example
            a state dict of tensors). Keys are ignored.

    Returns:
        int: Sum over all values of the product of their shape dimensions.
        A value with an empty shape (a scalar) counts as one element; an
        empty mapping yields 0.
    """
    total = 0
    for tensor in ckpt.values():
        count = 1
        for dim in tensor.shape:
            count *= dim
        total += count
    return total

def get_real_t(t,steps,classes):
    """Rescale a batch of timesteps and return the value for the first entry.

    Computes ``t * steps / classes`` element-wise, truncates the result to
    int64, and returns the first element of the converted NumPy array.

    Args:
        t: Tensor of timesteps; only the first entry determines the result.
        steps: Multiplier applied to ``t`` (``main`` passes the number of
            solver steps).
        classes: Divisor applied afterwards (``main`` passes
            ``args.diffusion_steps``).

    Returns:
        The first rescaled timestep, as produced by the NumPy conversion.
    """
    scaled = (t * steps / classes).long()
    per_sample = list(scaled.cpu().numpy())
    return per_sample[0]


def _load_for_inference(module, weights_path, use_fp16):
    """Load weights into ``module``, move it to the device and set eval mode.

    Args:
        module: Network exposing ``load_state_dict``, ``to``,
            ``convert_to_fp16`` and ``eval``.
        weights_path: Checkpoint path handed to ``dist_util.load_state_dict``.
        use_fp16: When truthy, ``module.convert_to_fp16()`` is called after
            the move to the device.

    Returns:
        The same ``module`` object, ready for inference.
    """
    module.load_state_dict(
        dist_util.load_state_dict(weights_path, map_location="cpu")
    )
    module.to(dist_util.dev())
    if use_fp16:
        module.convert_to_fp16()
    module.eval()
    return module


def main():
    """Sample images with DPM-Solver and classifier guidance.

    Workflow:

    1. Parse arguments, set up ``dist_util`` and the logger.
    2. Build and load the pruned model, the standard model and the
       classifier.
    3. Decide in which solver steps the pruned model replaces the standard
       one (``--start_step_of_pruned`` / ``--size_of_group``).
    4. Draw ``num_samples // batch_size`` batches and save every image as a
       PNG in ``--save_path``. A trailing partial batch is not sampled.

    With ``--label -1`` a fresh set of random labels is drawn for every
    batch; otherwise every sample uses the given label.

    Raises:
        KeyError: if the ``CUDA_VISIBLE_DEVICES`` environment variable is
            not set.
    """
    args = create_argparser().parse_args()

    # The device list is taken straight from the environment; the lookup
    # fails with KeyError when the variable is missing.
    dist_util.setup_dist(os.environ["CUDA_VISIBLE_DEVICES"])
    logger.configure(dir=args.save_path)

    # ------------------------ Create model and diffusion ------------------------
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    logger.log("Load pruned model from:", args.model_path)

    # The checkpoint is read once here only to report its parameter count;
    # the weights themselves are loaded through dist_util below.
    ckpt = torch.load(args.model_path, map_location="cpu")
    logger.log("The number of parameters in pruned model is", num_of_para(ckpt))
    del ckpt

    _load_for_inference(model, args.model_path, args.use_fp16)

    std_model, _ = std_create_model_and_diffusion(
        **args_to_dict(args, std_model_and_diffusion_defaults().keys())
    )
    _load_for_inference(std_model, args.standard_model_path, args.use_fp16)

    logger.log("loading classifier...")
    classifier = create_classifier(
        **args_to_dict(args, classifier_defaults().keys())
    )
    _load_for_inference(classifier, args.classifier_path, args.classifier_use_fp16)

    # Solver steps handled by the pruned model; [-1] matches no step, so the
    # standard model is used throughout.
    if args.start_step_of_pruned != -1:
        first_step = args.start_step_of_pruned
        steps_use_pruned = list(range(first_step, first_step + args.size_of_group))
    else:
        steps_use_pruned = [-1]
    logger.log("Use pruned model in the list of steps:", steps_use_pruned)

    # ----------------------------- Prepare sampling -----------------------------
    device = dist_util.dev()
    use_random_labels = args.label == -1
    labels = torch.full(
        (args.batch_size,), args.label, dtype=torch.int64, device=device
    )
    if not use_random_labels:
        logger.log("The label used for sampling is", args.label)

    def cond_fn(x, t, y=None):
        # Gradient of the classifier log-probability of label ``y`` with
        # respect to the input ``x``.
        assert y is not None
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return torch.autograd.grad(selected.sum(), x_in)[0]

    def model_fn0(x, t, y=None):
        # Route the call to the pruned or the standard model depending on
        # which solver step the timestep ``t`` falls into.
        assert y is not None
        real_t = get_real_t(t, args.steps, args.diffusion_steps)
        class_labels = y if args.class_cond else None
        if real_t in steps_use_pruned:
            return model(x, t, class_labels)
        return std_model(x, t, class_labels)

    betas = torch.tensor(diffusion.betas)
    model_kwargs = {}
    classifier_kwargs = {}
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)

    def build_solver(condition):
        # NOTE(review): guidance_scale is hard-coded to 2.5; the
        # --classifier_scale option is parsed but not used here.
        guided_fn = model_wrapper(
            model_fn0,
            noise_schedule,
            model_type="noise",
            model_kwargs=model_kwargs,
            guidance_type="classifier",
            condition=condition,
            guidance_scale=2.5,
            classifier_fn=cond_fn,
            classifier_kwargs=classifier_kwargs,
        )
        return DPM_Solver(
            guided_fn, noise_schedule, predict_x0=True, thresholding=True, max_val=1.0
        )

    # With a fixed label a single solver serves every batch; with random
    # labels a solver is rebuilt per batch inside the loop.
    dpm_solver = None if use_random_labels else build_solver(labels)
    shape = (args.batch_size, 3, args.image_size, args.image_size)

    # --------------------------------- Sampling ---------------------------------
    num_batches = args.num_samples // args.batch_size
    for i in range(num_batches):
        x_T = torch.randn(*shape, device=device)
        if use_random_labels:
            labels = torch.randint(
                low=0, high=NUM_CLASSES, size=(args.batch_size,), device=device
            )
            dpm_solver = build_solver(labels)

        logger.log("sampling", i + 1, "/", num_batches)
        x_sample = dpm_solver.sample(
            x_T,
            steps=args.steps,
            order=args.order,
            skip_type="time_uniform",
            method="multistep",
        )
        save_torch_example(
            x_sample,
            'sample_' + str(i) + '_batch_',
            list(range(args.batch_size)),
            args.save_path,
        )
    logger.log("sample finished")


# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
