import argparse


def str2bool(v):
    """Convert a command-line boolean spelling (or a bool) into a ``bool``.

    Intended as an ``argparse`` ``type=`` callable. Matching ignores case and
    surrounding whitespace:

    * true:  ``yes``, ``true``, ``t``, ``y``, ``1``
    * false: ``no``, ``false``, ``f``, ``n``, ``0``

    Bools are returned unchanged. Any other value is passed through ``str()``
    first, so ints such as ``0``/``1`` are accepted instead of crashing with
    ``AttributeError``.

    Raises:
        ValueError: if the value is not a recognised boolean spelling
            (argparse reports this as an "invalid str2bool value" error).
    """
    if isinstance(v, bool):
        return v
    # Normalise once; str() guards against non-string inputs (e.g. int, None).
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError('Boolean value expected.')


def get_args(ps=None):
    """Register every experiment option on an ``argparse.ArgumentParser``.

    Args:
        ps: parser to extend; a fresh ``argparse.ArgumentParser()`` is created
            when ``None``.

    Returns:
        The same parser with all options added (not yet parsed). Boolean
        options use ``str2bool`` so they accept spellings such as
        ``true``/``no``/``1``.
    """
    if ps is None:
        ps = argparse.ArgumentParser()
    ps.add_argument('--wandb', type=str2bool, default=False)
    ps.add_argument('--prefix', type=str, default="test")
    ps.add_argument("--device", type=str, default="cpu")

    # Common
    ps.add_argument("--policy_name", type=str, default="DDPG")
    ps.add_argument('--env_name', type=str, default='paint', help="mujoco-HalfCheetah-v3")
    ps.add_argument('--seed', type=int, default=1234, help='random seed')
    ps.add_argument('--env_seed', type=int, default=1234, help='random seed')
    ps.add_argument('--dataset', type=str, default='cub200')

    # Training setup
    ps.add_argument('--total_ts', type=int, default=10000000, help='total training steps')
    ps.add_argument('--max_episode_steps', type=int, default=50, help='max length for episode (*)')
    ps.add_argument('--warmup_ts', type=int, default=5000)
    ps.add_argument('--batch_size', type=int, default=64, help='minibatch size')
    ps.add_argument('--buffer_size', type=int, default=100000, help='replay memory size')
    ps.add_argument('--num_envs', type=int, default=16, help='concurrent environment number/ number of environments')
    ps.add_argument('--num_updates', type=int, default=10, help='number of training steps per episode')
    # NOTE(review): help text is identical to --num_updates; confirm what distinguishes the two.
    ps.add_argument('--num_RL_updates', type=int, default=3, help='number of training steps per episode')
    ps.add_argument('--if_async', type=str2bool, default=False)
    ps.add_argument('--if_visualise', type=str2bool, default=False)
    ps.add_argument('--eval_num_episodes', type=int, default=1, help='episodes used for performing validation')
    ps.add_argument('--eval_freq', type=int, default=50, help='episode interval for performing validation')

    # Scheduler
    ps.add_argument('--epsilon_start', type=float, default=1.0, help="init value of epsilon decay")
    ps.add_argument('--epsilon_end', type=float, default=0.01, help="final value of epsilon decay")
    ps.add_argument('--decay_steps', type=int, default=1000000, help="number of steps for epsilon decay")
    ps.add_argument("--gaussian_noise_std", type=float, default=0.1)  # Std of Gaussian exploration noise
    ps.add_argument("--policy_noise", type=float, default=0.2)  # Noise added to target policy during critic update
    ps.add_argument("--noise_clip", type=float, default=0.5)  # Range to clip target policy noise
    ps.add_argument("--policy_freq", type=int, default=2)  # Frequency of delayed policy updates

    # trivial ones
    ps.add_argument('--load_path', type=str, default=None, help='Load model and resume training')
    ps.add_argument('--dir_data', type=str, default='./data', help='Path of the data directory')
    ps.add_argument('--dir_model', type=str, default='./data', help='Output path for storing model')
    ps.add_argument('--output', type=str, default='./results', help='Output path for storing model')
    ps.add_argument("--if_save_agent", type=str2bool, default=False)
    ps.add_argument("--save_freq", type=int, default=300)
    ps.add_argument("--load_model", type=str, default="")

    # Agent related
    ps.add_argument("--discount", type=float, default=0.99)
    ps.add_argument("--tau", type=float, default=0.005)

    # PaintGym specific
    ps.add_argument('--paint_bundle_size', type=int, default=5, help='action bundle size')
    ps.add_argument('--paint_type_encoder', type=str, default="resnet", help="cnn / resnet / flat-mlp")
    ps.add_argument('--paint_if_gan_reward', type=str2bool, default=True)
    # NOTE(review): default "none" is not among the options listed in the help (copied from
    # --paint_type_diff_reward?); confirm the valid values.
    ps.add_argument('--paint_type_iqa_reward', type=str, default="none", help="original / fixed / one-step")
    ps.add_argument('--paint_type_diff_reward', type=str, default="original", help="original / fixed / one-step")
    ps.add_argument('--paint_if_patch', type=str2bool, default=True)
    ps.add_argument('--paint_if_change_lr', type=str2bool, default=False)

    ps.add_argument('--mjc_if_pomdp', type=str2bool, default=False)
    ps.add_argument('--SAC_if_automatic_entropy_tuning', type=str2bool, default=False)

    ps.add_argument('--if_use_prev_state', type=str2bool, default=True)
    ps.add_argument('--if_use_latent_state', type=str2bool, default=False)
    ps.add_argument('--if_update_per_ts', type=str2bool, default=False)
    ps.add_argument('--if_use_act_val_fn', type=str2bool, default=False)
    ps.add_argument('--if_train_models', type=str2bool, default=False)
    ps.add_argument('--if_train_state_model', type=str2bool, default=False)
    ps.add_argument('--if_train_reward_model', type=str2bool, default=False)
    ps.add_argument('--if_actor_reward', type=str2bool, default=False)
    ps.add_argument('--if_use_next_reward', type=str2bool, default=False)

    # RecSim
    ps.add_argument('--recsim_user_budget', type=int, default=20, help="")
    ps.add_argument('--recsim_num_actions', type=int, default=100, help="")
    ps.add_argument('--recsim_num_categories', type=int, default=30, help="")
    ps.add_argument('--recsim_dim_embed', type=int, default=30, help="")
    # Float default: argparse does not apply ``type`` to non-string defaults, so an
    # int literal here would leave an int when the flag is omitted.
    ps.add_argument('--recsim_no_click_mass', type=float, default=2.0, help="")
    ps.add_argument('--recsim_user_dist', type=str, default="sklearn-gmm", help="uniform / modal / gmm")
    ps.add_argument('--recsim_category_dist', type=str, default="random", help="")
    ps.add_argument('--recsim_item_dist', type=str, default="sklearn-gmm", help="")
    ps.add_argument('--recsim_choice_model_type', type=str, default="multinomial", help="deterministic / multinomial")
    ps.add_argument('--recsim_type_user_utility_computation', type=str, default="dot_prod", help="euc_dist / dot_prod")
    ps.add_argument('--recsim_step_penalty', type=float, default=0.5, help="")
    ps.add_argument('--recsim_if_user_global_transition', type=str2bool, default=False, help="")
    ps.add_argument('--recsim_if_switch_act_task_emb', type=str2bool, default=False, help="")
    ps.add_argument('--recsim_if_deterministic_user', type=str2bool, default=False, help="")
    return ps


