import argparse
import warnings
import sys

# from mujoco_py import GlfwContext
# GlfwContext(offscreen=True)

from garage_experiments import (
    run_sac_test,
    run_dnc_sac_test,
    run_mop_dnc_test,
    run_hmop_dnc_test,
    run_hmop_sac_test,
    run_cds_test,
    run_qmp_uds_test,
    run_resume_test,
)

from old_garage_experiments import (
    run_sac_discrete_test,
    run_hmop_dnc_v2_test,
    run_dnc_cem_test,
    run_distral_sac_test,
    run_distral_sql_discrete_test,
)


def str2bool(v):
    """Convert a command-line token into a bool (intended as an argparse ``type``).

    A value that is already a ``bool`` is passed through untouched. Strings are
    matched case-insensitively: ``yes``/``true``/``t``/``y``/``1`` map to True,
    ``no``/``false``/``f``/``n``/``0`` map to False.

    Raises:
        argparse.ArgumentTypeError: when the string matches neither set.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")

    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


def main(args) -> None:
    """Parse command-line options and launch the selected experiment.

    Parsing happens in two passes: a lenient ``parse_known_args`` pass finds
    out which experiment was requested, then that experiment's extra options
    are registered and the full command line is parsed strictly.

    Args:
        args: Unused. argparse reads the command line from ``sys.argv``
            itself; the parameter is kept so existing ``main(sys.argv)``
            callers continue to work.
    """
    parser = _build_base_parser()

    # First pass: only ``experiment`` is needed to decide which
    # method-specific options exist, so unknown flags are tolerated here.
    known, _ = parser.parse_known_args()
    _add_method_arguments(parser, known.experiment)

    # Second pass: strict parse now that every valid option is registered.
    config = parser.parse_args()
    _validate(parser, config)

    _run_experiment(config, _collect_env_args(config))


def _build_base_parser() -> argparse.ArgumentParser:
    """Create the parser holding the options shared by every experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "experiment",
        choices=[
            "sac",
            "mtsac",
            "mop_dnc",
            "hmop_dnc",
            "hmop_sac",
            "hmop_dnc_v2",
            "dnc_sac",
            "dnc_cem",
            "distral_sac",
            "distral_sql_discrete",
            "cds_dnc",
            "hmop_cds_dnc",
            "resume",
        ],
    )
    parser.add_argument("--env", type=str, default="InvertedPendulum-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--name", type=str, default="None")
    parser.add_argument("--gpu", type=int, default=None)
    parser.add_argument("--max_n_worker", type=int, default=4)
    parser.add_argument("--notes", type=str, default="")
    parser.add_argument("--wandb", type=str2bool, default=True)
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_entity", type=str, default=None)
    parser.add_argument("--discrete_action", type=str2bool, default=False)
    parser.add_argument("--norm_obs", type=str2bool, default=False)

    ### Training arguments
    # Split across all tasks.
    parser.add_argument("--gradient_steps_per_itr", type=int, default=1000)
    parser.add_argument("--min_buffer_size", type=int, default=int(1e4))
    parser.add_argument("--target_update_tau", type=float, default=5e-3)
    parser.add_argument("--target_entropy", type=float, default=None)
    parser.add_argument("--discount", type=float, default=0.99)
    parser.add_argument("--buffer_batch_size", type=int, default=32)
    parser.add_argument("--reward_scale", type=float, default=1.0)
    parser.add_argument("--steps_per_epoch", type=int, default=1)
    parser.add_argument("--hidden_sizes", nargs="+", type=int, default=[256, 256])
    parser.add_argument(
        "--policy_architecture",
        type=str,
        choices=["separate", "shared", "multihead"],
        default="separate",
    )
    parser.add_argument(
        "--Q_architecture",
        type=str,
        choices=["separate", "shared", "multihead"],
        default="separate",
    )
    parser.add_argument("--n_epochs", type=int, default=1000)
    # Split across all tasks.
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--qf_lr", type=float, default=3e-4)
    parser.add_argument("--alpha_lr", type=float, default=3e-4)
    parser.add_argument("--grad_clip", type=float, default=None)
    parser.add_argument(
        "--sampler_type", choices=["local", "ray", "mp", "mt_local"], default="local"
    )
    parser.add_argument(
        "--worker_type", choices=["default", "fragment"], default="default"
    )

    ### Evaluation arguments
    parser.add_argument("--num_evaluation_episodes", type=int, default=10)
    parser.add_argument("--evaluation_frequency", type=int, default=1)
    parser.add_argument(
        "--vis_freq", type=int, default=10, help="-1 for no visualization"
    )
    parser.add_argument("--vis_num", type=int, default=10)
    parser.add_argument("--vis_width", type=int, default=500)
    parser.add_argument("--vis_height", type=int, default=500)
    parser.add_argument("--vis_fps", type=int, default=40)
    parser.add_argument("--vis_format", type=str, default="mp4")
    parser.add_argument("--snapshot_gap", type=int, default=1)
    parser.add_argument("--snapshot_dir", type=str)

    ### Task arguments
    parser.add_argument("--n_policies", type=int, default=4)
    parser.add_argument(
        "--partition",
        default="task_id",
        choices=[
            "random",
            "goal_quadrant",
            "obstacle_id",
            "obstacle_orientation",
            "goal_cluster",
            "task_id",
        ],
    )
    parser.add_argument("--include_task_id", type=str2bool, default=False)
    parser.add_argument("--sparse_tasks", nargs="+", type=int, default=[])
    parser.add_argument("--task_name", type=str, default=None)
    parser.add_argument("--control_penalty", type=float, default=0.0)
    parser.add_argument("--reward_params", nargs="+", type=float, default=[0, 0])
    parser.add_argument(
        "--reward_type", type=str, choices=["shift", "scale", "shaped"], default="shift"
    )

    ### Kitchen-specific arguments
    parser.add_argument("--use_spirl_prior", type=str2bool, default=False)
    parser.add_argument("--vectorized_skills", type=str2bool, default=False)
    parser.add_argument("--target_divergence", type=float)
    parser.add_argument("--target_divergence_scheduler", type=str, default=None)
    parser.add_argument("--target_divergence_step", type=float)
    parser.add_argument("--max_target_divergence", type=float, default=None)
    parser.add_argument("--simpl_tricks", type=str2bool, default=False)
    return parser


