import argparse
import ray.tune as tune
import yaml
from ray.util.ml_utils.dict import merge_dicts
from argparse import Namespace


def _str_to_bool(value):
    """Parse a command-line token into a real boolean.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy. This helper accepts the usual
    textual spellings instead.

    Args:
        value: The raw command-line string (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def create_parser():
    """Build the command-line parser for the experiment scripts.

    The options are grouped into general/stopping options, fort_attack
    options, atari options, model options, RLlib options, evaluation options
    and a set of optional overrides that default to ``None``.

    List-valued options (``--concepts``, ``--replacement``,
    ``--concept-update``) take one or more space-separated values, e.g.
    ``--concepts can_shoot_ordinal agent_targeting_ordinal``.

    Boolean-valued options that take an explicit value
    (``--random-starting-rot``, ``--render-env``, ``--correct-concepts``)
    accept ``true/false``, ``yes/no``, ``t/f``, ``y/n`` and ``1/0``
    (case-insensitive).

    Returns:
        argparse.ArgumentParser: The configured parser (arguments are not
        parsed here).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--torch", action="store_true")
    parser.add_argument(
        "--scenario",
        type=str,
        default="fort_attack",
        choices=["fort_attack", "atari"],
        help="name of the scenario script",
    )
    parser.add_argument("--mixed-torch-tf", action="store_true")
    parser.add_argument(
        "--as-test",
        action="store_true",
        help="Whether this script should be run as a test: --stop-reward must "
        "be achieved within --stop-timesteps AND --stop-iters.",
    )
    parser.add_argument(
        "--stop-iters",
        type=int,
        default=1000,
        help="Number of iterations to train.",
    )
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=1000000000000,
        help="Number of timesteps to train.",
    )
    parser.add_argument(
        "--stop-reward",
        type=float,
        default=150.0,
        help="Reward at which we stop training.",
    )

    # Fortattack specific args
    parser.add_argument("--num-guards", type=int, default=5, help="number of guards")
    parser.add_argument(
        "--num-attackers", type=int, default=5, help="number of attackers"
    )
    parser.add_argument(
        "--num-steps", type=int, default=200, help="number of max environment steps"
    )
    parser.add_argument("--num-shots", type=int, default=10, help="number of shots")
    parser.add_argument(
        "--max-rot", type=float, default=0.17, help="maximum rotation of agents"
    )
    parser.add_argument(
        "--random-starting-rot",
        # Not type=bool: bool("False") is True, so the flag could never be
        # switched off from the command line.
        type=_str_to_bool,
        default=True,
        help="whether agents start with a random rotation (true/false)",
    )
    parser.add_argument("--return-image", action="store_true")
    parser.add_argument("--attacker-can-fire", action="store_true")
    parser.add_argument("--use-hard-coded-paths", action="store_true")
    parser.add_argument("--use-pygame", action="store_true")

    # atari specific args
    parser.add_argument(
        "--sub-scenario",
        type=str,
        default="MsPacmanNoFrameskip-v4",
        help="sub scenario",
    )
    parser.add_argument("--noop-max", type=int, default=30, help="noop")
    parser.add_argument(
        "--skip",
        type=int,
        default=4,
        help="number of frames to skip for atari environments",
    )
    # Model specific args
    parser.add_argument(
        "--model-type", type=str, default="rnn", help="type of custom policy to use"
    )
    parser.add_argument(
        "--group-type",
        type=int,
        default=0,
        help="0 for no group, 1 for adversary group, 2 for guard group, 3 for both",
    )
    parser.add_argument(
        "--concepts",
        # Not type=list: list("abc") yields ['a', 'b', 'c'], i.e. the single
        # argument would be split into characters.
        type=str,
        nargs="+",
        default=[
            "can_shoot_ordinal",
            "agent_targeting_ordinal",
            # "relative_orientation",
            # "distance_between",
            # "distance_from_base",
        ],
        help="names of concepts to use",
    )

    parser.add_argument("--include-concepts", action="store_true")
    parser.add_argument("--use-balanced", action="store_true")

    parser.add_argument("--include-whitening", action="store_true")

    parser.add_argument(
        "--concept-yaml",
        type=str,
        default="",
        help="concept yaml file to use",
    )
    parser.add_argument(
        "--experiment-yaml",
        type=str,
        default=None,
        help="experiment yml with parameters",
    )

    # Rllib args
    parser.add_argument("--use-exploration", action="store_true")
    parser.add_argument(
        "--num-workers", type=int, default=10, help="how many workers to use"
    )
    parser.add_argument(
        "--num-eval-workers",
        type=int,
        default=1,
        help="how many evaluation workers to use",
    )
    parser.add_argument(
        "--num-envs-per-worker",
        type=int,
        default=8,
        help="how many environments to use per worker",
    )
    parser.add_argument("--num-gpu", type=float, default=1, help="how many gpus to use")
    parser.add_argument(
        "--algo", type=str, default="ppo", help="which algorithm to use"
    )
    parser.add_argument(
        "--scheduler", type=str, default="medianstopping", help="which scheduler to use"
    )
    parser.add_argument(
        "--max-episode-steps", type=int, default=50000, help="max episode steps"
    )

    # evaluation specific args
    parser.add_argument("--test", action="store_true")
    parser.add_argument("--steps", type=int, default=200000, help="")
    parser.add_argument("--episodes", type=int, default=100, help="")
    parser.add_argument(
        "--render", action="store_true", help="Render the environment while evaluating."
    )
    parser.add_argument(
        "--video-dir",
        type=str,
        default="",
        help="video saving directory",
    )
    parser.add_argument(
        "--rollout-dir",
        type=str,
        default="",
        help="rollout saving directory",
    )
    parser.add_argument(
        "--replacement",
        type=str,
        nargs="+",
        default=["can_shoot_ordinal"],
        help="what intervention value to run",
    )

    # NOTE(review): this is store_false with default True, so passing
    # --save-info sets save_info to False, which reads as the opposite of the
    # help text. Left unchanged because callers may rely on it; confirm intent.
    parser.add_argument(
        "--save-info",
        default=True,
        action="store_false",
        help="Save the info field generated by the step() method, "
        "as well as the action, observations, rewards and done fields.",
    )
    parser.add_argument(
        "--use-shelve",
        default=False,
        action="store_true",
        help="Save rollouts into a python shelf file (will save each episode "
        "as it is generated). An output filename must be set using --out.",
    )
    parser.add_argument(
        "--track-progress",
        default=False,
        action="store_true",
        help="Write progress to a temporary file (updated "
        "after each episode). An output filename must be set using --out; "
        "the progress file will live in the same folder.",
    )

    parser.add_argument("--checkpoint", type=str, default="", help="checkpoint to use")
    # callback arg
    parser.add_argument(
        "--callback",
        type=str,
        default="sequential",
    )

    # args.sequential_win_rate_threshold
    parser.add_argument(
        "--sequential-win-rate-threshold",
        type=float,
        default=0.8,
    )

    # Optional overrides. These have no type, so a value given on the command
    # line is kept as a string; None means "not set".
    parser.add_argument("--lr", default=None)
    parser.add_argument("--entropy-coeff", default=None)
    parser.add_argument("--policies-to-train", default=None)
    parser.add_argument("--conceptdim", default=None)
    # Spelling of the flag is part of the CLI interface; kept as-is.
    parser.add_argument("--cirriculum", default=None)
    parser.add_argument("--rectangle", action="store_true")
    parser.add_argument("--entropy-coeff-schedule", default=None)
    parser.add_argument("--use-reward", default=None)
    parser.add_argument("--serve", action="store_true")
    parser.add_argument("--local-dir", default=None)
    parser.add_argument("--concept-update", type=str, nargs="+", default=None)
    parser.add_argument("--t-whitening", default=None)
    parser.add_argument("--num-samples", default=None)
    parser.add_argument("--embed-dim", default=None)
    # Boolean defaults need a boolean parser; without one "--render-env False"
    # would store the truthy string "False".
    parser.add_argument("--render-env", type=_str_to_bool, default=True)
    parser.add_argument("--correct-concepts", type=_str_to_bool, default=False)
    return parser


def create_args():
    """Parse the command line and optionally merge in an experiment YAML.

    The arguments are parsed with the parser from ``create_parser()``. When
    ``--experiment-yaml`` is given, that file is loaded with
    ``yaml.safe_load`` and its ``"parameters"`` mapping is merged into the
    parsed arguments via ``merge_dicts``.

    NOTE(review): the YAML mapping is passed as the second argument to
    ``merge_dicts``, so its values presumably take precedence over the
    command-line ones -- confirm against the ray helper.

    Returns:
        argparse.Namespace: The parsed (and possibly YAML-merged) arguments.

    Raises:
        KeyError: If the YAML mapping has no ``"parameters"`` key.
    """
    parsed = create_parser().parse_args()

    yaml_path = parsed.experiment_yaml
    if yaml_path is None:
        # No experiment file requested: the command line is the only source.
        return parsed

    with open(yaml_path, "r") as handle:
        experiment_cfg = yaml.safe_load(handle)

    combined = merge_dicts(vars(parsed), experiment_cfg["parameters"])
    return Namespace(**combined)
