import argparse
import json
import os

def load_system_config():
    """Load and return the parsed contents of ``configs/SystemConfig.json``.

    The config directory is located relative to this source file: it is the
    ``configs`` folder inside the parent of the directory holding this module.

    Returns:
        The decoded JSON value (``set_args`` indexes it by key, so it is
        expected to be a JSON object).

    Raises:
        FileNotFoundError: if the config file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # abspath first: if __file__ is relative (e.g. the __main__ script run via a
    # relative path on Python < 3.9), dirname(dirname(...)) collapses to '' and
    # the lookup would silently happen relative to the current working directory.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(os.path.dirname(module_dir), 'configs', 'SystemConfig.json')
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def set_args(argv=None):
    """Build the experiment CLI parser and parse the command-line arguments.

    Defaults for ``--model_cache_dir`` and ``--data_cache_dir`` come from the
    system config (``load_system_config``). That file is read on every call,
    even when both cache directories are given explicitly on the command line.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse parse ``sys.argv[1:]``, identical to the
            previous zero-argument behavior.

    Returns:
        argparse.Namespace holding the parsed options. Options declared
        without a default (``--gamma``, ``--delta``, ``--alpha``,
        ``--token_cut_length``, ``--topp``, ``--attack_suffix``) are ``None``
        when omitted.

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value fails type conversion.
        FileNotFoundError: if the system config file cannot be opened.
        KeyError: if the system config lacks ``model_cache_dir`` or
            ``data_cache_dir``.
    """
    system_config = load_system_config()

    parser = argparse.ArgumentParser()

    # NOTE(review): the help text lists the accepted values for --mode,
    # --dataset_name and --water_mode, but they are not enforced via `choices`.
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        help="watermark or detect",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--model_size",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="gsm, c4, c4_ko, c4_ja, mgsm_ja, gsm_ko",
    )
    parser.add_argument(
        "--water_mode",
        type=str,
        required=True,
        help="choose watermarking mode -> wo_watermark, patternbase, patternlen, kgw, watme, exp",
    )
    parser.add_argument(
        "--disable_gpu",
        action="store_true",
    )
    parser.add_argument(
        "--seeding_scheme",
        type=str,
        default="simple_1",
        help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        help="The parameter for adaptive watermarking.",
    )
    parser.add_argument(
        "--token_cut_length",
        type=int,
        help="Token length cutoff for pattern-based watermarking methods (patternbase, patternlen). Tokens with length <= this value go to neutral group.",
    )
    parser.add_argument(
        "--model_cache_dir",
        type=str,
        default=system_config["model_cache_dir"],
        help="huggingface model and tokenizer cache directory",
    )
    parser.add_argument(
        "--data_cache_dir",
        type=str,
        default=system_config["data_cache_dir"],
        help="huggingface datasets cache directory",
    )
    parser.add_argument(
        "--hash_key",
        type=int,
        default=15485863
    )

    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--generation_seed",
        type=int,
        default=123,
        help="Seed for setting the torch global rng prior to generation.",
    )
    parser.add_argument(
        "--topp",
        type=float,
        # No default is declared, so the parsed value is None (not False) when omitted.
        help="topp (default: None, i.e. not set)",
    )
    parser.add_argument(
        "--sampling_temp",
        type=float,
        default=0.7,
        help="Sampling temperature to use when generating using multinomial sampling.",
    )
    parser.add_argument(
        "--n_beams",
        type=int,
        default=1,
        help="Number of beams to use for beam search. 1 is normal greedy decoding",
    )
    parser.add_argument(
        "--attack_suffix",
        type=str,
        help="Attack suffix for the attack name.",
    )
    # argv=None -> argparse falls back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args