import argparse
import sys

from garage_experiments_class_structure import (
    DnCCEMExperiment,
    DnCSACExperiment,
    SACExperiment,
    SACDiscreteExperiment,
    DistralSACExperiment,
    DistralSQLDiscreteExperiment,
)


def str2bool(v):
    """Convert a command-line token into a bool, for use as an argparse ``type``.

    Values that are already ``bool`` are returned unchanged. Strings are
    compared case-insensitively against common spellings of true/false.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


def main(args=None) -> None:
    """Parse command-line options and run the selected experiment.

    Options are parsed in two passes. The first, lenient pass
    (``parse_known_args``) only needs the experiment name, which decides
    which agent- and experiment-specific options get registered. The second,
    strict pass (``parse_args``) then parses the whole command line and
    rejects anything still unknown.

    Args:
        args: argv-style sequence whose first element is the program name
            (as with ``sys.argv``); the remaining elements are parsed. When
            ``None``, argparse reads ``sys.argv[1:]`` itself.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
        NotImplementedError: if the experiment name has no implementation
            (unreachable while the positional's ``choices`` stay in sync).
    """
    argv = None if args is None else list(args)[1:]

    parser = argparse.ArgumentParser()
    _add_base_arguments(parser)
    # First pass: experiment-specific options are not registered yet, so
    # tolerate unknown options and only read what is needed to register them.
    known_args = parser.parse_known_args(argv)[0]
    _add_late_arguments(parser, known_args.experiment)
    config = parser.parse_args(argv)

    # Env args are derived from the final namespace rather than the
    # first-pass one so they always agree with the config handed to the
    # experiment (an abbreviated option can resolve differently once more
    # options are registered).
    env_args = _build_env_args(config)
    agent_args = {
        "gradient_steps_per_itr": config.gradient_steps_per_itr,
        "min_buffer_size": config.min_buffer_size,
        "target_update_tau": config.target_update_tau,
        "discount": config.discount,
        "buffer_batch_size": config.buffer_batch_size,
        "reward_scale": config.reward_scale,
        "steps_per_epoch": config.steps_per_epoch,
    }

    experiment = _build_experiment(config, env_args, agent_args)
    experiment.run()


def _add_base_arguments(parser):
    """Register the options shared by every experiment (read in pass one)."""
    parser.add_argument(
        "experiment",
        choices=[
            "sac", "mtsac", "dnc_sac", "dnc_cem",
            "distral_sac", "distral_sql_discrete",
        ],
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--name", type=str, default="None")
    parser.add_argument("--n_policies", type=int, default=4)
    parser.add_argument(
        "--kl_coeff",
        nargs="+",
        type=float,
        default=[0.01],
        help="order: 1->1 2->1 ... N->1 1->2 2->2 ... N->2 ...",
    )
    parser.add_argument("--stagewise", type=str2bool, default=False)
    parser.add_argument("--env", type=str, default="InvertedPendulum-v2")
    parser.add_argument("--hidden_sizes", nargs="+", type=int, default=[256, 256])
    parser.add_argument("--n_epochs", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--grad_clip", type=float, default=None)
    parser.add_argument("--num_evaluation_episodes", type=int, default=10)
    parser.add_argument("--evaluation_frequency", type=int, default=1)
    parser.add_argument(
        "--partition",
        default="random",
        # argparse does not validate defaults against ``choices``, so the
        # default must be listed explicitly for ``--partition random`` to be
        # accepted on the command line.
        choices=[
            "random",
            "goal_quadrant",
            "obstacle_id",
            "obstacle_orientation",
            "goal_cluster",
            "task_id",
        ],
    )
    parser.add_argument("--gpu", type=int, default=None)
    parser.add_argument("--ray", type=str2bool, default=False)
    parser.add_argument("--notes", type=str, default="")
    parser.add_argument("--wandb", type=str2bool, default=True)
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_entity", type=str, default=None)

    parser.add_argument("--goal_type", type=int, default=0)
    parser.add_argument("--sparse_tasks", nargs="+", type=int, default=[])
    parser.add_argument("--onehot_id", type=str2bool, default=False)
    parser.add_argument("--task_name", type=str, default=None)
    parser.add_argument("--discrete_action", type=str2bool, default=False)

    parser.add_argument(
        "--vis_freq", type=int, default=10, help="-1 for no visualization"
    )
    parser.add_argument("--vis_num", type=int, default=10)
    parser.add_argument("--vis_width", type=int, default=500)
    parser.add_argument("--vis_height", type=int, default=500)
    parser.add_argument("--vis_fps", type=int, default=80)
    parser.add_argument("--vis_format", type=str, default="mp4")

    # Reward arguments (only consumed for ReacherCluster-v0 currently).
    parser.add_argument("--reward_params", nargs="+", type=float, default=[0, 0])
    parser.add_argument(
        "--reward_type", type=str, choices=["shift", "scale", "shaped"], default="shift"
    )


def _add_late_arguments(parser, experiment):
    """Register agent options plus the options specific to ``experiment``."""
    parser.add_argument("--gradient_steps_per_itr", type=int, default=1000)
    parser.add_argument("--min_buffer_size", type=int, default=int(1e4))
    parser.add_argument("--target_update_tau", type=float, default=5e-3)
    parser.add_argument("--discount", type=float, default=0.99)
    parser.add_argument("--buffer_batch_size", type=int, default=32)
    parser.add_argument("--reward_scale", type=float, default=1.0)
    parser.add_argument("--steps_per_epoch", type=int, default=1)

    if "dnc" in experiment or "distral" in experiment:
        parser.add_argument(
            "--sampling_type",
            type=str,
            choices=["all", "i", "j", "i+j", "ixj"],
            default="all",
        )
        parser.add_argument("--regularize_representation", type=str2bool, default=False)
    # Independent check: the branch above also matches every distral
    # experiment, so this must not be chained to it with ``elif``.
    if "distral" in experiment:
        parser.add_argument("--two_column", type=str2bool, default=False)

    if experiment == "dnc_cem":
        # CEM arguments.
        parser.add_argument("--frequency", type=int, default=50)
        parser.add_argument("--sigma", type=float, default=1.1)
        parser.add_argument("--population_size", type=int, default=20)
        parser.add_argument("--num_elite", type=int, default=1)
        parser.add_argument("--iterations", type=int, default=10)
        parser.add_argument("--warm_up", type=int, default=20)
        parser.add_argument("--update", default="RL")
        parser.add_argument("--num_updates", type=int, default=1)
        parser.add_argument("--one_beta", type=str2bool, default=False)
        parser.add_argument("--discrete", type=str2bool, default=True)
        parser.add_argument("--individual_evaluation", type=str2bool, default=False)
        parser.add_argument("--fixed_search", type=str2bool, default=False)
        parser.add_argument("--restart", type=str2bool, default=False)


def _build_env_args(config):
    """Return the env keyword arguments implied by ``config.env``."""
    if config.env == "ReacherCluster-v0":
        env_args = {
            "reward_type": config.reward_type,
            "reward_params": config.reward_params,
        }
    else:
        env_args = {}

    if config.env.startswith("ReacherMultistage"):
        env_args.update(
            {
                "sparse_tasks": config.sparse_tasks,
                "goal_type": config.goal_type,
                "discrete": config.discrete_action,
            }
        )
    elif config.env.startswith("JacoReachMultistage"):
        env_args.update(
            {
                "sparse_tasks": config.sparse_tasks,
                "goal_type": config.goal_type,
            }
        )
    elif config.env.startswith(("MetaWorld", "Kitchen")):
        env_args.update(
            {
                "sparse_tasks": config.sparse_tasks,
                "goal_type": config.goal_type,
                "onehot_id": config.onehot_id,
                "task_name": config.task_name,
            }
        )
    return env_args


def _build_experiment(config, env_args, agent_args):
    """Instantiate the experiment class selected by ``config.experiment``."""
    common = {"config": config, "env_args": env_args, "agent_args": agent_args}
    name = config.experiment

    if name in ("sac", "mtsac"):
        # Both names use the SAC classes; the action-space flag picks which.
        if config.discrete_action:
            return SACDiscreteExperiment(**common)
        return SACExperiment(**common)
    if name == "dnc_sac":
        return DnCSACExperiment(**common)
    if name == "dnc_cem":
        cem_configs = {
            "frequency": config.frequency,
            "sigma": config.sigma,
            "population_size": config.population_size,
            "num_elite": config.num_elite,
            "iterations": config.iterations,
            "warm_up": config.warm_up,
            "update": config.update,
            "num_updates": config.num_updates,
            "one_beta": config.one_beta,
            "discrete": config.discrete,
            "individual_evaluation": config.individual_evaluation,
            "fixed_search": config.fixed_search,
            "restart": config.restart,
        }
        return DnCCEMExperiment(**common, cem_configs=cem_configs)
    if name == "distral_sac":
        return DistralSACExperiment(**common)
    if name == "distral_sql_discrete":
        return DistralSQLDiscreteExperiment(**common)
    raise NotImplementedError(f"Unsupported experiment: {name!r}")


if __name__ == "__main__":
    # Run the CLI with the full argv (program name included).
    main(args=sys.argv)
