# %%
# In this notebook, we run an ablation study on the effectiveness of the adaptive thresholding method for selective watermarking,
# as well as on the impact of watermark strength on our method.


# Watermark strengths swept by the experiment; each value is also used as a
# key in the per-metric result dictionaries below.
wm_strength = [1.5, 2, 2.5, 3, 3.5, 4]

# Every metric is stored as {"with" | "without": {strength: [per-sample values]}}.
# Each comprehension builds its own fresh lists, so no list is shared between
# metrics, categories or strengths.
z_value = {
    category: {key: [] for key in wm_strength}
    for category in ("with", "without")
}

perplexity = {
    category: {key: [] for key in wm_strength}
    for category in ("with", "without")
}

cossim = {
    category: {key: [] for key in wm_strength}
    for category in ("with", "without")
}




# %%
from watermark import Detector,Watermark
from watermark_ablation import Watermark_ablation,Detector_ablation
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM,AutoModel,OPTForCausalLM
from sentence_transformers import SentenceTransformer, util
import torch

# Pin the run to GPU index 1, but only when CUDA is actually present:
# calling torch.cuda.set_device() on a machine without CUDA raises, which
# would make the CPU fallback on the next line unreachable.
if torch.cuda.is_available():
    torch.cuda.set_device(1)
# A bare 'cuda' device resolves to the current CUDA device selected above.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the JSON-lines evaluation set and keep its "train" split.
dataset = load_dataset("json", data_files="./LTW/c4_subset_500.jsonl")
dataset = dataset["train"]

# Generation model used for watermarking, loaded in half precision.
# NOTE(review): float16 weights on the CPU fallback path may be slow or
# unsupported for some ops depending on the torch version -- confirm before
# relying on a CPU run.
model_path = "./LTW/models/opt-6.7b"

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.eval()


def get_first_n_words(text, n=200):
    """Return the first ``n`` whitespace-separated words of ``text``.

    The words are re-joined with single spaces, so any run of whitespace in
    the input (including leading/trailing whitespace and newlines) is
    collapsed. If ``text`` has fewer than ``n`` words, all of them are kept.
    """
    return ' '.join(text.split()[:n])


# Separate, smaller OPT checkpoint used only to score the perplexity of
# generated text; it is independent of the generation model.
model_name = "./LTW/models/opt-1.3b"
ppl_model = OPTForCausalLM.from_pretrained(model_name).to(device)
ppl_tokenizer = AutoTokenizer.from_pretrained(model_name)


def calculate_perplexity(text):
    """Return the perplexity of ``text`` under ``ppl_model`` as a Python float.

    The text is tokenized with ``ppl_tokenizer`` and fed to ``ppl_model`` with
    its own input ids as labels; the result is ``exp`` of the loss the model
    reports for that forward pass.
    """
    encoded = ppl_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

    # Scoring only: no gradients are needed for the forward pass.
    with torch.no_grad():
        lm_loss = ppl_model(**encoded, labels=encoded["input_ids"]).loss

    return torch.exp(lm_loss).item()


# Sentence encoder used to compare generated text against the human reference.
semantic_model = SentenceTransformer('./hf_models/all-mpnet-base-v2')


def calculate_sim(text1, text2):
    """Return the cosine similarity of the two texts' sentence embeddings.

    Both texts are encoded with ``semantic_model`` (first ``text1``, then
    ``text2``) and the similarity is returned as a Python float.
    """
    first_vec, second_vec = (
        semantic_model.encode(passage, convert_to_tensor=True)
        for passage in (text1, text2)
    )
    return util.pytorch_cos_sim(first_vec, second_vec).item()


# Checkpoint shared by both generators and both detectors.
ckpt_path = "./LTW/ckpt/tmp/selective_network_epoch0_step2000.pth"
gamma = 0.25

# The regular and the ablation generator are configured identically; only the
# class differs.
_generator_kwargs = dict(
    checkpoint_path=ckpt_path,
    device=device,
    k=6,
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=225,
    min_new_tokens=175,
    embed_unigram_wm=True,
)
wm = Watermark(**_generator_kwargs)
wm_ablation = Watermark_ablation(**_generator_kwargs)


def _detector_kwargs():
    """Build the keyword arguments shared by both detectors.

    A fresh vocabulary list is created on every call so the two detectors
    never share the same list object.
    """
    return dict(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=gamma,
        tokenizer=tokenizer,
        z_threshold=4,
        model=model,
        k=6,
        checkpoint_path=ckpt_path,
        embed_unigram_wm=True,
    )


detector = Detector(**_detector_kwargs())

detector_ablation = Detector_ablation(**_detector_kwargs())


# %%
def _evaluate_sample(generator, scorer, prompt, reference, delta):
    """Generate one watermarked continuation and score it.

    Args:
        generator: object exposing ``generate_watermark(prompt, gamma, delta)``;
            the first element of its return value is used as the continuation.
        scorer: object exposing ``detect(output_ids, prompt_ids)`` whose result
            is indexed with ``'z_score'``.
        prompt: text fed to the generator.
        reference: human-written continuation used for the similarity score.
        delta: watermark strength passed to the generator.

    Returns:
        Tuple ``(perplexity, similarity, z_score)`` for the continuation.
    """
    continuation = generator.generate_watermark(prompt, gamma, delta)[0]
    ppl = calculate_perplexity(continuation)
    sim = calculate_sim(reference, continuation)

    # Detection receives the full sequence (prompt + continuation) together
    # with the prompt ids. The prompt is encoded without special tokens, the
    # full sequence with the tokenizer's defaults, as in the original script.
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False).to(device)[0]
    full_ids = tokenizer.encode(prompt + continuation, return_tensors='pt').to(device)[0]
    z_score = scorer.detect(full_ids, prompt_ids)['z_score']
    return ppl, sim, z_score


# Each variant pairs a generator with its matching detector. Previously the
# "without" branch scored the ablation generator's text with the regular
# `detector`, leaving `detector_ablation` unused.
# NOTE(review): assumes Detector_ablation.detect has the same signature and
# result keys as Detector.detect -- confirm against watermark_ablation.
_variants = (
    ("with", wm, detector),
    ("without", wm_ablation, detector_ablation),
)

for delta in wm_strength:
    for data in dataset:
        text = data['text']
        # The first 300 characters serve as the prompt; the first 200 words of
        # the remainder are the human reference continuation.
        input_text = text[:300]
        human_ans = get_first_n_words(text[300:], 200)

        for category, generator, scorer in _variants:
            ppl, sim, z_score = _evaluate_sample(generator, scorer, input_text, human_ans, delta)
            perplexity[category][delta].append(ppl)
            cossim[category][delta].append(sim)
            z_value[category][delta].append(z_score)


import json
import os

# Bundle all collected metrics into one JSON document. json.dump converts the
# numeric strength keys to strings in the written file.
eval_records = {
    "perplexity": perplexity,
    "similarity": cossim,
    "z_value": z_value
}

# Create the output directory first, so a missing folder cannot discard the
# results of the whole sweep at the very last step.
os.makedirs('./LTW/eval_records', exist_ok=True)
with open('./LTW/eval_records/ablation.json', 'w') as f:
    json.dump(eval_records, f)
        

        
    



