import argparse
import os
import sys

import wandb
import gym

from configs.train_behavioral_policy_defaults import get_cfg_defaults

# Put the d4mrl checkout in the user's home directory on the import path.
# `utils` and `algorithms` (imported below) are resolved after this append.
# Note that `configs` (imported above) comes *before* it, so that module must
# already be importable from the existing sys.path (e.g. the working directory).
sys.path.append(f"{os.path.expanduser('~')}/d4mrl/")

import utils
from algorithms import SB3


# def get_agent_model(algorithm):
#     if config.train.algorithm == 'td3':
#         return TD3
#     elif config.train.algorithm == 'sac':
#         return SB3
#     else:
#         raise Exception('%s algorithm is not supported yet.' % algorithm)


def train_agent(config, train_env, eval_env):
    """Train an ``SB3`` agent and write its run artifacts to disk.

    Creates ``config.train.agent_path`` if it does not exist and writes the
    config there as ``config.yaml`` via ``utils.save_config``. It then hands
    ``SB3`` the paths ``<agent_path>/agent`` and
    ``<agent_path>/evaluations.pkl``. Both environments are seeded with
    ``config.system.seed`` before training starts.

    Args:
        config: experiment config node. This function reads
            ``env.train_env``, ``env.eval_env``, ``system.seed``,
            ``train.agent_path`` and ``wandb.enable``.
        train_env: environment passed to ``SB3`` as the training env.
        eval_env: environment passed to ``SB3`` second (labelled
            "Evaluation env" in the banner).
    """
    rule = "------------------------------------------------------------------------------"
    print(rule)
    print(f"Training agent on environment: {config.env.train_env}, Evaluation env: {config.env.eval_env}, Seed: {config.system.seed}")
    print(rule)

    run_dir = os.path.join(config.train.agent_path)
    os.makedirs(run_dir, exist_ok=True)
    # The value returned by save_config is what gets mirrored into wandb below.
    saved_config = utils.save_config(config, os.path.join(run_dir, 'config.yaml'))

    # Same seed for both environments, training env first.
    for env in (train_env, eval_env):
        env.seed(config.system.seed)

    if config.wandb.enable:
        wandb.config.update(saved_config)

    trainer = SB3(
        train_env,
        eval_env,
        config,
        os.path.join(run_dir, 'agent'),
        os.path.join(run_dir, 'evaluations.pkl'),
    )
    trainer.train()
    print('Created behavioral agents, saved in path %s' % run_dir)


def evaluate_agent(config, env, agent_path, eval_env=None):
    """Build an ``SB3`` wrapper for the agent at ``agent_path`` and run ``test_policy``.

    Args:
        config: experiment config node. ``config.env.type`` is printed in the
            banner, and the whole node is forwarded to ``SB3``.
        env: environment passed to ``SB3`` as its first environment.
        agent_path: agent location forwarded to ``SB3``.
        eval_env: environment passed to ``SB3`` as its second environment.
            Defaults to ``env``.
    """
    # Previously ``eval_env`` was read from a module-level global that is only
    # bound under ``if __name__ == '__main__'``, so any other caller got a
    # NameError. It is now an optional argument that falls back to ``env``.
    if eval_env is None:
        eval_env = env

    print("------------------------------------------------------------------------------")
    # NOTE(review): the rest of this file reads ``config.env.train_env`` /
    # ``config.env.eval_env``; confirm ``config.env.type`` exists in the defaults.
    print(f"Evaluating agent on environment: {config.env.type}")
    print("------------------------------------------------------------------------------")
    # The evaluations path is None here (presumably nothing is logged when
    # only testing a saved policy; confirm against SB3).
    model = SB3(env, eval_env, config, agent_path, None)
    model.test_policy()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('--config-file', type=str, default=None)
    parser.add_argument('--config-list', nargs="+", default=None)
    args = parser.parse_args()

    # Build the run config from the defaults plus the optional file/list
    # overrides (merge semantics live in get_cfg_defaults).
    config = get_cfg_defaults(args.config_file, args.config_list)

    train_env = gym.make(config.env.train_env)
    eval_env = None
    try:
        eval_env = gym.make(config.env.eval_env)

        if config.wandb.enable:
            # Start with an empty config; train_agent fills it in from the
            # saved config once the run directory has been written.
            wandb.init(project=config.wandb.project_name, entity=config.wandb.entity_name, config={})

        train_agent(config, train_env, eval_env)
    finally:
        # Release the environments' resources whether training finished
        # or raised; previously they were never closed.
        if eval_env is not None:
            eval_env.close()
        train_env.close()
