﻿import math
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional


def l2_clip(x: torch.Tensor, C: float) -> torch.Tensor:
    """Scale each vector along the last axis so its L2 norm is at most ``C``.

    Vectors whose norm is already <= ``C`` are returned unchanged (they are
    divided by 1); longer vectors are divided by ``norm / C``.

    Args:
        x: Tensor of shape ``(..., K)``; clipping is applied along ``K``.
        C: Maximum allowed L2 norm.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # Floor the norm so an all-zero vector never yields a zero divisor.
    row_norms = torch.clamp(torch.norm(x, p=2, dim=-1, keepdim=True), min=1e-12)
    # A divisor of exactly 1 leaves short vectors untouched.
    divisor = torch.clamp(row_norms / C, min=1.0)
    return x / divisor


def _mask_end_tokens_logits(
    logits: torch.Tensor,
    step: int,
    min_step_no_eos: int,
    eos_local_ids: List[int],
    neg_inf: float = -1e9,
) -> torch.Tensor:
    """Suppress end-of-sequence entries during the first decoding steps.

    While ``step <= min_step_no_eos`` the entries listed in ``eos_local_ids``
    are overwritten with ``neg_inf`` on a copy of ``logits``; the input tensor
    is never modified. Outside that window, or when no ids are given, the
    input tensor itself is returned.

    Args:
        logits: Scores whose last axis is the (possibly reduced) vocabulary.
            Leading batch axes are allowed.
        step: Current decoding step.
        min_step_no_eos: Last step (inclusive) at which EOS is still blocked.
        eos_local_ids: Indices along the last axis to block.
        neg_inf: Value written into the blocked positions.

    Returns:
        The masked copy, or ``logits`` unchanged.
    """
    if step > min_step_no_eos or not eos_local_ids:
        return logits
    masked = logits.clone()
    # Index the vocabulary (last) axis so batched inputs are masked per
    # column, consistent with mask_eos_early_prob; identical for 1-D input.
    masked[..., eos_local_ids] = neg_inf
    return masked


def dp_normalize_and_noise_logits(
    logits_rq: torch.Tensor,
    clip_norm: float,
    noise_multiplier: float,
    adjacency: str = "add_remove",
) -> torch.Tensor:
    """Clip, average, rescale and add Gaussian noise to a stack of logits.

    Each row of ``logits_rq`` is L2-clipped to ``clip_norm`` via ``l2_clip``,
    the rows are averaged over axis 0, and the mean is multiplied by
    ``M / clip_norm`` (or ``M / (2 * clip_norm)`` when ``adjacency`` is
    ``"replace"``), where ``M`` is the number of rows. Standard-normal noise
    scaled by ``noise_multiplier`` is then added.

    Args:
        logits_rq: Tensor whose first axis enumerates the ``M`` records.
        clip_norm: Per-row L2 clipping bound.
        noise_multiplier: Standard deviation of the added noise.
        adjacency: ``"replace"`` selects the halved scale; any other value
            uses the add/remove scale.

    Returns:
        The noisy, rescaled mean with the shape of one row.
    """
    num_records = logits_rq.size(0)
    averaged = l2_clip(logits_rq, clip_norm).mean(dim=0)
    # NOTE(review): presumably the divisor is the L2 sensitivity of the sum
    # under the chosen adjacency notion - confirm against the DP analysis.
    divisor = 2.0 * clip_norm if adjacency == "replace" else clip_norm
    normalized = averaged * (num_records / divisor)
    gaussian = torch.randn_like(normalized) * float(noise_multiplier)
    return normalized + gaussian


def load_llm(model_name: str, *, device="cuda", dtype=torch.bfloat16, use_flash_attn=True):
    """Load a tokenizer and a causal LM ready for cached inference.

    The tokenizer's pad token falls back to its EOS token when none is set.
    The model is loaded with ``device_map="auto"``, switched to eval mode and
    has ``config.use_cache`` enabled.

    Args:
        model_name: Identifier passed to ``from_pretrained``.
        device: Accepted for API compatibility; not used by this function
            (placement is delegated to ``device_map="auto"``).
        dtype: Forwarded as ``torch_dtype``.
        use_flash_attn: Selects ``"flash_attention_2"`` when true, otherwise
            ``"sdpa"``.

    Returns:
        ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if use_flash_attn:
        attention_backend = "flash_attention_2"
    else:
        attention_backend = "sdpa"

    load_kwargs = {
        "torch_dtype": dtype,
        "device_map": "auto",
        "attn_implementation": attention_backend,
    }
    llm = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    llm.eval()
    llm.config.use_cache = True
    return tokenizer, llm


def aggregate_by_word(data, field):
    """Merge per-token values into per-word values, preserving order.

    A new word starts at the first item, at any token that begins with a
    space, and at the literal ``"<|eot_id|>"`` token; every other token is
    appended to the current word and its value is added to the running sum.
    Words whose stripped text is empty are not emitted.

    Repeated words are kept as separate entries. (Accumulating into a dict
    keyed by the word text would let a later occurrence overwrite an earlier
    one and drop entries from the sequence.)

    Args:
        data: Sequence of dicts, each with a ``"token"`` string and a numeric
            entry under ``field``.
        field: Key of the value to sum per word.

    Returns:
        A list of ``{"token": word, field: summed_value}`` dicts in order of
        appearance.
    """
    words = []
    current_word = ""
    current_value = 0.0

    def _flush():
        # Emit the pending word, skipping boundaries that stripped to "".
        if current_word:
            words.append({"token": current_word, field: current_value})

    for i, item in enumerate(data):
        token = item["token"]
        value = item[field]
        if i == 0 or token.startswith(" ") or token == "<|eot_id|>":
            _flush()
            current_word = token.strip()
            current_value = value
        else:
            current_word += token
            current_value += value
    _flush()
    return words


