import json
import re
import string
import numpy as np
from tqdm import tqdm
from collections import defaultdict

try:
    from watermark_2 import generate_watermark_keys
except ImportError as import_error:
    # The script cannot do anything without the key generator. Abort with a
    # diagnostic on stderr and a non-zero exit status instead of exiting
    # silently with status 0, which would look like a successful run.
    raise SystemExit(
        f"Cannot import generate_watermark_keys from watermark_2: {import_error}"
    )

# Probability constant used by process_single_text in its z-score formula:
# the mean term is p0 * 20 and the variance term is p0 * (1 - p0) * 20.
p0 = 0.5

def split_sentences(text):
    """Split *text* into sentence-like chunks at single-character boundaries.

    The text is scanned one character at a time. Each character is tested
    against the boundary pattern below; because the pattern is only ever
    applied to a single character, it effectively matches a newline or one
    of ``. ? ! ， , 。 ？ ！ …``. A chunk ends with (and includes) the boundary
    character. Whitespace other than a newline that directly follows a
    boundary is attached to the start of the *next* chunk. Consecutive
    boundary characters therefore each form their own chunk.

    Args:
        text: The string to split. Any falsy value yields an empty list.

    Returns:
        The chunks in order of appearance, excluding chunks that consist
        only of whitespace.
    """
    if not text:
        return []

    boundary_re = re.compile(r'^\s*(?:\n|\.{3}|[.?!，,。？！…]+(?:\s)?)\s*$')

    chunks = []
    total = len(text)
    start = 0  # index at which the chunk currently being built begins
    pos = 0
    while pos < total:
        at_boundary = boundary_re.fullmatch(text[pos]) is not None
        pos += 1
        if not at_boundary:
            continue

        # Close the chunk right after the boundary character.
        chunks.append(text[start:pos])
        start = pos

        # Skip over trailing blanks (but not newlines); since `start` stays
        # put, they become the leading part of the following chunk.
        while pos < total and text[pos] != "\n" and text[pos].isspace():
            pos += 1

    remainder = text[start:]
    if remainder.strip():
        chunks.append(remainder)

    return [chunk for chunk in chunks if chunk.strip()]

def filter_short_word_sentences(sentences):
    """Return the sentences whose length exceeds three characters.

    Despite the name, the threshold is applied to ``len(sent)`` (a character
    count that includes whitespace and punctuation), not to a word count.

    Args:
        sentences: Iterable of sized items (strings in this file).

    Returns:
        A new list holding, in the original order, every item with
        ``len(item) > 3``.
    """
    return [candidate for candidate in sentences if len(candidate) > 3]

def process_single_text(text, test_index_val):
    """Score one text and return its z-score summary.

    The text is split with ``split_sentences`` and reduced with
    ``filter_short_word_sentences``. Each remaining sentence has ASCII
    punctuation (``string.punctuation``) removed and is stripped; sentences
    that become empty are skipped. The cleaned sentence is passed to
    ``generate_watermark_keys``, which returns
    ``(key_type, HMAC_syntax, HMAC_semanteme)``. A sentence contributes a
    score of 0.5 for each of the two HMAC values that is odd. Scores are
    tallied separately for ``key_type == "model_id"`` and for every other
    key type.

    Each z value is computed as
    ``(hits * (20 / seen) - p0 * 20) / sqrt(p0 * (1 - p0) * 20)``.

    Args:
        text: The text to evaluate.
        test_index_val: Identifier copied into the result unchanged.

    Returns:
        A dict with the keys ``"test_index"``, ``"z_model_id_rate"``,
        ``"z_model_id.user_id_rate"`` and ``"z_total_rate"``. A z value is
        ``None`` when no sentence was counted for it. The key
        ``"too_few_sentences"`` (value ``True``) is added only when either
        per-key-type z value is ``None`` or fewer than four sentences
        survived the length filter.
    """
    kept_sentences = filter_short_word_sentences(split_sentences(text))
    ascii_punct_remover = str.maketrans('', '', string.punctuation)

    model_hits = 0
    model_seen = 0
    user_hits = 0
    user_seen = 0

    for raw_sentence in kept_sentences:
        bare = raw_sentence.translate(ascii_punct_remover).strip()
        if bare:
            key_type, HMAC_syntax, HMAC_semanteme = generate_watermark_keys(bare)
            score = (HMAC_semanteme % 2) * 0.5 + (HMAC_syntax % 2) * 0.5

            if key_type == "model_id":
                model_seen += 1
                model_hits += score
            else:
                user_seen += 1
                user_hits += score

    def rescaled_z(hits, seen):
        # No counted sentences means the statistic is undefined.
        if seen > 0:
            return (hits * (20 / seen) - p0 * 20) / np.sqrt(p0 * (1 - p0) * 20)
        return None

    z_model = rescaled_z(model_hits, model_seen)
    z_user = rescaled_z(user_hits, user_seen)
    z_all = rescaled_z(model_hits + user_hits, user_seen + model_seen)

    result = {
        "test_index": test_index_val,
        "z_model_id_rate": z_model,
        "z_model_id.user_id_rate": z_user,
        "z_total_rate": z_all,
    }

    if z_model is None or z_user is None or len(kept_sentences) < 4:
        result["too_few_sentences"] = True

    return result

# Input JSONL: one JSON object per line. Each object is read for the keys
# "test_index", "output_with_watermark" and "attacked_versions".
input_file = "/home/awb/sentence_level_watermark/expe_result/llama3.1-8B_c4/attack_total.jsonl"
# Output JSONL: one object per input record, holding the per-variant results.
output_file = "/home/awb/sentence_level_watermark/expe_result/llama3.1-8B_c4/attack_total_output.jsonl"

# Maps each variant name ("Original" or an attack name) to the list of
# non-None z_total_rate values computed for that variant across all records.
stats = defaultdict(list)

with open(input_file, "r", encoding="utf-8") as fin, open(output_file, "w", encoding="utf-8") as fout:
    for i, line in enumerate(tqdm(fin, desc="Processing lines")):
        # A blank line is not valid JSON; skip it rather than letting
        # json.loads abort the run and leave a truncated output file.
        if not line.strip():
            continue

        entry = json.loads(line)
        # Fall back to the zero-based line number when the record carries no
        # explicit index.
        test_idx = entry.get("test_index", i)
        tasks = {}

        if entry.get("output_with_watermark"):
            tasks["Original"] = entry.get("output_with_watermark")

        # NOTE(review): presumably a mapping of attack name -> attacked text;
        # confirm against the producer of this file.
        attacks = entry.get("attacked_versions", {})
        if attacks:
            tasks.update(attacks)

        entry_robustness = {"test_index": test_idx, "results": {}}

        for name, text in tasks.items():
            res = process_single_text(text, test_idx)
            entry_robustness["results"][name] = res

            # Only variants with a defined overall z-score feed the stats.
            if res.get("z_total_rate") is not None:
                stats[name].append(res["z_total_rate"])

        fout.write(json.dumps(entry_robustness, ensure_ascii=False) + "\n")