import argparse
from src.models import load_model_and_tokenizer
from src.stanford_watermark import Stanford_watermark
from src.attacks.stanford_detection import generate_stanford_test_data
from src.attacks.cache_detection import generate_cache_detection_data, get_cacheTest_phase1
from src.attacks.kgw_detection import generate_data_custom_delta, get_model_info
import os
import pickle as pkl

# Command-line interface for the data-generation driver below.
parser = argparse.ArgumentParser(description='Generate integer switch data')
parser.add_argument(
    "--model_name",
    type=str,
    default="meta-llama/Llama-2-7b-chat-hf",
    help="HF model name",
)
# The help text used to be a copy-paste of --model_name's.
parser.add_argument(
    "--key_length",
    type=int,
    default=256,
    help="Length of the watermark key",
)
parser.add_argument("--context", type=int, default=5, help="context size")
parser.add_argument(
    "--temperature",
    type=float,
    default=1.0,
    help="Temperature for sampling",
)
parser.add_argument(
    "--custom_name",
    type=str,
    default="",
    help="path name for custom results",
)
parser.add_argument("--keys", nargs="+", type=int, help="key for watermark")
# Only these three values are handled by the driver; anything else would load
# the model and then do nothing, so reject it at parse time.
parser.add_argument(
    "--test",
    type=str,
    choices=["KGW", "stanford", "cache"],
    help="which test to run",
)
parser.add_argument(
    "--max_new_tokens",
    type=int,
    default=50,
    help="max new tokens generated",
)
parser.add_argument(
    "--n_samples",
    type=int,
    default=1000,
    help="Number of samples for stanford test",
)
parser.add_argument(
    "--disable_watermark_every",
    type=int,
    default=0,
    help="Disable watermark every n requests",
)
parser.add_argument(
    "--percent_disable",
    type=float,
    default=0,
    help="Percent of requests to disable watermark",
)
parser.add_argument(
    "--n_trial",
    type=int,
    default=1,
    help="Number of trials for alpha estimation",
)

# First-person, past-tense sentence openers.
# NOTE(review): not referenced anywhere in the visible part of this script;
# presumably consumed by another module - confirm before removing.
PREFIXES = [
    "I ate",
    "I chose",
    "I picked",
    "I selected",
    "I took",
    "I went for",
    "I settled on",
    "I got",
    "I gathered",
    "I harvested",
]

# Fixed leading string (note the trailing space). Also unused in the visible code.
FORMAT = "Sure! "

    

_SUPPORTED_TESTS = ("KGW", "stanford", "cache")
_WATERMARK_NAME = "Stanford"


def _load_previous_results(path):
    """Return the dict pickled at ``path``, or an empty dict if no file exists.

    NOTE: unpickling can execute arbitrary code. ``path`` should only ever
    point at a results file previously written by this script.
    """
    if not os.path.exists(path):
        return {}
    # Context manager so the handle is closed even if unpickling fails.
    with open(path, "rb") as f:
        return pkl.load(f)


def _save_results(path, results):
    """Pickle ``results`` to ``path``, creating the parent directory if needed."""
    folder = os.path.dirname(path)
    if folder:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(folder, exist_ok=True)
    with open(path, "wb") as f:
        pkl.dump(results, f)


def main():
    """Parse the CLI arguments and generate data for the selected test.

    Supported ``--test`` values:

    * ``KGW``: calls ``generate_data_custom_delta`` and merges its two outputs
      into the results pickle under ``key_length`` and ``"<key_length>_logits"``.
    * ``stanford``: calls ``generate_stanford_test_data`` with the output
      folder and file path (persisting is presumably done by that function).
    * ``cache``: calls ``generate_cache_detection_data`` and merges its first
      two outputs into the results pickle under ``0`` and ``"0_logits"``.

    Any other value aborts with a usage error before the model is loaded.
    """
    args = parser.parse_args()

    test = args.test
    if test not in _SUPPORTED_TESTS:
        # Fail before the (expensive) model load instead of silently doing nothing.
        parser.error(
            f"--test must be one of {', '.join(_SUPPORTED_TESTS)} (got {test!r})"
        )

    key_length = args.key_length
    context = args.context
    temperature = args.temperature
    keys = args.keys

    model, tokenizer = load_model_and_tokenizer(model_name=args.model_name)
    # From here on the model name is only used in paths / lookups, so make it
    # filesystem friendly.
    model_name = args.model_name.replace("/", "_")

    print(f"Launching generation for {model_name} with context {context} under key length {key_length}")

    watermark = Stanford_watermark(
        model=model,
        tokenizer=tokenizer,
        key_length=key_length,
        seeds=keys,
        disable_watermark_every=args.disable_watermark_every,
        percent_disable=args.percent_disable,
    )

    folder_path = f"pkl_results/{test}/{model_name}/{_WATERMARK_NAME}"

    if test == "KGW":
        # Relies on the insertion order of the returned dict's values.
        # `fmt` rather than `format` to avoid shadowing the builtin.
        word_list, example, fmt = get_model_info(model_name).values()

        path = f"{folder_path}/{args.custom_name}_context{context}_{temperature}_{keys}.pkl"
        # Merge into any earlier results so runs with other key lengths are kept.
        out = _load_previous_results(path)

        prob_out, logits_out = generate_data_custom_delta(
            watermark, temperature, word_list, context, example, fmt
        )
        out[key_length] = prob_out
        out[f"{key_length}_logits"] = logits_out

        _save_results(path, out)

    elif test == "stanford":
        path = (
            f"{folder_path}/{args.custom_name}_maxtokens{args.max_new_tokens}"
            f"_{temperature}_{keys}_{key_length}.pkl"
        )
        generate_stanford_test_data(
            watermark, temperature, args.n_samples, args.max_new_tokens, folder_path, path
        )

    else:  # test == "cache"
        word_list, example, fmt = get_cacheTest_phase1(model_name).values()

        path = f"{folder_path}/{args.custom_name}_{temperature}_{keys}_{key_length}.pkl"
        out = _load_previous_results(path)

        # The third return value is not used by this script.
        prob_out, logits_out, _ = generate_cache_detection_data(
            watermark, temperature, word_list, example, fmt, args.n_trial
        )
        out[0] = prob_out
        out["0_logits"] = logits_out

        _save_results(path, out)


if __name__ == "__main__":
    main()