
from transformers import pipeline
import torch
import json
from typing import Iterator, Dict
import gc
import os, math, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# NOTE(review): OPEN_AI_KEY is not read anywhere in this file; never commit a real
# key here — load it from the environment instead if it is ever needed.
OPEN_AI_KEY = "<your-openai-key>"
# JSON inputs loaded by read_inference_sequence(). Relative paths are resolved by
# open() against the current working directory, not this file's location.
ATTACK_SEQ_PATH = "../datasets/attack_list.json"
GROUND_TRUTH_PATH = "../datasets/ground_truth_list.json"

def calculate_perplexity(logits, target):
    """
    Compute the perplexity of ``target`` under the distribution given by ``logits``.

    No next-token shift is applied here: position t of ``logits`` is scored
    against position t of ``target``. Callers must align them beforehand.

    Args:
    - logits (torch.Tensor): Unnormalized scores (batch_size, seq_length, vocab_size).
    - target (torch.Tensor): Integer token ids (batch_size, seq_length).

    Returns:
    - perplexity (float): exp of the mean negative log-likelihood over every
      (batch, position) entry.
    """
    # Index tensor must have the same rank as log-probs for gather: (B, T, 1).
    index = target.unsqueeze(-1)
    token_log_probs = (
        torch.nn.functional.log_softmax(logits, dim=-1)
        .gather(dim=-1, index=index)
        .squeeze(-1)
    )
    mean_nll = (-token_log_probs).mean()
    return mean_nll.exp().item()

# pip install torch transformers accelerate -U
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def cal_seq_perplexity(text: str, *, model=None, tokenizer=None, device=None) -> float:
    """
    Perplexity of ``text`` under a causal LM: exp(mean next-token cross-entropy).

    The model, tokenizer and device can be passed explicitly. When omitted, the
    module-level globals ``model``, ``tokenizer`` and ``DEVICE`` are used if they
    exist (this module does not define them itself, so normally pass them).

    Args:
        text: Text to score.
        model: HF-style causal LM; ``model(**enc, labels=...)`` must return an
            object with a scalar ``loss`` (mean CE in nats/token).
        tokenizer: HF-style tokenizer; ``tokenizer(text, return_tensors="pt")``
            must return an encoding supporting ``.to(device)`` and ``["input_ids"]``.
        device: Device the encoded inputs are moved to.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if model, tokenizer or device is neither passed nor
            available as a module-level global.
    """
    module_globals = globals()
    if model is None:
        model = module_globals.get("model")
    if tokenizer is None:
        tokenizer = module_globals.get("tokenizer")
    if device is None:
        device = module_globals.get("DEVICE")

    missing = [
        name
        for name, value in (("model", model), ("tokenizer", tokenizer), ("device", device))
        if value is None
    ]
    if missing:
        raise ValueError(
            f"cal_seq_perplexity: no {', '.join(missing)} given and no module-level fallback defined"
        )

    enc = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        # HF shifts labels internally, so passing input_ids as labels scores next tokens.
        out = model(**enc, labels=enc["input_ids"])
    return torch.exp(out.loss).item()

def try_using_gpt_oss_20b():
    """
    Smoke-test chat-style generation with the ``openai/gpt-oss-20b`` checkpoint.

    Builds a Hugging Face ``text-generation`` pipeline (dtype and device placement
    chosen automatically), sends a single user message, generates up to 256 new
    tokens and prints the last element of the returned ``generated_text``
    (presumably the assistant turn — confirm against the pipeline's chat output).
    """
    generator = pipeline(
        "text-generation",
        model="openai/gpt-oss-20b",
        torch_dtype="auto",
        device_map="auto",
    )

    chat = [{"role": "user", "content": "Explain quantum mechanics clearly and concisely."}]
    results = generator(chat, max_new_tokens=256)

    conversation = results[0]["generated_text"]
    print(conversation[-1])

# Path to the local model weight snapshot (placeholder; replace before running).
# NOTE(review): every function below that loads a model defines its own local
# MODEL_DIR, so this module-level value is not read anywhere in this file.
MODEL_DIR = "<your-model-dir>"

