import json
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import os

# ================= Configuration =================
MODEL_PATH = "/sdb1/awb/models--llama-3.1-8B"
DATA_FILE = "/home/awb/sentence_level_watermark/expe_result/baseline/PF_c4/Llama-3.2-3B/total.jsonl"
USE_4BIT = False  # Set to True to load the model with 4-bit quantization; currently disabled
# =================================================

# Pin inference to cuda:4 when CUDA is available; otherwise fall back to the CPU
DEVICE = "cuda:4" if torch.cuda.is_available() else "cpu"


def load_model(model_path, use_4bit=False):
    """Load a causal language model and its tokenizer from a local directory.

    Args:
        model_path: Directory holding the model files; it must contain
            ``config.json`` directly (i.e. be the model root directory).
        use_4bit: When True, load the weights with 4-bit NF4 quantization
            through ``BitsAndBytesConfig``; otherwise load them in float16.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode and is
        placed on ``DEVICE`` in both the quantized and the float16 branch.

    Raises:
        ValueError: If ``config.json`` is not found under ``model_path``.
        SystemExit: With exit code 1 if loading the tokenizer or model fails
            for any reason; the error message is printed first.
    """
    print(f"正在从本地加载模型: {model_path} ...")

    if not os.path.exists(os.path.join(model_path, "config.json")):
        raise ValueError(f"错误：在 {model_path} 下未找到 config.json。请检查路径是否为模型根目录。")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # Some tokenizers ship without a pad token; reuse EOS so that the
        # tokenizer always has one defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        if use_4bit:
            # Imported lazily so the quantization config is only needed when
            # 4-bit loading is actually requested.
            from transformers import BitsAndBytesConfig
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4"
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=bnb_config,
                # Put the whole model on the same device the rest of the
                # script sends its inputs to, instead of a hard-coded GPU
                # index that could disagree with DEVICE.
                device_map={"": DEVICE},
                trust_remote_code=True
            )
        else:
            # Deliberately not using device_map="auto": the entire model is
            # moved to the single configured device.
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                trust_remote_code=True
            ).to(DEVICE)

        model.eval()
        print(f"模型加载成功！当前 device = {next(model.parameters()).device}")
        return model, tokenizer

    except Exception as e:
        print(f"模型加载失败: {e}")
        # Raise SystemExit directly rather than calling the `exit` helper,
        # which is injected by the `site` module and is not always available.
        raise SystemExit(1) from e


def calculate_conditional_ppl(prompt, output, model, tokenizer):
    """Compute the conditional perplexity PPL(output | prompt).

    The concatenation ``prompt + output`` is tokenized and scored by the
    model, with the label positions covering the prompt set to -100 so that
    only the remaining positions contribute to the loss.

    Args:
        prompt: Conditioning text; its tokens are excluded from the loss.
        output: Continuation text whose perplexity is measured.
        model: Callable model exposing a ``device`` attribute and returning
            an object with a ``loss`` attribute when called with
            ``(input_ids, labels=...)``.
        tokenizer: Callable tokenizer returning an object with ``input_ids``
            when called with ``(text, return_tensors="pt")``.

    Returns:
        ``exp(loss)`` as a Python float, or ``None`` when the full text does
        not tokenize to more tokens than the prompt alone (nothing to score).
    """
    full_text = prompt + output

    # Send the inputs to wherever the model actually lives rather than to a
    # separately configured device, so the two can never disagree.
    input_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(model.device)

    # Only the prompt's token count is needed, so it is read on the host
    # without copying the ids to the accelerator.
    # NOTE(review): assumes the prompt tokenizes to the same prefix inside
    # the full text; tokens merging across the prompt/output boundary would
    # shift this count slightly -- confirm this is acceptable.
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]

    if input_ids.shape[1] <= prompt_len:
        return None

    # -100 marks the prompt positions as ignored by the loss.
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100

    with torch.no_grad():
        nll = model(input_ids, labels=labels).loss

    return torch.exp(nll).item()


def main():
    """Score every record in DATA_FILE and print summary perplexity stats.

    Each line of the file is parsed as JSON; its 'prompt' and
    'output_with_watermark' fields are passed to calculate_conditional_ppl.
    Lines that fail to parse or raise any other error are reported and
    skipped. The count, mean and median of the collected scores are printed
    at the end, or a notice if no score was produced.
    """
    model, tokenizer = load_model(MODEL_PATH, USE_4BIT)

    print(f"开始计算 PPL，数据文件: {DATA_FILE}")

    with open(DATA_FILE, 'r', encoding='utf-8') as handle:
        records = handle.readlines()

    scores = []
    for idx, raw in enumerate(tqdm(records)):
        try:
            sample = json.loads(raw)
            value = calculate_conditional_ppl(
                sample['prompt'],
                sample['output_with_watermark'],
                model,
                tokenizer,
            )
        except json.JSONDecodeError:
            print(f"Skipping line {idx}: JSON Error")
        except Exception as err:
            print(f"Error on line {idx}: {err}")
        else:
            # A None result means there was nothing to score for this record.
            if value is not None:
                scores.append(value)

    if not scores:
        print("未计算出有效 PPL。")
        return

    mean_value = np.mean(scores)
    median_value = np.median(scores)
    divider = "=" * 30
    print("\n" + divider)
    print(f"处理样本数: {len(scores)}")
    print(f"平均困惑度 (Mean PPL): {mean_value:.4f}")
    print(f"中位困惑度 (Median PPL): {median_value:.4f}")
    print(divider)

if __name__ == "__main__":
    main()
