import os
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"
import gym
gym.logger.set_level(40)
import yaml
from argparse import ArgumentParser
from nn_eval import batched_nn_eval, parallel_nn_eval
from train_bc import train_model

from logging_utils import logger

def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return the resulting object.

    The file is opened as UTF-8 explicitly, so parsing does not depend on the
    platform's default locale encoding. ``yaml.FullLoader`` is kept (as in the
    original inline loads) so configs that rely on its tag set still load.
    """
    # NOTE(review): FullLoader is more permissive than yaml.safe_load; only
    # point this at trusted config files.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def main():
    """Evaluate a policy on an environment, driven by command-line arguments.

    Positional arguments:
        env_config_path: YAML file with the environment config.
        policy_config_path: YAML file with the policy config; it must contain
            ``model_config`` (with ``type`` and optionally ``dan``) and
            ``train_config`` sections.

    Options:
        --trials: number of evaluation trials (default 100).
        --results_file_name: forwarded to the eval function as ``results``.
        --split_policy_config_path: optional second policy config. Its agent
            is passed as ``split_agent`` (non-batched mode only).
        --batched: use ``batched_nn_eval`` instead of ``parallel_nn_eval``.
        --dump_trial: forwarded as ``dump_trial`` (non-batched mode only).

    Unrecognized command-line arguments are ignored rather than rejected.
    """
    parser = ArgumentParser()
    parser.add_argument("env_config_path", help="Path to environment config file")
    parser.add_argument("policy_config_path", help="Path to policy config file")
    parser.add_argument("--trials", type=int, default=100)
    parser.add_argument("--results_file_name", default=None)
    parser.add_argument("--split_policy_config_path", default=None)
    parser.add_argument("--batched", action="store_true")
    parser.add_argument("--dump_trial", action="store_true")
    # Extra, unknown arguments are discarded instead of raising an error.
    # NOTE(review): presumably so launcher-injected flags don't abort the run — confirm.
    args, _ = parser.parse_known_args()
    logger.info(f"Evaluating with {args.trials} trial{'s' if args.trials != 1 else ''}")

    env_cfg = _load_yaml_config(args.env_config_path)
    policy_cfg = _load_yaml_config(args.policy_config_path)

    # Model-type flags are read up front; they decide which eval mode flags are passed.
    model_cfg = policy_cfg['model_config']
    dan = model_cfg.get("dan", False)
    lwr = model_cfg['type'] == 'lwr'
    regent = model_cfg['type'] == 'regent'

    env_cfg['seed'] = 42
    # The GPU index comes from the LOCAL_RANK environment variable (0 when unset).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    env_cfg['device'] = f"cuda:{local_rank}"
    # NOTE(review): force_retrain=False presumably makes train_model reuse an
    # existing checkpoint instead of retraining — confirm in train_bc.
    policy_cfg['train_config']['force_retrain'] = False
    # NOTE(review): (0, 1) look like (rank, world_size) — confirm against train_model.
    agent, _ = train_model(0, 1, env_cfg, policy_cfg)
    agent.eval()

    split_agent = None
    if args.split_policy_config_path is not None:
        split_policy_cfg = _load_yaml_config(args.split_policy_config_path)
        split_policy_cfg['train_config']['force_retrain'] = False
        split_agent, _ = train_model(0, 1, env_cfg, split_policy_cfg)
        split_agent.eval()
        # Turn on delta/query recording for the split agent, starting from empty lists.
        split_agent.save_deltas = True
        split_agent.save_queries = True
        split_agent.episode_deltas = []
        split_agent.episode_queries = []

    if args.batched:
        # In batched mode, lwr and regent models are also treated as "dan".
        # NOTE(review): split_agent and dump_trial are not forwarded here.
        batched_nn_eval(env_cfg, agent, trials=args.trials, results=args.results_file_name, reset=True, dan=(dan or lwr or regent))
    else:
        parallel_nn_eval(env_cfg, agent, trials=args.trials, results=args.results_file_name, dan=dan, split_agent=split_agent, dump_trial=args.dump_trial)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
