import argparse
import os
import sys

import numpy as np
import torch

# Print NumPy arrays in full instead of summarising long ones with "...".
# The stdlib ``sys`` module is used directly: ``np.sys`` is only an incidental
# re-export of NumPy's own internal import, not part of its public API.
np.set_printoptions(threshold=sys.maxsize)

# Add the parent directory to the import search path. This must run before the
# project-local imports that follow (presumably they live one level up from
# this script -- confirm against the repository layout).
sys.path.append("..")
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    classifier_defaults,
    create_rep_model_and_diffusion,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
from pruner.utils import (
    align_skip_channel,
    pruning_conv1d, sort_conv1d_score,
    pruning_lin, align_lin, sort_lin_score,
    get_rep_score, sort_rep_score, pruning_rep,
)

def create_argparser():
    """Build the command-line parser for the pruning script.

    The parser is populated in two steps:

    1. A dictionary of defaults (script-level values, then the entries of
       ``model_and_diffusion_defaults()``, then ``classifier_defaults()``;
       later sources override earlier ones on key collisions) is registered
       through ``add_dict_to_argparser``.
    2. Pruning-specific options (``--rep_path``, ``--thresh``,
       ``--conv1d_scale``, ``--linear_scale``, ``--conv1d_score_path``,
       ``--linear_score_path``) are added explicitly.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    script_defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "",
        "classifier_path": "",
        "classifier_scale": 1.0,
    }
    # Unpacking keeps the same precedence as successive ``dict.update`` calls.
    defaults = {
        **script_defaults,
        **model_and_diffusion_defaults(),
        **classifier_defaults(),
    }

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)

    # (flag, add_argument keyword arguments) for the pruning-specific options.
    # Note: the two score paths default to a single space, which ``main``
    # treats as "not provided".
    pruning_options = (
        ("--rep_path", dict(type=str, help="The path saving the rep model")),
        ("--thresh", dict(type=float, default=0.001, help="The thresh used to prune Conv2d's out channel")),
        ("--conv1d_scale", dict(type=float, default=0.0, help="Percentage of channels to be pruned in Conv1d")),
        ("--linear_scale", dict(type=float, default=0.0, help="Percentage of channels to be pruned in Linear")),
        ("--conv1d_score_path", dict(type=str, default=" ", help="The path saving the conv1d score lists")),
        ("--linear_score_path", dict(type=str, default=" ", help="The path saving the linear score lists")),
    )
    for flag, options in pruning_options:
        parser.add_argument(flag, **options)

    return parser

def get_module_type(m,name):
    """Resolve a dotted parameter name to its owning module and classify it.

    ``name`` is split on ``.`` and walked from ``m``: numeric components index
    into the current object, the first non-numeric component must be one of
    the known top-level attribute names, and every later non-numeric
    component must be one of the known nested attribute names. The walk stops
    at the first ``weight`` or ``bias`` component.

    Args:
        m: root object to walk (attribute access and integer indexing are
            used on it and on its descendants).
        name: dotted path such as ``"input_blocks.1.0.in_layers.2.weight"``.

    Returns:
        tuple: ``(module_tag, param_tag)``. ``module_tag`` is ``'Conv2d'``,
        ``'Linear'``, ``'GroupNorm'`` or ``'Conv1d'`` when ``str(module)``
        starts with that word, otherwise the full ``str(module)``.
        ``param_tag`` is ``'weight'``/``'bias'``, or ``None`` if the name
        contained neither.

    Raises:
        Exception: if a non-numeric component is not in the relevant list of
            known attribute names.
    """
    top_level_names = (
        'time_embed', 'label_emb', 'input_blocks',
        'middle_block', 'output_blocks', 'out',
    )
    nested_names = (
        'in_layers', 'emb_layers', 'out_layers', 'skip_connection',
        'skip_com', 'pwc', 'mask', 'norm', 'qkv', 'proj_out',
    )

    param_tag = None
    at_top_level = True
    for part in name.split('.'):
        if part in ('weight', 'bias'):
            # The parameter kind terminates the module path.
            param_tag = part
            break
        if part.isdigit():
            m = m[int(part)]
            continue
        if at_top_level:
            if part not in top_level_names:
                raise Exception('Unknown 1st_order module name!')
            # Every later attribute name is validated against the nested list.
            at_top_level = False
        elif part not in nested_names:
            logger.log(part)
            raise Exception('Unknown 2nd_order module name!')
        m = getattr(m, part)

    description = str(m)
    for layer_kind in ('Conv2d', 'Linear', 'GroupNorm', 'Conv1d'):
        if description.startswith(layer_kind):
            return layer_kind, param_tag
    # Unrecognised modules are reported by their full string form.
    return description, param_tag

def num_of_para(ckpt):
    """Count the total number of scalar elements stored in a checkpoint.

    Args:
        ckpt: mapping whose values expose a ``shape`` sequence of integers
            (the keys are ignored).

    Returns:
        The sum, over all values, of the product of that value's ``shape``
        dimensions. A value with an empty shape contributes 1; an empty
        mapping yields 0.
    """
    total = 0
    for entry in ckpt.values():
        element_count = 1
        for dim in entry.shape:
            element_count = element_count * dim
        total += element_count
    return total


def main():
    """Prune the Conv2d, Conv1d and Linear layers of a model and save the result.

    Steps, in order:

    1. Parse the command line and validate the score-file directories.
    2. Build the model and its "rep" counterpart.
    3. Prune Conv2d layers using the rep scores and ``--thresh``.
    4. Prune Conv1d layers using the saved Conv1d scores and ``--conv1d_scale``.
    5. Prune Linear layers using the saved Linear scores and ``--linear_scale``.
    6. Log the remaining-parameter ratio and write ``models/pruned_diffusion.pt``
       and ``models/reverse_info.pt`` (relative to the working directory).

    Raises:
        ValueError: if ``--conv1d_score_path`` or ``--linear_score_path`` was
            left at its placeholder default.
        KeyError: if the ``CUDA_VISIBLE_DEVICES`` environment variable is unset.
    """
    args = create_argparser().parse_args()

    # Fail fast, before any model is built or pruned: both score directories
    # are mandatory and their argparse default is the single-space placeholder.
    # An explicit raise is used (not ``assert``) so the check survives ``-O``.
    for flag, score_dir in (
        ("--conv1d_score_path", args.conv1d_score_path),
        ("--linear_score_path", args.linear_score_path),
    ):
        if score_dir == " ":
            raise ValueError("{} must be provided".format(flag))

    # NOTE: requires CUDA_VISIBLE_DEVICES to be set in the environment.
    dist_util.setup_dist(os.environ["CUDA_VISIBLE_DEVICES"])
    logger.configure()

    # ------------------------- Create model and diffusion -------------------------
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    layer_after_att = model.layer_after_att

    rep_model, _ = create_rep_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    # ------------------------- Prepare the rep score list -------------------------
    # The original checkpoint is only needed for the parameter count at the end.
    ori_ckpt = torch.load(args.model_path, map_location="cpu")
    out_c_of_para = []
    idx_of_scores = [0]
    idx = 0
    for k, p in model.named_parameters():
        type_p, _ = get_module_type(model, k)
        if k.endswith('weight') and type_p == 'Conv2d':
            out_c_of_para.append(p.shape[0])
            idx += p.shape[0]
            idx_of_scores.append(idx)
    # Drop the trailing grand total so the list holds one start offset per layer.
    idx_of_scores.pop()

    rep_scores = get_rep_score(args.rep_path, idx_of_scores)
    rep_scores = align_skip_channel(model, rep_scores, idx_of_scores, layer_after_att, get_module_type)
    cpiel, best_channels, pruned_kernel, pruned_bias = sort_rep_score(
        model,
        rep_model,
        args.model_path,
        args.rep_path,
        rep_scores,
        idx_of_scores,
        out_c_of_para,
        args.thresh,
        get_module_type,
    )

    # ------------------------------ Prune the model -------------------------------
    logger.log("pruning the Conv2d layers...")
    new_ckpt, reverse_info = pruning_rep(
        model,
        'Conv2d',
        get_module_type,
        args.model_path,
        cpiel,
        layer_after_att,
        best_channels,
        pruned_kernel,
        pruned_bias,
        args.rep_path,
    )

    logger.log("pruning the Conv1d layers...")
    conv1d_in_score_list = np.load(os.path.join(args.conv1d_score_path, 'conv1d_in_score.npy'))
    conv1d_out_score_list = np.load(os.path.join(args.conv1d_score_path, 'conv1d_out_score.npy'))

    in_c_of_para = []
    out_c_of_para = []
    in_idx_of_scores = [0]
    out_idx_of_scores = [0]
    in_idx = 0
    out_idx = 0
    for k, p in model.named_parameters():
        type_p, _ = get_module_type(model, k)
        if k.endswith('weight') and type_p == 'Conv1d':
            in_c_of_para.append(p.shape[1])
            out_c_of_para.append(p.shape[0])
            in_idx += p.shape[1]
            out_idx += p.shape[0]
            in_idx_of_scores.append(in_idx)
            out_idx_of_scores.append(out_idx)
    # As above: keep only the per-layer start offsets.
    in_idx_of_scores.pop()
    out_idx_of_scores.pop()

    in_pciel, out_pciel, in_best_channels, out_best_channels = sort_conv1d_score(
        conv1d_in_score_list, conv1d_out_score_list,
        args.conv1d_scale,
        in_idx_of_scores, in_c_of_para,
        out_idx_of_scores, out_c_of_para,
    )
    new_ckpt, reverse_info = pruning_conv1d(
        model, get_module_type, new_ckpt, reverse_info,
        in_pciel, out_pciel, in_best_channels, out_best_channels,
    )

    logger.log("pruning the Linear layers...")
    linear_in_score_list = np.load(os.path.join(args.linear_score_path, 'linear_in_score.npy'))
    linear_out_score_list = np.load(os.path.join(args.linear_score_path, 'linear_out_score.npy'))

    in_c_of_para = []
    out_c_of_para = []
    in_idx_of_scores = [0]
    out_idx_of_scores = [0]
    in_idx = 0
    out_idx = 0
    layer_id = -1
    for k, p in model.named_parameters():
        type_p, _ = get_module_type(model, k)
        if k.endswith('weight') and type_p == 'Linear':
            layer_id += 1

            # From the third Linear layer onwards only half of the output rows
            # are counted as channels.
            # NOTE(review): presumably these layers emit two values per channel
            # (e.g. scale and shift) -- confirm against the model definition.
            if layer_id >= 2:
                assert p.shape[0] % 2 == 0
                nums_of_out_ch = p.shape[0] // 2
            else:
                nums_of_out_ch = p.shape[0]
            in_c_of_para.append(p.shape[1])
            out_c_of_para.append(nums_of_out_ch)
            in_idx += p.shape[1]
            out_idx += nums_of_out_ch
            in_idx_of_scores.append(in_idx)
            out_idx_of_scores.append(out_idx)
    in_idx_of_scores.pop()
    out_idx_of_scores.pop()

    linear_in_score_list, linear_out_score_list = align_lin(
        linear_in_score_list, linear_out_score_list,
        in_c_of_para, out_c_of_para,
    )

    in_pciel, out_pciel, in_best_channels, out_best_channels = sort_lin_score(
        linear_in_score_list, linear_out_score_list,
        args.linear_scale,
        in_idx_of_scores, in_c_of_para,
        out_idx_of_scores, out_c_of_para,
    )
    new_ckpt, reverse_info = pruning_lin(
        model, get_module_type, new_ckpt, reverse_info,
        in_pciel, out_pciel, in_best_channels, out_best_channels,
    )
    logger.log("pruning finished")

    pruned_paras_num = num_of_para(new_ckpt)
    origin_paras_num = num_of_para(ori_ckpt)
    perc = float(pruned_paras_num)/float(origin_paras_num)

    logger.log("The threshold of rep Conv2d's out channels is",args.thresh)
    logger.log("The pruning scale of Conv1d's channels is",args.conv1d_scale)
    logger.log("The pruning scale of Linear's channels is",args.linear_scale)
    logger.log("The reserve rate of total parameters is",pruned_paras_num,"/",origin_paras_num,", i.e.",perc)

    logger.log("saving check point...")
    logger.log("check points will be saved at scripts/models/")
    # Create the output directory if needed so the finished pruning run is not
    # lost to a missing-directory error at save time.
    os.makedirs('models', exist_ok=True)
    torch.save(new_ckpt,'models/pruned_diffusion.pt')
    torch.save(reverse_info,'models/reverse_info.pt')
    logger.log("finished")


# Run the pruning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()