def read_inference_sequence(type: str):
    """
    Load one of the evaluation sequence collections from its JSON file.

    Args:
        type: ``"attack"`` selects ``ATTACK_SEQ_PATH``; any other value selects
            ``GROUND_TRUTH_PATH``. (The name shadows the builtin ``type``; it is
            kept so keyword callers keep working.)

    Returns:
        The decoded JSON document. Callers in this file iterate it as a
        name -> text mapping; presumably a dict — confirm against the data files.
    """
    path = ATTACK_SEQ_PATH if type == "attack" else GROUND_TRUTH_PATH
    # Context manager closes the handle deterministically (the bare open() leaked it).
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def cal_ppl_ce_v1():
    """
    Score every ground-truth and attack sequence with a local causal LM, 10 rounds.

    For each sequence this prints perplexity, cross-entropy (nats and bits per
    token), mean predictive entropy of the next-token distribution, exp(-CE) and
    PE - CE. The model is loaded from the local ``MODEL_DIR`` (local files only),
    sharded with ``device_map="auto"`` and forced to float16.

    NOTE(review): the 10 rounds repeat the same computation on the same inputs,
    presumably to check run-to-run numerical stability — confirm intent.
    """

    import os, math, torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

    # Reduce CUDA fragmentation by using expandable segments
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    # Placeholder path to the local checkpoint (shadows the module-level MODEL_DIR).
    MODEL_DIR = "<your-model-dir>"

    tok = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, trust_remote_code=True, local_files_only=True)

    # Multi-GPU sharding + FP16 load.
    # NOTE(review): FP16 and BF16 both take 2 bytes per parameter, so choosing FP16
    # does not reduce memory footprint; FP16 has a narrower exponent range than BF16.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True
    ).eval()

    # === Global default: temperature=0 equivalent (greedy decoding).
    # This does NOT affect loss-based metrics (PPL/entropy) computed below. ===
    try:
        gen_cfg = GenerationConfig.from_pretrained(MODEL_DIR)
    except Exception:
        # No generation config shipped with the checkpoint: start from library defaults.
        gen_cfg = GenerationConfig()
    # Key settings: disable sampling -> greedy. temperature is ignored when do_sample=False,
    # but we keep it for clarity.
    gen_cfg.do_sample = False
    gen_cfg.temperature = 0.0
    gen_cfg.top_p = 1.0
    gen_cfg.top_k = 0
    gen_cfg.num_beams = 1
    model.generation_config = gen_cfg
    # ======================================================================

    # Disable KV cache (not needed for PPL/entropy; saves VRAM)
    model.config.use_cache = False

    # Force any remaining BF16/FP32 parameters & buffers to FP16.
    # NOTE(review): this also downcasts every FP32 buffer the model registers,
    # which may cost precision in buffers that expect FP32 — confirm intended.
    def force_model_fp16(m: torch.nn.Module):
        for n, p in m.named_parameters(recurse=True):
            if p.dtype in (torch.bfloat16, torch.float32):
                p.data = p.data.to(torch.float16)
        for n, b in m.named_buffers(recurse=True):
            # Some buffers may not allow writing .data; ignore failures
            if getattr(b, "dtype", None) in (torch.bfloat16, torch.float32):
                try:
                    b.data = b.data.to(torch.float16)
                except Exception:
                    pass

    force_model_fp16(model)

    # Put inputs on the same device as the embedding layer to avoid device mismatch.
    # Falls back to device 0 when hf_device_map is absent or has no embedding entry.
    first_dev = next((v for k, v in getattr(model, "hf_device_map", {}).items()
                      if any(x in k for x in ["embed", "wte", "tok_embeddings"])), 0)

    # ---------- Metrics helpers ----------
    def compute_aux_metrics(ce_nats: float, pe_nats: float):
        """
        Returns:
          - avg_true_token_prob: exp(-CE), the geometric-mean probability the model
            assigns to the true tokens
          - uncertainty_minus_loss: PE - CE (both in nats/token)
        """
        avg_true_token_prob = math.exp(-ce_nats)
        uncertainty_minus_loss = pe_nats - ce_nats
        return avg_true_token_prob, uncertainty_minus_loss


    def nats_to_bits(x: float) -> float:
        """Convert nats to bits."""
        return x / math.log(2)

    # NOTE: ppl() is not called in this function; cross_entropy_and_ppl() is used instead.
    @torch.inference_mode()
    def ppl(text: str) -> float:
        """
        Perplexity via exp(cross_entropy).
        Cross-entropy returned by HF is in nats/token.
        """
        enc = tok(text, return_tensors="pt").to(first_dev)
        out = model(**enc, labels=enc["input_ids"])
        return math.exp(out.loss.item())

    @torch.inference_mode()
    def cross_entropy_and_ppl(text: str):
        """
        Returns (cross_entropy_nats_per_token, cross_entropy_bits_per_token, ppl).

        Passing input_ids as labels makes HF compute the shifted next-token loss
        over the whole sequence in a single forward pass (no sliding window).
        """
        enc = tok(text, return_tensors="pt").to(first_dev)
        out = model(**enc, labels=enc["input_ids"])
        ce_nats = float(out.loss.detach().float().item())   # nats/token
        ce_bits = nats_to_bits(ce_nats)                     # bits/token
        ppl_val = math.exp(ce_nats)
        return ce_nats, ce_bits, ppl_val


    # Only used by debug_compare_pe(), whose call below is commented out.
    @torch.inference_mode()
    def predictive_entropy_v2(text: str, chunk_len: int = 1024, dtype: torch.dtype = torch.float32,
                           return_per_pos: bool = False):
        """
        Predictive entropy of the next-token distribution, averaged per token.
        Returns:
          - if return_per_pos=False: (mean_entropy_nats, mean_entropy_bits)
          - if return_per_pos=True:  (mean_entropy_nats, mean_entropy_bits, per_pos_ent_nats[np.ndarray])
        Notes:
          - Uses log_softmax for stability: H = -sum p * log p
          - Casts logits to `dtype` (float32 by default; float64 for precision checks)
          - The forward pass itself runs in the model's dtype; only the entropy
            arithmetic uses `dtype`.
        """
        enc = tok(text, return_tensors="pt").to(first_dev)
        out = model(**enc, use_cache=False)
        logits = out.logits[:, :-1, :].to(dtype)  # [1, L-1, V]

        Lm1 = logits.shape[1]
        ents = []
        # Chunk over positions to bound the size of the [1, S, V] temporaries.
        for start in range(0, Lm1, chunk_len):
            end = min(Lm1, start + chunk_len)
            z = logits[:, start:end, :]  # [1, S, V]
            log_probs = torch.log_softmax(z, dim=-1)  # [1, S, V]
            probs = torch.exp(log_probs)  # [1, S, V]
            ent = -(probs * log_probs).sum(dim=-1)  # [1, S] in nats
            ents.append(ent)

        ent_all = torch.cat(ents, dim=1).squeeze(0)  # [L-1]
        mean_ent_nats = float(ent_all.mean().item())
        mean_ent_bits = nats_to_bits(mean_ent_nats)

        if return_per_pos:
            return mean_ent_nats, mean_ent_bits, ent_all.detach().cpu().numpy()
        else:
            return mean_ent_nats, mean_ent_bits


    def debug_compare_pe(text1: str, text2: str):
        """
        Compare predictive entropies of two texts:
          - float32 means and per-position arrays
          - float64 means (to check precision) and deltas
          - prints max absolute per-position diff over the common prefix length
        """
        import numpy as np

        print("\n[DEBUG] Predictive Entropy comparison (float32):")
        m1_32, _, e1_32 = predictive_entropy_v2(text1, dtype=torch.float32, return_per_pos=True)
        m2_32, _, e2_32 = predictive_entropy_v2(text2, dtype=torch.float32, return_per_pos=True)
        minlen = min(len(e1_32), len(e2_32))
        diffs32 = e1_32[:minlen] - e2_32[:minlen]
        print(f"  mean PE32: {m1_32:.12f} vs {m2_32:.12f}")
        print(f"  max|diff| per-pos (PE32): {np.max(np.abs(diffs32)):.12e}")

        print("\n[DEBUG] Predictive Entropy comparison (float64):")
        m1_64, _, e1_64 = predictive_entropy_v2(text1, dtype=torch.float64, return_per_pos=True)
        m2_64, _, e2_64 = predictive_entropy_v2(text2, dtype=torch.float64, return_per_pos=True)
        # Reuses minlen from the float32 pass; the per-position lengths do not depend on dtype.
        diffs64 = e1_64[:minlen] - e2_64[:minlen]
        print(f"  mean PE64: {m1_64:.12f} vs {m2_64:.12f}")
        print(f"  max|diff| per-pos (PE64): {np.max(np.abs(diffs64)):.12e}")

        print(f"\n  mean delta (64-32): seq1 {m1_64 - m1_32:.12e} | seq2 {m2_64 - m2_32:.12e}")

        # Show last few positions to verify whether only the final token differs
        print("\n  tail entropies (last 3 positions if available):")
        print(f"    seq1 PE32 tail: {e1_32[-3:] if len(e1_32) >= 3 else e1_32}")
        print(f"    seq2 PE32 tail: {e2_32[-3:] if len(e2_32) >= 3 else e2_32}")


    @torch.inference_mode()
    def predictive_entropy(text: str, chunk_len: int = 1024):
        """
        Predictive entropy of the next-token distribution, averaged per token.
        Returns (entropy_nats_per_token, entropy_bits_per_token).

        Notes:
        - We compute entropy over the next-token distribution at each position:
          H(p) = -sum_i p_i log p_i.
        - We drop the last position (no "next token" to predict) to align with LM training.
        - For numerical stability, cast logits to float32 while keeping weights in FP16.
        - 'chunk_len' controls sequence chunking to reduce memory usage.
        """
        enc = tok(text, return_tensors="pt").to(first_dev)
        out = model(**enc, use_cache=False)
        logits = out.logits  # [1, L, V]
        logits = logits[:, :-1, :].to(torch.float32)  # align with next-token prediction

        L = logits.shape[1]
        total_ent = 0.0
        count = 0

        for start in range(0, L, chunk_len):
            end = min(L, start + chunk_len)
            z = logits[:, start:end, :]                     # [1, S, V]
            # Stable entropy: H = logsumexp(z) - sum(softmax(z) * z)
            lse = torch.logsumexp(z, dim=-1)                # [1, S]
            probs = torch.nn.functional.softmax(z, dim=-1)  # [1, S, V]
            exp_z = (probs * z).sum(dim=-1)                 # [1, S]
            ent = (lse - exp_z)                             # [1, S] in nats
            total_ent += float(ent.sum().item())
            count += ent.numel()

        # max(count, 1) avoids division by zero for texts with fewer than 2 tokens.
        mean_ent_nats = total_ent / max(count, 1)
        mean_ent_bits = nats_to_bits(mean_ent_nats)
        return mean_ent_nats, mean_ent_bits

    # ---------- Data sources ----------
    # read_inference_sequence() is defined at module level in this file; both results
    # are iterated below as name -> text mappings.
    ground_truth_seqs = read_inference_sequence("ground_truth")
    attack_seqs = read_inference_sequence("attack")
    ground_truth_len = len(ground_truth_seqs)
    attack_len = len(attack_seqs)

    # ---------- Evaluation / printing ----------
    for cal_time in range(10):
        print(f"\nRound NO.{cal_time + 1}/10 ...\n")
        # Shared counter across both groups; attack numbering subtracts ground_truth_len.
        index = 1

        for seq in ground_truth_seqs:
            print(f"Ground Truth {index}/{ground_truth_len}:")
            index += 1
            print("====" * 10)
            print(seq)

            ce_nats, ce_bits, ppl_val = cross_entropy_and_ppl(ground_truth_seqs[seq])
            pe_nats, pe_bits = predictive_entropy(ground_truth_seqs[seq])

            print(f"PPL: {ppl_val:.12f}")
            print(f"Cross-Entropy: {ce_nats:.12f} nats/token ({ce_bits:.12f} bits/token)")
            print(f"Predictive Entropy: {pe_nats:.12f} nats/token ({pe_bits:.12f} bits/token)")
            avg_p_true, pe_minus_ce = compute_aux_metrics(ce_nats, pe_nats)
            print(f"Avg True Token Prob: {avg_p_true:.12f}")
            print(f"Uncertainty - Loss (PE - CE): {pe_minus_ce:.12f} nats/token")

        for seq in attack_seqs:
            print(f"Attack {index - ground_truth_len}/{attack_len}:")
            index += 1
            print("====" * 10)
            print(seq)

            ce_nats, ce_bits, ppl_val = cross_entropy_and_ppl(attack_seqs[seq])
            pe_nats, pe_bits = predictive_entropy(attack_seqs[seq])

            print(f"PPL: {ppl_val:.12f}")
            print(f"Cross-Entropy: {ce_nats:.12f} nats/token ({ce_bits:.12f} bits/token)")
            print(f"Predictive Entropy: {pe_nats:.12f} nats/token ({pe_bits:.12f} bits/token)")
            avg_p_true, pe_minus_ce = compute_aux_metrics(ce_nats, pe_nats)
            print(f"Avg True Token Prob: {avg_p_true:.12f}")
            print(f"Uncertainty - Loss (PE - CE): {pe_minus_ce:.12f} nats/token")

        # run a focused debug on the first two attack sequences
        # print("\n===== DEBUG: Comparing attack[0] vs attack[1] predictive entropy =====")
        # debug_compare_pe(attack_seqs["attack_1"], attack_seqs["attack_2"])

        # Optional sanity check for greedy generation (does not affect metrics above).
        # Disabled: change `if False` to `if True` to run it.
        if False:
            prompt = "Write a one-sentence definition of perplexity:"
            enc = tok(prompt, return_tensors="pt").to(first_dev)
            out_ids = model.generate(**enc, max_new_tokens=64)  # uses the global greedy config
            print(tok.decode(out_ids[0], skip_special_tokens=True))