def _add_method_arguments(parser: argparse.ArgumentParser, experiment: str) -> None:
    """Register options that only exist for some experiments.

    Selection is by substring of the experiment name (e.g. every name
    containing ``"mop"`` gets the mixture-of-policies options).
    """
    if "dnc" in experiment or "distral" in experiment:
        parser.add_argument(
            "--kl_coeff",
            nargs="+",
            type=float,
            default=[0],
            help="order: 1->1 2->1 ... N->1 1->2 2->2 ... N->2 ...",
        )
        parser.add_argument(
            "--sampling_type",
            type=str,
            choices=["all", "i", "j", "i+j", "ixj"],
            default="all",
        )
        parser.add_argument("--regularize_representation", type=str2bool, default=False)
        parser.add_argument("--distillation_period", type=int, default=None)
        parser.add_argument("--distillation_n_epochs", type=int, default=100)

        if "hmop" not in experiment:
            parser.add_argument("--stagewise", type=str2bool, default=False)

    if "distral" in experiment:
        parser.add_argument("--two_column", type=str2bool, default=False)

    if experiment == "dnc_cem":
        ### CEM arguments
        parser.add_argument("--frequency", type=int, default=50)
        parser.add_argument("--sigma", type=float, default=1.1)
        parser.add_argument("--population_size", type=int, default=20)
        parser.add_argument("--num_elite", type=int, default=1)
        parser.add_argument("--iterations", type=int, default=10)
        parser.add_argument("--warm_up", type=int, default=20)
        parser.add_argument("--update", default="RL")
        parser.add_argument("--num_updates", type=int, default=1)
        parser.add_argument("--one_beta", type=str2bool, default=False)
        parser.add_argument("--discrete", type=str2bool, default=True)
        parser.add_argument("--individual_evaluation", type=str2bool, default=False)
        parser.add_argument("--fixed_search", type=str2bool, default=False)
        parser.add_argument("--restart", type=str2bool, default=False)

    if "mop" in experiment:
        ### Mixture of Policies arguments
        parser.add_argument(
            "--mixture_probs",
            nargs="+",
            type=float,
            default=[0.0],
            help="order: 1->1 2->1 ... N->1 1->2 2->2 ... N->2 ...",
        )
        parser.add_argument("--mixture_warmup", type=int, default=0)
        parser.add_argument("--train_samples_by_task", type=str2bool, default=True)
        parser.add_argument("--policy_sampling_freq", type=int, default=1)
        parser.add_argument(
            "--Qfilter",
            type=str,
            choices=["softmax", "argmax", "topQ", "rankQ"],
            default="argmax",
        )
        parser.add_argument("--resample", type=str2bool, default=False)
        parser.add_argument("--evaluate_mean", type=str2bool, default=False)

    if "cds" in experiment:
        parser.add_argument("--sharing_quantile", type=float, default=0.0)
        parser.add_argument("--unsupervised", type=str2bool, default=False)


def _validate(parser: argparse.ArgumentParser, config: argparse.Namespace) -> None:
    """Reject inconsistent option combinations.

    Uses ``parser.error`` (usage message + exit status 2) rather than
    ``assert``, so the checks survive ``python -O``.
    """
    # vis_freq must be -1 (disabled) or a multiple of evaluation_frequency;
    # a zero evaluation_frequency is rejected instead of dividing by zero.
    if config.vis_freq != -1 and (
        config.evaluation_frequency == 0
        or config.vis_freq % config.evaluation_frequency != 0
    ):
        parser.error("--vis_freq must be -1 or a multiple of --evaluation_frequency")

    if config.use_spirl_prior and not (
        ("KitchenSkill" in config.env or config.env == "kitchenmixed-v0")
        and not config.include_task_id
    ):
        parser.error(
            "--use_spirl_prior requires a KitchenSkill env or kitchenmixed-v0 "
            "and --include_task_id false"
        )


