import argparse
import os

import numpy as np
np.set_printoptions(threshold=np.sys.maxsize)
import torch
torch.manual_seed(500)
np.random.seed(1)
import torch.nn.functional as F
from PIL import Image

import sys
sys.path.append("..")
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    NUM_CLASSES,
    classifier_defaults,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
)
from guided_diffusion.train_util import get_scales, get_gate, gates2index, expand_gates
from guided_diffusion.unet import FNN
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver, get_dpmsolver_ts


def create_argparser():
    """Build the command-line parser for DPM-Solver sampling.

    Dictionary-driven flags are merged from three sources, in this order
    (later sources win on duplicate keys):

    1. the sampling defaults defined below;
    2. ``model_and_diffusion_defaults()``;
    3. ``classifier_defaults()``.

    The merged dict is registered through ``add_dict_to_argparser``. A few
    DPM-Solver / gating options are then added explicitly.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    sampling_defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "",
        "classifier_path": "",
        "classifier_scale": 1.0,
    }
    merged_defaults = {
        **sampling_defaults,
        **model_and_diffusion_defaults(),
        **classifier_defaults(),
    }

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, merged_defaults)

    # (flag, type, default, help). A help of None is argparse's own default,
    # so omitting it and passing None are equivalent.
    extra_options = (
        ("--label", int, 207, "Label for pruning"),
        ("--steps", int, 20, "dpm-solver steps"),
        ("--order", int, 2, "dpm-solver order"),
        ("--save_path", str, "", None),
        ("--step_respacing", int, -1, "A Utility variable for dpm-solver"),
        ("--gater_path", str, "", None),
    )
    for flag, flag_type, default_value, help_text in extra_options:
        parser.add_argument(flag, type=flag_type, default=default_value, help=help_text)

    return parser

def save_torch_example(img, name, idxs, base_path):
    """Write selected images of a batch to PNG files.

    Each selected image ``i`` is written to
    ``os.path.join(base_path, name + 'id_in_batch_is_' + str(i) + '.png')``.
    An empty ``base_path`` therefore writes into the current directory rather
    than the filesystem root.

    Args:
        img: image batch tensor. The ``permute(0, 2, 3, 1)`` below implies a
            4-D channels-first layout. The ``(x + 1) * 127.5`` mapping implies
            values in roughly [-1, 1]; anything outside is clamped to [0, 255].
        name: filename prefix.
        idxs: iterable of batch indices to save.
        base_path: output directory (not created here).
    """
    prefix = os.path.join(base_path, name + 'id_in_batch_is_')
    # Map [-1, 1] -> [0, 255] and quantize to 8-bit.
    img = ((img + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    # Channels-first -> channels-last, the layout PIL's fromarray expects.
    # The tensor is already uint8, so the numpy array needs no dtype cast.
    pixels = img.permute(0, 2, 3, 1).contiguous().cpu().numpy()
    for i in idxs:
        Image.fromarray(pixels[i]).save(prefix + str(i) + '.png')

def num_of_para(ckpt):
    """Count the scalar entries across all tensors in a state dict.

    Args:
        ckpt: mapping from parameter name to an object with a ``.shape``
            (e.g. a tensor or array).

    Returns:
        int: sum over all values of the product of their dimensions. A
        0-d (scalar) entry counts as 1, and an empty mapping gives 0.
    """
    total = 0
    for tensor in ckpt.values():
        element_count = 1
        for dim in tensor.shape:
            element_count *= dim
        total += element_count
    return total


def main():
    """Sample images with classifier-guided DPM-Solver from a gated model.

    All settings come from the command line (see ``create_argparser``). The
    steps are:

    1. Load the FNN gater from ``--gater_path`` and derive per-layer scales.
    2. Build the diffusion model and the classifier, load their weights and
       move them to the distributed device.
    3. Sample ``num_samples // batch_size`` batches and save every image as a
       PNG under ``--save_path``.

    ``--label -1`` draws fresh random class labels for each batch. Any other
    value conditions every sample on that single class.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is not set.
        ValueError: if the gater checkpoint is empty.
    """
    args = create_argparser().parse_args()

    # setup_dist is handed the visible-device string; fail with a clear
    # message instead of a bare KeyError when it is missing.
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES must be set (it is passed to dist_util.setup_dist)"
        )
    dist_util.setup_dist(visible_devices)
    logger.configure(dir=args.save_path)

    # ------------------------------ Gater (FNN) ------------------------------
    fnn_ckpt = torch.load(args.gater_path, map_location="cpu")
    if not fnn_ckpt:
        raise ValueError(f"gater checkpoint {args.gater_path!r} contains no tensors")
    # Output width = leading dim of the LAST tensor in the state dict
    # (presumably the final layer's bias -- TODO confirm against FNN).
    out_size = int(list(fnn_ckpt.values())[-1].shape[0])

    fnn = FNN(dims=[args.t_spli_num, 64, out_size], threshold=args.threshold)
    fnn.load_state_dict(fnn_ckpt)
    fnn.to(dist_util.dev())
    fnn.eval()

    # NOTE(review): stored on `args`; presumably read during model
    # construction or later -- confirm who consumes it.
    args.layer_used_num_list = get_scales(fnn, args.t_spli_num)

    # ----------------------- Create model and diffusion -----------------------
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    logger.log("Load model from:", args.model_path)
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading classifier...")
    classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
    classifier.load_state_dict(
        dist_util.load_state_dict(args.classifier_path, map_location="cpu")
    )
    classifier.to(dist_util.dev())
    if args.classifier_use_fp16:
        classifier.convert_to_fp16()
    classifier.eval()

    # ---------------------------- Prepare sampling ----------------------------
    label = args.label
    # One int64 class id per sample. -1 is a placeholder that the loop
    # replaces with random labels for every batch.
    labels = torch.full(
        (args.batch_size,), label, dtype=torch.int64, device=dist_util.dev()
    )
    if label != -1:
        logger.log("The label used for sampling is", label)

    def cond_fn(x, t, y=None):
        """Gradient w.r.t. ``x`` of the summed classifier log-probabilities of ``y``."""
        assert y is not None
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_probs = F.log_softmax(classifier(x_in, t), dim=-1)
            selected = log_probs[range(len(log_probs)), y.view(-1)]
            return torch.autograd.grad(selected.sum(), x_in)[0]

    logger.log("Prepare gates and index for layers...")
    # NOTE(review): these results are not used -- model_fn0 recomputes the
    # gates for each timestep. The calls are kept in case the helpers
    # validate their arguments; confirm and remove.
    used_ts = get_dpmsolver_ts(args.step_respacing, args.steps)
    get_gate(fnn, used_ts, args.t_spli_num, model.steps)

    def model_fn0(x, t, y=None):
        """Noise prediction with timestep-dependent layer gates from the FNN."""
        # NOTE(review): model_kwargs below is empty, so `y` must be supplied
        # by the local dpm_solver_pytorch wrapper -- confirm.
        assert y is not None
        step_gates = get_gate(fnn, t, args.t_spli_num, model.steps)
        step_gates = expand_gates(
            step_gates, model.expand_scale_list, t, model.t_spli_num, model.steps
        )
        return model(x, t, y if args.class_cond else None, True, step_gates)

    betas = torch.tensor(diffusion.betas)
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
    model_kwargs = {}
    classifier_kwargs = {}
    # NOTE(review): guidance strength is hard-coded. The `--classifier_scale`
    # flag is parsed but not used here -- confirm which one is intended.
    guidance_scale = 2.5

    def build_solver(condition):
        """Return a classifier-guided DPM_Solver conditioned on class ids ``condition``."""
        guided_fn = model_wrapper(
            model_fn0,
            noise_schedule,
            model_type="noise",
            model_kwargs=model_kwargs,
            guidance_type="classifier",
            condition=condition,
            guidance_scale=guidance_scale,
            classifier_fn=cond_fn,
            classifier_kwargs=classifier_kwargs,
        )
        return DPM_Solver(
            guided_fn, noise_schedule, predict_x0=True, thresholding=True, max_val=1.0
        )

    dpm_solver = build_solver(labels)
    shape = (args.batch_size, 3, args.image_size, args.image_size)

    # -------------------------------- Sampling --------------------------------
    num_batches = args.num_samples // args.batch_size
    leftover = args.num_samples % args.batch_size
    if leftover:
        logger.log(
            "num_samples is not a multiple of batch_size; skipping the last",
            leftover,
            "samples",
        )
    for i in range(num_batches):
        # x_T is drawn before the random labels; keep this order so that
        # seeded runs reproduce the same samples.
        x_T = torch.randn(*shape, device=dist_util.dev())
        if args.label == -1:
            labels = torch.randint(
                low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
            )
            dpm_solver = build_solver(labels)

        logger.log("sampling", i + 1, "/", num_batches)
        x_sample = dpm_solver.sample(
            x_T,
            steps=args.steps,
            order=args.order,
            skip_type="time_uniform",
            method="multistep",
        )
        save_torch_example(
            x_sample,
            'sample_' + str(i) + '_batch_',
            list(range(args.batch_size)),
            args.save_path,
        )
    logger.log("sample finished")


if __name__ == "__main__":
    # Entry point: run the sampler only when executed as a script, not on import.
    main()
