
import argparse


def _str2bool(value):
    """Convert a command-line string such as 'true'/'False'/'1'/'0' to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy.  This converter parses the text
    instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


def parser_args(args=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        args: optional list of argument strings to parse.  When ``None``
            (the default) ``sys.argv[1:]`` is parsed, which is the original
            behaviour of this function.

    Returns:
        argparse.Namespace holding one attribute per option defined below.

    Help texts use ``%(default)s`` so the documented default always matches
    the real default value.
    """
    parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')

    # Exp
    parser.add_argument('--exp_name', default="TODO",
                        help='name of experiment (default: %(default)s)')

    # Env
    # NOTE(review): Gym ids are case-sensitive and the usual id is
    # "HalfCheetah-v3"; default kept as-is - confirm against the env factory.
    parser.add_argument('--env_name', default="Halfcheetah-v3",
                        help='Mujoco Gym environment (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=123, metavar='N',
                        help='random seed (default: %(default)s)')

    # Action repeat
    parser.add_argument('--num_action_repeat', type=int, default=1, metavar='N',
                        help='number of action repeat (default: %(default)s)')

    # Image related
    parser.add_argument('--output_size', type=int, default=84, metavar='N',
                        help='output size of crop operation (default: %(default)s)')
    parser.add_argument('--num_stacked_frames', type=int, default=3, metavar='N',
                        help='num_stacked_frames (default: %(default)s)')

    # CNN encoder
    parser.add_argument('--cnn_depth', type=int, default=12, metavar='N',
                        help='cnn_depth (default: %(default)s)')

    # Reward weight
    parser.add_argument('--rew_weight', type=float, default=0.8, metavar='G',
                        help='weight for pseudo reward (default: %(default)s)')

    parser.add_argument('--use_reward_scheduling', type=_str2bool, default=False,
                        metavar='BOOL',
                        help='enable reward scheduling (default: %(default)s)')
    parser.add_argument('--max_weight', type=float, default=0.6, metavar='G',
                        help='Specifies the maximum weight for the pseudo reward '
                             'under a rescheduling regimen (default: %(default)s)')
    parser.add_argument('--saturation_episode', type=float, default=3000, metavar='G',
                        help='Specifies the episode where saturation is reached. '
                             '(default: %(default)s)')

    # SAC
    parser.add_argument('--policy', default="Gaussian",
                        help='Policy Type: Gaussian | Deterministic (default: %(default)s)')
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                        help='discount factor for reward (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='G',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                        help='Temperature parameter α determines the relative '
                             'importance of the entropy term against the reward '
                             '(default: %(default)s)')
    parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=True,
                        metavar='BOOL',
                        help='Automatically adjust α (default: %(default)s)')
    parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                        help='target smoothing coefficient(τ) (default: %(default)s)')
    parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                        help='Value target update per no. of updates per step '
                             '(default: %(default)s)')

    # Buffer RL
    parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                        help='minibatch size of batch (default: %(default)s)')
    parser.add_argument('--replay_size', type=int, default=30000, metavar='N',
                        help='total capacity of replay buffer (default: %(default)s)')

    # Buffer Trajectory
    parser.add_argument('--traj_batch_size', type=int, default=32, metavar='N',
                        help='minibatch size of traj buffer (default: %(default)s)')
    parser.add_argument('--traj_eval_size', type=int, default=8, metavar='N',
                        help='minibatch size of traj buffer for eval (default: %(default)s)')
    parser.add_argument('--traj_replay_size', type=int, default=256, metavar='N',
                        help='total capacity of trajectory replay buffer '
                             '(default: %(default)s)')
    parser.add_argument('--num_samples_per_trajectory', type=int, default=16, metavar='N',
                        help='num_samples_per_trajectory (default: %(default)s)')

    # Network
    parser.add_argument('--hidden_size', type=int, default=1024, metavar='N',
                        help='hidden size (default: %(default)s)')

    # Psi
    parser.add_argument('--lambda_value', type=float, default=3, metavar='G',
                        help='lambda_value for psi (default: %(default)s)')
    parser.add_argument('--epsilon', type=float, default=1e-5, metavar='G',
                        help='epsilon for psi (default: %(default)s)')

    # Max/start step
    parser.add_argument('--start_steps', type=int, default=100, metavar='N',
                        help='Steps sampling random actions (default: %(default)s)')
    parser.add_argument('--num_steps', type=int, default=5000000, metavar='N',
                        help='maximum number of steps (default: %(default)s)')
    parser.add_argument('--num_episode', type=int, default=500000, metavar='N',
                        help='maximum number of episodes (default: %(default)s)')

    # Update to data(UTD) ratio
    # (Option1) : Sequential Updates Method
    parser.add_argument('--sac_gradient_steps_per_epoch', type=int, default=64, metavar='N',
                        help='SAC model updates per epoch (default: %(default)s)')
    parser.add_argument('--psi_gradient_steps_per_epoch', type=int, default=64, metavar='N',
                        help='psi model updates per epoch (default: %(default)s)')
    # (Option2) : Alternating Updates Method
    parser.add_argument('--gradient_steps_per_epoch', type=int, default=64, metavar='N',
                        help='model updates per epoch (default: %(default)s)')

    parser.add_argument('--episodes_per_epoch', type=int, default=8, metavar='N',
                        help='episodes per epoch (default: %(default)s)')

    # Evaluation period
    parser.add_argument('--eval_epoch_ratio', type=int, default=250, metavar='N',
                        help='evaluation period (default: %(default)s)')

    # Video : for Ant-v3 env, setting fps under 50 is recommended
    parser.add_argument('--video_fps', type=int, default=50, metavar='N',
                        help='fps for video (default: %(default)s)')

    # Latent dimension
    parser.add_argument('--radius_latent_dim', type=int, default=3, metavar='N',
                        help='dimension of radius latent (default: %(default)s)')
    parser.add_argument('--radius_input_dim', type=int, default=128, metavar='N',
                        help='input dimension of radius (default: %(default)s)')
    parser.add_argument('--radius_bound', type=str, default='3,6',
                        help='Comma-separated values (default: %(default)s)')
    parser.add_argument('--num_intervals', type=int, default=3, metavar='N',
                        help='number of intervals for sampling radius (default: %(default)s)')

    # Device
    # The flag is store_false: CUDA is on by default and passing --cuda turns
    # it OFF.  Behaviour is kept for existing launch scripts; only the help
    # text was corrected to say what the flag really does.
    parser.add_argument('--cuda', action="store_false",
                        help='pass this flag to DISABLE CUDA (CUDA is enabled by default)')

    return parser.parse_args(args)