import argparse
import os

import numpy as np
np.set_printoptions(threshold=np.sys.maxsize)
import torch as th
th.manual_seed(0)
np.random.seed(0)
import torch.nn.functional as F
import sys
from PIL import Image
sys.path.append("..")

from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
)


def save_torch_example(img, name, idxs, base_path):
    """Save selected images of a batch as PNG files.

    The batch is mapped from the [-1, 1] range to 8-bit pixels
    ((x + 1) * 127.5, clamped to [0, 255]) and moved to channels-last before
    encoding. Image ``i`` is written to
    ``base_path + '/' + name + 'batch_is_' + str(i) + '.png'``.

    Args:
        img: 4-D image tensor with channels on dim 1 (e.g. NCHW).
        name: file-name prefix.
        idxs: iterable of batch indices to write.
        base_path: directory to write into; it must already exist (this
            function does not create it).
    """
    file_prefix = ''.join([base_path, '/', name, 'batch_is_'])
    as_bytes = ((img + 1) * 127.5).clamp(0, 255).to(th.uint8)
    # PIL expects height x width x channels, so move the channel axis last.
    hwc = as_bytes.permute(0, 2, 3, 1).contiguous().cpu().numpy().astype(np.uint8)
    for batch_idx in idxs:
        Image.fromarray(hwc[batch_idx]).save(file_prefix + str(batch_idx) + '.png')

def get_module_type(m,name):
    """Resolve a state-dict key to its owning module's type and parameter name.

    Walks ``name`` (e.g. ``'input_blocks.1.0.in_layers.2.weight'``) starting
    from the root module ``m``:

    - numeric components index into a container (``m[int(part)]``);
    - any other component is looked up as a child module via ``getattr``;
    - the walk stops at the first ``'weight'`` or ``'bias'`` component.

    Any child-module name is accepted, not only a fixed list. This covers
    keys such as ``input_blocks.N.0.op.weight`` (a ``Downsample`` conv) or
    ``output_blocks.N.M.conv.weight`` (an ``Upsample`` conv).

    Args:
        m: root module to resolve ``name`` against.
        name: dotted state-dict key.

    Returns:
        ``(tag1, tag2)``.

        - ``tag1`` is ``'Conv2d'``, ``'Linear'``, ``'GroupNorm'`` or
          ``'Conv1d'`` when the owning module's ``str()`` starts with that
          name. Otherwise it is the module's full ``str()``.
        - ``tag2`` is ``'weight'`` / ``'bias'``, or None when neither
          component appears in ``name``.

    Raises:
        Exception: if a non-numeric component does not name a child module.
    """
    tag2 = None
    for part in name.split('.'):
        if part in ('weight', 'bias'):
            tag2 = part
            break
        if part.isdigit():
            m = m[int(part)]
            continue
        child = getattr(m, part, None)
        # Only descend into real submodules. A plain attribute of the same
        # name (an int, a method, a tensor) means the key does not describe
        # this model.
        if not isinstance(child, th.nn.Module):
            raise Exception('Unknown module name ' + repr(part) + ' in ' + repr(name) + '!')
        m = child
    tag1 = str(m)
    # Collapse the module repr to its class name for the layer kinds callers
    # test for (e.g. 'GroupNorm32(...)' -> 'GroupNorm').
    for prefix in ('Conv2d', 'Linear', 'GroupNorm', 'Conv1d'):
        if tag1.startswith(prefix):
            tag1 = prefix
            break
    return tag1,tag2

def create_argparser():
    """Build the command-line parser for the channel-pruning scan.

    Starts from the guided-diffusion model/diffusion and classifier defaults
    together with a few sampling options. It then adds the scan-specific
    flags: --label, --steps, --order, --save_path, --begin_id and
    --split_size.
    """
    defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "",
        "classifier_path": "",
        "classifier_scale": 1.0,
        **model_and_diffusion_defaults(),
        **classifier_defaults(),
    }
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    # NOTE: --save_path defaults to a single space, not ''. main() hands this
    # value to logger.configure as the output directory.
    extra_flags = (
        ("--label", dict(type=int, default=207, help="Label for pruning")),
        ("--steps", dict(type=int, default=20, help="dpm-solver steps")),
        ("--order", dict(type=int, default=2, help="dpm-solver order")),
        ("--save_path", dict(type=str, default=" ")),
        ("--begin_id", dict(type=int, default=0)),
        ("--split_size", dict(type=int, default=30)),
    )
    for flag, options in extra_flags:
        parser.add_argument(flag, **options)
    return parser


