import argparse
from src.models import load_model_and_tokenizer
from src.kgw_watermarks import KGW_watermark
from src.attacks.kgw_detection import generate_data_custom_delta, generate_data_context_size_estimation, get_model_info
import os
from src.attacks.cache_detection import generate_cache_detection_data, get_cacheTest_phase1
from src.attacks.stanford_detection import generate_stanford_test_data
import pickle as pkl
from src.utils import logit
import numpy as np

# Command-line interface of the script. The parsed values are read in the
# ``__main__`` section below.
parser = argparse.ArgumentParser(description='Generate integer switch data')

# Model and watermark configuration.
parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-chat-hf", help='HF model name')
parser.add_argument('--seeding_scheme', type=str, default="lefthash", help='watermark seeding scheme')
parser.add_argument("--context", type=int, default=5, help="context size")
parser.add_argument("--max_new_tokens", type=int, default=50, help="max new tokens generated")
# A delta of exactly 0 makes the script label the run "no_watermark".
parser.add_argument("--delta", type=float, default=4.0, help="delta value (0 is treated as no watermark)")
parser.add_argument("--gamma", type=float, default=0.25, help="gamma value")

# Sampling configuration.
parser.add_argument("--n_samples", type=int, default=0, help="Number of samples for p-value")
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")

# Output naming and run selection.
parser.add_argument("--custom_name", type=str, default="", help="path name for custom results")
# One or more integers; stays None when the flag is omitted.
parser.add_argument("--keys", nargs='+', type=int, help="key(s) for watermark")
parser.add_argument("--test", type=str, help="which test to run: KGW, stanford, cache or context")
parser.add_argument("--n_trial", type=int, default=1, help="Number of trials for alpha estimation")
parser.add_argument("--disable_watermark_every", type=int, default=0, help="Disable watermark every n requests")


# Short first-person sentence openers. Not referenced in the visible part of
# this script.
# NOTE(review): presumably prompt prefixes consumed by other modules - confirm.
PREFIXES = [
    "I ate",
    "I chose",
    "I picked",
    "I selected",
    "I took",
    "I went for",
    "I settled on",
    "I got",
    "I gathered",
    "I harvested",
]



