from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import importlib
import argparse

import _init_paths
import torch 
import torch.optim as optim

from model.scorenet import get_score_net
from loss.multi_loss import MultiLossFactory
from utils.rescore import valid_core
from utils.rescore import read_rescore_data

def parse_args(argv=None):
    """Parse command-line arguments for evaluating the rescore network.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, so calling
            ``parse_args()`` behaves exactly as a plain CLI parse.

    Returns:
        argparse.Namespace with attributes ``hpe_dataset``, ``hpe_model``,
        ``model_file`` (``None`` when omitted) and ``opts`` (every trailing
        positional token, collected via ``argparse.REMAINDER``).

    Raises:
        SystemExit: if a required option is missing or parsing fails
            (standard argparse behavior).
    """
    parser = argparse.ArgumentParser(description='Test rescore network')
    # general
    parser.add_argument('--hpe_dataset',
                        help='dataset name, please choose from coco or crowdpose',
                        required=True,
                        type=str)
    parser.add_argument('--hpe_model',
                        help='model name',
                        required=True,
                        type=str)
    parser.add_argument('--model_file',
                        help='trained parameters path',
                        required=False,
                        type=str)
    # NOTE(review): `opts` is parsed but not consumed anywhere in this file's
    # visible code -- confirm whether config overrides are applied elsewhere.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    return parser.parse_args(argv)

def import_cfg(args):
    """Import the config module named ``<hpe_dataset>.<hpe_model>``.

    Args:
        args: Parsed arguments exposing ``hpe_dataset`` and ``hpe_model``.

    Returns:
        The imported module object, which callers use as the config.
    """
    return importlib.import_module(
        '{}.{}'.format(args.hpe_dataset, args.hpe_model))
  
def create_save_path(cfg):
    """Ensure the run's output directory exists and return its path.

    The directory is ``output/<cfg.train['save_path']>``, relative to the
    current working directory; intermediate directories are created as needed.

    Args:
        cfg: Config object whose ``train`` mapping has a ``'save_path'`` key.

    Returns:
        str: The joined directory path.

    Raises:
        FileExistsError: if the path already exists but is not a directory.
    """
    save_path = os.path.join('output', cfg.train['save_path'])
    # exist_ok=True replaces the exists()-then-makedirs() check, which could
    # raise if another process created the directory in between.
    os.makedirs(save_path, exist_ok=True)
    return save_path

def main():
    """Evaluate the rescore network and print its losses.

    Parses CLI args, imports the config module, loads data via
    ``read_rescore_data(..., is_train=False)``, builds the score net
    (optionally restoring weights from ``--model_file``), and prints the
    L1 and order losses returned by ``valid_core``.
    """
    args = parse_args()
    cfg = import_cfg(args)

    # create data; the third returned value is not used here
    x_data, y_data, _ = read_rescore_data(args, cfg, is_train=False)

    # create model; the input width is taken from dimension 1 of x_data
    # NOTE(review): is_train=True is passed even though this is evaluation --
    # its effect inside get_score_net is not visible here; confirm intended.
    model = get_score_net(cfg, input_channel=x_data.size(1), is_train=True)
    if args.model_file:
        # map_location='cpu' lets checkpoints saved on a CUDA device that is
        # absent on this machine still load; model.cuda() below moves them.
        # torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(args.model_file, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
    model.cuda()
    model.eval()

    # create loss
    loss_fn = MultiLossFactory(cfg).cuda()

    # start testing
    test_l1_loss, test_order_loss = valid_core(
        x_data, y_data, model, loss_fn,
        cfg.test['batch_size'], cfg.test['pose_size'])
    print("test_l1_loss:", test_l1_loss, "test_order_loss:", test_order_loss)

    # TODO: Convert pkl to json
    

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