def _collect_env_args(config: argparse.Namespace) -> dict:
    """Build the ``env_args`` dict handed to the experiment runner.

    The keys included depend on the ``--env`` name. Later updates may
    overwrite a key set earlier (e.g. ``include_task_id``) with the same value.
    """
    env = config.env
    env_args = {}
    if env not in ["InvertedPendulum-v2", "kitchenmixed-v0", "Walker2d-v3", "Walker2dForward-v0"]:
        env_args.update({"include_task_id": config.include_task_id})

    if env == "ReacherCluster-v0" or env.startswith("JacoReachMultistage"):
        env_args.update(
            {
                "reward_type": config.reward_type,
                "reward_params": config.reward_params,
            }
        )

    if env.startswith(("ReacherMultistage", "JacoReachMultistage")):
        env_args.update(
            {
                "sparse_tasks": config.sparse_tasks,
                "include_task_id": config.include_task_id,
            }
        )
    elif env.startswith("MetaWorld"):
        env_args.update(
            {
                "sparse_tasks": config.sparse_tasks,
                "include_task_id": config.include_task_id,
                "task_name": config.task_name,
            }
        )
    elif env.startswith("Kitchen"):
        if "MS" in env:
            env_args.update(
                {
                    "sparse_tasks": config.sparse_tasks,
                    "include_task_id": config.include_task_id,
                    "control_penalty": config.control_penalty,
                }
            )
        else:
            env_args.update(
                {
                    "sparse_tasks": config.sparse_tasks,
                    "include_task_id": config.include_task_id,
                    "task_name": config.task_name,
                    "control_penalty": config.control_penalty,
                    "vectorized_skills": config.vectorized_skills,
                }
            )
    elif env.startswith("Maze"):
        env_args.update(
            {
                "sparse_tasks": config.sparse_tasks,
                "include_task_id": config.include_task_id,
                "task_name": config.task_name,
            }
        )
    return env_args


def _run_experiment(config: argparse.Namespace, env_args: dict) -> None:
    """Dispatch ``config.experiment`` to its runner function.

    Raises:
        NotImplementedError: for an experiment name with no runner (not
            reachable through the parser's ``choices``).
    """
    experiment = config.experiment

    if experiment in ("sac", "mtsac"):
        # "mtsac" shares the SAC runner; --discrete_action selects the variant.
        if config.discrete_action:
            run_sac_discrete_test(config=config, env_args=env_args)
        else:
            run_sac_test(config=config, env_args=env_args)
        return

    if experiment == "resume":
        run_resume_test(config=config)
        return

    maintained = {
        "dnc_sac": run_dnc_sac_test,
        "mop_dnc": run_mop_dnc_test,
        "hmop_dnc": run_hmop_dnc_test,
        "hmop_sac": run_hmop_sac_test,
        "cds_dnc": run_cds_test,
        "hmop_cds_dnc": run_qmp_uds_test,
    }
    if experiment in maintained:
        maintained[experiment](config=config, env_args=env_args)
        return

    # Legacy runners additionally take an explicit ``agent_args`` dict.
    warnings.warn(f"Calling unmaintained method: {experiment}")

    agent_args = {
        "gradient_steps_per_itr": config.gradient_steps_per_itr,
        "min_buffer_size": config.min_buffer_size,
        "target_update_tau": config.target_update_tau,
        "discount": config.discount,
        "buffer_batch_size": config.buffer_batch_size,
        "reward_scale": config.reward_scale,
        "steps_per_epoch": config.steps_per_epoch,
    }

    if experiment == "hmop_dnc_v2":
        run_hmop_dnc_v2_test(
            config=config,
            env_args=env_args,
            agent_args=agent_args,
        )
    elif experiment == "dnc_cem":
        cem_configs = {
            "frequency": config.frequency,
            "sigma": config.sigma,
            "population_size": config.population_size,
            "num_elite": config.num_elite,
            "iterations": config.iterations,
            "warm_up": config.warm_up,
            "update": config.update,
            "num_updates": config.num_updates,
            "one_beta": config.one_beta,
            "discrete": config.discrete,
            "individual_evaluation": config.individual_evaluation,
            "fixed_search": config.fixed_search,
            "restart": config.restart,
        }
        run_dnc_cem_test(
            config=config,
            env_args=env_args,
            agent_args=agent_args,
            cem_configs=cem_configs,
        )
    elif experiment == "distral_sac":
        run_distral_sac_test(
            config=config,
            env_args=env_args,
            agent_args=agent_args,
        )
    elif experiment == "distral_sql_discrete":
        run_distral_sql_discrete_test(
            config=config,
            env_args=env_args,
            agent_args=agent_args,
        )
    else:
        raise NotImplementedError


if __name__ == "__main__":
    # Script entry point: hand the raw argv list to main().
    main(args=sys.argv)

