
import os
import os.path as osp
import argparse
import copy

import torch

from deep2flat.utils import load_yaml, load_txt
from deep2flat.dnn2dyn import dnn2dyn
from deep2flat.dyn2dnn import dyn2dnn

from model_zoo import get_model

def get_args_parser(add_help=True):
    """Build the command-line parser for the DNN -> DyN -> DNN conversion script.

    Args:
        add_help (bool): forwarded to ``argparse.ArgumentParser``; controls
            whether the ``-h/--help`` option is added.

    Returns:
        argparse.ArgumentParser: the configured parser. The fine-tune checkpoint
        path is exposed as ``args.funetune_weights`` (historical spelling kept
        for backward compatibility).
    """
    parser = argparse.ArgumentParser(description="DyN Training", add_help=add_help)

    parser.add_argument("--dynconfig", default="./configs/vit_l_16_IMAGENET1K_SWAG_E2E_V1.yaml", type=str, help="dyn config path")
    parser.add_argument("--model", default="vit_l_16", type=str, help="model name")
    parser.add_argument("--vars_path", default="", type=str, help="variable name txt to be converted")
    parser.add_argument("--save_path", default="./deep_weights", type=str, help="save path of converted model")
    parser.add_argument("--weights", default="IMAGENET1K_SWAG_E2E_V1", type=str, help="weight name")
    # `--finetune_weights` is the correctly spelled flag; the original misspelled
    # `--funetune_weights` is kept as an alias, and `dest` pins the attribute
    # name so existing code reading `args.funetune_weights` keeps working.
    parser.add_argument(
        "--finetune_weights",
        "--funetune_weights",
        dest="funetune_weights",
        default="",
        type=str,
        help="path to a fine-tuned checkpoint to load (optional)",
    )

    return parser


def deep2flat2deep(dyn_configs, model, vars_list, deep_pth):
    """
    Convert a DNN to DyN, recover the DNN from the DyN, and save it.

    ``dnn2dyn`` is run on ``model``. A DNN is then rebuilt with ``dyn2dnn``
    from a deep copy of ``model`` and ``dyn_configs['SAVE_ROOT']``. The
    recovered model's ``state_dict`` is saved to ``deep_pth``.

    Args:
        dyn_configs (dict): configs of dyn; must contain the key ``'SAVE_ROOT'``
        model (nn.Module): dnn model
        vars_list (list[str]): names of the parameters to convert
        deep_pth (str): save path of the recovered dnn

    Returns:
        The recovered model returned by ``dyn2dnn`` (the object whose
        ``state_dict`` was saved).
    """
    print('dnn2dyn...')
    # The returned DyN object is not used here. Presumably dnn2dyn persists its
    # result under dyn_configs['SAVE_ROOT'], which dyn2dnn reads back below.
    # TODO confirm against dnn2dyn.
    dnn2dyn(dyn_configs, model, vars_list)
    # Hand dyn2dnn a copy, presumably because it rewrites the module it is
    # given; this leaves the caller's `model` object untouched.
    deepen_model = copy.deepcopy(model)
    print('dyn2dnn...')
    recovered = dyn2dnn(deepen_model, dyn_configs['SAVE_ROOT'])
    torch.save(recovered.state_dict(), deep_pth)
    return recovered

def main(argv=None):
    """Load a model (optionally fine-tuned), run the DNN -> DyN -> DNN round trip, save it.

    The recovered weights are written to ``<save_path>/<model>_<weights>.pth``.

    Args:
        argv (list[str] | None): command-line arguments; ``None`` makes argparse
            read ``sys.argv``.

    Raises:
        FileNotFoundError: if ``--finetune_weights``/``--funetune_weights`` is
            given but does not point to an existing file.
    """
    args = get_args_parser().parse_args(argv)

    # makedirs (unlike mkdir) also creates missing parent directories and does
    # not fail if the directory already exists.
    os.makedirs(args.save_path, exist_ok=True)
    save_path = osp.join(args.save_path, '{}_{}.pth'.format(args.model, args.weights))

    dyn_configs = load_yaml(args.dynconfig)
    device = torch.device(dyn_configs['device'])
    vars_list = load_txt(args.vars_path)

    model = get_model(args.model, args.weights)
    if args.funetune_weights:
        # An explicitly requested checkpoint that cannot be found is an error;
        # silently skipping it would convert the wrong (non-fine-tuned) weights.
        if not osp.isfile(args.funetune_weights):
            raise FileNotFoundError(
                'fine-tune checkpoint not found: {}'.format(args.funetune_weights))
        print('load state dict from {}'.format(args.funetune_weights))
        # Load onto CPU so checkpoints saved from a GPU can also be read on a
        # CPU-only machine; the model is moved to `device` right after.
        checkpoint = torch.load(args.funetune_weights, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    model.to(device)

    deep2flat2deep(dyn_configs, model, vars_list, save_path)


if __name__ == "__main__":
    main()