import os
import importlib
import logging
import torch

from train.pytorch_wrapper.prediction import Predictor
from train.pytorch_wrapper.utils import BColors

from train.behavioral_cloning.run_train import parse_args, get_arg_string, build_model
from train.behavioral_cloning.eval_hooks.eval_env import EvalEnv

# Run on a CUDA device whenever torch reports one is available; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared console-color helper; used below to print the evaluation results.
bcols = BColors()


def get_dataset(args):
    """Import and return the dataset-definition module named by ``args.dataset``.

    ``args.dataset`` is a path to a Python source file relative to the import
    root, e.g. ``train/datasets/foo.py``. The file extension is stripped and
    the path separators are turned into dots, giving ``train.datasets.foo``.
    That dotted name is then imported. Callers read module-level attributes
    such as ``INPUT_SPACE`` / ``ACTION_SPACE`` from the returned module.

    Raises:
        ValueError: if ``args.dataset`` is empty.
        ImportError: if the derived module name cannot be imported.
    """
    if not args.dataset:
        raise ValueError("no dataset file given (args.dataset is empty)")

    # Normalize first so "./pkg/mod.py" or "pkg//mod.py" do not turn into
    # ".pkg.mod" / "pkg..mod", which import_module rejects.
    stem = os.path.splitext(os.path.normpath(args.dataset))[0]
    # Convert the native separator and, on Windows, the "/" alternative.
    for sep in (os.sep, os.altsep):
        if sep:
            stem = stem.replace(sep, ".")
    return importlib.import_module(stem)


def _resolve_param_file(args, arg_str):
    """Return the path of the parameter file to load for evaluation.

    An explicit ``args.paramfile`` takes precedence. Otherwise the path is
    ``<args.param_root>/params_<arg_str>.npy``. When ``args.checkpoint`` is
    truthy, a ``_cpNNNNNN`` suffix (six zero-padded digits) is appended to
    ``arg_str`` first.
    """
    if args.paramfile:
        return args.paramfile
    # NOTE(review): a checkpoint value of 0 is falsy and therefore ignored
    # here -- confirm 0 is never a real checkpoint index.
    if args.checkpoint:
        arg_str = "%s_cp%06d" % (arg_str, args.checkpoint)
    return os.path.join(args.param_root, "params_%s.npy" % arg_str)


def test_model(args):
    """Evaluate a trained behavioral-cloning model in its environment.

    The steps are:
    1. Build the model and move it to ``DEVICE``.
    2. Import the dataset definition named by ``args.dataset``.
    3. Load the trained parameters into a ``Predictor``.
    4. Run the predictor through an ``EvalEnv``.
    Each reward entry is printed in color.

    Args:
        args: parsed command-line options from ``parse_args()``.

    Returns:
        The mapping of reward names to values produced by the evaluation
        environment (also printed to stdout).
    """
    # init model
    logging.info("BehavioralCloning: Initializing model ...")
    model = build_model(args).to(DEVICE)
    dataset = get_dataset(args)

    # The argument string names both the parameter file and the env run.
    arg_str = get_arg_string(args)

    # init prediction
    param_file = _resolve_param_file(args, arg_str)
    # logging uses %-style lazy formatting; the value needs a placeholder.
    logging.debug("param_file: %s", param_file)
    predictor = Predictor(model, param_file=param_file)

    # initialize eval env (input/action spaces, transforms, sequence length
    # and frame skip all come from the dataset definition module)
    env = EvalEnv(env_name=args.env, input_space=dataset.INPUT_SPACE, action_space=dataset.ACTION_SPACE,
                  transforms=dataset.DATA_TRANSFORM, seq_len=dataset.SEQ_LENGTH,
                  trials=args.trials, max_steps=args.maxsteps, record_dir=args.recordto, argstr=arg_str,
                  seed=args.envseed, verbosity=args.verbosity, frame_skip=dataset.FRAME_SKIP, env_server=False)

    # call evaluation
    rewards = env(predictor)

    for name, value in rewards.items():
        print(bcols.print_colored("%s: %.3f" % (name, value), bcols.FAIL))

    return rewards


if __name__ == "__main__":
    # Script entry point: parse the command-line options and run the evaluation.
    test_model(parse_args())
