from argparse import ArgumentParser, ArgumentTypeError


def parse_args():
    """Parses the command line arguments."""
    parser = ArgumentParser()
    parser.add_argument(
        "--model_base_dir",
        default="PATH_TO_MODELS/",
        help="The directory where the model is stored",
    )
    parser.add_argument(
        "--fitness_mode",
        default="BIAS_PPL",
        help="The fitness function used for the evolutionary algorithm",
    )
    parser.add_argument(
        "--path_to_prompts",
        default="./FASP/prompts/holistic/",
        help="The directory where the prompts are stored",
    )
    parser.add_argument(
        "--path_to_head_contributions",
        default="./FASP/model/head_contributions.json",
        help="The directory where the head contributions are stored",
    )
    parser.add_argument(
        "--path_to_models_config",
        default="./FASP/model/models_config.json",
        help="The directory where the model configurations are stored",
    )
    parser.add_argument(
        "--path_to_tox_model",
        default="PATH_TO_TOXICITY_MODEL/",
        help="The directory where the unbiased model is stored",
    )
    parser.add_argument(
        "--pop_size",
        type=int,
        default=5,
        help="The size of the population for the evolutionary algorithm",
    )
    parser.add_argument(
        "--evo_epoch",
        type=int,
        default=10,
        help="The number of epochs for the evolutionary algorithm",
    )
    parser.add_argument(
        "--pop_init_mutation_rate",
        type=float,
        default=0.95,
        help="The mutation rate for the initial population",
    )
    parser.add_argument(
        "--mutation_rate",
        type=float,
        default=0.1,
        help="The mutation rate for the evolutionary algorithm",
    )
    parser.add_argument(
        "--scalar",
        type=float,
        default=1,
        help="The scalar for the fitness function",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:1",
        help="The device that we are using. We normally use cuda:1.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="The seed that we are using. We normally run every experiment for 5 seeds.",
    )
    parser.add_argument(
        "--head_knockout",
        type=int,
        default=None,
        help="the id of the attention head to be knocked out in the language generation models",
    )
    parser.add_argument(
        "--model",
        choices=[
            "gpt2",
            "distilgpt2",
        ],
        default="distilgpt2",
        help="Type of language generation model used",
    )
    parser.add_argument(
        "--method",
        choices=[
            "FASP",
            "FPVE",
        ],
        default="FPVE",
        help="Method for pruning the attention heads",
    )
    parser.add_argument(
        "--nsgaii",
        action="store_true",
        default=False,
        help="Whether or not to use the NSGA-II algorithm for the evolutionary algorithm",
    )
    parser.add_argument(
        "--ppl_threshold",
        type=float,
        default=0,
        help="The ppl threshold for evolution, default is 0 which means no threshold",
    )
    parser.add_argument(
        "--parent_selection",
        type=str,
        choices=["random", "roulette_wheel", "binary_tournament"],
        default="random",
        help="Parent selection method for evolutionary algorithm. "
        "'rolette_wheel_selection' cannot be used with NSGA-II. "
        "'binary_tournament_selection' is compatible with NSGA-II. "
        "Default is 'random'.",
    )
    parser.add_argument(
        "--pruned_heads_ratio",
        type=float,
        default=0.0,
        help="The ratio of the pruned attention heads, which is referred to as alpha in the main paper.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.3,
        help="The hyperparameter controling the percentage of examples that are considered important for performance",
    )
    parser.add_argument(
        "--prompting",
        choices=[
            "holistic",
        ],
        default="holistic",
        help="Type of prompt used for the language model",
    )
    parser.add_argument(
        "--targeted_holistic_bias",
        choices=[
            "gender_and_sex",
        ],
        default="gender_and_sex",
        help="The group for which biased is assessed using the holistic bias framework",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="Batch size for the language model.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=512,
        help="Stride used for computing the model preplexity. This corresponds to the number of tokens the model conditions on at each step.",
    )
    parser.add_argument(
        "--max_continuation_length",
        type=int,
        default=40,
        help="The maximum length of the continuation for the language generation model",
    )
    parser.add_argument(
        "--max_prompt_length",
        type=int,
        default=22,
        help="The maximum length of the prompt for the language generation model",
    )
    parser.add_argument(
        "--output_dir",
        default="PATH_TO_OUTPUT/",
        help="Directory to the output",
    )
    parser.add_argument(
        "--use_gender_scores",
        type=bool,
        default=True,
        help="Whether or not to use the head scores for gender bias when reducing other biases",
    )

    return parser.parse_args()