@torch.inference_mode()
def cg_stream_hf_fast_loaded(
    tok,
    model,
    *,
    system_prompt: str,
    user_template: str,
    query: str,
    retrieved: str,
    max_new_tokens: int = 128,
    gen_temperature: float = 0.7,
    eval_temperature: float = 1.0,
    device: str = "cuda",
):
    """Stream tokens while recording the per-step entropy gap (batched).

    Two chat prompts are built from ``user_template``: one with ``retrieved``
    filled in and one with the literal ``"not available"``. Both are run as a
    single right-padded batch of two rows (row 0: query only, row 1: with
    retrieval). At each step the next token is chosen from row 1 and fed to
    both rows, and ``CG = H(row 0) - H(row 1)`` is recorded, with entropies
    computed by ``entropy_from_logits`` at ``eval_temperature``.

    Args:
        tok: Tokenizer providing ``apply_chat_template``, ``decode``,
            ``pad_token_id`` and ``eos_token_id``.
        model: Causal LM called with ``input_ids``/``attention_mask``/
            ``past_key_values``/``position_ids``.
        system_prompt: System message content.
        user_template: Format string with ``{retrieved}`` and ``{query}``.
        query: User query.
        retrieved: Retrieved context; falsy values become ``""``.
        max_new_tokens: Maximum number of decoding steps.
        gen_temperature: Sampling temperature; falsy or <= 0 selects argmax.
        eval_temperature: Temperature used for the entropy computation.
        device: Device the input tensors are moved to.

    Returns:
        ``(results, text)`` where ``results`` is a list of
        ``{"token": str, "CG": float}`` and ``text`` is the decoded sequence.
    """
    user_rq = user_template.format(retrieved=retrieved or "", query=query)
    user_q = user_template.format(retrieved="not available", query=query)
    msgs_rq = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_rq},
    ]
    msgs_q = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_q},
    ]
    rq_ids = tok.apply_chat_template(
        msgs_rq, add_generation_prompt=True, return_tensors="pt"
    )
    q_ids = tok.apply_chat_template(
        msgs_q, add_generation_prompt=True, return_tensors="pt"
    )

    prompts = [q_ids, rq_ids]  # row 0: query only, row 1: with retrieval
    maxlen = max(p.shape[1] for p in prompts)
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    # Build the mask from the true prompt lengths. Comparing ids against the
    # pad id would also hide genuine prompt tokens whenever pad == eos and the
    # chat template emits that token, and would undercount the lengths.
    batch_ids = torch.full((2, maxlen), pad_id, dtype=torch.long)
    attn_mask = torch.zeros((2, maxlen), dtype=torch.long)
    for row, ids in enumerate(prompts):
        n = ids.shape[1]
        batch_ids[row, :n] = ids[0]
        attn_mask[row, :n] = 1
    batch_ids = batch_ids.to(device)
    attn_mask = attn_mask.to(device)

    out = model(input_ids=batch_ids, attention_mask=attn_mask, use_cache=True)
    past = out.past_key_values
    cur_len = attn_mask.sum(dim=-1)
    results, seq_rq = [], []
    # NOTE(review): the first decode input is EOS rather than a token sampled
    # from the prefill logits - kept as in the original; confirm intent.
    next_tokens = torch.full((2, 1), tok.eos_token_id, device=device, dtype=torch.long)
    step_ones = torch.ones((2, 1), dtype=attn_mask.dtype, device=device)

    for _ in range(max_new_tokens):
        # Keep masking the pad slots stored in the KV cache of the shorter
        # prompt; without a mask they would be attended during decoding.
        attn_mask = torch.cat([attn_mask, step_ones], dim=-1)
        out = model(
            input_ids=next_tokens,
            attention_mask=attn_mask,
            past_key_values=past,
            position_ids=cur_len.unsqueeze(-1),
            use_cache=True,
        )
        logits = out.logits[:, -1, :]
        past = out.past_key_values
        H_q = entropy_from_logits(logits[0], temperature=eval_temperature).item()
        H_rq = entropy_from_logits(logits[1], temperature=eval_temperature).item()
        CG = float(H_q - H_rq)
        if gen_temperature and gen_temperature > 0.0:
            probs_rq = F.softmax(logits[1] / gen_temperature, dim=-1)
            next_rq = torch.multinomial(probs_rq, num_samples=1)
        else:
            next_rq = logits[1].argmax(dim=-1, keepdim=True)
        token_id = int(next_rq)
        # Both rows continue with the token chosen from the retrieval row.
        next_tokens = next_rq.view(1, 1).repeat(2, 1)
        token_str = tok.decode([token_id])
        results.append({"token": token_str, "CG": CG})
        seq_rq.append(token_id)
        cur_len = cur_len + 1
        if token_id == tok.eos_token_id or token_str == "<|eot_id|>":
            break

    return results, tok.decode(seq_rq)