import torch.nn.functional as F

@torch.inference_mode()
def ce_ppl_sliding_with_temperature(
    text: str,
    temperature: float,
    max_length: int | None = None,
    stride: int = 512,
):
    """
    Temperature-scaled CE/PPL on a fixed text (analysis, non-standard):
      - For each sliding window, run a forward pass WITHOUT labels.
      - Scale logits by 1/τ and compute cross-entropy on valid positions only.
      - τ=1.0 -> matches standard CE/PPL (up to tiny float diffs).
      - τ<1 makes the distribution sharper; τ>1 makes it flatter.

    NOTE:
      * This does NOT change model weights.
      * This is NOT decoding; it only re-calibrates the scoring distribution.
      * Requires temperature > 0.
    """
    if temperature is None or float(temperature) <= 0:
        raise ValueError("temperature must be > 0")
    tau = float(temperature)

    enc = tok(text, return_tensors="pt").to(first_dev)
    input_ids = enc["input_ids"]  # [1, L]
    L = input_ids.size(1)

    if max_length is None:
        max_length = _model_max_len()

    nll_sum = 0.0
    n_tokens = 0
    prev_end = 0

    for begin in range(0, L, stride):
        end = min(begin + max_length, L)
        trg_len = end - prev_end  # align with your baseline sliding-window logic
        ids = input_ids[:, begin:end]        # [1, S]
        target_ids = ids.clone()
        # Only the last `trg_len` positions are scored; HF internally shifts by 1
        target_ids[:, :-trg_len] = -100

        # Forward pass without labels to get raw logits
        out = model(input_ids=ids, use_cache=False)
        # Align logits and labels for next-token prediction
        logits = out.logits[:, :-1, :].to(torch.float32) / tau   # temperature scaling
        labels = target_ids[:, 1:]                                # shift labels

        V = logits.size(-1)
        ce = F.cross_entropy(
            logits.reshape(-1, V),
            labels.reshape(-1),
            reduction="mean",
            ignore_index=-100,
        ).item()

        num_valid = (labels != -100).sum().item()
        if num_valid > 0:
            nll_sum += ce * num_valid
            n_tokens += num_valid

        prev_end = end
        if end == L:
            break

    ce_nats = nll_sum / max(n_tokens, 1)
    ce_bits = nats_to_bits(ce_nats)
    ppl = math.exp(ce_nats)
    return ce_nats, ce_bits, ppl

