import argparse

def _str_to_bool(value) -> bool:
    """Convert a command-line token such as "true"/"False"/"1"/"no" to a bool.

    argparse's ``type=bool`` simply calls ``bool()`` on the raw string, so any
    non-empty string (including "False") becomes ``True``. This converter
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string is kept as False because bool("") was False before.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args(argv=None) -> argparse.Namespace:
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding every option defined below, with
        defaults filled in for options that were not supplied.
    """
    parser = argparse.ArgumentParser(description="Load arguments", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    ### Experiment settings
    parser.add_argument("--exp_name", type=str, default="test", help="The name of the experiment.")

    ### Environment settings
    parser.add_argument("--seed", type=int, default=11, help="Random seed for running experiments.")
    parser.add_argument("--task_seed", type=int, default=42, help="Random seed for generating tasks.")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID.")
    # type=bool would turn "--parallel False" into True; parse the text explicitly.
    # nargs="?" + const=True additionally lets a bare "--parallel" enable the flag.
    parser.add_argument("--parallel", type=_str_to_bool, nargs="?", const=True, default=False,
                        help="Whether to run in parallel, i.e., use multiple GPUs.")
    parser.add_argument("--root_save_path", type=str, default=None, help="The root path to save the experiment results.")

    ### Model settings
    parser.add_argument("--model_name", type=str, default="llama-3", help="The language model adopted as the user.")
    parser.add_argument("--max_seq_len", type=int, default=512, help="The maximum of the sequence, no more than 8192.")
    parser.add_argument("--max_batch_size", type=int, default=4, help="The maximum of the batch size.")
    parser.add_argument("--num_responses", type=int, default=1, help="The number of responses generated by LLMs.")
    parser.add_argument("--temperature", type=float, default=0.2, help="The temperature coefficient of LLMs.")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling.")
    parser.add_argument("--max_tokens", type=int, default=4096, help="Maximum tokens.")

    ### Task mode
    parser.add_argument("--data_type", type=str, default="2D", choices=["2D", "3D"], help="The type of classification data.")
    parser.add_argument("--task_mode", type=str, default="linear_classification", help="The task mode to be tested.")
    parser.add_argument("--pred_mode", type=str, default="all", choices=["all", "single"], help="The prediction mode for conventional methods.")

    ### Data Settings
    parser.add_argument("--batch_size", type=int, default=10, help="Batch size used in dataloader.")
    parser.add_argument("--precision", type=int, default=None, help="The precision of data in each task.")
    parser.add_argument("--circle_task_noise", type=float, default=0.03, help="The noise coefficient of circle task.")
    parser.add_argument("--num_samples", type=int, default=128, help="The number of data samples in tasks.")
    parser.add_argument("--num_classes", type=int, default=2, help="The number of classes in tasks.")
    parser.add_argument("--num_coordinate", type=int, default=50, help="The number of coordinate points in each axis.")
    parser.add_argument("--num_query", type=int, default=2500, help="The number of query points in each (>2D) task.")
    parser.add_argument("--num_eval", type=int, default=1000, help="The number of evaluation points in each task. Only used in validation experiments.")

    ### Prompt settings
    parser.add_argument("--prompt_mode", type=str, default="standard", help="The prompt mode.")

    parser.add_argument("--inference_mode", type=str, default="implicit", choices=["implicit", "explicit"], help="The inference mode.")

    ### Algorithm settings
    parser.add_argument("--ml_alg", type=str, default=None, help="Machine learning algorithms.")
    parser.add_argument("--prob_type", type=str, default="standard",
                        choices=["standard", "ml-only", "uniform", "random"], help="The type of probability.")
    parser.add_argument("--hybrid_type", type=str, default="prob",
                        choices=["prob", "hsic"], help="The type of perform classification in the hybrid way.")

    # argv=None makes argparse fall back to sys.argv[1:], matching the original behaviour.
    args = parser.parse_args(argv)
    return args