def _apply_paint_defaults(args: argparse.Namespace) -> None:
    """Overwrite ``args`` in place with the settings used for the "paint" env."""
    args.total_ts = 2500000
    # Compound the per-step discount over the action bundle
    # (--paint_bundle_size is described as the "action bundle size").
    args.discount = args.discount ** args.paint_bundle_size
    if args.if_train_state_model:
        args.if_use_prev_state = False
    if args.policy_name.lower() in ("td3", "sac"):
        args.paint_if_patch = False


def _apply_control_defaults(args: argparse.Namespace) -> None:
    """Overwrite ``args`` in place for "mujoco-*" (non-"single") and "classic-*" envs."""
    overrides = {
        "if_update_per_ts": True,
        "warmup_ts": 25000,
        "buffer_size": 500000,
        "batch_size": 256,
        "eval_freq": 10,
    }
    for key, value in overrides.items():
        setattr(args, key, value)
    if args.env_name.startswith("mujoco-"):
        args.max_episode_steps = 1000
    # Without an action-value function, both learned models are switched on.
    if not args.if_use_act_val_fn:
        args.if_train_state_model = True
        args.if_train_reward_model = True


def _apply_recsim_defaults(args: argparse.Namespace) -> None:
    """Overwrite ``args`` in place with the settings used for the "recsim" env."""
    args.total_ts = 1000000
    args.eval_num_episodes = 3


def add_args(args: argparse.Namespace):
    """Apply environment-specific and derived settings to parsed arguments.

    ``args`` is modified in place and also returned. ``args.env_name`` selects
    at most one environment branch:

    * ``"paint"`` (case-insensitive)
    * names starting with ``"mujoco-"`` that do not contain ``"single"``, or
      names starting with ``"classic-"`` (both case-sensitive)
    * ``"recsim"`` (case-insensitive)

    Those branches assign fixed values unconditionally, so, for example, a
    ``--total_ts`` given on the command line is overwritten for paint and
    recsim. The paint branch also compounds ``discount`` in place, so calling
    this twice on the same namespace is not idempotent. After the branch, the
    derived flags (``if_use_prev_state``, ``if_train_models``, update counts)
    are recomputed for every env.
    """
    name = args.env_name
    lowered = name.lower()
    is_control = (name.startswith("mujoco-") and "single" not in name) or name.startswith("classic-")

    if lowered == "paint":
        _apply_paint_defaults(args)
    elif is_control:
        _apply_control_defaults(args)
    elif lowered == "recsim":
        _apply_recsim_defaults(args)

    # Derived settings: these run after the env branch because they read
    # flags that the branch may have just overwritten.
    if args.if_use_act_val_fn:
        args.if_use_prev_state = False
    args.if_train_models = args.if_train_state_model or args.if_train_reward_model
    if args.if_update_per_ts:
        args.num_updates = args.num_RL_updates = 1
    return args


def get_all_args(argv=None):
    """Build the parser, parse arguments, and apply env-specific overrides.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]`` exactly as before;
            passing a list makes this usable from tests and notebooks.

    Returns:
        The ``argparse.Namespace`` after ``add_args`` has applied its overrides.
    """
    ps = get_args(ps=argparse.ArgumentParser())
    args = ps.parse_args(argv)
    return add_args(args=args)