def cal_perplexity_and_cross_entropy_using_demo():
    """
    Score the ground-truth and attack sequences with a local causal LM and print metrics.

    For every sequence this prints:
      * sliding-window cross-entropy (nats and bits per token) and perplexity,
      * mean predictive entropy of the next-token distribution,
      * exp(-CE) ("avg true-token prob", a geometric mean) and PE - CE,
      * temperature-scaled CE/PPL for tau in (0.7, 1.0, 1.3).

    The model is loaded from the local ``MODEL_DIR`` (local files only) in bfloat16
    and sharded with ``device_map="auto"``. All scoring helpers are closures over
    this model/tokenizer, so nothing here depends on module-level ``model``/``tok``.
    """
    import os, math, torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # optional: reduce CUDA fragmentation by using expandable segments
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    MODEL_DIR = "<your-model-dir>"

    tok = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True, local_files_only=True
    )

    # Shard the model across the available devices.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=torch.bfloat16,  # load weights in bfloat16
        device_map="auto",
        low_cpu_mem_usage=True,
    ).eval()

    # turn off KV cache (not needed for PPL/entropy; saves VRAM)
    model.config.use_cache = False

    # Inputs go to the device holding the embedding layer; fall back to device 0
    # when hf_device_map is absent or has no embedding entry.
    first_dev = next(
        (v for k, v in getattr(model, "hf_device_map", {}).items()
         if any(x in k for x in ["embed", "wte", "tok_embeddings"])),
        0,
    )

    def nats_to_bits(x: float) -> float:
        """Convert nats to bits."""
        return x / math.log(2)

    def _model_max_len():
        """Context length: config n_positions, else max_position_embeddings, else 2048."""
        return (
            getattr(model.config, "n_positions", None)
            or getattr(model.config, "max_position_embeddings", 2048)
        )

    @torch.inference_mode()
    def ce_ppl_sliding(text: str, max_length: int | None = None, stride: int = 512,
                       temperature: float = 1.0):
        """
        Sliding-window cross-entropy/perplexity, optionally temperature-scaled.

        Each window is scored from a plain forward pass; logits are divided by
        ``temperature`` (1.0 = standard CE). Only tokens not already scored by the
        previous window contribute, so every token after the first is scored once.

        Returns (ce_nats_per_token, ce_bits_per_token, ppl).
        """
        if temperature <= 0:
            raise ValueError("temperature must be > 0")

        input_ids = tok(text, return_tensors="pt")["input_ids"].to(first_dev)  # [1, L]
        L = input_ids.size(1)

        if max_length is None:
            max_length = _model_max_len()

        nll_sum = 0.0
        n_tokens = 0
        prev_end = 0

        for begin in range(0, L, stride):
            end = min(begin + max_length, L)
            trg_len = end - prev_end
            ids = input_ids[:, begin:end]  # [1, S]
            target_ids = ids.clone()
            # Mask the context that overlaps the previous window.
            target_ids[:, :-trg_len] = -100

            # Next-token alignment: logits at t predict token t + 1.
            logits = model(input_ids=ids).logits[:, :-1, :].float() / temperature
            # Sharded models may emit logits on a different device than the inputs.
            labels = target_ids[:, 1:].to(logits.device)
            # Count on the shifted labels: subtracting the batch size from the
            # unshifted count under-counts windows whose first label is masked.
            num_loss_tokens = int((labels != -100).sum().item())
            if num_loss_tokens > 0:
                nll_sum += torch.nn.functional.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    labels.reshape(-1),
                    ignore_index=-100,
                    reduction="sum",
                ).item()
                n_tokens += num_loss_tokens

            prev_end = end
            if end == L:
                break

        ce_nats = nll_sum / max(n_tokens, 1)
        return ce_nats, nats_to_bits(ce_nats), math.exp(ce_nats)

    @torch.inference_mode()
    def predictive_entropy(text: str, chunk_len: int = 1024):
        """
        Predictive entropy of the next-token distribution, averaged per token.
        Returns (entropy_nats_per_token, entropy_bits_per_token).
        """
        enc = tok(text, return_tensors="pt").to(first_dev)
        logits = model(**enc, use_cache=False).logits[:, :-1, :].to(torch.float32)

        total = 0.0
        count = 0
        Lm1 = logits.shape[1]

        # Chunk over positions to bound the size of the [1, S, V] temporaries.
        for start in range(0, Lm1, chunk_len):
            z = logits[:, start:start + chunk_len, :]  # [1, S, V]
            log_probs = torch.log_softmax(z, dim=-1)  # [1, S, V]
            probs = torch.exp(log_probs)  # [1, S, V]
            ent = -(probs * log_probs).sum(dim=-1)  # [1, S] in nats
            total += float(ent.sum().item())
            count += ent.numel()

        mean_nats = total / max(count, 1)
        mean_bits = nats_to_bits(mean_nats)
        return mean_nats, mean_bits

    def compute_aux_metrics(ce_nats: float, pe_nats: float):
        """Return (exp(-CE), PE - CE)."""
        avg_true_token_prob = math.exp(-ce_nats)
        uncertainty_minus_loss = pe_nats - ce_nats

        return avg_true_token_prob, uncertainty_minus_loss

    def evaluate_group(name: str, seqs: dict[str, str], max_length=None, stride=512):
        """Print base and temperature-scaled metrics for every sequence in ``seqs``."""
        print(f"\n=== {name} ===")
        for i, (key, text) in enumerate(seqs.items(), 1):
            ce_nats, ce_bits, ppl = ce_ppl_sliding(text, max_length=max_length, stride=stride)
            pe_nats, pe_bits = predictive_entropy(text)
            p_true, gap = compute_aux_metrics(ce_nats, pe_nats)

            print(f"{name} {i}/{len(seqs)}: {key}")
            print(f"  CE: {ce_nats:.6f} nats/token ({ce_bits:.6f} bits/token), PPL: {ppl:.6f}")
            print(f"  Predictive Entropy: {pe_nats:.6f} nats/token ({pe_bits:.6f} bits/token)")
            print(f"  Avg True Token Prob: {p_true:.6f}")
            print(f"  Uncertainty - Loss (PE - CE): {gap:.6f} nats/token")

            for tau in (0.7, 1.0, 1.3):
                if tau == 1.0:
                    # tau = 1 is exactly the base computation above; reuse it.
                    ce_t, ce_b, ppl_t = ce_nats, ce_bits, ppl
                else:
                    ce_t, ce_b, ppl_t = ce_ppl_sliding(
                        text, max_length=max_length, stride=stride, temperature=tau
                    )
                print(f"  [Temp={tau:.1f}] CE: {ce_t:.6f} nats/token ({ce_b:.6f} bits/token), PPL: {ppl_t:.6f}")

    ground_truth_seqs = read_inference_sequence("ground_truth")  # dict[str, str]
    attack_seqs = read_inference_sequence("attack")  # dict[str, str]
    evaluate_group("Ground Truth", ground_truth_seqs)
    evaluate_group("Attack", attack_seqs)

