import argparse
import os
import os.path as osp
import subprocess

from deep2flat.utils import write_txt, write_yaml, load_yaml

from model_zoo import get_model

def get_args_parser(add_help=True):
    """
    Build the command-line parser for the deep2flat2deep driver.

    Args:
        add_help (bool): Forwarded to ``argparse.ArgumentParser`` so this parser
            can also be used as a parent parser without a duplicate ``-h``.

    Returns:
        argparse.ArgumentParser: Parser defining config paths, model/weight
            names, epoch count, output root and GPU count.
    """
    parser = argparse.ArgumentParser(description="DyN Training&Funetuning", add_help=add_help)

    parser.add_argument("--dynconfig", default="./configs/dyn/vit_b_16_IMAGENET1K_V1.yaml", type=str, help="dyn config path")
    parser.add_argument("--dnnconfig", default="./configs/dnn/vit_b_16_IMAGENET1K_V1.yaml", type=str, help="dnn config path")
    parser.add_argument("--model", default="vit_b_16", type=str, help="model name")
    parser.add_argument("--weights", default="IMAGENET1K_V1", type=str, help="weight name")

    parser.add_argument("--n_epoch", default=5, type=int, help="epoch for deep2flat2deep")

    parser.add_argument("--work_dirs", default="./work_dirs", type=str, help="root dir for save dyn and dnn weights")
    # A GPU count: parse as int so the value has the same type whether it comes
    # from the default or from the command line (type=str made it int vs str).
    parser.add_argument("--ngpus", default=4, type=int, help="gpu num for dnn funetuning")

    return parser

def check_exist(dir_path):
    """
    Ensure ``dir_path`` exists as a directory, creating it (and parents) if needed.

    Args:
        dir_path (str): Directory path to create.

    Returns:
        str: ``dir_path`` unchanged, so the call can be used inline in assignments.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    if not osp.isdir(dir_path):
        print('makedirs {}'.format(dir_path))
        # exist_ok=True: don't crash if the directory appears between the check
        # above and this call; a non-directory at dir_path still raises.
        os.makedirs(dir_path, exist_ok=True)
    return dir_path

def make_vars(args):
    """
    Collect the names of the model parameters that will be converted.

    Only ``state_dict`` entries with exactly two dimensions are selected,
    because conversion currently supports 2-D matrices only.

    Args:
        args (argparse.Namespace): Parsed CLI arguments; ``args.model`` and
            ``args.weights`` are passed to ``get_model``.

    Returns:
        list[str]: Matching ``state_dict`` keys, in ``state_dict`` order.
    """
    state = get_model(args.model, args.weights).state_dict()
    # Only 2-D matrices are supported for now.
    return [name for name, tensor in state.items() if len(tensor.shape) == 2]

if __name__=="__main__":

    args = get_args_parser().parse_args()

    dyn_configs = load_yaml(args.dynconfig)
    
    work_dirs = check_exist(args.work_dirs)

    logger_dir = check_exist(osp.join(args.work_dirs, 'train_log'))

    # save path of dnn converted from dyn
    weights_save_path = check_exist(osp.join(args.work_dirs, 'deep_weights'))
    weights_pth_path = osp.join(weights_save_path, '{}_{}.pth'.format(args.model, args.weights))

    # save path of funetuned dnn
    funetune_save_dir = check_exist(osp.join(args.work_dirs, 'funetune_weights'))
    funetune_weights_path = osp.join(funetune_save_dir, 'checkpoint.pth')

    # save the dyn config
    dyn_config_path = osp.join(args.work_dirs, 'dyn_configs.yaml')
    dyn_configs['SAVE_ROOT'] = check_exist(osp.join(args.work_dirs, dyn_configs['SAVE_ROOT']))
    write_yaml(dyn_configs, dyn_config_path)

    # select top-layer params to convert
    vars_path = osp.join(args.work_dirs, 'vars_list.txt')
    vars_list = make_vars(args)

    for layer_idx, var_group in enumerate(vars_list):
        if isinstance(var_group, str):
            var_group = [var_group]

        dyn_training_logger_dir = check_exist(osp.join(logger_dir, '_'.join(var_group), 'dyn'))
        dnn_training_logger_dir = check_exist(osp.join(logger_dir, '_'.join(var_group), 'dnn'))
        dnn_testing_logger_dir = check_exist(osp.join(logger_dir, '_'.join(var_group), 'test'))

        write_txt([], vars_path)
        # n_epoch times for a parameter matrix
        for epoch in range(args.n_epoch):

            dyn_training_logger_path = osp.join(dyn_training_logger_dir, '{}.log'.format(epoch))
            dnn_training_logger_path = osp.join(dnn_training_logger_dir, '{}.log'.format(epoch))
            dnn_testing_logger_path = osp.join(dnn_testing_logger_dir, '{}.log'.format(epoch))

            # funetune dnn
            dnn_funetuning_order = 'CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node {} dnn_funetune.py \
            --model {} --weights {} --dyn_weights {} --output-dir {} --vars-list {} --config_yaml {} > {}'.format(
                args.ngpus, args.model, args.weights, weights_pth_path, funetune_save_dir, vars_path, args.dnnconfig, dnn_training_logger_path
            )
            os.system(dnn_funetuning_order)

            write_txt(var_group, vars_path)


            # dnn->dyn->dnn
            dyn_training_order = 'python -u dyn_train.py \
            --dynconfig {} --model {} --weights {} --vars_path {} --save_path {} --funetune_weights {} > {}'.format(
                dyn_config_path, args.model, args.weights, vars_path, weights_save_path, funetune_weights_path, dyn_training_logger_path
            )
            os.system(dyn_training_order)
            
            # test dnn
            dnn_test_order = 'CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node {} dnn_funetune.py \
            --model {} --weights {} --dyn_weights {} --output-dir {} --vars-list {} --config_yaml {} --test-only > {}'.format(
                args.ngpus, args.model, args.weights, weights_pth_path, funetune_save_dir, vars_path, args.dnnconfig, dnn_testing_logger_path
            )
            os.system(dnn_test_order)