def entropy_from_logits(
    logits: torch.Tensor, temperature: float = 1.0, base: str = "e"
):
    """Shannon entropy of ``softmax(logits / temperature)`` along the last axis.

    Uses the identity ``H = logsumexp(z) - sum(p * z)`` with
    ``z = logits / temperature`` and ``p = softmax(z)``.

    Args:
        logits: Unnormalised scores; the last axis is reduced.
        temperature: Divisor applied to the logits before the softmax.
        base: ``"2"`` returns bits; any other value returns nats.

    Returns:
        Tensor of entropies with the last axis removed.
    """
    scaled = logits / temperature
    log_partition = torch.logsumexp(scaled, dim=-1)
    expected_score = (F.softmax(scaled, dim=-1) * scaled).sum(dim=-1)
    entropy = log_partition - expected_score
    if base == "2":
        # Convert nats to bits.
        return entropy / math.log(2.0)
    return entropy


@torch.inference_mode()
def cg_stream_hf_separate(
    tok,
    model,
    *,
    system_prompt: str,
    user_template: str,
    query: str,
    retrieved: str,
    max_new_tokens: int = 128,
    gen_temperature: float = 0.7,
    eval_temperature: float = 1.0,
    device: str = "cuda",
) -> Tuple[List[Dict[str, float]], str]:
    """Stream tokens while recording the per-step entropy gap (two passes).

    Same computation as the batched variant, but the query-only prompt and
    the retrieval prompt are run through the model separately, each with its
    own KV cache, so no padding is involved. The next token is chosen from
    the retrieval pass and fed to both passes;
    ``CG = H(query-only) - H(retrieval)`` is recorded per step.

    Args:
        tok: Tokenizer providing ``apply_chat_template``, ``decode`` and
            ``eos_token_id``.
        model: Causal LM called with ``input_ids``/``attention_mask``/
            ``past_key_values``/``position_ids``.
        system_prompt: System message content.
        user_template: Format string with ``{retrieved}`` and ``{query}``.
        query: User query.
        retrieved: Retrieved context; falsy values become ``""``.
        max_new_tokens: Maximum number of decoding steps.
        gen_temperature: Sampling temperature; falsy or <= 0 selects argmax.
        eval_temperature: Temperature used for the entropy computation.
        device: Device the input tensors are moved to.

    Returns:
        ``(results, text)`` where ``results`` is a list of
        ``{"token": str, "CG": float}`` and ``text`` is the decoded sequence.
    """
    user_rq = user_template.format(retrieved=retrieved or "", query=query)
    user_q = user_template.format(retrieved="not available", query=query)
    msgs_rq = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_rq},
    ]
    msgs_q = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_q},
    ]
    rq_ids = tok.apply_chat_template(
        msgs_rq, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    q_ids = tok.apply_chat_template(
        msgs_q, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    # Each prompt is a single unpadded sequence, so every position is real.
    # Deriving the mask from `ids != pad_token_id` would hide genuine prompt
    # tokens whenever pad == eos and the chat template emits that token, and
    # would make the position ids below lag behind the cache length.
    attn_rq = torch.ones_like(rq_ids)
    attn_q = torch.ones_like(q_ids)

    out_rq = model(input_ids=rq_ids, attention_mask=attn_rq, use_cache=True)
    out_q = model(input_ids=q_ids, attention_mask=attn_q, use_cache=True)
    past_rq = out_rq.past_key_values
    past_q = out_q.past_key_values
    cur_len_rq = attn_rq.sum(dim=-1)
    cur_len_q = attn_q.sum(dim=-1)
    results: List[Dict[str, float]] = []
    seq_ids: List[int] = []
    # NOTE(review): the first decode input is EOS rather than a token sampled
    # from the prefill logits - kept as in the original; confirm intent.
    next_tok = torch.tensor([[tok.eos_token_id]], device=device, dtype=torch.long)

    for _ in range(max_new_tokens):
        out_q = model(
            input_ids=next_tok,
            past_key_values=past_q,
            position_ids=cur_len_q.unsqueeze(-1),
            use_cache=True,
        )
        logits_q = out_q.logits[:, -1, :]
        past_q = out_q.past_key_values

        out_rq = model(
            input_ids=next_tok,
            past_key_values=past_rq,
            position_ids=cur_len_rq.unsqueeze(-1),
            use_cache=True,
        )
        logits_rq = out_rq.logits[:, -1, :]
        past_rq = out_rq.past_key_values

        H_q = entropy_from_logits(logits_q, temperature=eval_temperature).item()
        H_rq = entropy_from_logits(logits_rq, temperature=eval_temperature).item()
        CG = float(H_q - H_rq)

        if gen_temperature and gen_temperature > 0.0:
            probs_rq = F.softmax(logits_rq / gen_temperature, dim=-1)
            next_rq = torch.multinomial(probs_rq, num_samples=1)
        else:
            next_rq = logits_rq.argmax(dim=-1, keepdim=True)

        token_id = int(next_rq.item())
        token_str = tok.decode([token_id])
        results.append({"token": token_str, "CG": CG})
        seq_ids.append(token_id)
        cur_len_q = cur_len_q + 1
        cur_len_rq = cur_len_rq + 1
        next_tok = next_rq
        if token_id == tok.eos_token_id or token_str == "<|eot_id|>":
            break

    return results, tok.decode(seq_ids)


def mask_eos_early_prob(
    probs: torch.Tensor,
    step: int,
    min_step_no_eos: int,
    eos_local_ids: list[int],
    eps: float = 1e-12,
) -> torch.Tensor:
    """Zero out EOS probabilities during early steps and renormalise.

    While ``step <= min_step_no_eos`` and ``eos_local_ids`` is non-empty, a
    copy of ``probs`` has the listed last-axis entries set to 0 and is divided
    by its last-axis sum (floored at ``eps``). Otherwise ``probs`` itself is
    returned.

    Args:
        probs: Probabilities over the last axis.
        step: Current decoding step.
        min_step_no_eos: Last step (inclusive) at which EOS is blocked.
        eos_local_ids: Last-axis indices to zero out.
        eps: Lower bound on the renormalisation denominator.

    Returns:
        The renormalised copy, or ``probs`` unchanged.
    """
    if step > min_step_no_eos or len(eos_local_ids) == 0:
        return probs
    blocked = probs.clone()
    blocked[..., eos_local_ids] = 0.0
    # Floor the total so a fully-masked row does not divide by zero.
    total = blocked.sum(dim=-1, keepdim=True).clamp_min(eps)
    return blocked / total


@torch.inference_mode()
def cg_stream_hf_ensemble_parallel_probs(
    tok,
    model,
    *,
    system_prompt: str,
    user_template: str,
    query: str,
    retrieved_list: List[str],
    sigma: float,
    max_new_tokens: int = 128,
    gen_temperature: float = 0.7,
    eval_temperature: float = 1.0,
    vocab_keep_k: Optional[int] = None,
    min_step: int = 200,
    device: str = "cuda",
) -> Tuple[List[Dict[str, float]], str]:
    """Decode from a noisy average of per-document next-token probabilities.

    One chat prompt is built per entry of ``retrieved_list`` plus one
    query-only prompt (``retrieved="not available"``). At each step the
    per-document probabilities (softmax at ``eval_temperature``) are averaged,
    Gaussian noise with standard deviation ``sigma / M`` is added, negatives
    are clamped to 0 and the vector is renormalised. EOS ids are zeroed while
    ``step <= min_step``. The next token is sampled from that distribution
    (re-tempered by ``gen_temperature`` when it is > 0) and fed to every
    prompt. ``CG = H(query-only) - H(ensemble)`` is recorded per step.

    Args:
        tok: Tokenizer providing ``apply_chat_template``, ``decode``,
            ``convert_tokens_to_ids``, ``pad_token_id`` and ``eos_token_id``.
        model: Causal LM; inputs are placed on its input-embedding device.
        system_prompt: System message content.
        user_template: Format string with ``{retrieved}`` and ``{query}``.
        query: User query.
        retrieved_list: Retrieved documents, one ensemble member each.
        sigma: Noise scale before division by the ensemble size.
        max_new_tokens: Maximum number of decoding steps.
        gen_temperature: Re-tempering applied to the noisy distribution
            before sampling; falsy or <= 0 samples from it directly.
        eval_temperature: Softmax temperature for the per-prompt probabilities.
        vocab_keep_k: If set and smaller than the vocabulary, restricts each
            step to the union of the top-k ids of the mean document logits
            and of the query-only logits.
        min_step: Last step (inclusive) at which EOS is blocked.
        device: Unused; tensors follow the model's embedding device.

    Returns:
        ``(results, text)``; each result holds ``step``, ``token``, ``CG``,
        ``H_base`` and ``H_ens``.

    Raises:
        ValueError: If ``retrieved_list`` is empty.
    """
    model.eval()
    embed_device = model.get_input_embeddings().weight.device
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    eos_global_ids = [tok.eos_token_id]
    try:
        eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
    except Exception:
        # Best effort: tokenizers without this token may not support the lookup.
        eot_id = None
    # Unknown tokens are mapped to the unk id by the tokenizer; do not treat
    # that fallback id as an end-of-turn marker.
    unk_id = getattr(tok, "unk_token_id", None)
    if eot_id is not None and eot_id != -1 and eot_id != unk_id:
        eos_global_ids.append(eot_id)
    eos_global_ids = list({i for i in eos_global_ids if i is not None})

    if len(retrieved_list) == 0:
        raise ValueError("retrieved_list must contain at least one string.")

    msgs_rq_list = []
    for retrieved in retrieved_list:
        user_rq = user_template.format(retrieved=retrieved or "", query=query)
        msgs_rq_list.append(
            [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_rq}]
        )
    user_q = user_template.format(retrieved="not available", query=query)
    msgs_q = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_q}]

    rq_id_list = [
        tok.apply_chat_template(m, add_generation_prompt=True, return_tensors="pt").squeeze(0)
        for m in msgs_rq_list
    ]
    M = len(rq_id_list)
    max_len_rq = max(x.size(0) for x in rq_id_list)
    rq_ids = torch.full((M, max_len_rq), pad_id, dtype=torch.long)
    # Mask built from true lengths: comparing against pad_id would also hide
    # genuine prompt tokens when pad == eos and undercount the lengths.
    attn_rq = torch.zeros((M, max_len_rq), dtype=torch.long)
    for i, ids in enumerate(rq_id_list):
        rq_ids[i, : ids.size(0)] = ids
        attn_rq[i, : ids.size(0)] = 1
    q_ids = tok.apply_chat_template(msgs_q, add_generation_prompt=True, return_tensors="pt")

    rq_ids = rq_ids.to(embed_device)
    attn_rq = attn_rq.to(embed_device)
    q_ids = q_ids.to(embed_device)
    attn_q = torch.ones_like(q_ids)  # single unpadded prompt

    out_rq = model(input_ids=rq_ids, attention_mask=attn_rq, use_cache=True)
    out_q = model(input_ids=q_ids, attention_mask=attn_q, use_cache=True)
    past_rq = out_rq.past_key_values
    past_q = out_q.past_key_values
    cur_len_rq = attn_rq.sum(dim=-1)
    cur_len_q = attn_q.sum(dim=-1)
    results: List[Dict[str, float]] = []
    seq_ids: List[int] = []
    # NOTE(review): the first decode input is EOS rather than a token sampled
    # from the prefill logits - kept as in the original; confirm intent.
    next_tok = torch.tensor([[tok.eos_token_id]], device=embed_device, dtype=torch.long)
    step_ones = torch.ones((M, 1), dtype=attn_rq.dtype, device=embed_device)

    def _softmax_temp(logits: torch.Tensor, T: float) -> torch.Tensor:
        z = logits / max(T, 1e-8)
        z = z - z.max(dim=-1, keepdim=True).values
        return F.softmax(z, dim=-1)

    def _entropy_from_probs(p: torch.Tensor) -> torch.Tensor:
        p = torch.clamp(p, 1e-12, 1.0)
        return -(p * torch.log(p)).sum(dim=-1)

    MIN_STEP_NO_EOS = min_step

    for step in range(1, max_new_tokens + 1):
        # Keep hiding the pad slots held in the KV cache of shorter prompts.
        attn_rq = torch.cat([attn_rq, step_ones], dim=-1)
        out_rq = model(
            input_ids=next_tok.expand(M, 1),
            attention_mask=attn_rq,
            past_key_values=past_rq,
            position_ids=cur_len_rq.unsqueeze(-1),
            use_cache=True,
        )
        logits_rq = out_rq.logits[:, -1, :]
        past_rq = out_rq.past_key_values

        out_q = model(
            input_ids=next_tok,
            past_key_values=past_q,
            position_ids=cur_len_q.unsqueeze(-1),
            use_cache=True,
        )
        logits_q = out_q.logits[:, -1, :]
        past_q = out_q.past_key_values

        keep_idx: Optional[torch.Tensor] = None
        V = logits_rq.size(-1)
        if vocab_keep_k is not None and 0 < vocab_keep_k < V:
            mean_ctx = logits_rq.mean(dim=0)
            top_ctx = torch.topk(mean_ctx, k=vocab_keep_k).indices
            top_base = torch.topk(logits_q.squeeze(0), k=vocab_keep_k).indices
            keep_idx = torch.unique(torch.cat([top_ctx, top_base], dim=-1))
            logits_rq = logits_rq.index_select(-1, keep_idx)
            logits_q = logits_q.index_select(-1, keep_idx)

        probs_rq = _softmax_temp(logits_rq, eval_temperature)
        base_probs = _softmax_temp(logits_q, eval_temperature).squeeze(0)

        mean_probs = probs_rq.mean(dim=0)
        sigma_mean = float(sigma) / max(M, 1)
        noise = torch.randn_like(mean_probs) * sigma_mean
        tilde_probs = torch.clamp(mean_probs + noise, min=0.0)
        tilde_probs = tilde_probs / tilde_probs.sum().clamp_min(1e-12)

        if keep_idx is not None:
            # Translate global EOS ids into positions within the kept subset.
            eos_local_ids = []
            for gid in eos_global_ids:
                m = (keep_idx == gid).nonzero(as_tuple=True)[0]
                if m.numel() > 0:
                    eos_local_ids.append(int(m.item()))
        else:
            eos_local_ids = eos_global_ids
        tilde_probs = mask_eos_early_prob(
            tilde_probs, step=step, min_step_no_eos=MIN_STEP_NO_EOS, eos_local_ids=eos_local_ids
        )

        H_q = _entropy_from_probs(base_probs).item()
        H_ens = _entropy_from_probs(tilde_probs).item()
        CG = float(H_q - H_ens)

        if gen_temperature and gen_temperature > 0.0:
            logp = torch.log(tilde_probs.clamp_min(1e-12))
            sample_probs = F.softmax(logp / gen_temperature, dim=-1)
        else:
            sample_probs = tilde_probs
        next_local = torch.multinomial(sample_probs, num_samples=1)
        local_id = int(next_local.item())
        if keep_idx is not None:
            token_id = int(keep_idx[local_id].item())
        else:
            token_id = local_id

        seq_ids.append(token_id)
        token_str = tok.decode([token_id], skip_special_tokens=True)
        results.append({"step": step, "token": token_str, "CG": CG, "H_base": H_q, "H_ens": H_ens})
        next_tok = torch.tensor([[token_id]], device=embed_device, dtype=torch.long)
        cur_len_q = cur_len_q + 1
        cur_len_rq = cur_len_rq + 1
        if token_id in eos_global_ids or token_str == "<|eot_id|>":
            break

    return results, tok.decode(seq_ids, skip_special_tokens=True)


