"""
Train the rep-diffusion model on images.
"""

import argparse
import sys
import os
sys.path.append("..")

from guided_diffusion import dist_util, logger
from guided_diffusion.image_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_rep_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from guided_diffusion.train_util import RepTrain


def create_argparser():
    """Build the command-line parser for rep-diffusion training.

    The parser exposes the generic training defaults listed below, every key
    returned by ``model_and_diffusion_defaults()`` (registered through
    ``add_dict_to_argparser``), and the rep-training specific options that
    ``main`` forwards to ``RepTrain``.

    Returns:
        argparse.ArgumentParser: the configured, not yet parsed, parser.
    """
    defaults = dict(
        data_dir="",
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=1,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=10,
        save_interval=10000,
        resume_checkpoint="",
        use_fp16=False,
        fp16_scale_growth=1e-3,
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    parser.add_argument(
        "--save_path", type=str, help="The path to save the log and the result"
    )
    parser.add_argument(
        "--single_label", type=int, default=-1,
        help="If training should use a single label, make sure the label you assign matches the data"
    )
    # Passed through as a raw string; main() converts it to a bool.
    parser.add_argument(
        "--use_simple_train", type=str, default="False", help="If you want to use simple train based on DPM Solver."
    )
    parser.add_argument(
        "--step_respacing", type=int, default=-1, help="A utility variable for simple train"
    )
    parser.add_argument(
        "--dpm_solver_steps", type=int, default=20, help="The number of DPM-Solver steps"
    )
    parser.add_argument(
        "--lasso_strength", type=float, default=1e-4, help="The lasso strength"
    )
    parser.add_argument(
        "--mask_interval", type=int, default=200, help="The interval to update the mask"
    )
    parser.add_argument(
        "--before_mask_iters", type=int, default=10000, help="The number of iterations before updating the mask"
    )
    parser.add_argument(
        "--rep_train_thresh", type=float, default=1e-3, help="The threshold for updating the mask"
    )
    parser.add_argument(
        "--num_upd_msk", type=int, default=-1,
        help="The number of channels to update in the mask; if not -1, it overrides rep_train_thresh."
    )
    # NOTE(review): the two options below are forwarded verbatim to RepTrain;
    # their exact semantics live there - confirm before adding help text.
    parser.add_argument(
        "--rep_thresh_decay", type=int, default=10
    )
    parser.add_argument(
        "--rep_thresh_min", type=float, default=1e-5
    )
    parser.add_argument(
        "--num_upd_msk_increment", type=int, default=-1,
        help="Mask-update increment passed to RepTrain; -1 reuses --num_upd_msk, otherwise it must be positive."
    )

    return parser

def num_of_para(ckpt):
    """Return the total element count across all entries of a checkpoint.

    Each value of ``ckpt`` only needs a ``shape`` attribute. Its element count
    is the product of its dimensions, and a 0-dim entry counts as one element.
    """
    total = 0
    for tensor in ckpt.values():
        count = 1
        for dim in tensor.shape:
            count *= dim
        total += count
    return total

def get_module_type(m, name):
    """Resolve a dotted parameter name against ``m`` and classify the leaf.

    Numeric path components index into the current module. The first
    non-numeric component must be a known top-level attribute, and every
    later one must be a known nested attribute. Walking stops at the first
    ``weight``/``bias`` component.

    Returns:
        (module_type, param_kind): ``module_type`` is ``'Conv2d'``,
        ``'Linear'``, ``'GroupNorm'`` or ``'Conv1d'`` when ``str()`` of the
        resolved module starts with that prefix, otherwise the full ``str()``.
        ``param_kind`` is ``'weight'``, ``'bias'`` or ``None`` if neither
        appears in ``name``.

    Raises:
        Exception: on an unrecognised top-level or nested component.
    """
    top_level = ('time_embed', 'label_emb', 'input_blocks',
                 'middle_block', 'output_blocks', 'out')
    nested = ('in_layers', 'emb_layers', 'out_layers', 'skip_connection',
              'skip_com', 'pwc', 'norm', 'qkv', 'proj_out')

    param_kind = None
    seen_top_level = False
    for part in name.split('.'):
        if part in ('weight', 'bias'):
            param_kind = part
            break
        if part.isdigit():
            m = m[int(part)]
        elif not seen_top_level:
            if part not in top_level:
                raise Exception('Unknown 1st_order module name!')
            m = getattr(m, part)
            seen_top_level = True
        else:
            if part not in nested:
                raise Exception('Unknown 2nd_order module name!')
            m = getattr(m, part)

    # Classify by the module's printed representation; the first matching
    # prefix wins, in the same order as listed here.
    module_repr = str(m)
    for prefix in ('Conv2d', 'Linear', 'GroupNorm', 'Conv1d'):
        if module_repr.startswith(prefix):
            return prefix, param_kind
    return module_repr, param_kind


def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist(os.environ["CUDA_VISIBLE_DEVICES"])
    logger.configure(dir = args.save_path)


    '''---------------------------------- Create rep_model and dataloader -----------------------------------'''
    logger.log("creating model and diffusion...")
    rep_model,diffusion = create_rep_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )


    logger.log("View the rep model structure:")
    logger.log(rep_model)
    logger.log("##########################################################################")

    logger.log("View all input specified parameters:")
    for arg in vars(args):
        logger.log(arg, ":", getattr(args, arg))
    logger.log("##########################################################################")

    logger.log("You have enough time to check your training hyper-parameters.")
    logger.log('If the check is over, enter "Continue" at the terminal to continue the train.')
    logger.log("waiting for the continue signal...")
    while True:
        p = input()
        if p == 'Continue':
            break
    logger.log("continue training...")

    rep_model.to(dist_util.dev())

    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
    )


    '''----------------------------------------- Prepare model info -----------------------------------------'''
    out_c_of_para = []
    idx_of_scores = [0]
    idx = 0
    num_of_compactor = 0
    for k,p in rep_model.named_parameters():
        if 'pwc' in k:
            num_of_compactor += 1
            idx += p.shape[0]
            idx_of_scores.append(idx)
            out_c_of_para.append(p.shape[0])
        # type_p,_ = get_module_type(rep_model,k)
        # if (k.endswith('weight')) and (type_p == 'Conv2d') and ('pwc' not in k):
        #     out_c_of_para.append(p.shape[0])
        #     idx += p.shape[0]
        #     idx_of_scores.append(idx)
    logger.log("There are", num_of_compactor, "compactors in total.")
    idx_of_scores.pop()


    '''----------------------------------------- Train the rep-model ----------------------------------------'''
    if args.use_simple_train == 'False':
        use_simple_train = False
    else:
        use_simple_train = True

    if args.num_upd_msk_increment == -1:
        num_upd_msk_increment = args.num_upd_msk
    else:
        assert args.num_upd_msk_increment > 0
        num_upd_msk_increment = args.num_upd_msk_increment

    logger.log("training the rep model...")
    RepTrain(
        model=rep_model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        idx_of_scores=idx_of_scores,
        out_c_of_para=out_c_of_para,
        save_mask_score_path=args.save_path,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        use_simple_train=use_simple_train,
        step_respacing=args.step_respacing,
        dpm_solver_steps=args.dpm_solver_steps,
        lasso_strength=args.lasso_strength,
        mask_interval=args.mask_interval,
        before_mask_iters=args.before_mask_iters,
        rep_train_thresh=args.rep_train_thresh,
        num_upd_msk=args.num_upd_msk,
        rep_thresh_decay=args.rep_thresh_decay,
        rep_thresh_min=args.rep_thresh_min,
        num_upd_msk_increment=num_upd_msk_increment,
    ).run_loop(args.single_label)


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