if __name__ == "__main__":
    
    args = parser.parse_args()
    model_name = args.model_name
    context = args.context
    delta_value = args.delta
    gamma = args.gamma
    seeding_scheme = args.seeding_scheme
    n_samples = args.n_samples
    temperature = args.temperature
    CUSTOM_NAME = args.custom_name
    keys = args.keys
    test = args.test
    max_new_tokens = args.max_new_tokens
    n_trial = args.n_trial
    disable_watermark_every = args.disable_watermark_every
        
    model, tokenizer = load_model_and_tokenizer(model_name=model_name)
    model_name = model_name.replace("/", "_")
    word_list, example, format = get_model_info(model_name).values()
    
    print(f"Launching generation for {model_name} with delta {delta_value} and gamma {gamma} and context {context} under seeding scheme {seeding_scheme}")
        
    watermark = KGW_watermark(
        model = model, 
        tokenizer= tokenizer, 
        delta = delta_value, 
        gamma = gamma, 
        seeding_scheme = seeding_scheme, 
        seeds = keys,
        disable_watermark_every=disable_watermark_every)
    
    if delta_value == 0:
        w = "no_watermark"
    else:
        w = "KGW"

    # Open the prev out if it exists
    

    if test == "KGW":
        
        if w == "no_watermark":
            path = f"pkl_results/{test}/{model_name}/{w}/{CUSTOM_NAME}_context{context}_{temperature}.pkl"
        else:
            path = f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}/{CUSTOM_NAME}_context{context}_gamma{gamma}_{temperature}_{keys}.pkl"
    
        if os.path.exists(path):
            out = pkl.load(open(path, "rb"))
        else:
            out = {}
        
        delta = delta_value
        prob_out, logits_out  = generate_data_custom_delta(watermark, temperature, word_list, context, example, format)
        out[delta] = prob_out
        out[str(delta) + "_logits"] = logits_out

        if w == "no_watermark":
            if os.path.exists(f"pkl_results/{test}/{model_name}/{w}") is False:
                os.makedirs(f"pkl_results/{test}/{model_name}/{w}")
        else:
            if os.path.exists(f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}") is False:
                os.makedirs(f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}")
            
        with open(path, "wb") as f:
            pkl.dump(out, f)
            

    elif test == "stanford":
        
        if w == "no_watermark":
            path = f"pkl_results/{test}/{model_name}/{w}/{CUSTOM_NAME}_maxtokens{max_new_tokens}_{temperature}.pkl"
            folder_path = f"pkl_results/{test}/{model_name}/{w}"
        else:
            path = f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}/{CUSTOM_NAME}_maxtokens{max_new_tokens}_gamma{gamma}_delta{delta_value}_{temperature}_{keys}.pkl"
            folder_path = f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}"
            
        generate_stanford_test_data(watermark, temperature, n_samples, max_new_tokens, folder_path, path)
 
    elif test == "cache":
            
        word_list, example, format = get_cacheTest_phase1(model_name).values()
        
        test = "cache"
        
        if w == "no_watermark":
            path = f"pkl_results/{test}/{model_name}/{w}/{CUSTOM_NAME}_{temperature}.pkl"
        else:
            path = f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}/{CUSTOM_NAME}_gamma{gamma}_{temperature}_{keys}.pkl"
    
        if os.path.exists(path):
            out = pkl.load(open(path, "rb"))
        else:
            out = {}
        
        prob_out, logits_out, n_trials_phase1  = generate_cache_detection_data(watermark, temperature,word_list, example, format, n_trial)
        out[delta_value] = prob_out
        out[str(delta_value) + "_logits"] = logits_out

        if w == "no_watermark":
            if os.path.exists(f"pkl_results/{test}/{model_name}/{w}") is False:
                os.makedirs(f"pkl_results/{test}/{model_name}/{w}")
        else:
            if os.path.exists(f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}") is False:
                os.makedirs(f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}")

        with open(path, "wb") as f:
            pkl.dump(out, f)
            
    elif test == "context":
        
        if w == "no_watermark": # normally this should not be used on non-watermarked model
            path = f"pkl_results/{test}/{model_name}/{w}/{CUSTOM_NAME}_context{context}_{temperature}.pkl"
        else:
            path = f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}/{CUSTOM_NAME}_context{context}_gamma{gamma}_{temperature}_{keys}.pkl"
    
        if os.path.exists(path):
            out = pkl.load(open(path, "rb"))
        else:
            out = {}
            
        word_list, example, format = get_model_info(model_name).values()
        
        if False:
            with open(f"pkl_results/KGW/{model_name}/{w}/{seeding_scheme}/{CUSTOM_NAME}_context{context}_gamma{gamma}_{temperature}_{keys}.pkl", "rb") as f:
                data = pkl.load(f)[delta_value]
                
                data = logit(data)
                weigth = np.sum(data, axis= (0,1))
                chosen = np.argmax(weigth)
                
                chosen_k = [np.argmax(data[0,:, chosen]) + 1]
        else:
            chosen_k = [1,2,3,4,5,6,7,8,9]
        
                
        
        delta = delta_value
        prob_out, logits_out  = generate_data_context_size_estimation(watermark, temperature, word_list, example,  format = format, digits_of_interest= chosen_k)
        out[delta] = prob_out
        out[str(delta) + "_logits"] = logits_out

        if w == "no_watermark":
            if os.path.exists(f"pkl_results/{test}/{model_name}/{w}") is False:
                os.makedirs(f"pkl_results/{test}/{model_name}/{w}")
        else:
            if os.path.exists(f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}") is False:
                os.makedirs(f"pkl_results/{test}/{model_name}/{w}/{seeding_scheme}")
            
        with open(path, "wb") as f:
            pkl.dump(out, f)
