
import os
import os.path as osp
import argparse
import copy
import logging
import torch
import torchvision

from deep2flat.utils import load_yaml, load_txt
from deep2flat.dnn2dyn import dnn2dyn
#from deep2flat.dyn2dnn import dyn2dnn

from model_zoo import get_model
import multiprocessing as mp
from multiprocessing import Pool
from functools import partial

def get_args_parser(add_help=True):
    """Build the command-line parser for the DNN-to-DyN conversion script.

    Args:
        add_help (bool): Forwarded to ``argparse.ArgumentParser``; pass False
            when this parser is used as a parent of another parser.

    Returns:
        argparse.ArgumentParser: Parser with all conversion options registered.
    """
    parser = argparse.ArgumentParser(description="DyN Training", add_help=add_help)

    parser.add_argument("--dynconfig", default="", type=str, help="dyn config path")
    parser.add_argument("--model", default="vit_l_16", type=str, help="model name")
    parser.add_argument("--num_gpus", default=4, type=int,
                        help="number of GPUs; pool worker i runs on cuda:<i %% num_gpus>")
    parser.add_argument("--log_name", default="dyn_fitting", type=str, help="log file name")
    parser.add_argument("--num_threads", default=1, type=int,
                        help="number of worker processes for dnn2dyn (1 = run in the main process)")
    parser.add_argument("--vars_path", default="", type=str, help="variable name txt to be converted")
    parser.add_argument("--logdir", default="", type=str, help="log dir for multiprocessing")
    parser.add_argument("--save_path", default="./deep_weights", type=str, help="save path of converted model")
    parser.add_argument("--weights", default="", type=str, help="weight name")
    # "--funetune_weights" is the historical (misspelled) flag and stays the
    # canonical dest so existing scripts keep working; the correctly spelled
    # spelling is accepted as an alias.
    parser.add_argument("--funetune_weights", "--finetune_weights", dest="funetune_weights",
                        default="", type=str,
                        help="path to fine-tuned weights; reloaded into the model if the file exists")

    return parser


def myPID():
    """Return the relative index of the current pool worker process.

    ``multiprocessing`` stores a (private) ``_identity`` tuple on every
    process it starts; for ``Pool`` workers its first element is a small
    positive per-pool counter. The main process has an empty ``_identity``,
    so return 0 there instead of raising ``IndexError``.

    Returns:
        int: Worker index for pool processes, 0 in the main process.
    """
    identity = mp.current_process()._identity
    return identity[0] if identity else 0

def deep2flat_mp(vars_list, dyn_configs, model, num_gpus, logdir):
    """
    Convert a single DNN parameter to DyN inside a pool worker process.

    Meant as the ``Pool.map`` target: each call handles exactly one parameter
    name, selects a GPU from the worker index, opens a per-parameter log file
    and runs ``dnn2dyn`` on that parameter.

    Args:
        vars_list (str): Name of the single parameter (state-dict key) to convert.
        dyn_configs (dict): Configs of dyn; its ``'device'`` entry is overwritten
            with the GPU chosen for this worker.
        model (nn.Module): DNN model holding the parameter; moved to that GPU.
        num_gpus (int): Number of GPUs to spread workers over; the device is
            ``cuda:<myPID() % num_gpus>``.
        logdir (str): Directory for the log files; created if missing. An empty
            string writes the log into the current working directory.

    Raises:
        TypeError: If ``vars_list`` is not a single string.
    """
    if not isinstance(vars_list, str):
        raise TypeError('deep2flat_mp expects a single parameter name, got {!r}'.format(type(vars_list)))

    if logdir:
        # Several pool workers may try to create the same directory at once;
        # exist_ok avoids a FileExistsError race. An empty logdir means "cwd"
        # (os.makedirs('') would raise).
        os.makedirs(logdir, exist_ok=True)

    pid = myPID()
    dyn_configs['device'] = 'cuda:{}'.format(pid % num_gpus)
    logging.basicConfig(level=logging.DEBUG,
                filename=osp.join(logdir, '{}_{}.log'.format(vars_list, pid)),
                datefmt='%Y/%m/%d %H:%M:%S',
                format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)d - %(module)s - %(message)s')
    logger = logging.getLogger(__name__)

    logger.info('************device is {}'.format(dyn_configs['device']))

    print('************device is {}'.format(dyn_configs['device']))
    model.to(torch.device(dyn_configs['device']))

    print('dnn2dyn...')
    # The result is not returned: a Pool task's return value is pickled back
    # to the parent process. dnn2dyn is presumably responsible for persisting
    # its output — TODO confirm.
    dnn2dyn(dyn_configs, model, [vars_list], logger)

