from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import importlib
import argparse

import _init_paths
import torch 
import torch.optim as optim

from model.scorenet import get_score_net
from utils.utils import rescore_json


def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, so
            existing no-argument callers behave exactly as before.

    Returns:
        argparse.Namespace with ``hpe_dataset`` and ``hpe_model`` (required
        strings), ``model_file`` (``None`` when not given) and ``opts`` (the
        list of any trailing, unparsed command-line tokens).
    """
    parser = argparse.ArgumentParser(description='Train rescore network')
    # general
    parser.add_argument('--hpe_dataset',
                        help='dataset name, please choose from coco or crowdpose',
                        required=True,
                        type=str)
    parser.add_argument('--hpe_model',
                        help='model name',
                        required=True,
                        type=str)
    parser.add_argument('--model_file',
                        help='trained parameters path',
                        required=False,
                        type=str)
    # REMAINDER collects every token after the recognised options verbatim.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args(argv)
    return args

def import_cfg(args):
    """Import the config module named ``<hpe_dataset>.<hpe_model>``.

    Args:
        args: object exposing ``hpe_dataset`` and ``hpe_model`` attributes
            (e.g. the namespace returned by ``parse_args``).

    Returns:
        The imported module object.

    Raises:
        ImportError (e.g. ModuleNotFoundError): if that dotted module name
            cannot be imported.
    """
    module_path = "{}.{}".format(args.hpe_dataset, args.hpe_model)
    return importlib.import_module(module_path)
  
def create_save_path(cfg, root='output'):
    """Return ``<root>/<cfg.train['save_path']>``, creating the directory if needed.

    Args:
        cfg: config object whose ``train`` mapping has a ``'save_path'`` entry.
        root: parent directory to join the save path onto (defaults to
            ``'output'``, the previously hard-coded value).

    Returns:
        The joined path string.

    Raises:
        FileExistsError: if the path already exists but is not a directory.
    """
    save_path = os.path.join(root, cfg.train['save_path'])
    # exist_ok=True replaces the racy "exists() then makedirs()" check: the
    # call succeeds even if another process creates the directory first.
    os.makedirs(save_path, exist_ok=True)
    return save_path

def main():
    """Build the score network, optionally load weights, and run rescoring.

    Steps:
    1. Read the CLI options with ``parse_args``.
    2. Import the config module they name.
    3. Construct the network with ``cfg.model['feature_channels']`` input
       channels.
    4. If ``--model_file`` is given, load its state dict with strict key
       matching.
    5. Move the model to CUDA, switch it to eval mode, and pass it to
       ``rescore_json`` together with ``cfg`` and ``args``.
    """
    args = parse_args()
    cfg = import_cfg(args)

    # create model
    # NOTE(review): is_train=True is passed although the model is only used
    # in eval mode below -- confirm this is what get_score_net expects.
    model = get_score_net(cfg, input_channel=cfg.model['feature_channels'], is_train=True)
    if args.model_file:
        # Load onto CPU: the model has not been moved to the GPU yet, and this
        # avoids failures when the checkpoint was saved on a CUDA device that
        # is unavailable here. load_state_dict copies the tensors into the
        # model's own parameters, which model.cuda() then moves.
        state_dict = torch.load(args.model_file, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
    model.cuda()
    model.eval()

    rescore_json(cfg, args, model)


if __name__ == '__main__':
    main()
