import argparse
from deep_sprl.util.parameter_parser import parse_parameters
import tensorflow as tf
import deep_sprl.environments

# Raise the TensorFlow v1-compat logger threshold to ERROR so lower-severity
# TensorFlow log output does not clutter the experiment runner's console.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def _build_parser():
    """Create the argument parser for the options this runner understands itself.

    Arguments not listed here are not rejected: ``main`` uses
    ``parse_known_args`` and forwards the leftovers to ``parse_parameters``.
    """
    # NOTE(review): the first positional argument of ArgumentParser is ``prog``
    # (the program name shown in usage), not ``description``; kept as-is to
    # preserve the existing usage/help output.
    parser = argparse.ArgumentParser("Self-Paced Learning experiment runner")
    parser.add_argument("--base_log_dir", type=str, default="logs")
    parser.add_argument("--type", type=str, default="np_self_paced",
                        choices=["default", "random", "self_paced", "np_self_paced", "wasserstein", "alp_gmm",
                                 "goal_gan", "acl"])
    parser.add_argument("--learner", type=str, default="trpo", choices=["trpo", "ppo", "sac"])
    # "maze_vec" must be listed here: argparse rejects values outside `choices`,
    # so without it the MazeExperimentVec branch in main() is unreachable.
    parser.add_argument("--env", type=str, default="point_mass_2d",
                        choices=["pick_and_place", "point_mass_2d", "maze", "maze_vec"])
    parser.add_argument("--hard_likelihood", action="store_true")
    parser.add_argument("--seed", type=int, default=1)
    return parser


def main():
    """Parse the command line, build the selected experiment, train it, then evaluate it.

    Recognised options are defined in ``_build_parser``. Any unrecognised
    arguments are passed to ``parse_parameters`` to build the experiment's
    parameter dict, and the ``--hard_likelihood`` flag is stored in that dict
    under ``"hard_likelihood"``.

    Raises:
        RuntimeError: if ``args.env`` does not match any known environment
            (defensive; argparse's ``choices`` normally prevents this).
    """
    args, remainder = _build_parser().parse_known_args()
    parameters = parse_parameters(remainder)
    parameters["hard_likelihood"] = args.hard_likelihood

    if args.type in ("self_paced", "np_self_paced", "wasserstein"):
        # These curriculum types import torch; pin it to a single intra-op thread
        # (presumably to avoid CPU oversubscription alongside the RL learner — confirm).
        import torch
        torch.set_num_threads(1)

    # Experiment classes are imported lazily so only the selected environment's
    # dependencies get loaded.
    if args.env == "point_mass_2d":
        from deep_sprl.experiments import PointMass2DExperiment
        exp = PointMass2DExperiment(args.base_log_dir, args.type, args.learner, parameters, args.seed)
    elif args.env == "pick_and_place":
        from deep_sprl.experiments import PickAndPlaceExperiment
        exp = PickAndPlaceExperiment(args.base_log_dir, args.type, args.learner, parameters, args.seed)
    elif args.env == "maze":
        from deep_sprl.experiments import MazeExperiment
        exp = MazeExperiment(args.base_log_dir, args.type, args.learner, parameters, args.seed)
    elif args.env == "maze_vec":
        from deep_sprl.experiments import MazeExperimentVec
        exp = MazeExperimentVec(args.base_log_dir, args.type, args.learner, parameters, args.seed)
    else:
        raise RuntimeError("Unknown environment '%s'!" % args.env)

    exp.train()
    exp.evaluate()


# Run the experiment only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
