from hpo.pbt import PopulationBasedTraining
from hpo.mfpbt import MultipleFrequenciesPopulationBasedTraining
from hpo.ablation import MultipleFrequenciesPopulationBasedTraining as Ablation

from pprint import pprint

import os
import argparse
import yaml

from hpo.pbt import PATH_TO_MAIN_PROJECT


def _build_arg_parser():
    """Create the command-line parser for launching an HPO experiment.

    All options default to ``None`` except ``--jax-seed`` and ``--numpy-seed``
    (default ``0``); ``None`` values are treated as "not given" when merging
    into the YAML config.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--env",
        type=str,
        help="Environment to use for the experiment",
    )
    parser.add_argument(
        "--hpo",
        type=str,
        help="HPO algorithm to use for the experiment",
    )
    parser.add_argument(
        "--frequencies",
        nargs="*",
        type=int,
        help="Frequencies to use, if in MF-PBT, give in the form --frequencies 1 10 20",
    )
    parser.add_argument(
        "--exp-name",
        type=str,
        help="Experiment name - if we detect an experiment with the same name, we will reload it",
    )
    parser.add_argument(
        "--num-agents", type=int, help="Number of agents to use in parallel"
    )
    parser.add_argument(
        "--num-envs-per-agent",
        type=int,
        help="Number of brax envs to use for each agent",
    )
    parser.add_argument(
        "--num-rounds", type=int, help="Number of rounds in the experiment"
    )
    parser.add_argument("--bucket-path", type=str, help="path for external logging")
    parser.add_argument(
        "--num-timesteps-round",
        type=int,
        help=(
            "Number of environment steps per round, note that depending on the training config, "
            "batch size and num_minibatches, more steps will be performed"
        ),
    )
    # NOTE(review): the two seeds default to 0 rather than None, so whenever the
    # YAML config defines `jax_seed` / `numpy_seed`, the merge in load_config()
    # replaces the file's value with 0 unless the flag is passed explicitly.
    parser.add_argument(
        "--jax-seed",
        type=int,
        default=0,
        help="jax seed used for the training and env simulations",
    )
    parser.add_argument(
        "--numpy-seed",
        type=int,
        default=0,
        help="Numpy seed used for the random initialization of hyperparameters",
    )
    return parser


# Values accepted by --env / --hpo; each pair maps to
# configurations/<env>/<hpo>.yml under PATH_TO_MAIN_PROJECT.
_SUPPORTED_ENVS = ("ant", "halfcheetah", "hopper", "humanoid", "pusher", "walker2d")
_SUPPORTED_HPOS = ("pbt", "mfpbt", "ablation", "do_nothing")


def _resolve_config_file(env_choice, hpo_choice):
    """Validate the CLI choices and return ``(env, hpo, config_file_path)``.

    With neither an environment nor an HPO algorithm given, fall back to the
    bundled example (``mfpbt`` on ``inverted_pendulum``).

    Raises:
        ValueError: if the environment or HPO algorithm is not supported.
    """
    if env_choice is None and hpo_choice is None:
        return (
            "inverted_pendulum",
            "mfpbt",
            os.path.join(PATH_TO_MAIN_PROJECT, "configurations/example.yml"),
        )

    if env_choice not in _SUPPORTED_ENVS:
        raise ValueError(
            f"Environment {env_choice} not supported, choose one of {list(_SUPPORTED_ENVS)}"
        )
    if hpo_choice not in _SUPPORTED_HPOS:
        raise ValueError(
            f"HPO algorithm {hpo_choice} not supported, choose one of {list(_SUPPORTED_HPOS)}"
        )

    config_file = os.path.join(
        PATH_TO_MAIN_PROJECT, f"configurations/{env_choice}/{hpo_choice}.yml"
    )
    return env_choice, hpo_choice, config_file


def _apply_cli_overrides(config, xp_args):
    """Overwrite keys of ``config`` with CLI values that are not ``None``.

    Only keys already present in the YAML config are overridden, so CLI-only
    options (such as ``env`` or ``hpo`` when absent from the file) are not
    forwarded to the algorithm.
    """
    for key, value in vars(xp_args).items():
        if value is not None and key in config:
            config[key] = value


def load_config():
    """Build the experiment configuration from the command line and a YAML file.

    With neither ``--env`` nor ``--hpo`` given, ``configurations/example.yml``
    is loaded and the HPO algorithm is ``mfpbt``. Otherwise
    ``configurations/<env>/<hpo>.yml`` under ``PATH_TO_MAIN_PROJECT`` is
    loaded. CLI values that are not ``None`` then replace matching keys of the
    YAML config. ``exp_name`` is always set: from ``--exp-name`` if given,
    otherwise generated as ``<hpo>/<env>/n_agents_<num_agents>``.

    Returns:
        tuple: ``(config, hpo)``, the merged config dict (passed as keyword
        arguments to the HPO algorithm class by the entry point) and the name
        of the selected HPO algorithm.

    Raises:
        ValueError: if ``--env``/``--hpo`` is unsupported, or if the config
            file does not contain a YAML mapping.
    """
    xp_args = _build_arg_parser().parse_args()

    # xp_args.hpo is updated in place so the default "mfpbt" also takes part in
    # the CLI override merge below, exactly like an explicitly passed --hpo.
    env_choice, xp_args.hpo, config_file = _resolve_config_file(
        xp_args.env, xp_args.hpo
    )

    # safe_load (not load): config files are plain data and must not be able
    # to construct arbitrary Python objects.
    with open(config_file, encoding="utf-8") as stream:
        config = yaml.safe_load(stream)
    if not isinstance(config, dict):
        raise ValueError(f"Config file {config_file} must contain a YAML mapping")

    _apply_cli_overrides(config, xp_args)

    # Set exp_name explicitly: the generic merge above would silently drop a
    # user-supplied --exp-name when the YAML file has no exp_name key.
    if xp_args.exp_name is None:
        config["exp_name"] = f"{xp_args.hpo}/{env_choice}/n_agents_{config['num_agents']}"
    else:
        config["exp_name"] = xp_args.exp_name

    return config, xp_args.hpo


if __name__ == "__main__":
    # Merge CLI flags with the YAML config and learn which HPO algorithm to run.
    config, hpo_choice = load_config()

    print("Launching experiment with the following config: \n")
    pprint(config)
    print("\n")

    # Dedicated classes for MF-PBT and its ablation; every other choice
    # ("pbt", "do_nothing") is run by plain PopulationBasedTraining.
    special_algorithms = {
        "mfpbt": MultipleFrequenciesPopulationBasedTraining,
        "ablation": Ablation,
    }
    algorithm_cls = special_algorithms.get(hpo_choice, PopulationBasedTraining)

    runner = algorithm_cls(**config)
    runner.run()
