import argparse
from src.models import load_model_and_tokenizer
from src.DeltaReweight_watermark import DeltaReweight_watermark
from src.attacks.stanford_detection import generate_stanford_test_data
from src.attacks.kgw_detection import generate_data_custom_delta, get_model_info
import os
import pickle as pkl

def _build_parser():
    """Assemble the command-line interface for DeltaReweight data generation.

    Options are registered in a fixed order so ``--help`` always lists them
    the same way.
    """
    cli = argparse.ArgumentParser(description='Generate integer switch data')
    add = cli.add_argument
    add('--model_name', type=str, default="meta-llama/Llama-2-7b-chat-hf", help='HF model name')
    add('--test', type=str, default="DiPMarks", help='test')
    add("--context", type=int, default=5, help="Context length")
    add("--temperature", type=float, default=1.0, help="Temperature for sampling")
    add("--custom_name", type=str, default="", help="path name for custom results")
    # One or more integer seeds; there is no default, so the value is None when omitted.
    add("--keys", nargs='+', type=int, help="key for watermark")
    add("--max_new_tokens", type=int, default=50, help="max new tokens generated")
    add("--n_samples", type=int, default=1000, help="Number of samples for stanford test")
    add("--disable_watermark_every", type=int, default=0, help="Disable watermark every n requests")
    return cli


parser = _build_parser()

# Sentence openers and a response preamble. They are not used by the __main__
# section of this script; presumably consumed by importers — verify.
PREFIXES = [
    "I ate",
    "I chose",
    "I picked",
    "I selected",
    "I took",
    "I went for",
    "I settled on",
    "I got",
    "I gathered",
    "I harvested",
]
FORMAT = "Sure! "


if __name__ == "__main__":
    
    args = parser.parse_args()
    model_name = args.model_name
    temperature = args.temperature
    CUSTOM_NAME = args.custom_name
    keys = args.keys
    test = args.test
    max_new_tokens = args.max_new_tokens
    n_samples = args.n_samples
    context = args.context
    disable_watermark_every = args.disable_watermark_every
        
    model, tokenizer = load_model_and_tokenizer(model_name=model_name)
    model_name = model_name.replace("/", "_")
    
    print(f"Launching generation for {model_name}")
        
    watermark = DeltaReweight_watermark(
        model=model,
        tokenizer=tokenizer,
        seeds=keys,
        context_size=3,
        disable_watermark_every=disable_watermark_every
    )
    
    w = "DeltaReweight"
        
    if test == "stanford":
        
        path = f"pkl_results/{test}/{model_name}/{w}/{CUSTOM_NAME}_maxtokens{max_new_tokens}_{temperature}_{keys}.pkl"
        folder_path = f"pkl_results/{test}/{model_name}/{w}"
        
        generate_stanford_test_data(watermark, temperature, n_samples, max_new_tokens, folder_path, path)
    
    elif test == "KGW":
        
        path = f"pkl_results/{test}/{model_name}/{w}/{CUSTOM_NAME}_maxtokens{max_new_tokens}_{temperature}_{keys}.pkl"
        folder_path = f"pkl_results/{test}/{model_name}/{w}"
        
        word_list, example, format = get_model_info(model_name).values()
        
        prob_out, logits_out  = generate_data_custom_delta(watermark, temperature, word_list, context, example, format)

        out = {0: prob_out, "0_logits": logits_out}

        if os.path.exists(folder_path) is False:
            os.makedirs(folder_path)
            
        with open(path, "wb") as f:
            pkl.dump(out, f)