def build_eval_prompt(intent: str, text_full: str, framed: bool = True) -> str:
    """
    Combine an intent label and a flattened conversation into one prompt string.

    Parameters
    ----------
    intent : str
        Parsed intent, e.g. "car_rental". A falsy value (None / "") becomes
        "unknown"; surrounding whitespace is stripped.
    text_full : str
        The whole conversation flattened as "role: content" lines. None is
        treated as an empty string.
    framed : bool, default True
        When True, the conversation is wrapped between
        "=== CONVERSATION START ===" / "=== CONVERSATION END ===" marker lines so
        the header cannot be confused with conversation content.

    Returns
    -------
    str
        "The user wants to do <intent>, the multi-function call and conversations
        are shown below:" followed by a blank line and the (optionally framed)
        conversation.
    """
    label = (intent or "unknown").strip()
    conversation = text_full or ""
    header = f"The user wants to do {label}, the multi-function call and conversations are shown below:"

    if framed:
        return f"{header}\n\n=== CONVERSATION START ===\n{conversation}\n=== CONVERSATION END ==="
    return f"{header}\n\n{conversation}"

def iter_jsonl(path: str) -> Iterator[Dict]:
    """
    Lazily yield one decoded JSON value per non-blank line of a UTF-8 .jsonl file.

    Lines are stripped before decoding; blank lines are skipped. A malformed
    line raises RuntimeError naming its 1-based physical line number.
    """
    with open(path, "r", encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle, start=1):
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                yield json.loads(stripped)
            except json.JSONDecodeError as err:
                raise RuntimeError(f"JSON parse error at line {line_no}: {err}")

