"""
Top-level script to run all baseline environments and algorithms used in the paper
"Robust Adversarial Reinforcement Learning via Bounded Rationality Curricula"
"""

import argparse
import datetime
import os
import time
from run_experiment import evaluation
import numpy as np
# Directory that receives all results: a "results" folder located next to this
# script (symlinks in the script path are resolved first).
BASE_RESULTS_DIRECTORY = os.path.dirname(os.path.realpath(__file__)) + "/results"

# Name of the experiment; used as a sub-directory of BASE_RESULTS_DIRECTORY.
exp_name = 'cheetah_run_qarl'
# Alternative experiment name kept for reference:
# exp_name = 'hopper_hop_qarl'

# Default to GPU index 2, but respect a CUDA_VISIBLE_DEVICES value that the
# user or a job scheduler has already exported instead of overwriting it.
# NOTE(review): this assignment runs after `run_experiment` has been imported;
# if that import already initializes CUDA the setting may come too late -- confirm.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

# Number of evenly spaced values evaluated along each metric axis.
_N_METRIC_POINTS = 5

# (low, high) bounds of the first and second metric for every domain whose
# ranges do not depend on the task.
_DOMAIN_METRIC_BOUNDS = {
    "acrobot_two_players": ((0.95, 1.05), (0.95, 1.05)),
    "ball_in_cup_two_players": ((0.01, 0.11), (0.01, 0.11)),
    "cheetah_two_players": ((3, 9), (0.1, 1.9)),
    "hopper_two_players": ((1, 9), (0.1, 1.9)),
    "pendulum_two_players": ((0.05, 0.15), (0.5, 1.5)),
    "quadruped_two_players": ((65, 75), (0.95, 1.05)),
    "reacher_two_players": ((0.02, 0.06), (0.02, 0.05)),
    "walker_two_players": ((5, 15), (0.1, 1.3)),
}

# The cartpole domain is the only one whose bounds depend on the task.
_CARTPOLE_METRIC_BOUNDS = {
    "balance_vs_adversary": ((1, 20), (0.5, 1.5)),
    "balance_sparse_vs_adversary": ((1, 20), (0.5, 1.5)),
    "swingup_vs_adversary": ((0.05, 0.15), (0.5, 1.5)),
    "swingup_sparse_vs_adversary": ((0.05, 0.15), (0.5, 1.5)),
}


def _str_to_bool(value):
    """Convert a command-line string such as "True" or "false" to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is accepted as False because bool("") was False before.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _metric_ranges(domain_name, task_name):
    """Return the (first, second) metric value ranges to sweep for a domain/task.

    Each range is a 1-D numpy array of ``_N_METRIC_POINTS`` evenly spaced
    values. Two empty arrays are returned when no range is defined for the
    combination (unknown domain, or a cartpole task without a range).
    """
    if domain_name == "cartpole_two_players":
        bounds = _CARTPOLE_METRIC_BOUNDS.get(task_name)
    else:
        bounds = _DOMAIN_METRIC_BOUNDS.get(domain_name)
    if bounds is None:
        return np.array([]), np.array([])

    (first_low, first_high), (second_low, second_high) = bounds
    return (
        np.linspace(first_low, first_high, _N_METRIC_POINTS),
        np.linspace(second_low, second_high, _N_METRIC_POINTS),
    )


def main(argv=None):
    """Parse the command line and run ``evaluation`` over the metric grid.

    For every seed in ``range(--seeds)`` and every pair of values from the
    domain's two metric ranges, ``evaluation`` is called once with the parsed
    settings. Results are directed to
    ``BASE_RESULTS_DIRECTORY/<exp_name>/algorithm___<algorithm>/<seed>``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.
    """
    parser = argparse.ArgumentParser(
        "Curriculum Adversarial RL Experiment Runner",
        description="Launches an adversarial RL experiment",
    )

    parser.add_argument("--seeds", type=int, default=1)
    parser.add_argument(
        "--algorithm",
        type=str,
        default="qarl",
        choices=[
            "baseline",
            "rarl",
            "mas",
            "sgld",
            "force",
            # The default value must itself be selectable explicitly.
            "qarl",
            "qarl_spdl",
            "qarl_linear",
            "qarl_point",
            "qarl_single",
            "fixed_force",
            "fixed_temp",
            "random",
        ],
    )
    parser.add_argument(
        "--domain_name",
        type=str,
        default="cheetah_two_players",
        choices=[
            "acrobot_two_players",
            "ball_in_cup_two_players",
            "cartpole_two_players",
            "cheetah_two_players",
            "hopper_two_players",
            "pendulum_two_players",
            "quadruped_two_players",
            "reacher_two_players",
            "walker_two_players",
        ],
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="run_vs_adversary",
        choices=[
            "balance_vs_adversary",
            "balance_sparse_vs_adversary",
            "swingup_vs_adversary",
            "swingup_sparse_vs_adversary",
            "catch_vs_adversary",
            "walk_vs_adversary",
            "run_vs_adversary",
            "hop_vs_adversary",
            "reach_goal_vs_adversary",
            "reach_goal_vs_wind_adversary",
            "easy_vs_adversary",
            "hard_vs_adversary",
        ],
    )
    parser.add_argument("--horizon", type=int, default=500)
    parser.add_argument("--gamma", type=float, default=0.99)
    # Booleans go through _str_to_bool so that "--bool_render False" is False.
    parser.add_argument("--bool_render", type=_str_to_bool, default=False)
    parser.add_argument("--new_adv_max_force", type=float, default=None)
    parser.add_argument("--use_cuda", type=_str_to_bool, default=True)
    parser.add_argument("--n_total_iterations", type=int, default=200)
    parser.add_argument("--n_evaluation_episodes", type=int, default=20)
    args = parser.parse_args(argv)

    # Range of robustness values to evaluate
    first_metric_range, second_metric_range = _metric_ranges(
        args.domain_name, args.task_name
    )
    if first_metric_range.size == 0 or second_metric_range.size == 0:
        # Make an empty sweep visible instead of silently doing nothing.
        print(
            'no metric ranges defined for domain/task: ',
            args.domain_name,
            args.task_name,
        )

    # Settings that are identical for every evaluation call.
    shared_args = {
        "domain_name": args.domain_name,
        "task_name": args.task_name,
        "horizon": args.horizon,
        "gamma": args.gamma,
        "bool_render": args.bool_render,
        "new_adv_max_force": args.new_adv_max_force,
        "use_cuda": args.use_cuda,
        "n_total_iterations": args.n_total_iterations,
        "n_evaluation_episodes": args.n_evaluation_episodes,
    }

    print('base results directory: ', BASE_RESULTS_DIRECTORY)
    print('exp name: ', exp_name)
    start_time = time.time()
    for i in range(args.seeds):
        results_dir = (
            BASE_RESULTS_DIRECTORY + f"/{exp_name}/algorithm___{args.algorithm}/{i}"
        )
        for first_metric_value in first_metric_range:
            for second_metric_value in second_metric_range:
                evaluation(
                    algorithm=args.algorithm,
                    seed=i,
                    results_dir=results_dir,
                    first_metric_value=first_metric_value,
                    second_metric_value=second_metric_value,
                    **shared_args,
                )
    print('training time: ', time.time() - start_time)

# Run the experiment sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
