import json
from datasets import load_dataset
import torch
from transformers import AutoTokenizer
from watermark import Detector
from sweet import SweetLogitsProcessor,SweetDetector
from kgw_watermark import WatermarkLogitsProcessor,WatermarkDetector
from transformers import AutoTokenizer,AutoModelForCausalLM,LogitsProcessorList
import torch
import numpy as np
from tqdm import tqdm

# Attack outputs loaded from JSON. The detection loop further down indexes
# this as dipper_output[distortion_key][output_name][sample_index].
# The encoding is given explicitly so decoding does not depend on the
# platform's locale default.
with open('./LTW/eval_records/eval_plot/dipper_attack_output.json', 'r', encoding='utf-8') as f:
    dipper_output = json.load(f)


# Prompt source: the 'train' split of a local JSON-lines file. The detection
# loop reads each record's 'text' field.
dataset = load_dataset("json", data_files="./LTW/c4_subset_500.jsonl")["train"]


# Collects z-scores as {distortion_key: {output_name: [z_score, ...]}}.
detection_scores = {}


# Detection hyper-parameters handed to every detector constructed below.
gamma = 0.25
z_threshold = 4

# Select the GPU only when CUDA is actually usable. Calling
# torch.cuda.set_device() unconditionally raises on hosts without CUDA,
# which made the CPU fallback unreachable.
cuda_device_index = 4
if torch.cuda.is_available():
    torch.cuda.set_device(cuda_device_index)
    # A bare 'cuda' device refers to the current device chosen just above.
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the language model and its tokenizer from the same local directory.
model_path = "./LTW/models/opt-6.7b"
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.eval()  # switch the module to evaluation mode


def _common_detector_args():
    """Return the constructor arguments shared by all three detectors.

    The vocabulary list is rebuilt on every call, so each detector
    receives its own list object.
    """
    return {
        'vocab': list(tokenizer.get_vocab().values()),
        'gamma': gamma,
        'tokenizer': tokenizer,
        'z_threshold': z_threshold,
    }


# Checkpoint path used by both Detector instances.
_selector_checkpoint = "./LTW/ckpt/tmp/selective_network_epoch0_step2000.pth"

# Detector applied to the 'our_wm_output1' generations in the loop below.
watermark_detector = Detector(
    **_common_detector_args(),
    model=model,
    k=6,
    checkpoint_path=_selector_checkpoint,
)

# Same configuration plus embed_unigram_wm=True; applied to the
# 'our_wm_output' generations in the loop below.
watermark_detector2 = Detector(
    **_common_detector_args(),
    model=model,
    k=6,
    embed_unigram_wm=True,
    checkpoint_path=_selector_checkpoint,
)

# KGW baseline detector; takes only the shared arguments.
kgw_detector = WatermarkDetector(**_common_detector_args())


# Pairs each output name in the attack JSON with the detector that scores it.
# The order also fixes the key order inside each detection_scores entry.
scored_outputs = (
    ('our_wm_output1', watermark_detector),
    ('our_wm_output', watermark_detector2),
    ('kgw_wm_output', kgw_detector),
)

# This script only reads z-scores and never back-propagates, so gradient
# tracking is disabled for the whole scoring pass.
with torch.no_grad():
    for distortion_key in dipper_output:
        attacked_outputs = dipper_output[distortion_key]
        key_scores = {output_name: [] for output_name, _ in scored_outputs}
        detection_scores[distortion_key] = key_scores

        for count, data in enumerate(tqdm(dataset, desc=f"Processing {distortion_key}")):
            # The prompt is the first 300 characters of the record's text.
            input_text = data['text'][:300]

            # Tokenize the prompt once per sample; all three detectors
            # receive the same prompt tensor.
            # NOTE(review): the prompt is encoded without special tokens
            # while the full text below is encoded with them; kept as in the
            # original -- confirm this matches what detect() expects.
            tokenized_input = tokenizer.encode(
                input_text, return_tensors='pt', add_special_tokens=False
            ).to(device)[0]

            for output_name, detector in scored_outputs:
                # Score the prompt followed by the attacked continuation.
                output = input_text + attacked_outputs[output_name][count]
                tokenized_output = tokenizer.encode(output, return_tensors='pt').to(device)[0]
                z_score = detector.detect(tokenized_output, tokenized_input)['z_score']
                key_scores[output_name].append(z_score)


# Serialize before opening the destination: opening with 'w' truncates the
# file, so a score that cannot be serialized would otherwise leave a
# partial file behind.
serialized_scores = json.dumps(detection_scores, indent=4)
with open('./LTW/eval_records/eval_plot/dipper_attack_zscore.json', 'w', encoding='utf-8') as f:
    f.write(serialized_scores)

# Report success only after the file has been closed.
# (The message reads "detection results saved successfully".)
print("检测结果保存成功")