def parse_func_call_into_jsonl(
    in_path: str = "../datasets/ComplexFuncBench.text_full.jsonl",
    out_path: str = "../datasets/ComplexFuncBench.prompts.jsonl",
    preview_n: int = 2,
) -> int:
    """
    Convert a {id, intent, text_full} JSONL file into a {id, intent, prompt} JSONL file.

    Each prompt is built with ``build_eval_prompt(intent, text_full, framed=True)``.
    Missing ``intent`` defaults to "unknown" and missing ``text_full`` to "".

    Args:
        in_path: Input JSONL (defaults to the previously hard-coded dataset path).
        out_path: Output JSONL; it is overwritten.
        preview_n: Number of leading prompts to preview on stdout (first 400 chars).

    Returns:
        The number of prompts written.
    """
    n = 0
    with open(out_path, "w", encoding="utf-8") as w:
        for rec in iter_jsonl(in_path):
            sample_id = rec.get("id")
            intent = rec.get("intent", "unknown")
            text_full = rec.get("text_full", "")

            prompt = build_eval_prompt(intent, text_full, framed=True)

            w.write(json.dumps({
                "id": sample_id,
                "intent": intent,
                "prompt": prompt
            }, ensure_ascii=False) + "\n")

            # Print a small preview for the first few lines
            n += 1
            if n <= preview_n:
                preview = prompt[:400].replace("\n", "\\n")
                print(f"[Preview #{n}] id={sample_id} intent={intent} | prompt[0:400]={preview}...")

    print(f"Done. Wrote {n} prompts to: {out_path}")
    return n

