import argparse

from pathlib import Path

from transfer.utils.utils import str2bool


def single_parse_args(args=None):
    parser = argparse.ArgumentParser()

    parser.add_argument("--env", type=str, default="hopper")
    parser.add_argument("--dataset_type", type=str, default="buffer")
    parser.add_argument("--seed", type=int, help="Seed for randomness", default=0)
    parser.add_argument("--dataset_path", type=Path, default="data")  # path where datasets are stored
    parser.add_argument("--dataset", type=str, default="medium")  # medium, medium-replay, medium-expert, expert
    parser.add_argument("--mode", type=str, default="normal")  # normal for standard setting, delayed for sparse
    parser.add_argument("--K", type=int, default=20)
    parser.add_argument("--pct_traj", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--use_actions", type=str2bool, default=True)
    parser.add_argument("--use_returns", type=str2bool, default=True)
    parser.add_argument("--model_type", type=str, default="dt")  # mlp, dt
    parser.add_argument("--embed_dim", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=3)
    parser.add_argument("--n_head", type=int, default=1)
    parser.add_argument("--activation_function", type=str, default="relu")
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", "-lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", "-wd", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=10000)
    parser.add_argument("--num_eval_episodes", type=int, default=100)
    parser.add_argument("--max_iters", type=int, default=10)
    parser.add_argument("--num_steps_per_iter", type=int, default=10000)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--log_to_wandb", "-w", action="store_true")
    parser.add_argument("--exp_prefix", type=str, default="gym-experiment")

    return parser.parse_known_args(args=args)[0]


def mt_parse_args(args=None):
    parser = argparse.ArgumentParser()

    parser.add_argument("--env", type=str, default=["hopper"], nargs="+")
    parser.add_argument("--dataset_type", type=str, default="buffer")
    parser.add_argument("--seed", type=int, help="Seed for randomness", default=0)
    parser.add_argument("--dataset_path", type=Path, default="data")
    parser.add_argument("--dataset", type=str, default=["medium"], nargs="+")
    parser.add_argument("--mode", type=str, default="normal")  # normal for standard setting, delayed for sparse
    parser.add_argument("--K", type=int, default=20)
    parser.add_argument("--pct_traj", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--use_actions", type=str2bool, default=True)
    parser.add_argument("--use_returns", type=str2bool, default=True)
    parser.add_argument("--model_type", type=str, default="dt")  # mlp, dt
    parser.add_argument("--embed_dim", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=3)
    parser.add_argument("--n_head", type=int, default=1)
    parser.add_argument("--activation_function", type=str, default="relu")
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", "-lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", "-wd", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=10000)
    parser.add_argument("--num_eval_episodes", type=int, default=100)
    parser.add_argument("--max_iters", type=int, default=10)
    parser.add_argument("--num_steps_per_iter", type=int, default=10000)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--log_to_wandb", "-w", action="store_true")
    parser.add_argument("--exp_prefix", type=str, default="gym-experiment")
    parser.add_argument("--add_one_hot", action="store_true")

    return parser.parse_known_args(args=args)[0]


def sac_parse_args(args=None):
    parser = argparse.ArgumentParser()

    parser.add_argument("--sequence_name", type=str)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--seed", "-s", type=int, default=0)
    parser.add_argument("--steps", type=int, default=1_000_000)
    parser.add_argument("--start_steps", type=int, default=10_000)
    parser.add_argument("--log_every", type=int, default=20_000)
    parser.add_argument("--update_after", type=int, default=1_000)
    parser.add_argument("--replay_size", type=int, default=1_000_000)
    parser.add_argument("--exp_name", type=str, default="sac")
    parser.add_argument("--exp_tags", type=str, nargs="+", default=[])
    parser.add_argument("--use_layer_norm", type=str2bool, default=True)
    parser.add_argument("--sparse_rewards", type=str2bool, default=False)
    parser.add_argument("--accumulate_rewards", type=str2bool, default=False)
    parser.add_argument("--continue_from_pos", type=str2bool, default=False)
    parser.add_argument("--append_timestep", type=str2bool, default=False)
    parser.add_argument("--finish_on_done", type=str2bool, default=False)
    parser.add_argument("--reward_early_finish", type=str2bool, default=False)
    parser.add_argument("--bootstrap_on_time_limit", type=str2bool, default=True)
    parser.add_argument("--save_buffer", type=str2bool, default=False)
    parser.add_argument("--target_output_std", type=float, default=0.089)
    parser.add_argument("--alpha_init")
    parser.add_argument("--done_on_transition", type=str2bool, default=False)
    parser.add_argument("--ordering", type=int, nargs="+")
    parser.add_argument("--memory", type=int)
    parser.add_argument("--tasks_to_memorize", type=int, nargs="+")
    parser.add_argument("--actor_memory_weight", type=float, default=1.0)
    parser.add_argument("--critic_memory_weight", type=float, default=1.0)
    parser.add_argument("--reset_weights_after_memory", type=str2bool, default=False)
    parser.add_argument("--num_test_eps_stochastic", type=int, default=10)
    parser.add_argument("--num_test_eps_deterministic", type=int, default=10)
    parser.add_argument("--num_render_eps_stochastic", type=int, default=0)
    parser.add_argument("--num_render_eps_deterministic", type=int, default=0)
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--num_layers", type=int, default=4)
    parser.add_argument("--apply_kl_loss", type=str2bool, default=False)
    parser.add_argument("--reg_method", type=str, default=None)

    parser.add_argument("--init_weights_path", type=str)
    parser.add_argument("--log_to_wandb", action="store_true")

    return parser.parse_known_args(args=args)[0]
