import argparse
import os
import sys

import wandb
import gymnasium
sys.modules["gym"] = gymnasium

import highway
import algorithms
import utils
from configs.train_on_highway_defaults import get_cfg_defaults
from utils.results_utils import get_date_time_str, save_config, set_seed, num


def get_agent_model(algorithm):
    """Return the agent class from ``algorithms`` matching ``algorithm``.

    Args:
        algorithm: Name of the algorithm; one of ``'td3'``, ``'sac'``,
            ``'mopo'``, ``'mopo_plus'``, ``'td3_bc'``, ``'h2o'`` or ``'iql'``.

    Returns:
        The corresponding attribute of the ``algorithms`` package.

    Raises:
        Exception: If ``algorithm`` is not one of the supported names.
    """
    # (algorithm name, attribute name on ``algorithms``) pairs. The attribute
    # is only looked up once a name matches, so an unknown name never touches
    # the ``algorithms`` package.
    registry = (
        ('td3', 'TD3'),
        ('sac', 'SB3'),
        ('mopo', 'MOPO'),
        ('mopo_plus', 'MOPOPlus'),
        ('td3_bc', 'TD3_BC'),
        ('h2o', 'H2O'),
        ('iql', 'IQL'),
    )
    for known_name, attr_name in registry:
        if algorithm == known_name:
            return getattr(algorithms, attr_name)
    raise Exception('%s algorithm is not supported yet.' % algorithm)


def train_agent(config, env, eval_env):
    """Train the algorithm named in ``config.train.algorithm`` and save its outputs.

    A fresh run directory, named by ``get_date_time_str()``, is created under
    ``config.train.agent_path``. It receives ``config.yaml``, an ``agent``
    sub-directory and the path ``evaluations.pkl``, which are handed to the
    agent model.

    Args:
        config: Run configuration; the ``train``, ``env``, ``system``,
            ``simulator`` and ``wandb`` sections are read here.
        env: Environment passed to the model as its first argument.
        eval_env: Environment passed to the model as its second argument.
    """
    banner = "---------------------------------------"
    print(banner)
    print(f"Training agent with algorithm: {config.train.algorithm}")
    print(f"Evaluation env: {config.env.eval_env}. Seed: {config.system.seed}")
    print(f"Simulator env: {config.env.train_env}, specifications: {config.simulator.transform_list}")
    print(banner)
    set_seed(config.system.seed)

    # Lay out the run directory: <agent_path>/<timestamp>/{config.yaml, agent/, evaluations.pkl}
    run_dir = os.path.join(config.train.agent_path, get_date_time_str())
    os.makedirs(run_dir)
    config_dict = save_config(config, os.path.join(run_dir, 'config.yaml'))
    agent_dir = os.path.join(run_dir, 'agent')
    os.makedirs(agent_dir)
    evaluations_file = os.path.join(run_dir, 'evaluations.pkl')

    # The environments are not seeded explicitly here; only set_seed() is called.

    if config.wandb.enable:
        # Mirror the saved configuration into the active wandb run.
        wandb.config.update(config_dict)

    agent_cls = get_agent_model(config.train.algorithm)
    agent = agent_cls(env, eval_env, config, agent_dir, evaluations_file)
    agent.train()

    print('Created behavioral agents, saved in path %s' % run_dir)


def evaluate_agent(config, env, agent_path, eval_env=None):
    """Build the agent named in ``config.agent.type`` and run its ``test_policy``.

    Args:
        config: Run configuration; ``config.env.eval_env`` is printed and
            ``config.agent.type`` selects the agent class.
        env: Environment passed to the model as its first argument.
        agent_path: Path passed to the model as the agent location.
        eval_env: Environment passed to the model as its second argument.
            Defaults to ``env`` when omitted.

    Raises:
        Exception: If ``config.agent.type`` is not a supported algorithm name.
    """
    # Previously this relied on a module-level ``eval_env`` that only exists
    # when the file is run as a script, so importing callers hit a NameError.
    if eval_env is None:
        eval_env = env

    print("---------------------------------------")
    print(f"Evaluating agent on environment: {config.env.eval_env}")
    print("---------------------------------------")
    # No evaluations path is needed for a pure test run, hence the trailing None.
    model = get_agent_model(config.agent.type)(env, eval_env, config, agent_path, None)
    model.test_policy()


if __name__ == '__main__':
    # Command-line interface: every option is registered in the order listed.
    parser = argparse.ArgumentParser(description="parse args")
    for flag, options in (
        ('--config-file', dict(type=str, default=None)),
        ('--no-wandb', dict(action='store_true')),
        ('--seed', dict(type=int, default=0)),
        ('--train-env', dict(type=str, default='')),
        ('--eval-env', dict(type=str, default='')),
        ('--algorithm', dict(type=str, default='')),
        ('--stacked-frames', dict(type=int, default=0)),
        ('--added-name', dict(type=str, default='')),
        ('--transform-list', dict(nargs='+')),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # --transform-list is a flat "name value name value ..." sequence; pair the
    # tokens up and convert each value with num(). An unpaired trailing token
    # is ignored.
    raw_transforms = args.transform_list
    if raw_transforms is None:
        transform_list = []
    else:
        transform_list = [
            (name, num(value))
            for name, value in zip(raw_transforms[0::2], raw_transforms[1::2])
        ]

    # Simulator (training) environment with the requested transforms applied.
    sim_env, wandb_name = highway.get_transformed_env(args.train_env, transform_list)
    sim_env.reset()

    # The evaluation environment only receives hidden_cars=True when the
    # 'obs_hidden_cars' transform was requested for the simulator.
    eval_env_kwargs = {'wrap': True}
    if any(entry[0] == 'obs_hidden_cars' for entry in transform_list):
        eval_env_kwargs['hidden_cars'] = True
    eval_env = highway.make_env(args.eval_env, **eval_env_kwargs)
    eval_env.reset()

    # Optionally stack frames on both environments and reflect it in the run name.
    if args.stacked_frames > 0:
        sim_env, eval_env = (
            utils.stack_frames(single_env, args.stacked_frames, flatten=True)
            for single_env in (sim_env, eval_env)
        )
        wandb_name += f'_stacked_frames_{args.stacked_frames}'

    if args.added_name:
        wandb_name += f'_{args.added_name}'

    # Changing run configurations: key -> value overrides, flattened below into
    # the alternating [key, value, key, value, ...] list given to get_cfg_defaults.
    overrides = {
        'system.seed': args.seed,
        'wandb.enable': not args.no_wandb,
        'wandb.project_name': args.eval_env,
        'wandb.entity_name': 'mechanistic_offline_rl',
        'wandb.name': f'{args.algorithm}_{wandb_name}'.replace("_", '-'),
        'env.train_env': args.train_env,
        'env.eval_env': args.eval_env,
        'train.agent_path': f'trained_agents/{args.algorithm}/{wandb_name}/'.replace("-", '_'),
        'train.algorithm': args.algorithm,
        'simulator.transform_list': transform_list,
    }
    config_list = [item for pair in overrides.items() for item in pair]

    config = get_cfg_defaults(args.config_file, config_list)

    if config.wandb.enable:
        # Start with an empty wandb config; train_agent fills it in later.
        wandb.init(project=config.wandb.project_name, entity=config.wandb.entity_name, config={})

    train_agent(config, sim_env, eval_env)