@torch.inference_mode()
def cg_stream_hf_ensemble_parallel_logits(
    tok,
    model,
    *,
    system_prompt: str,
    user_template: str,
    query: str,
    retrieved_list: List[str],
    eps_step: float,
    adjacency: str = "add_remove",
    max_new_tokens: int = 128,
    gen_temperature: float = 0.7,
    eval_temperature: float = 1.0,
    clipThr: float = 30.0,
    min_steps: int = 300,
    vocab_keep_k: Optional[int] = None,
    device: str = "cuda",
) -> Tuple[List[Dict[str, float]], str]:
    """Decode by sampling from a softmax over clipped, averaged logits.

    One chat prompt is built per entry of ``retrieved_list`` plus one
    query-only prompt (``retrieved="not available"``). At each step the
    per-document logits are clamped element-wise to ``[-clipThr, clipThr]``
    and averaged into ``u``. The next token is sampled from
    ``softmax(eps_step / (2 * delta_u * gen_temperature) * u)`` with
    ``delta_u = clipThr / M`` (doubled for ``adjacency == "replace"``); EOS
    ids are blocked while ``step <= min_steps``. The sampled token is fed to
    every prompt. ``CG = H(query-only) - H(softmax(u))`` at
    ``eval_temperature`` is recorded per step.

    Args:
        tok: Tokenizer providing ``apply_chat_template``, ``decode``,
            ``convert_tokens_to_ids``, ``pad_token_id`` and ``eos_token_id``.
        model: Causal LM; inputs are placed on its input-embedding device.
        system_prompt: System message content.
        user_template: Format string with ``{retrieved}`` and ``{query}``.
        query: User query.
        retrieved_list: Retrieved documents, one ensemble member each.
        eps_step: Per-step scale in the sampling scores (presumably the
            per-token privacy budget - confirm against the DP analysis).
        adjacency: ``"replace"`` doubles ``delta_u``.
        max_new_tokens: Maximum number of decoding steps.
        gen_temperature: Temperature folded into the score scale (floored at
            1e-8).
        eval_temperature: Softmax temperature for the entropy computation.
        clipThr: Element-wise clamp bound applied to the document logits.
        min_steps: Last step (inclusive) at which EOS is blocked.
        vocab_keep_k: If set and smaller than the vocabulary, restricts each
            step to the union of the top-k ids of the mean document logits
            and of the query-only logits.
        device: Unused; tensors follow the model's embedding device.

    Returns:
        ``(results, text)``; each result holds ``step``, ``token``, ``CG``,
        ``H_base`` and ``H_ens``.

    Raises:
        ValueError: If ``retrieved_list`` is empty.
    """
    model.eval()
    embed_device = model.get_input_embeddings().weight.device
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    eos_global_ids = [tok.eos_token_id]
    try:
        eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
    except Exception:
        # Best effort: tokenizers without this token may not support the lookup.
        eot_id = None
    # Unknown tokens are mapped to the unk id by the tokenizer; do not treat
    # that fallback id as an end-of-turn marker.
    unk_id = getattr(tok, "unk_token_id", None)
    if eot_id is not None and eot_id != -1 and eot_id != unk_id:
        eos_global_ids.append(eot_id)
    eos_global_ids = list({i for i in eos_global_ids if i is not None})

    if len(retrieved_list) == 0:
        raise ValueError("retrieved_list must contain at least one string.")

    msgs_rq_list = []
    for retrieved in retrieved_list:
        user_rq = user_template.format(retrieved=retrieved or "", query=query)
        msgs_rq_list.append(
            [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_rq}]
        )
    user_q = user_template.format(retrieved="not available", query=query)
    msgs_q = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_q}]

    rq_id_list = [
        tok.apply_chat_template(m, add_generation_prompt=True, return_tensors="pt").squeeze(0)
        for m in msgs_rq_list
    ]
    max_len_rq = max(x.size(0) for x in rq_id_list)
    M = len(rq_id_list)
    rq_ids = torch.full((M, max_len_rq), pad_id, dtype=torch.long)
    # Mask built from true lengths: comparing against pad_id would also hide
    # genuine prompt tokens when pad == eos, undercount the lengths and make
    # the "last prompt token" lookup below pick the wrong position.
    attn_rq = torch.zeros((M, max_len_rq), dtype=torch.long)
    for i, ids in enumerate(rq_id_list):
        rq_ids[i, : ids.size(0)] = ids
        attn_rq[i, : ids.size(0)] = 1
    q_ids = tok.apply_chat_template(msgs_q, add_generation_prompt=True, return_tensors="pt").squeeze(0)

    rq_ids = rq_ids.to(embed_device)
    attn_rq = attn_rq.to(embed_device)
    q_ids = q_ids.to(embed_device)
    attn_q = torch.ones_like(q_ids).unsqueeze(0)  # single unpadded prompt

    out_rq = model(input_ids=rq_ids, attention_mask=attn_rq, use_cache=True)
    out_q = model(input_ids=q_ids.unsqueeze(0), attention_mask=attn_q, use_cache=True)
    past_rq = out_rq.past_key_values
    past_q = out_q.past_key_values
    cur_len_rq = attn_rq.sum(dim=-1)
    cur_len_q = attn_q.sum(dim=-1)
    results: List[Dict[str, float]] = []
    seq_ids: List[int] = []

    # NOTE(review): the first decode input repeats each prompt's last token
    # instead of using the prefill logits - kept as in the original; confirm.
    last_rq = rq_ids[torch.arange(M, device=embed_device), (cur_len_rq - 1).to(torch.long)].unsqueeze(-1)
    last_q = q_ids[(cur_len_q.item() - 1)].view(1, 1)
    next_tok_rq = last_rq.clone()
    next_tok_q = last_q.clone()
    step_ones = torch.ones((M, 1), dtype=attn_rq.dtype, device=embed_device)

    def _softmax_T(logits: torch.Tensor, T: float) -> torch.Tensor:
        z = logits / max(T, 1e-8)
        z = z - z.max(dim=-1, keepdim=True).values
        return F.softmax(z, dim=-1)

    def _entropy(p: torch.Tensor) -> torch.Tensor:
        p = p.clamp_min(1e-12)
        return -(p * p.log()).sum(dim=-1)

    if adjacency == "replace":
        delta_u = 2.0 * clipThr / M
    else:
        delta_u = clipThr / M

    MIN_STEP_NO_EOS = int(min_steps)

    for step in range(1, max_new_tokens + 1):
        # Keep hiding the pad slots held in the KV cache of shorter prompts.
        attn_rq = torch.cat([attn_rq, step_ones], dim=-1)
        out_rq = model(
            input_ids=next_tok_rq,
            attention_mask=attn_rq,
            past_key_values=past_rq,
            position_ids=cur_len_rq.unsqueeze(-1),
            use_cache=True,
        )
        logits_rq = out_rq.logits[:, -1, :]
        past_rq = out_rq.past_key_values

        out_q = model(
            input_ids=next_tok_q,
            past_key_values=past_q,
            position_ids=cur_len_q.unsqueeze(-1),
            use_cache=True,
        )
        logits_q = out_q.logits[:, -1, :]
        past_q = out_q.past_key_values

        keep_idx: Optional[torch.Tensor] = None
        V = logits_rq.size(-1)
        if vocab_keep_k is not None and 0 < vocab_keep_k < V:
            mean_ctx = logits_rq.mean(dim=0)
            top_ctx = torch.topk(mean_ctx, k=vocab_keep_k).indices
            top_base = torch.topk(logits_q.squeeze(0), k=vocab_keep_k).indices
            keep_idx = torch.unique(torch.cat([top_ctx, top_base], dim=-1))
            logits_rq = logits_rq.index_select(-1, keep_idx)
            logits_q = logits_q.index_select(-1, keep_idx)

        # Clamp each document's logits before averaging.
        u = logits_rq.clamp(min=-clipThr, max=clipThr).mean(dim=0)

        if keep_idx is not None:
            # Translate global EOS ids into positions within the kept subset.
            eos_local_ids = []
            for gid in eos_global_ids:
                m = (keep_idx == gid).nonzero(as_tuple=True)[0]
                if m.numel() > 0:
                    eos_local_ids.append(int(m.item()))
        else:
            eos_local_ids = eos_global_ids

        base_probs = _softmax_T(logits_q.squeeze(0), eval_temperature)
        ens_probs_eval = _softmax_T(u, eval_temperature)
        H_q = float(_entropy(base_probs))
        H_ens = float(_entropy(ens_probs_eval))
        CG = float(H_q - H_ens)

        denom = 2.0 * delta_u * max(gen_temperature, 1e-8)
        scores = (eps_step / denom) * u
        scores_masked = _mask_end_tokens_logits(scores, step, MIN_STEP_NO_EOS, eos_local_ids, neg_inf=-1e9)
        scores_masked = scores_masked - scores_masked.max()
        sample_probs = torch.softmax(scores_masked, dim=-1)
        next_local = torch.multinomial(sample_probs, num_samples=1)
        local_id = int(next_local.item())
        token_id = int(keep_idx[local_id].item()) if keep_idx is not None else local_id

        seq_ids.append(token_id)
        token_str = tok.decode([token_id], skip_special_tokens=True)
        results.append({"step": step, "token": token_str, "CG": CG, "H_base": H_q, "H_ens": H_ens})
        next_tok_q = torch.tensor([[token_id]], device=embed_device, dtype=torch.long)
        next_tok_rq = next_tok_q.expand(M, 1).contiguous()
        cur_len_q = cur_len_q + 1
        cur_len_rq = cur_len_rq + 1
        if token_id in eos_global_ids or token_str == "<|eot_id|>":
            break

    return results, tok.decode(seq_ids, skip_special_tokens=True)