def _conv1d_layers(model, state_dict):
    """List the Conv1d layers of `model` in `state_dict` order.

    Returns a list of ``[weight_key, bias_key]`` pairs. ``bias_key`` is None
    for a layer without a bias. A layer's index in this list is the Conv1d
    layer id used by ``--begin_id`` and by the ``sample_layer_<id>`` output
    directories.
    """
    layers = []
    for key in state_dict:
        layer_type, _ = get_module_type(model, key)
        if layer_type != 'Conv1d':
            continue
        if key.endswith('weight'):
            layers.append([key, None])
        elif key.endswith('bias') and layers:
            # Conv1d registers `weight` before `bias`, so a layer's bias key
            # directly follows its weight key in the state dict.
            layers[-1][1] = key
    return layers


def _zeroed_state_dict(state_dict, weight_key, bias_key=None, out_ch=None, in_ch=None):
    """Return a copy of `state_dict` with one channel of one Conv1d layer zeroed.

    - ``out_ch`` zeroes row ``out_ch`` of the weight. When ``bias_key`` is
      given, it also zeroes the matching bias entry.
    - ``in_ch`` zeroes column ``in_ch`` of the weight.

    Only the touched tensors are cloned. ``state_dict`` itself is never
    modified, so one in-memory checkpoint can serve every trial.
    """
    pruned = type(state_dict)(state_dict)
    # Carry over the per-module version metadata that load_state_dict reads.
    if hasattr(state_dict, '_metadata'):
        pruned._metadata = state_dict._metadata
    weight = state_dict[weight_key].clone()
    if out_ch is not None:
        weight[out_ch] = 0
        if bias_key is not None:
            bias = state_dict[bias_key].clone()
            bias[out_ch] = 0
            pruned[bias_key] = bias
    if in_ch is not None:
        weight[:, in_ch] = 0
    pruned[weight_key] = weight
    return pruned


