"""
Train a diffusion model on images.
"""

import torch
import argparse
import sys
import os
sys.path.append("..")

from guided_diffusion import dist_util, logger
from guided_diffusion.image_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from guided_diffusion.train_util import TrainLoop

from guided_diffusion.pruned_script_util import create_model_and_diffusion as create_pruned_model_and_diffusion
from guided_diffusion.pruned_script_util import model_and_diffusion_defaults as pruned_model_and_diffusion_defaults

def create_argparser():
    """Build the command-line parser for retraining a pruned diffusion model.

    Training hyper-parameters come first. They are followed by the standard
    model/diffusion defaults and then the pruned model/diffusion defaults;
    for keys shared between them, the later (pruned) value wins. A set of
    script-specific options is then appended.

    Returns:
        argparse.ArgumentParser: the configured (unparsed) parser.
    """
    training_defaults = {
        "data_dir": "",
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 1,
        "microbatch": -1,  # -1 disables microbatches
        "ema_rate": "0.9999",  # comma-separated list of EMA values
        "log_interval": 10,
        "save_interval": 10000,
        "resume_checkpoint": "",
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
    }
    # Same semantics as successive dict.update calls: first-seen key order,
    # last value wins.
    merged_defaults = {
        **training_defaults,
        **model_and_diffusion_defaults(),
        **pruned_model_and_diffusion_defaults(),
    }

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_defaults)

    # Script-specific options, registered in this exact order.
    # Boolean-like switches are deliberately typed as str ("True"/"False").
    extra_options = (
        ("--save_path", {
            "type": str,
            "help": "The path to save the log and the result",
        }),
        ("--use_gen_loss", {
            "type": str,
            "default": "False",
            "help": "If you want to use std_model to get gen loss.",
        }),
        ("--standard_model_path", {
            "type": str,
            "default": '',
            "help": "The path to load the standard model",
        }),
        ("--single_label", {
            "type": int,
            "default": -1,
            "help": "If your training wants to use single label, you have to make sure the label you assign match with the data",
        }),
        ("--use_simple_train", {
            "type": str,
            "default": "False",
            "help": "If you want to use simple train based on DPM Solver.",
        }),
        ("--step_respacing", {
            "type": int,
            "default": -1,
            "help": "A Utility variable for simple train",
        }),
        ("--dpm_solver_steps", {
            "type": int,
            "default": 20,
            "help": "dpm-solver steps",
        }),
    )
    for flag, options in extra_options:
        arg_parser.add_argument(flag, **options)

    return arg_parser

def num_of_para(ckpt):
    """Count the total number of scalar entries across all values of a mapping.

    Args:
        ckpt: A mapping (e.g. a state dict from ``torch.load``) whose values
            expose a ``shape`` sequence of ints. Keys are ignored.

    Returns:
        int: the sum over all values of the product of their dimensions.
        A 0-dimensional value (empty shape) counts as one entry; an empty
        mapping yields 0.
    """
    import math

    return sum(math.prod(value.shape) for value in ckpt.values())


def _str_to_bool(value):
    """Interpret a string-valued command-line switch as a boolean.

    The parser declares these switches with ``type=str`` and a default of
    ``"False"``. Matching is case-insensitive: ``"true"``, ``"1"`` and
    ``"yes"`` mean True; anything else means False.
    """
    return str(value).strip().lower() in ("true", "1", "yes")


def main():
    """Retrain a pruned diffusion model, optionally with a standard model's gen loss.

    Parses CLI arguments, sets up distributed execution and logging, optionally
    builds and loads the standard model (whose diffusion object is then used
    for training), builds the pruned model, schedule sampler and data loader,
    and finally runs ``TrainLoop``.
    """
    args = create_argparser().parse_args()
    # Parse both string switches the same way (previously "--use_simple_train
    # false" silently enabled simple training).
    use_gen_loss = _str_to_bool(args.use_gen_loss)
    use_simple_train = _str_to_bool(args.use_simple_train)

    # NOTE(review): raises KeyError when CUDA_VISIBLE_DEVICES is unset.
    dist_util.setup_dist(os.environ["CUDA_VISIBLE_DEVICES"])
    logger.configure(dir=args.save_path)

    logger.log("View all input specified parameters:")
    for arg in vars(args):
        logger.log(arg, ":", getattr(args, arg))
    logger.log("##########################################################################")

    # ------------------------------- Create standard_model -------------------------------
    if use_gen_loss:
        logger.log("creating model and diffusion...")
        standard_model, diffusion = create_model_and_diffusion(
            **args_to_dict(args, model_and_diffusion_defaults().keys())
        )
        standard_model.load_state_dict(
            dist_util.load_state_dict(args.standard_model_path, map_location="cpu")
        )
        standard_model.to(dist_util.dev())
        if args.use_fp16:
            standard_model.convert_to_fp16()
        standard_model.eval()
    else:
        standard_model = None

    # -------------------------- Create pruned_model and dataloader --------------------------
    pruned_model, pruned_diffusion = create_pruned_model_and_diffusion(
        **args_to_dict(args, pruned_model_and_diffusion_defaults().keys())
    )
    pruned_model.to(dist_util.dev())
    if args.resume_checkpoint:
        # Loaded only to count parameters (TrainLoop gets resume_checkpoint
        # itself); drop it so this CPU copy is not kept alive during training.
        pruned_ckpt = torch.load(args.resume_checkpoint, map_location="cpu")
        paras_num = num_of_para(pruned_ckpt)
        del pruned_ckpt
    else:
        # No checkpoint: count the freshly built model instead of failing on torch.load("").
        paras_num = num_of_para(pruned_model.state_dict())
    logger.log("The number of parameters in the pruned model is", paras_num)

    # With the gen loss, the standard model's diffusion object is kept.
    if not use_gen_loss:
        diffusion = pruned_diffusion

    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
    )

    # ---------------------------------- Retrain the model ----------------------------------
    logger.log("training...")
    TrainLoop(
        model=pruned_model,
        standard_model=standard_model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        use_simple_train=use_simple_train,
        step_respacing=args.step_respacing,
        dpm_solver_steps=args.dpm_solver_steps,
    ).run_loop(args.single_label)


if __name__ == "__main__":
    main()