@torch.inference_mode()
def em_generate_ensemble_only(
    tok,
    model,
    *,
    user_prompts: List[str],
    system_prompt: str = "You are a careful text refiller that follows privacy and factuality rules.",
    eps_step: float,
    adjacency: str = "add_remove",
    max_new_tokens: int = 128,
    gen_temperature: float = 1.0,
    clipThr: float = 30.0,
    min_steps: int = 50,
    vocab_keep_k: Optional[int] = None,
    device: str = "cuda",
) -> str:
    """Generate one text from an ensemble of prompts via clipped mean logits.

    Each entry of ``user_prompts`` becomes a chat prompt with the shared
    ``system_prompt``. At every step the per-prompt logits are clamped
    element-wise to ``[-clipThr, clipThr]`` and averaged into ``u``; the next
    token is sampled from
    ``softmax(eps_step / (2 * delta_u * gen_temperature) * u)`` with
    ``delta_u = clipThr / M`` (doubled for ``adjacency == "replace"``). EOS
    ids are blocked while ``step < min_steps``. The sampled token is fed to
    every prompt.

    Args:
        tok: Tokenizer providing ``apply_chat_template``, ``decode``,
            ``convert_tokens_to_ids``, ``pad_token_id`` and ``eos_token_id``.
        model: Causal LM; inputs are placed on its input-embedding device.
        user_prompts: User messages, one ensemble member each.
        system_prompt: System message shared by all members.
        eps_step: Per-step scale in the sampling scores (presumably the
            per-token privacy budget - confirm against the DP analysis).
        adjacency: ``"replace"`` doubles ``delta_u``.
        max_new_tokens: Maximum number of decoding steps.
        gen_temperature: Temperature folded into the score scale (floored at
            1e-8).
        clipThr: Element-wise clamp bound applied to each member's logits.
        min_steps: EOS is blocked at steps strictly below this value.
        vocab_keep_k: If set and smaller than the vocabulary, restricts each
            step to the top-k ids of the mean (unclamped) logits.
        device: Unused; tensors follow the model's embedding device.

    Returns:
        The decoded text with special tokens skipped.

    Raises:
        ValueError: If ``user_prompts`` is empty.
    """
    model.eval()
    embed_device = model.get_input_embeddings().weight.device
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    eos_global_ids = [tok.eos_token_id]
    try:
        eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
    except Exception:
        # Best effort: tokenizers without this token may not support the lookup.
        eot_id = None
    # Unknown tokens are mapped to the unk id by the tokenizer; do not treat
    # that fallback id as an end-of-turn marker.
    unk_id = getattr(tok, "unk_token_id", None)
    if eot_id is not None and eot_id != -1 and eot_id != unk_id:
        eos_global_ids.append(eot_id)
    eos_global_ids = list({i for i in eos_global_ids if i is not None})

    msg_batches = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": up},
        ]
        for up in user_prompts
    ]
    if len(msg_batches) == 0:
        raise ValueError("user_prompts must contain at least one string.")

    encoded_list = [
        tok.apply_chat_template(m, add_generation_prompt=True, return_tensors="pt").squeeze(0)
        for m in msg_batches
    ]
    max_len = max(x.size(0) for x in encoded_list)
    M = len(encoded_list)
    input_ids = torch.full((M, max_len), pad_id, dtype=torch.long)
    # Mask built from true lengths: comparing against pad_id would also hide
    # genuine prompt tokens when pad == eos, undercount the lengths and make
    # the "last prompt token" lookup below pick the wrong position.
    attn = torch.zeros((M, max_len), dtype=torch.long)
    for i, ids in enumerate(encoded_list):
        input_ids[i, : ids.size(0)] = ids
        attn[i, : ids.size(0)] = 1
    input_ids = input_ids.to(embed_device)
    attn = attn.to(embed_device)

    out = model(input_ids=input_ids, attention_mask=attn, use_cache=True)
    past = out.past_key_values
    cur_len = attn.sum(dim=-1)

    if adjacency == "replace":
        delta_u = (2.0 * clipThr) / M
    else:
        delta_u = clipThr / M

    # NOTE(review): the first decode input repeats each prompt's last token
    # instead of using the prefill logits - kept as in the original; confirm.
    next_tok = input_ids[torch.arange(M, device=embed_device), (cur_len - 1).to(torch.long)].unsqueeze(-1)
    step_ones = torch.ones((M, 1), dtype=attn.dtype, device=embed_device)
    block_eos_before = int(min_steps)
    seq_ids: List[int] = []

    for step in range(1, max_new_tokens + 1):
        # Keep hiding the pad slots held in the KV cache of shorter prompts.
        attn = torch.cat([attn, step_ones], dim=-1)
        out = model(
            input_ids=next_tok,
            attention_mask=attn,
            past_key_values=past,
            position_ids=cur_len.unsqueeze(-1),
            use_cache=True,
        )
        logits = out.logits[:, -1, :]
        past = out.past_key_values
        cur_len = cur_len + 1

        keep_idx = None
        V = logits.size(-1)
        if vocab_keep_k is not None and 0 < vocab_keep_k < V:
            mean_logits_for_topk = logits.mean(dim=0)
            keep_idx = torch.topk(mean_logits_for_topk, k=vocab_keep_k).indices
            logits = logits.index_select(-1, keep_idx)

        # Clamp every member BEFORE averaging: delta_u = clipThr / M bounds
        # one member's influence on the mean only if each member's logits are
        # bounded individually. Clamping the mean afterwards does not.
        u = logits.clamp(min=-clipThr, max=clipThr).mean(dim=0)
        denom = 2.0 * delta_u * max(gen_temperature, 1e-8)
        scores = (eps_step / denom) * u

        if keep_idx is None:
            eos_locals = eos_global_ids
        else:
            # Translate global EOS ids into positions within the kept subset.
            eos_locals = []
            for gid in eos_global_ids:
                m = (keep_idx == gid).nonzero(as_tuple=True)[0]
                if m.numel() > 0:
                    eos_locals.append(int(m.item()))

        scores = scores - scores.max()
        if step < block_eos_before:
            for j in eos_locals:
                if 0 <= j < scores.size(-1):
                    scores[j] = -1e9

        probs = F.softmax(scores, dim=-1)
        local_id = int(torch.multinomial(probs, num_samples=1).item())
        token_id = int(keep_idx[local_id].item()) if keep_idx is not None else local_id
        seq_ids.append(token_id)
        next_tok = torch.full((M, 1), token_id, device=embed_device, dtype=torch.long)
        if token_id in eos_global_ids:
            break

    return tok.decode(seq_ids, skip_special_tokens=True)


if __name__ == "__main__":
    # Script entry point: turn autograd off globally before running.
    torch.set_grad_enabled(False)
    # NOTE(review): no `main` function is defined in the visible part of this
    # file, so this call would raise NameError - confirm where it should come from.
    main()