def deep2flat(log_name, vars_list, dyn_configs, model, num_gpus, logdir):
    """
    Convert DNN parameters to DyN in the current process.

    All requested parameters are handed to a single ``dnn2dyn`` call on the
    device named by ``dyn_configs['device']``.

    Args:
        log_name (str): Base name of the log file, written to ``<logdir>/<log_name>.log``.
        vars_list (list[str] | str): Parameter name(s) to convert; a single
            string is wrapped into a one-element list.
        dyn_configs (dict): Configs of dyn; must contain a ``'device'`` entry.
        model (nn.Module): DNN model holding the parameters; moved to that device.
        num_gpus (int): Not used by this function (the device comes from
            ``dyn_configs``); accepted so the signature matches ``deep2flat_mp``.
        logdir (str): Directory for the log file; created if missing. An empty
            string writes the log into the current working directory.
    """
    if logdir:
        # logging's FileHandler does not create missing parent directories.
        os.makedirs(logdir, exist_ok=True)
    logging.basicConfig(level=logging.DEBUG,
                filename=osp.join(logdir, '{}.log'.format(log_name)),
                datefmt='%Y/%m/%d %H:%M:%S',
                format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)d - %(module)s - %(message)s')
    logger = logging.getLogger(__name__)

    logger.info('************device is {}'.format(dyn_configs['device']))

    print('************device is {}'.format(dyn_configs['device']))
    model.to(torch.device(dyn_configs['device']))
    print(vars_list)
    if isinstance(vars_list, str):
        vars_list = [vars_list]
    print('dnn2dyn...')
    dnn2dyn(dyn_configs, model, vars_list, logger)

def make_vars(args):
    """
    Collect the names of the parameters that should be converted.

    A model is built from ``args.model`` / ``args.weights`` and every
    state-dict entry whose tensor has exactly 2 or 4 dimensions is selected
    (typically linear weight matrices and convolution kernels), preserving
    state-dict order.

    Args:
        args (argparse.Namespace): Parsed arguments; ``model`` and ``weights`` are read.

    Returns:
        list[str]: State-dict keys of the selected tensors.
    """
    state = get_model(args.model, args.weights).state_dict()
    wanted_ranks = (2, 4)
    return [name for name, tensor in state.items() if len(tensor.shape) in wanted_ranks]

if __name__=="__main__":
    args = get_args_parser().parse_args()

    # A save_path containing '.pth' is used as a file path; otherwise it is a
    # directory (created, including parents) and the file name is derived
    # from the model and weight names.
    save_path = args.save_path
    if '.pth' not in save_path:
        os.makedirs(save_path, exist_ok=True)
        save_path = osp.join(save_path, '{}_{}.pth'.format(args.model, args.weights))
    print(save_path)
    dyn_configs = load_yaml(args.dynconfig)
    # Drop the first and last convertible parameters (presumably the input
    # stem and the classifier head — TODO confirm per model). Unlike two
    # pop() calls, slicing does not raise when fewer than two names exist.
    vars_list = make_vars(args)[1:-1]
    model = get_model(args.model, args.weights)
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Parameter name: {name}, Shape: {param.shape}")

    checkpoint = None
    if osp.exists(args.funetune_weights):
        print('load state dict from {}'.format(args.funetune_weights))
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        # Tensors are kept on CPU here; the model is moved to its target
        # device inside deep2flat / deep2flat_mp.
        checkpoint = torch.load(args.funetune_weights, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    if args.num_threads > 1:
        # 'spawn' is needed to use CUDA in child processes (forked children
        # cannot re-initialise CUDA).
        mp.set_start_method('spawn')
        partial_deep2flat2deep = partial(deep2flat_mp, dyn_configs=dyn_configs, model=copy.deepcopy(model), num_gpus=args.num_gpus, logdir=args.logdir)
        # maxtasksperchild=1: every parameter is converted in a fresh process.
        with Pool(args.num_threads, maxtasksperchild=1) as p:
            p.map(partial_deep2flat2deep, vars_list, chunksize=1)
    else:
        deep2flat(args.log_name, vars_list, dyn_configs, model, args.num_gpus, logdir=args.logdir)
