import json
import re
import numpy as np
from tqdm import tqdm
from watermark_2 import generate_watermark_keys
import string

# Baseline probability used by the z-score formulas further down as the
# expected per-sentence score (presumably the hit rate expected for
# unwatermarked text -- TODO confirm against the watermark design).
p0 = 0.5

def split_sentences(text):
    """Split ``text`` into sentence-like pieces, keeping the delimiters.

    ``text`` is walked one element at a time.  Each element is appended to
    the piece being built, and the piece is closed right after any element
    that fully matches the boundary pattern.  For a plain string every
    element is a single character, so a break happens after a newline or
    one of ``. ? ! ， , 。 ？ ！ …``.

    Whitespace (other than a newline) that directly follows a break is not
    tested against the pattern; it is carried over as the start of the next
    piece.

    Returns:
        The pieces in order of appearance.  Pieces that are empty or consist
        only of whitespace are dropped.
    """
    boundary_re = re.compile(r'^\s*(?:\n|\.{3}|[.?!，,。？！…]+(?:\s)?)\s*$')

    pieces = []
    buffer = []
    # True while we are still inside the whitespace run that follows a break.
    after_break = False

    for token in text:
        if after_break and token.isspace() and token != "\n":
            # Leading whitespace of the next piece: keep it, skip the check.
            buffer.append(token)
            continue
        after_break = False

        buffer.append(token)
        if boundary_re.fullmatch(token):
            pieces.append("".join(buffer))
            buffer = []
            after_break = True

    remainder = "".join(buffer)
    if remainder.strip():
        pieces.append(remainder)

    return [piece for piece in pieces if piece.strip()]


def filter_short_word_sentences(sentences):
    """Return the sentences whose length exceeds three, in original order.

    Length is measured with ``len()`` on each sentence, so for strings the
    threshold counts characters (including whitespace and punctuation), not
    words.
    """
    return [candidate for candidate in sentences if len(candidate) > 3]


input_file = "/home/awb/sentence_level_watermark/expe_result/llama3.2-3B_c4/ablation_study/without_quality.jsonl"
output_file = "/home/awb/sentence_level_watermark/expe_result/llama3.2-3B_c4/ablation_study/without_quality_output.jsonl"

# Hit counts are rescaled to this fixed sample size before the z-score is
# computed.
Z_SAMPLE_SIZE = 20

# Deletes ASCII punctuation; built once because it is identical for every line.
translator = str.maketrans('', '', string.punctuation)


def _z_score(valid, total):
    """Return the z-score for ``valid`` accumulated score over ``total`` sentences.

    The accumulated score is first rescaled to ``Z_SAMPLE_SIZE`` sentences and
    then standardised against the module-level baseline ``p0``.

    Returns:
        The z-score, or ``None`` when ``total`` is zero and no rate exists.
    """
    if total <= 0:
        return None
    scaled_hits = valid * (Z_SAMPLE_SIZE / total)
    return (scaled_hits - p0 * Z_SAMPLE_SIZE) / np.sqrt(p0 * (1 - p0) * Z_SAMPLE_SIZE)


# Read one JSON record per input line and write one JSON result per record.
with open(input_file, "r", encoding="utf-8") as fin, open(output_file, "w", encoding="utf-8") as fout:
    for i, line in enumerate(tqdm(fin)):
        # Skip blank lines (e.g. a trailing newline) instead of failing in
        # json.loads.
        if not line.strip():
            continue
        entry = json.loads(line)

        # A record without watermarked text is scored as empty text: it yields
        # no sentences, so every z-score is None and the record is flagged
        # below, rather than crashing the whole run.
        text = entry.get("output_with_watermark") or ""

        sentences = split_sentences(text)
        sentences = filter_short_word_sentences(sentences)

        model_id_count_valid = 0
        model_id_user_id_count_valid = 0
        model_id_count_total = 0
        model_id_user_id_count_total = 0

        for sent in sentences:
            sentence_withoutpunctuation = sent.translate(translator).strip()
            key_type, HMAC_syntax, HMAC_semanteme = generate_watermark_keys(sentence_withoutpunctuation)

            # The parity of each HMAC value contributes half a point
            # (assuming integer HMAC values, a sentence scores 0, 0.5 or 1).
            score = (HMAC_semanteme % 2) * 0.5 + (HMAC_syntax % 2) * 0.5
            if key_type == "model_id":
                model_id_count_total += 1
                model_id_count_valid += score
            else:
                # Any other key type is counted in the model_id.user_id bucket.
                model_id_user_id_count_total += 1
                model_id_user_id_count_valid += score

        result = {
            "test_index": entry.get("test_index", i),
            "z_model_id_rate": _z_score(model_id_count_valid, model_id_count_total),
            "z_model_id.user_id_rate": _z_score(
                model_id_user_id_count_valid, model_id_user_id_count_total
            ),
            "z_total_rate": _z_score(
                model_id_count_valid + model_id_user_id_count_valid,
                model_id_user_id_count_total + model_id_count_total,
            ),
        }

        # Flag records where either bucket is empty or fewer than four
        # sentences survived filtering.
        if result["z_model_id_rate"] is None or result["z_model_id.user_id_rate"] is None or len(sentences) < 4:
            result["too_few_sentences"] = True

        fout.write(json.dumps(result, ensure_ascii=False) + "\n")