# Evaluate a single specific file only.
# Evaluate a specific directory.
import argparse
from evaluate_module import get_ppl, get_acc, get_auroc, get_fpr_tpr

def set_args(argv=None):
    """Build the evaluation command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour. Passing an
            explicit list lets other scripts and tests reuse this parser.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        ``store_true`` flags default to ``False``. ``--file_path`` and
        ``--attack_suffix`` are ``None`` when omitted.

    Raises:
        SystemExit: raised by argparse when a required option
            (``--mode``, ``--model_name``, ``--model_size``,
            ``--dataset_name``) is missing or a value fails type conversion
            (e.g. a non-integer ``--n_beams``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        help="Evaluation to run: ppl, acc, auroc (alias: auro) or fpr_tpr.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--model_size",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
    )
    # Single-file evaluation switch; --file_path is presumably only read
    # when this flag is set (NOTE(review): confirm in evaluate_module).
    parser.add_argument(
        "--use_one_file",
        action="store_true",
    )
    parser.add_argument(
        "--file_path",
        type=str,
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default='/data2/models'
    )
    parser.add_argument(
        "--data_cache_dir",
        type=str,
        default='/data2/datasets'
    )
    parser.add_argument(
        "--use_200push",
        action="store_true",
        help="When you generate watermark, set min_new_tokens (default : False)",
    )
    parser.add_argument(
        "--use_topk",
        action="store_true",
        help="topk (default : False)",
    )
    parser.add_argument(
        "--use_topp",
        action="store_true",
        help="topp (default : False)",
    )
    parser.add_argument(
        "--use_sampling",
        action="store_true",
        help="Whether to generate using multinomial sampling. (default : False)",
    )
    parser.add_argument(
        "--n_beams",
        type=int,
        default=1,
        help="Number of beams to use for beam search. 1 is normal greedy decoding",
    )
    parser.add_argument(
        "--attack_suffix",
        type=str,
        help="Attack suffix for attacked data evaluation (e.g., word_del, syn_sub)",
    )
    # argv=None falls back to sys.argv[1:], preserving the CLI behaviour.
    args = parser.parse_args(argv)
    return args



if __name__ == "__main__":
    cli_args = set_args()

    # Each accepted --mode value maps to the evaluation routine it runs.
    # "auro" is accepted as an alias for "auroc".
    mode_handlers = {
        "ppl": get_ppl,
        "acc": get_acc,
        "auroc": get_auroc,
        "auro": get_auroc,
        "fpr_tpr": get_fpr_tpr,
    }

    run_eval = mode_handlers.get(cli_args.mode)
    if run_eval is None:
        raise ValueError(f"Unsupported mode name ({cli_args.mode}).")
    run_eval(cli_args)