
import argparse


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse's ``type=bool`` does not parse booleans: it calls ``bool()`` on
    the raw string, and every non-empty string (including ``"False"``) is
    truthy. This converter accepts the usual spellings instead.

    Args:
        value: The raw string from the command line (a ``bool`` is passed
            through unchanged).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def parser_args(argv=None):
    """Build the experiment argument parser and parse the command line.

    Help strings use argparse's ``%(default)s`` placeholder so the documented
    default always matches the actual ``default=`` value.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')

    # Exp
    parser.add_argument('--exp_name', default="#####TODO#####",
                        help='name of experiment')

    # Env
    parser.add_argument('--env_name', default="Ant-v3",
                        help='Mujoco Gym environment (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=1234, metavar='N',
                        help='random seed (default: %(default)s)')

    # Reward weight
    parser.add_argument('--rew_weight', type=float, default=0.8, metavar='G',
                        help='weight for pseudo reward (default: %(default)s)')

    # Boolean options use _str2bool: with type=bool, "--flag False" would
    # still evaluate to True.
    parser.add_argument('--use_reward_scheduling', type=_str2bool, default=False, metavar='G',
                        help='enable reward scheduling (default: %(default)s)')
    parser.add_argument('--max_weight', type=float, default=0.8, metavar='G',
                        help='Specifies the maximum weight for the pseudo reward under a '
                             'rescheduling regimen (default: %(default)s)')
    parser.add_argument('--saturation_episode', type=float, default=15000, metavar='G',
                        help='Specifies the episode where saturation is reached. '
                             '(default: %(default)s)')

    # SAC
    parser.add_argument('--policy', default="Gaussian",
                        help='Policy Type: Gaussian | Deterministic (default: %(default)s)')
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                        help='discount factor for reward (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='G',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--alpha', type=float, default=0.01, metavar='G',
                        help='Temperature parameter α determines the relative importance of '
                             'the entropy term against the reward (default: %(default)s)')
    parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=True, metavar='G',
                        help='Automatically adjust α (default: %(default)s)')
    parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                        help='target smoothing coefficient(τ) (default: %(default)s)')
    parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                        help='Value target update per no. of updates per step '
                             '(default: %(default)s)')

    # Buffer RL
    parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                        help='minibatch size of batch (default: %(default)s)')
    parser.add_argument('--replay_size', type=int, default=10000, metavar='N',
                        help='total capacity of replay buffer (default: %(default)s)')

    # Buffer Trajectory
    parser.add_argument('--traj_batch_size', type=int, default=32, metavar='N',
                        help='minibatch size of traj buffer (default: %(default)s)')
    parser.add_argument('--traj_eval_size', type=int, default=8, metavar='N',
                        help='minibatch size of traj buffer for eval (default: %(default)s)')
    parser.add_argument('--traj_replay_size', type=int, default=128, metavar='N',
                        help='total capacity of trajectory replay buffer (default: %(default)s)')
    parser.add_argument('--num_samples_per_trajectory', type=int, default=32, metavar='N',
                        help='num_samples_per_trajectory (default: %(default)s)')

    # Network
    parser.add_argument('--hidden_size', type=int, default=1024, metavar='N',
                        help='hidden size (default: %(default)s)')

    # Psi
    parser.add_argument('--lambda_value_1', type=float, default=5, metavar='G',
                        help='lambda_value for psi (default: %(default)s)')
    parser.add_argument('--lambda_value_L', type=float, default=5, metavar='G',
                        help='lambda_value for psi (default: %(default)s)')
    parser.add_argument('--epsilon', type=float, default=1e-5, metavar='G',
                        help='epsilon for psi (default: %(default)s)')

    # Max/start step
    parser.add_argument('--start_steps', type=int, default=100, metavar='N',
                        help='Steps sampling random actions (default: %(default)s)')
    parser.add_argument('--num_steps', type=int, default=100000000, metavar='N',
                        help='maximum number of steps (default: %(default)s)')
    parser.add_argument('--num_episode', type=int, default=500000, metavar='N',
                        help='maximum number of episodes (default: %(default)s)')

    # Update to data(UTD) ratio
    # (Option1) : Sequential Updates Method
    parser.add_argument('--sac_gradient_steps_per_epoch', type=int, default=64, metavar='N',
                        help='SAC model updates per epoch (default: %(default)s)')
    parser.add_argument('--psi_gradient_steps_per_epoch', type=int, default=64, metavar='N',
                        help='psi model updates per epoch (default: %(default)s)')
    # (Option2) : Alternating Updates Method
    parser.add_argument('--gradient_steps_per_epoch', type=int, default=64, metavar='N',
                        help='model updates per epoch (default: %(default)s)')

    parser.add_argument('--episodes_per_epoch', type=int, default=8, metavar='N',
                        help='episodes per epoch (default: %(default)s)')

    # Evaluation period
    parser.add_argument('--eval_epoch_ratio', type=int, default=250, metavar='N',
                        help='evaluation period (default: %(default)s)')

    # Video : for Ant-v3 env, setting fps under 50 is recommended
    parser.add_argument('--video_fps', type=int, default=100, metavar='N',
                        help='fps for video (default: %(default)s)')

    # Latent dimension
    parser.add_argument('--radius_latent_dim', type=int, default=6, metavar='N',
                        help='dimension of radius latent (default: %(default)s)')
    parser.add_argument('--radius_input_dim', type=int, default=8, metavar='N',
                        help='dimension of radius (default: %(default)s)')
    parser.add_argument('--radius_bound', type=str, default='10,11',
                        help='Comma-separated values (default: %(default)s)')
    parser.add_argument('--num_intervals', type=int, default=4, metavar='N',
                        help='number of intervals for sampling radius (default: %(default)s)')
    parser.add_argument('--use_adaptive_sampling', type=_str2bool, default=True, metavar='G',
                        help='enable use adaptive sampling (default: %(default)s)')

    # Device
    # NOTE: store_false means args.cuda is True unless --cuda is given, so
    # the flag acts as an opt-out. Kept as-is so existing launch scripts
    # behave the same; only the help text was corrected.
    parser.add_argument('--cuda', action="store_false",
                        help='pass this flag to set cuda=False, i.e. run without CUDA '
                             '(default: cuda=%(default)s)')

    return parser.parse_args(argv)