def main():
    """Score every Conv1d channel of the diffusion UNet by ablation.

    1. Sample a reference batch with the unmodified model from a fixed noise
       ``x_T``.
    2. For each Conv1d layer id in ``[begin_id, begin_id + split_size)``,
       zero one channel at a time: either one output channel (weight row and
       bias entry) or one input channel (weight column).
    3. Re-sample from the same ``x_T`` and record the sum of absolute
       differences to the reference batch.

    Scores are written to ``<save_path>/out_score_list.npy`` and
    ``<save_path>/in_score_list.npy`` after every trial.
    """
    args = create_argparser().parse_args()
    dist_util.setup_dist(os.environ["CUDA_VISIBLE_DEVICES"])
    logger.configure(dir = args.save_path)


    '''------------------------------------- Create model and diffusion -------------------------------------'''
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    # Pristine weights, read from disk once and never mutated. Every pruning
    # trial below derives a modified copy from this dict.
    ckpt = th.load(args.model_path, map_location="cpu")

    logger.log("loading classifier...")
    classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
    classifier.load_state_dict(
        dist_util.load_state_dict(args.classifier_path, map_location="cpu")
    )
    classifier.to(dist_util.dev())
    if args.classifier_use_fp16:
        classifier.convert_to_fp16()
    classifier.eval()


    '''----------------------------------------- Prepare DPM Solver -----------------------------------------'''
    logger.log("The label used for pruning is", args.label)
    labels = th.ones(args.batch_size,device=dist_util.dev())
    labels = labels*args.label
    labels = labels.type(th.int64)

    def cond_fn(x, t, y=None):
        """Gradient of the classifier's log p(y | x_t) with respect to x."""
        # NOTE(review): this returns a gradient, not a log-probability.
        # Confirm that the local dpm_solver_pytorch.model_wrapper expects that
        # from `classifier_fn`.
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0]

    def load_weights(state_dict):
        """Load `state_dict` into the shared model and ready it for sampling."""
        model.load_state_dict(state_dict)
        model.to(dist_util.dev())
        if args.use_fp16:
            model.convert_to_fp16()
        model.eval()

    load_weights(ckpt)

    def model_fn_g(x, t, y=None):
        assert y is not None
        return model(x, t, y if args.class_cond else None)

    betas = th.tensor(diffusion.betas)
    model_kwargs = {}
    classifier_kwargs = {}
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
    shape = (args.batch_size, 3, args.image_size, args.image_size)

    # One starting noise shared by the reference and every trial, so a
    # score difference can only come from the zeroed channel.
    x_T = th.randn(*shape, device=dist_util.dev())

    def sample():
        """Sample a batch from `x_T` using the model's current weights."""
        model_fn = model_wrapper(
            model_fn_g,
            noise_schedule,
            model_type="noise",
            model_kwargs=model_kwargs,
            guidance_type="classifier",
            condition=labels,
            guidance_scale=2.5,
            classifier_fn=cond_fn,
            classifier_kwargs=classifier_kwargs,
        )
        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True,thresholding=True,max_val=1.0)
        return dpm_solver.sample(
            x_T,
            steps=args.steps,
            order=args.order,
            skip_type="time_uniform",
            method="multistep",
        )


    '''------------------------------------- Sample and get score list --------------------------------------'''
    logger.log("getting the ground truth first...")
    ground_truth_img = sample()
    save_torch_example(ground_truth_img, 'ground_truth_', list(range(args.batch_size)), args.save_path)

    def run_trial(state_dict, layerid, kind, score_list):
        """Run one pruning trial and record its score.

        Samples with `state_dict`, saves image 0 under sample_layer_<layerid>,
        and appends the L1 distance to the reference batch to `score_list`.
        The list is re-saved after every trial. `kind` is 'out' or 'in' and
        selects the image and score-file names.
        """
        th.manual_seed(0)
        load_weights(state_dict)
        x_sample = sample()
        layer_dir = args.save_path + '/sample_layer_' + str(layerid)
        os.makedirs(layer_dir, exist_ok=True)
        save_torch_example(x_sample, 'sample_newest_' + kind, [0], layer_dir)
        loss = float(th.sum(th.abs(ground_truth_img - x_sample)))
        logger.log("The loss of this pruning is", loss, ".")
        score_list.append(loss)
        np.save(args.save_path + '/' + kind + '_score_list.npy', score_list)
        logger.log("The length of the " + kind + " score list is", len(score_list))
        logger.log("########################################################")

    conv1d_layers = _conv1d_layers(model, ckpt)
    num = len(conv1d_layers)

    logger.log("This script will deal with", args.split_size, "layers.")

    out_score_list = []
    in_score_list = []
    for layerid in range(args.begin_id, args.begin_id + args.split_size):
        # Use `>=` rather than `==`: a --begin_id past the last Conv1d layer
        # must stop cleanly instead of indexing past the end of the layer list.
        if layerid >= num:
            break
        layer_prunes = [layerid]
        weight_key, bias_key = conv1d_layers[layerid]
        # Conv1d weight is (out_channels, in_channels // groups, kernel).
        out_channels, in_channels = ckpt[weight_key].shape[:2]

        for numm in range(out_channels):
            logger.log("Dealing with layer", layer_prunes, ", it has", out_channels, "out channels in total.")
            logger.log("Set the out channel", [numm], "to zero.")
            run_trial(_zeroed_state_dict(ckpt, weight_key, bias_key, out_ch=numm), layerid, 'out', out_score_list)

        for numm in range(in_channels):
            logger.log("Dealing with layer", layer_prunes, ", it has", in_channels, "in channels in total.")
            logger.log("Set the in channel", [numm], "to zero.")
            run_trial(_zeroed_state_dict(ckpt, weight_key, in_ch=numm), layerid, 'in', in_score_list)

    logger.log('finish.')


if __name__ == "__main__":
    main()