def iter_prompts(jsonl_path: str) -> Iterator[str]:
    """
    Lazily yield the 'prompt' field of each JSON object in a UTF-8 JSONL file.

    Blank lines and objects without a 'prompt' key are skipped; non-string
    prompts are converted with ``str()``. A malformed line raises RuntimeError
    naming its 1-based physical line number.
    """
    with open(jsonl_path, "r", encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle, start=1):
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError as err:
                raise RuntimeError(f"JSON parse error at line {line_no}: {err}")
            if "prompt" in record:
                value = record["prompt"]
                yield value if isinstance(value, str) else str(value)

@torch.inference_mode()
def ce_ppl_sliding_streaming(
    text: str,
    max_length: int = 1024,   # <- set a safe window, 1024 or 2048
    stride: int = 512,
    empty_cache_every: int = 0,  # e.g. set 4 to clear cache every 4 windows; 0 = never
):
    """
    CE/PPL with a sliding window, keeping the full tokenized sequence on CPU.
    Only the current window (ids/mask/labels) is moved to GPU to avoid OOM.
    """
    # ❶ tokenize on CPU (do NOT .to(first_dev) here)
    MODEL_DIR = "<your_model_path>"

    tok = AutoTokenizer.from_pretrained(
        MODEL_DIR, use_fast=True, trust_remote_code=True, local_files_only=True
    )

    # sliding for multi-GPU
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=torch.bfloat16,  # ← 改成 bfloat16
        device_map="auto",
        low_cpu_mem_usage=True
    ).eval()

    # turn off KV cache (not needed for PPL/entropy; saves VRAM)
    model.config.use_cache = False


    enc = tok(text, return_tensors="pt")

    # find embedding device (for inputs)
    first_dev = next(
        (v for k, v in getattr(model, "hf_device_map", {}).items()
         if any(x in k for x in ["embed", "wte", "tok_embeddings"])),
        0
    )
    input_ids = enc["input_ids"]                      # CPU tensor [1, L]
    attn_mask = enc.get("attention_mask", None)       # CPU tensor [1, L] or None
    L = input_ids.size(1)

    # ❷ explicit safe window size
    if max_length is None:
        max_length = 1024  # do not trust model.config for memory; pick a safe default

    nll_sum = 0.0
    n_tokens = 0
    prev_end = 0
    win_idx = 0

    for begin in range(0, L, stride):
        end = min(begin + max_length, L)
        trg_len = end - prev_end                          # align with HF sliding logic

        ids_cpu  = input_ids[:, begin:end]                # still on CPU
        mask_cpu = attn_mask[:, begin:end] if attn_mask is not None else None

        # ❸ move only this window to GPU
        ids  = ids_cpu.to(first_dev, non_blocking=True)
        mask = mask_cpu.to(first_dev, non_blocking=True) if mask_cpu is not None else None

        labels = ids.clone()
        labels[:, :-trg_len] = -100                       # score only last trg_len tokens

        out = model(input_ids=ids, attention_mask=mask, labels=labels)
        loss = float(out.loss)                            # mean CE on valid labels (nats/token)

        num_valid = int((labels != -100).sum().item())
        num_loss_tokens = num_valid - labels.size(0)      # adjust for label shift (batch=1)

        nll_sum += loss * num_loss_tokens
        n_tokens += num_loss_tokens

        # ❹ free per-window tensors ASAP
        del ids, mask, labels, out
        win_idx += 1
        if empty_cache_every and (win_idx % empty_cache_every == 0):
            torch.cuda.empty_cache()
            gc.collect()

        prev_end = end
        if end == L:
            break

    ce_nats = nll_sum / max(n_tokens, 1)
    ce_bits = ce_nats / math.log(2.0)
    ppl = math.exp(ce_nats)
    return ce_nats, ce_bits, ppl

def cal_text_ppl_ce(text: str):
    """
    Print sliding-window cross-entropy and perplexity for a single text.

    Delegates to ``ce_ppl_sliding_streaming`` with 1024-token windows, stride 512,
    and a CUDA cache flush every 4 windows. That helper loads the tokenizer and
    model from its own ``MODEL_DIR``, so this function does not load a model of
    its own; a second, unused copy would only double device memory.

    Args:
        text: Text to score.
    """
    # optional: reduce CUDA fragmentation by using expandable segments
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    ce_nats, ce_bits, ppl = ce_ppl_sliding_streaming(
        text,
        max_length=1024,  # safety window 1024/2048
        stride=512,
        empty_cache_every=4,
    )
    print(f"CE: {ce_nats:.6f} nats/token ({ce_bits:.6f} bits/token), PPL: {ppl:.6f}")

# Placeholder text to score; replace it with the prompt whose perplexity and
# cross-entropy should be measured.
text = "<your_prompt_need_to_calculate_perplexity_and_cross_entropy>"
if __name__ == '__main__':
    # Script entry point: print CE (nats and bits per token) and PPL for `text`.
    cal_text_ppl_ce(text)