import torch
import numpy as np
import torch.nn.functional as F
import os, sys

# Put the project root (the parent of this script's directory) on sys.path so
# that project-local packages such as `utils` can be imported from this module.
current_script_path = os.path.abspath(__file__)
scripts_dir = os.path.dirname(current_script_path)
project_root = os.path.dirname(scripts_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Module-level cache of the background-frequency tensor built by load_baseline().
BASE_LINE = None
BASE_LINE_VOCAB_SIZE = None  # vocabulary size the cached BASE_LINE was built for

def pc_sampler_function(
    probabilities: torch.Tensor,
    token_ids: torch.Tensor,
    lambda_val: float,
    alpha: float,
    bg_freq_tensor: torch.Tensor
) -> torch.Tensor:
    """Compute the PC-Sampler score for every position.

    For position ``i`` (index along dim 1) the score is::

        exp(-lambda_val * i) * min(-p * log(f_bg[token] + 1e-9), alpha)

    Args:
        probabilities: Per-position probability of the chosen token; must have
            the same shape as ``token_ids`` and at least two dimensions
            (dim 1 is treated as the sequence axis).
        token_ids: Integer token ids used to index ``bg_freq_tensor``.
        lambda_val: Decay rate of the positional bias.
        alpha: Upper clamp applied to the frequency-weighted term.
        bg_freq_tensor: 1-D tensor of background frequencies indexed by token id.

    Returns:
        Tensor of scores with the same shape as ``probabilities``.

    Raises:
        ValueError: If ``probabilities`` and ``token_ids`` differ in shape.
    """
    if probabilities.shape != token_ids.shape:
        # A bare string cannot be raised in Python 3; use a real exception type.
        raise ValueError(
            f"probabilities.shape: {probabilities.shape}, token_ids.shape: {token_ids.shape} must be equal"
        )

    device = probabilities.device
    sequence_len = probabilities.shape[1]

    # Background frequency of each chosen token.
    f_bg_tensor = bg_freq_tensor[token_ids]
    epsilon = 1e-9  # keeps log() finite for zero-frequency tokens
    cross_entropy_scores = -probabilities * torch.log(f_bg_tensor + epsilon)
    cross_entropy_scores = torch.clamp(cross_entropy_scores, max=alpha)

    # Exponentially decaying weight over the sequence axis (broadcast over batch).
    positions = torch.arange(sequence_len, device=device, dtype=torch.float32)
    positional_bias = torch.exp(-lambda_val * positions)

    return positional_bias * cross_entropy_scores

def load_baseline(model, baseline_name, vocab_size=None):
    """Build (or reuse) the module-level background-frequency tensor ``BASE_LINE``.

    The baseline file is read with ``load_json_or_jsonl`` and is expected to
    provide ``'num_token'`` (total token count) and ``'p_baseline_dict'``
    (token id -> count). ``BASE_LINE`` becomes a float32 tensor of length
    ``vocab_size`` holding ``count / num_token`` for listed tokens and
    ``1 / num_token`` for all others.

    The tensor is rebuilt when no cache exists, when the vocabulary size
    changes, or when a different ``baseline_name`` is requested after a
    previous load through this function; otherwise the cached tensor is only
    moved to ``model.device``.

    Args:
        model: Object exposing ``device``; ``config.vocab_size`` is used when
            present and ``vocab_size`` is not given.
        baseline_name: Path/name passed to ``load_json_or_jsonl``.
        vocab_size: Optional explicit vocabulary size.
    """
    global BASE_LINE, BASE_LINE_VOCAB_SIZE

    # Resolve the vocabulary size dynamically.
    if vocab_size is None:
        if hasattr(model, 'config') and hasattr(model.config, 'vocab_size'):
            vocab_size = model.config.vocab_size
        else:
            vocab_size = 200000  # fallback chosen to be large enough

    # Remember which file the cache came from so that switching baselines does
    # not silently reuse stale frequencies. A cache populated by other means
    # (no recorded name) is left alone, as before.
    cached_name = getattr(load_baseline, '_cached_baseline_name', None)
    baseline_changed = cached_name is not None and cached_name != baseline_name

    if BASE_LINE is None or BASE_LINE_VOCAB_SIZE != vocab_size or baseline_changed:
        from utils.load_json_or_jsonl import load_json_or_jsonl
        payload = load_json_or_jsonl(baseline_name)
        token_num_ = payload['num_token']
        # Build a fresh int-keyed mapping; this works whether the loader
        # yields string keys (JSON) or keys that are already integers.
        frequencies = {
            int(token_id): count / token_num_
            for token_id, count in payload['p_baseline_dict'].items()
        }

        BASE_LINE = torch.full((vocab_size,), 1/token_num_, device=model.device, dtype=torch.float32)
        keys = torch.tensor(list(frequencies.keys()), device=model.device, dtype=torch.long)
        values = torch.tensor(list(frequencies.values()), device=model.device, dtype=torch.float32)
        BASE_LINE.scatter_(0, keys, values)
        BASE_LINE_VOCAB_SIZE = vocab_size
        load_baseline._cached_baseline_name = baseline_name
    else:
        BASE_LINE = BASE_LINE.to(model.device)

def apply_eos_penalty(logits, model, eos_penalty=0.0):
    '''
    Apply EOS (End of Sequence) penalty to logits to discourage early termination.

    The penalty is subtracted in place from the EOS column of ``logits``; the
    same tensor object is returned.

    Args:
        logits: Model output logits of shape [batch_size, seq_len, vocab_size]
        model: The model (used to get eos_token_id from config). When
            ``config.eos_token_id`` is a list/tuple, its first element is used.
        eos_penalty: Penalty value to subtract from EOS token logits (default: 0.0, no penalty)

    Returns:
        logits with EOS penalty applied (unchanged when the penalty is 0, the
        model has no usable EOS id, or the id is outside the vocabulary axis)
    '''
    if eos_penalty == 0.0:
        return logits

    config = getattr(model, 'config', None)
    eos_token_id = getattr(config, 'eos_token_id', None)

    # The sequence case must be handled before the scalar comparison below,
    # otherwise a list-valued id reaches `list < int` and raises TypeError.
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_id = eos_token_id[0] if len(eos_token_id) > 0 else None

    if eos_token_id is not None and eos_token_id < logits.shape[-1]:
        # Apply penalty: subtract penalty value from EOS token logits
        logits[:, :, eos_token_id] = logits[:, :, eos_token_id] - eos_penalty

    return logits


def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits for Gumbel-max sampling of a categorical distribution.

    With ``temperature == 0`` the input is returned untouched. Otherwise the
    result is ``exp(logits) / (-log(u)) ** temperature`` with ``u`` drawn
    uniformly from [0, 1), all in float64: arXiv:2409.02908 reports that
    low-precision Gumbel max improves perplexity for MDMs but hurts
    generation quality.
    '''
    if temperature != 0:
        logits_f64 = logits.to(torch.float64)
        uniform = torch.rand_like(logits_f64, dtype=torch.float64)
        scaled_noise = (-torch.log(uniform)) ** temperature
        return logits_f64.exp() / scaled_noise
    return logits


def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many tokens each row should unmask at every step.

    The reverse process splits [0, 1] uniformly into ``steps`` intervals, and
    LLaDA's linear noise schedule (Eq. (8)) implies the expected number of
    tokens transitioned per step is constant. Each row's masked-token count is
    therefore divided evenly over the steps, with the remainder handed out one
    token at a time to the earliest steps.

    Returns an int64 tensor of shape [batch, steps].
    '''
    masked_per_row = mask_index.sum(dim=1, keepdim=True)
    quotient = masked_per_row // steps
    leftover = masked_per_row % steps

    schedule = torch.zeros(masked_per_row.size(0), steps, device=mask_index.device, dtype=torch.int64) + quotient

    # Steps whose index is below the row's remainder receive one extra token.
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)
    gets_extra = step_ids < leftover
    schedule = schedule + gets_extra.to(schedule.dtype)

    return schedule

def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None):
    """Pick the tokens to commit this step and the positions to commit them at.

    Args:
        logits: Model logits; argmax over the last axis gives the candidate tokens.
        temperature: Forwarded to ``add_gumbel_noise``.
        remasking: ``'low_confidence'`` (softmax probability of the candidate)
            or ``'random'`` (uniform random confidence).
        mask_index: Boolean tensor marking the positions that are still masked.
        x: Current token ids; kept wherever ``mask_index`` is False.
        num_transfer_tokens: Per-row number of positions to select. Ignored
            when ``threshold`` is given (every masked position is considered).
        threshold: Optional confidence cut-off. Selected positions other than
            the top-ranked one are dropped when their confidence is below it.

    Returns:
        ``(x0, transfer_index)``: candidate sequence and boolean selection mask.

    Raises:
        NotImplementedError: For an unknown ``remasking`` strategy.
    """
    candidates = torch.argmax(add_gumbel_noise(logits, temperature=temperature), dim=-1)  # b, l

    if remasking == 'low_confidence':
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        candidate_conf = torch.gather(probs, dim=-1, index=candidates.unsqueeze(-1)).squeeze(-1)  # b, l
    elif remasking == 'random':
        candidate_conf = torch.rand((candidates.shape[0], candidates.shape[1]), device=candidates.device)
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, candidates, x)
    # Already-decoded positions can never win the top-k below.
    confidence = torch.where(mask_index, candidate_conf, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    use_threshold = threshold is not None
    if use_threshold:
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)

    for row, row_conf in enumerate(confidence):
        _, ranked = torch.topk(row_conf, k=num_transfer_tokens[row])
        transfer_index[row, ranked] = True
        if not use_threshold:
            continue
        # The best-ranked position is always kept; the rest must clear the threshold.
        for pos in ranked[1:]:
            if row_conf[pos] < threshold:
                transfer_index[row, pos] = False
    return x0, transfer_index

def get_transfer_index_dynamic(logits, temperature, remasking, mask_index, x, num_transfer_tokens, factor=1):
    """Select a variable number of positions per row using rank-dependent thresholds.

    For a row with ``N`` masked positions, confidences are sorted in
    descending order and the ``n``-th best (1-based) must reach
    ``1 - factor / (n + 1)``; the first threshold is replaced by ``-1`` so at
    least one token is always transferred. The leading run of positions that
    pass is selected.

    Args:
        logits: Model logits; argmax over the last axis gives the candidate tokens.
        temperature: Forwarded to ``add_gumbel_noise``.
        remasking: ``'low_confidence'`` or ``'random'`` confidence source.
        mask_index: Boolean tensor marking the positions that are still masked.
        x: Current token ids; kept wherever ``mask_index`` is False.
        num_transfer_tokens: Unused; kept for signature compatibility (the
            per-row masked count is always used instead).
        factor: Scale of the threshold schedule.

    Returns:
        ``(x0, transfer_index)``: candidate sequence and boolean selection mask.

    Raises:
        NotImplementedError: For an unknown ``remasking`` strategy.
    """
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
    if remasking == 'low_confidence':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, x0_p, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    masked_counts = mask_index.sum(dim=1).tolist()

    for j, num_masked in enumerate(masked_counts):
        if num_masked == 0:
            # Nothing left to decode in this row.
            continue

        thresholds = [1 - factor / (n + 1) for n in range(1, num_masked + 1)]
        # at least one token is transferred
        thresholds[0] = -1

        sorted_confidence = torch.sort(confidence[j][mask_index[j]], dim=-1, descending=True)[0]
        threshold_tensor = torch.tensor(
            thresholds, dtype=sorted_confidence.dtype, device=sorted_confidence.device)

        # Number of leading positions that clear their threshold. A failure at
        # the very last rank must exclude that position, so "no failure" and
        # "failed at the last rank" are distinguished explicitly.
        failing_ranks = torch.nonzero(sorted_confidence < threshold_tensor).flatten()
        num_accept = int(failing_ranks[0]) if failing_ranks.numel() > 0 else num_masked
        num_accept = max(num_accept, 1)

        _, select_index = torch.topk(confidence[j], k=num_accept)
        transfer_index[j, select_index] = True

    return x0, transfer_index

@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, lambd=1, alpha=1, baseline_name='P_baseline.json', temperature=0.,
                  cfg_scale=0., remasking='low_confidence', mask_id=126336, return_order=False,
                  candidate_number=1, position_temperature=0.0, debug=False, prefilled_positions=None,
                  heuristic='confidence', return_cumulative_entropy=False, tokens_per_step=None, is_dream=False,
                  save_monotone_residual_path=None, eos_penalty=0.0):
    """
    Base generation routine (built on PC-Sampler) with the Info-Gain Sampler extension.

    Args:
        model: Model called as ``model(x).logits``; ``model.device`` is also read.
        prompt: Prompt token ids of shape [1, prompt_len].
        steps: Total number of decoding steps (split evenly across blocks).
        gen_length: Number of tokens to generate. ``0`` means the masks are
            already inside ``prompt`` (e.g. Sudoku pre-filled mode).
        block_length: Block size; must divide ``gen_length``.
        lambd, alpha: Parameters of the ``'pc'`` heuristic.
        baseline_name: Baseline frequency file passed to ``load_baseline``.
        temperature: Gumbel-noise temperature for token sampling.
        cfg_scale: Classifier-free guidance scale (``> 0`` enables it).
        remasking: ``'low_confidence'`` or ``'random'``.
        mask_id: Id of the mask token.
        return_order: Also return ``{block_number: [selected positions per step]}``.
        candidate_number: Number of candidate actions. Falls back to
            PC-Sampler when ``<= 1`` or when ``position_temperature <= 0``.
        position_temperature: Position sampling temperature for the IG-Sampler
            action sampling. Falls back to PC-Sampler when ``<= 0``.
        debug: Whether to print debug information (currently forced off, see below).
        prefilled_positions: List of ``(position, token_id)`` to pre-fill;
            ``position`` is an offset relative to the generation region.
        heuristic: Heuristic type: ``'pc'`` (PC value), ``'confidence'``,
            ``'neg_entropy'`` (negative entropy), ``'margin'`` (top1 - top2
            probability) or ``'uniform'`` (random).
        return_cumulative_entropy: Also return the cumulative entropy (sum of
            the entropies of the positions chosen at every step).
        tokens_per_step: Number of tokens decoded per step (K). When set,
            ``steps = num_masks // K``.
        is_dream: Whether the model is a Dream model (logits need shifting).
        save_monotone_residual_path: When set, monotone-residual records are
            collected and written to this JSON file.
        eos_penalty: EOS token penalty value to discourage early termination (default: 0.0, no penalty)

    Returns:
        ``x`` alone, or ``(x, orders)``, ``(x, cumulative_entropy)`` or
        ``(x, orders, cumulative_entropy)`` depending on ``return_order`` and
        ``return_cumulative_entropy``.
    """
    # NOTE(review): debug output is unconditionally disabled here, so the
    # `debug` argument currently has no effect — confirm whether this override
    # is still wanted.
    debug=False
    global BASE_LINE
    # Called on every invocation so the cached baseline matches this model's vocab size.
    load_baseline(model, baseline_name)

    orders = {}
    # Running total of the entropies of all selected positions.
    cumulative_entropy = 0.0

    def _pack_result(tokens, order_log, entropy_total):
        # Shape the return value according to the requested extras.
        if return_order and return_cumulative_entropy:
            return tokens, order_log, entropy_total
        if return_order:
            return tokens, order_log
        if return_cumulative_entropy:
            return tokens, entropy_total
        return tokens

    # Special case: gen_length=0 means the masks already live in the prompt
    # (e.g. Sudoku pre-filled mode).
    if gen_length == 0:
        x = prompt.clone()
        # Locate the mask positions inside the prompt.
        mask_positions = (x == mask_id)
        num_masks = mask_positions.sum().item()

        if debug:
            print(f"[DEBUG] gen_length=0 mode: found {num_masks} masks in prompt")

        if num_masks == 0:
            # Nothing to decode; still honour the requested return shape.
            return _pack_result(x, orders, cumulative_entropy)

        # prompt_len is the position of the first mask.
        first_mask_pos = torch.where(mask_positions[0])[0][0].item()
        prompt_len = first_mask_pos

        # The single block spans from the first mask to the end of the sequence.
        seq_len = x.shape[1]
        block_length = seq_len - prompt_len
        gen_length = block_length

        # Derive steps from tokens_per_step (K).
        if tokens_per_step is not None and tokens_per_step > 0:
            steps = max(1, num_masks // tokens_per_step)
        else:
            # Default: transfer one token per step.
            steps = num_masks

        if debug:
            print(f"[DEBUG] First mask at position {first_mask_pos}, seq_len={seq_len}")
            print(f"[DEBUG] prompt_len={prompt_len}, block_length={block_length}, num_masks={num_masks}, steps={steps}")
    else:
        x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
        x[:, :prompt.shape[1]] = prompt.clone()
        prompt_len = prompt.shape[1]

        # Pre-fill known positions (e.g. the given digits of a Sudoku).
        if prefilled_positions is not None:
            for pos, token_id in prefilled_positions:
                x[:, prompt.shape[1] + pos] = token_id
            if debug:
                print(f"[DEBUG] Prefilled {len(prefilled_positions)} positions")

        # Derive steps from tokens_per_step (K).
        if tokens_per_step is not None and tokens_per_step > 0:
            num_masks = gen_length  # mask count is taken to be gen_length in this mode
            steps = max(1, num_masks // tokens_per_step)
            if debug:
                print(f"[DEBUG] tokens_per_step={tokens_per_step}, num_masks={num_masks}, steps={steps}")

    prompt_index = (x != mask_id)

    # Monotone-residual records collected during decoding.
    monotone_residual_data = []

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = prompt_len + (num_block + 1) * block_length

        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                # Dream models need shifted logits.
                if is_dream:
                    logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                # Apply EOS penalty after CFG
                logits = apply_eos_penalty(logits, model, eos_penalty)
            else:
                logits = model(x).logits
                # Dream models need shifted logits.
                if is_dream:
                    logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
                # Apply EOS penalty
                logits = apply_eos_penalty(logits, model, eos_penalty)

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1) # b, l

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Positions after the current block are not eligible yet.
            x0_p[:, block_end:] = -np.inf

            x0 = torch.where(mask_index, x0, x)

            # Heuristic scores for the current block (selected by `heuristic`).
            block_x0_p = x0_p[:, block_start:block_end]
            block_x0 = x0[:, block_start:block_end]
            block_logits = logits[:, block_start:block_end]

            if heuristic == 'pc':
                # PC-Sampler heuristic.
                heuristic_scores = pc_sampler_function(
                    probabilities=block_x0_p,
                    token_ids=block_x0,
                    lambda_val=lambd,
                    alpha=alpha,
                    bg_freq_tensor=BASE_LINE
                )
            elif heuristic == 'confidence':
                # Use the model confidence directly.
                heuristic_scores = block_x0_p
            elif heuristic == 'neg_entropy':
                # Negative entropy: lower entropy (more certain) scores higher.
                block_probs_for_entropy = F.softmax(block_logits, dim=-1)
                entropy = -torch.sum(block_probs_for_entropy * torch.log(block_probs_for_entropy + 1e-10), dim=-1)
                heuristic_scores = -entropy
            elif heuristic == 'margin':
                # Margin between the top-1 and top-2 probabilities.
                block_probs_for_margin = F.softmax(block_logits, dim=-1)
                top2_probs, _ = torch.topk(block_probs_for_margin, k=2, dim=-1)
                heuristic_scores = top2_probs[:, :, 0] - top2_probs[:, :, 1]
            elif heuristic == 'uniform':
                # Uniformly random scores.
                heuristic_scores = torch.rand_like(block_x0_p)
            else:
                raise ValueError(f"Unknown heuristic: {heuristic}. Supported: 'pc', 'confidence', 'neg_entropy', 'margin', 'uniform'")

            # Only masked positions inside the current block compete; all
            # others are set to -inf so a max/top-k never picks them first.
            block_mask = mask_index[:, block_start:block_end]
            confidence = torch.where(block_mask, heuristic_scores, torch.tensor(-np.inf, device=x.device))

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)

            # Per-position entropy of the block (cumulative entropy and debug output).
            block_probs = F.softmax(block_logits, dim=-1)
            block_entropy = -torch.sum(block_probs * torch.log(block_probs + 1e-10), dim=-1)  # [batch, block_len]

            # Model prediction confidence inside the block.
            block_confidence = x0_p[:, block_start:block_end]

            for j in range(confidence.shape[0]):
                k = num_transfer_tokens[j, i].item()

                # Debug output is limited to the very first step.
                step_debug = debug and (num_block == 0 and i == 0)
                if step_debug:
                    print(f"\n[DEBUG] ===== Step {i}, k={k}, candidate_number={candidate_number}, position_temperature={position_temperature} =====")
                    # Eligible positions.
                    valid_mask = confidence[j] > -np.inf
                    valid_indices = torch.where(valid_mask)[0]
                    valid_conf = confidence[j][valid_mask]

                    if len(valid_conf) > 0:
                        # Show the 10 largest heuristic values (the ones that get selected).
                        num_show = min(10, len(valid_conf))
                        top_heuristic, top_idx_in_valid = torch.topk(valid_conf, k=num_show, largest=True)
                        top_positions = valid_indices[top_idx_in_valid]

                        print(f"[DEBUG] Top-{num_show} 启发值（最大，将被选择）:")
                        print(f"  {'位置':<6} {'启发值':<12} {'置信度':<12} {'熵':<12}")
                        print(f"  {'-'*42}")
                        for idx in range(num_show):
                            pos = top_positions[idx].item()
                            h_val = top_heuristic[idx].item()
                            conf_val = block_confidence[j, pos].item()
                            entropy_val = block_entropy[j, pos].item()
                            print(f"  {pos:<6} {h_val:<12.4f} {conf_val:<12.4f} {entropy_val:<12.4f}")

                        # Also show the smallest ones.
                        bottom_heuristic, bottom_idx_in_valid = torch.topk(valid_conf, k=num_show, largest=False)
                        bottom_positions = valid_indices[bottom_idx_in_valid]

                        print(f"\n[DEBUG] Bottom-{num_show} 启发值（最小）:")
                        print(f"  {'位置':<6} {'启发值':<12} {'置信度':<12} {'熵':<12}")
                        print(f"  {'-'*42}")
                        for idx in range(num_show):
                            pos = bottom_positions[idx].item()
                            h_val = bottom_heuristic[idx].item()
                            conf_val = block_confidence[j, pos].item()
                            entropy_val = block_entropy[j, pos].item()
                            print(f"  {pos:<6} {h_val:<12.4f} {conf_val:<12.4f} {entropy_val:<12.4f}")
                        print()

                # True once x has been replaced by a successor state that the
                # IG-Sampler already computed.
                state_replaced = False

                # Fall back to PC-Sampler when candidate_number<=1 or position_temperature<=0.
                if candidate_number <= 1 or position_temperature <= 0 or k == 0:
                    # PC-Sampler: take the k positions with the largest heuristic value.
                    _, select_index = torch.topk(confidence[j], k=k, largest=True)
                    # Accumulate the entropy of the selected positions.
                    if return_cumulative_entropy and k > 0:
                        selected_entropy = block_entropy[j, select_index].sum().item()
                        cumulative_entropy += selected_entropy
                    # The monotone residual is computed in PC-Sampler mode as well.
                    if save_monotone_residual_path is not None and k > 0:
                        monotone_data = compute_monotone_residual(
                            model=model,
                            x=x,
                            x0=x0,
                            logits=logits,
                            select_index=select_index,
                            block_start=block_start,
                            block_end=block_end,
                            mask_id=mask_id,
                            is_dream=is_dream
                        )
                        if monotone_data is not None:
                            monotone_data['block'] = num_block
                            monotone_data['step'] = i
                            monotone_residual_data.append(monotone_data)
                else:
                    # IG-Sampler: sample candidate actions, score them by
                    # information gain and keep the best one.
                    remaining_steps = steps - i  # transfer steps left in this block
                    select_index, selected_x, monotone_data = ig_sampler_select(
                        model=model,
                        x=x,
                        x0=x0,
                        logits=logits,
                        confidence=confidence[j],
                        k=k,
                        candidate_number=candidate_number,
                        position_temperature=position_temperature,
                        block_start=block_start,
                        block_end=block_end,
                        mask_id=mask_id,
                        remaining_steps=remaining_steps,
                        debug=step_debug,
                        is_dream=is_dream
                    )
                    # Collect monotone-residual data.
                    if monotone_data is not None and save_monotone_residual_path is not None:
                        monotone_data['block'] = num_block
                        monotone_data['step'] = i
                        monotone_residual_data.append(monotone_data)
                    # Accumulate the entropy of the selected positions.
                    if return_cumulative_entropy and k > 0:
                        selected_entropy = block_entropy[j, select_index].sum().item()
                        cumulative_entropy += selected_entropy
                    # Reuse the successor state that was already built instead
                    # of applying the transfer again.
                    if selected_x is not None:
                        x = selected_x
                        state_replaced = True

                if not state_replaced:
                    transfer_index[j, select_index + block_start] = True

                # Record the selection exactly once per step, regardless of
                # which sampler produced it.
                if return_order:
                    orders.setdefault(num_block + 1, []).append(select_index.tolist())

                # Commit the selected tokens (a no-op for this row when the
                # successor state already replaced x).
                x[transfer_index] = x0[transfer_index]

    # Persist the monotone-residual records.
    if save_monotone_residual_path is not None and len(monotone_residual_data) > 0:
        import json
        os.makedirs(os.path.dirname(save_monotone_residual_path) if os.path.dirname(save_monotone_residual_path) else '.', exist_ok=True)
        with open(save_monotone_residual_path, 'w', encoding='utf-8') as f:
            json.dump(monotone_residual_data, f, indent=2, ensure_ascii=False)
        print(f"[INFO] 单调性残差数据已保存到: {save_monotone_residual_path}")
        print(f"[INFO] 共 {len(monotone_residual_data)} 条记录")
        # Summary statistics.
        residuals = [d['monotone_residual'] for d in monotone_residual_data]
        print(f"[INFO] 单调性残差统计: min={min(residuals):.6f}, max={max(residuals):.6f}, mean={sum(residuals)/len(residuals):.6f}")
        print(f"[INFO] 所有残差 < 0: {all(r < 0 for r in residuals)}")

    return _pack_result(x, orders, cumulative_entropy)


def compute_monotone_residual(model, x, x0, logits, select_index, block_start, block_end, mask_id, is_dream=False, x_next=None, next_logits=None):
    """
    Compute the monotone residual ``h(t) - h(t-1) - cost`` of one transfer action.

    Args:
        model: Model used to obtain the successor logits (may be None when
            ``next_logits`` is supplied).
        x: Current state [1, seq_len].
        x0: Predicted tokens [1, seq_len].
        logits: Logits of the current state [1, seq_len, vocab_size].
        select_index: Selected positions, relative to the block start [k].
        block_start: Start position of the current block.
        block_end: End position of the current block (not used by the computation).
        mask_id: Mask token id.
        is_dream: Whether the model is a Dream model (logits are shifted by one).
        x_next: Optional successor state [1, seq_len]; built from ``x`` and
            ``x0`` when omitted.
        next_logits: Optional logits of the successor state
            [1, seq_len, vocab_size]; computed with ``model`` when omitted.

    Returns:
        Dict with ``h_t``, ``h_t_minus_1``, ``cost``, ``monotone_residual``,
        ``num_current_masks``, ``num_next_masks`` and ``k``; ``None`` when
        ``select_index`` is empty.
    """
    if len(select_index) == 0:
        return None

    def _entropy_per_position(raw_logits):
        # Shannon entropy of the softmax distribution at every position.
        dist = F.softmax(raw_logits, dim=-1)
        return -torch.sum(dist * torch.log(dist + 1e-10), dim=-1)

    def _mean_or_zero(values):
        # Mean of a 1-D tensor, or 0.0 when it is empty.
        return values.mean().item() if len(values) > 0 else 0.0

    # Absolute positions of the transferred tokens.
    picked = block_start + select_index

    # Build the successor state unless the caller supplied one.
    if x_next is None:
        x_next = x.clone()
        x_next[0, picked] = x0[0, picked]

    # Forward the successor state unless its logits were supplied.
    if next_logits is None:
        with torch.no_grad():
            next_logits = model(x_next).logits  # [1, seq_len, vocab_size]
            if is_dream:
                next_logits = torch.cat([next_logits[:, :1], next_logits[:, :-1]], dim=1)

    # h(t): mean entropy over the positions masked in the current state.
    entropy_now = _entropy_per_position(logits)  # [1, seq_len]
    masked_entropy_now = entropy_now[x == mask_id]
    h_t = _mean_or_zero(masked_entropy_now)

    # h(t-1): mean entropy over the positions still masked after the transfer.
    entropy_after = _entropy_per_position(next_logits)  # [1, seq_len]
    masked_entropy_after = entropy_after[x_next == mask_id]
    h_t_minus_1 = _mean_or_zero(masked_entropy_after)

    # cost: total current-state entropy of the transferred positions.
    cost = entropy_now[0, picked].sum().item()

    return {
        'h_t': h_t,
        'h_t_minus_1': h_t_minus_1,
        'cost': cost,
        'monotone_residual': h_t - h_t_minus_1 - cost,
        'num_current_masks': len(masked_entropy_now),
        'num_next_masks': len(masked_entropy_after),
        'k': len(select_index)
    }

def ig_sampler_select(model, x, x0, logits, confidence, k, candidate_number, position_temperature, block_start, block_end, mask_id, remaining_steps=1, debug=False, is_dream=False):
    """
    IG-Sampler 动作选择
    
    1. 用启发值作为 action_sampler 的分数，采样 candidate_number 个候选动作组合
    2. 对每个候选动作，生成后继状态（并行处理）
    3. 计算启发值：动作集合熵总和 + 下一状态熵总和 / 剩余步数
    4. 选择启发值最小的动作组合
    
    Args:
        model: 模型
        x: 当前状态 [1, seq_len]
        x0: 预测的 token [1, seq_len]
        logits: 当前 logits [1, seq_len, vocab_size]
        confidence: 当前 block 内每个位置的启发值 [block_len]
        k: 需要转移的 token 数量
        candidate_number: 候选动作数量
        position_temperature: 位置采样温度
        block_start: 当前 block 的起始位置
        block_end: 当前 block 的结束位置
        mask_id: mask token id
        remaining_steps: 剩余转移步数
        debug: 是否输出调试信息
        is_dream: 是否是 Dream 模型（需要 shift logits）
    
    Returns:
        select_index: 选中的位置索引（相对于 block 起始位置）
        selected_x: 选中的后继状态 [1, seq_len]，如果已计算则直接返回，否则为 None
    """
    device = x.device
    block_len = block_end - block_start
    
    # 获取有效的 mask 位置（confidence > -inf 的位置，非 mask 位置为 -inf）
    valid_mask = confidence > -np.inf
    valid_indices = torch.where(valid_mask)[0]
    num_valid = valid_indices.shape[0]
    
    # DEBUG: 输出原始位置的启发分数（最大的几个）
    if debug:
        print(f"[IG-Sampler DEBUG] k={k}, num_valid={num_valid}, candidate_number={candidate_number}")
        valid_conf = confidence[valid_indices]
        # 显示最大的 5 个
        top_k_conf, top_k_idx = torch.topk(valid_conf, k=min(5, num_valid), largest=True)
        print(f"[IG-Sampler DEBUG] Top-5 启发分数（最大）: {top_k_conf.tolist()}")
        print(f"[IG-Sampler DEBUG] Top-5 位置索引: {valid_indices[top_k_idx].tolist()}")
    
    if num_valid <= k:
        # 如果有效位置数量不超过 k，直接返回所有有效位置
        # 没有计算后继状态，返回 None
        return valid_indices, None, None
    
    # 获取有效位置的 confidence
    valid_confidence = confidence[valid_indices]
    
    # 计算采样概率（大的 confidence 有大的采样概率）
    if position_temperature > 0:
        # 使用 softmax 温度
        sample_logits = valid_confidence / position_temperature
    else:
        # 贪婪：直接选最大的 k 个
        _, select_index = torch.topk(confidence, k=k, largest=True)
        return select_index, None, None
    
    # ========== k=1 的优化：直接取加了 Gumbel 噪声后的 top candidate_number 个 ==========
    if k == 1:
        # 对 sample_logits 加单个 Gumbel 噪声
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(sample_logits) + 1e-10) + 1e-10)
        perturbed_logits = sample_logits + gumbel_noise  # [num_valid]
        
        # 取 top candidate_number 个位置，每个位置作为一个候选动作
        num_candidates = min(candidate_number, num_valid)
        _, top_indices_in_valid = torch.topk(perturbed_logits, k=num_candidates, largest=True)
        candidate_positions = valid_indices[top_indices_in_valid]  # [num_candidates]
        
        # 每个位置是一个独立的候选动作（k=1，所以每个动作只包含 1 个位置）
        unique_actions = [candidate_positions[i:i+1] for i in range(num_candidates)]
        
        if debug:
            print(f"[IG-Sampler DEBUG] k=1 优化：采样了 {num_candidates} 个候选位置: {candidate_positions.tolist()}")
    else:
        # ========== k>1：使用 Gumbel-Top-k 采样多个动作组合 ==========
        sample_probs = F.softmax(sample_logits, dim=-1)
        
        # 并行采样 candidate_number 个动作组合（使用 Gumbel-Top-k）
        # 生成 [candidate_number, num_valid] 的 Gumbel 噪声
        gumbel_noise = -torch.log(-torch.log(torch.rand(candidate_number, num_valid, device=device) + 1e-10) + 1e-10)
        perturbed_logits = torch.log(sample_probs + 1e-10).unsqueeze(0) + gumbel_noise  # [candidate_number, num_valid]
        
        # 对每个候选取 top-k（采样概率最高的，即原始 confidence 最大的）
        _, sampled_indices_batch = torch.topk(perturbed_logits, k=min(k, num_valid), dim=-1)  # [candidate_number, k]
        
        # 转换为原始 block 内的索引
        candidate_actions_tensor = valid_indices[sampled_indices_batch]  # [candidate_number, k]
        
        # 去重：相同的动作组合只保留一个
        unique_actions = []
        seen = set()
        for i in range(candidate_number):
            action = candidate_actions_tensor[i]
            action_tuple = tuple(sorted(action.tolist()))
            if action_tuple not in seen:
                seen.add(action_tuple)
                unique_actions.append(action)
    
    if len(unique_actions) == 0:
        # fallback: 选择启发值最大的 k 个
        _, select_index = torch.topk(confidence, k=k, largest=True)
        return select_index, None, None
    
    # 并行构建后继状态
    num_candidates = len(unique_actions)
    x_batch = x.expand(num_candidates, -1).clone()  # [num_candidates, seq_len]
    
    # 构建索引用于并行赋值
    # 创建 [num_candidates, k] 的全局位置索引
    action_tensor = torch.stack(unique_actions)  # [num_candidates, k]
    global_positions = block_start + action_tensor  # [num_candidates, k]
    
    # 使用 scatter 并行赋值
    batch_indices = torch.arange(num_candidates, device=device).unsqueeze(1).expand(-1, k)  # [num_candidates, k]
    x0_values = x0[0, global_positions.flatten()].view(num_candidates, k)  # [num_candidates, k]
    
    # 并行赋值到 x_batch
    for c_idx in range(num_candidates):
        x_batch[c_idx, global_positions[c_idx]] = x0_values[c_idx]
    
    # 如果只有一个候选，直接返回后继状态，不需要比较
    if num_candidates == 1:
        # 先计算后继状态的 logits（需要前向传播）
        with torch.no_grad():
            next_logits_single = model(x_batch[0:1]).logits  # [1, seq_len, vocab_size]
            if is_dream:
                next_logits_single = torch.cat([next_logits_single[:, :1], next_logits_single[:, :-1]], dim=1)
        # 使用辅助函数计算单调性残差
        monotone_data = compute_monotone_residual(
            model=None,  # 不需要，因为已经提供了 next_logits
            x=x,
            x0=x0,
            logits=logits,
            select_index=unique_actions[0],
            block_start=block_start,
            block_end=block_end,
            mask_id=mask_id,
            is_dream=is_dream,
            x_next=x_batch[0:1],
            next_logits=next_logits_single
        )
        return unique_actions[0], x_batch[0:1], monotone_data
    
    # 批量计算后继状态的 logits（单次前向传播）
    with torch.no_grad():
        next_logits = model(x_batch).logits  # [num_candidates, seq_len, vocab_size]
        # Dream 模型需要 shift logits
        if is_dream:
            next_logits = torch.cat([next_logits[:, :1], next_logits[:, :-1]], dim=1)
    
    # 计算后继状态每个位置的熵
    next_probs = F.softmax(next_logits, dim=-1)  # [num_candidates, seq_len, vocab_size]
    next_entropy = -torch.sum(next_probs * torch.log(next_probs + 1e-10), dim=-1)  # [num_candidates, seq_len]
    
    # 只考虑剩余 mask 位置的熵
    remaining_mask = (x_batch == mask_id)  # [num_candidates, seq_len]
    
    # 计算后继状态的总熵（只计算 mask 位置）
    masked_next_entropy = torch.where(remaining_mask, next_entropy, torch.zeros_like(next_entropy))
    next_total_entropy = masked_next_entropy.sum(dim=-1)  # [num_candidates]
    
    # 计算当前状态的熵（用于动作集合的熵）
    current_probs = F.softmax(logits, dim=-1)  # [1, seq_len, vocab_size]
    current_entropy = -torch.sum(current_probs * torch.log(current_probs + 1e-10), dim=-1)  # [1, seq_len]
    
    # 计算每个候选动作的启发值：动作集合熵总和 + 下一状态熵总和 / 转移后剩余步数
    # 越小越好
    action_entropy_sum = torch.zeros(num_candidates, device=device)
    for c_idx in range(num_candidates):
        # 动作集合中每个位置的熵（在当前状态下）
        action_positions = unique_actions[c_idx]
        global_action_positions = block_start + action_positions
        action_entropy_sum[c_idx] = current_entropy[0, global_action_positions].sum()
    
    # 转移后的剩余步数
    remaining_steps_after = remaining_steps - 1
    
    # 启发值 = 动作集合熵总和 + 下一状态熵总和 / 转移后剩余步数
    heuristic_value = action_entropy_sum + next_total_entropy / max(remaining_steps_after, 1)
    
    # 选择启发值最小的候选动作
    best_idx = torch.argmin(heuristic_value).item()
    
    if debug:
        print(f"\n[IG-Sampler DEBUG] ===== 候选动作评估 =====")
        print(f"[IG-Sampler DEBUG] 当前剩余步数: {remaining_steps}, 转移后剩余步数: {remaining_steps_after}")
        print(f"[IG-Sampler DEBUG] 共 {num_candidates} 个候选动作")
        print(f"  {'候选':<6} {'位置':<20} {'动作熵':<12} {'后继熵':<12} {'后继熵/步数':<12} {'启发值':<12} {'剩余mask':<10}")
        print(f"  {'-'*96}")
        for c_idx in range(num_candidates):
            action_pos = unique_actions[c_idx].tolist()
            action_ent = action_entropy_sum[c_idx].item()
            next_ent = next_total_entropy[c_idx].item()
            next_ent_per_step = next_ent / max(remaining_steps_after, 1)
            h_value = heuristic_value[c_idx].item()
            remaining_masks = remaining_mask[c_idx].sum().item()
            marker = " <-- BEST" if c_idx == best_idx else ""
            print(f"  {c_idx:<6} {str(action_pos):<20} {action_ent:<12.4f} {next_ent:<12.4f} {next_ent_per_step:<12.4f} {h_value:<12.4f} {int(remaining_masks):<10}{marker}")
        print(f"\n[IG-Sampler DEBUG] 选择第 {best_idx} 个候选 (启发值最小={heuristic_value[best_idx].item():.4f})")
        print(f"[IG-Sampler DEBUG] 选中的位置: {unique_actions[best_idx].tolist()}")
    
    # ========== 计算单调性残差 ==========
    # 使用辅助函数计算单调性残差（使用已计算好的后继状态和 logits）
    selected_x = x_batch[best_idx:best_idx+1]  # [1, seq_len]
    next_logits_selected = next_logits[best_idx:best_idx+1]  # [1, seq_len, vocab_size]
    monotone_data = compute_monotone_residual(
        model=None,  # 不需要，因为已经提供了 next_logits
        x=x,
        x0=x0,
        logits=logits,
        select_index=unique_actions[best_idx],
        block_start=block_start,
        block_end=block_end,
        mask_id=mask_id,
        is_dream=is_dream,
        x_next=selected_x,
        next_logits=next_logits_selected
    )
    
    return unique_actions[best_idx], x_batch[best_idx:best_idx+1], monotone_data

@torch.no_grad()
def generate_with_eb_sampler(model, prompt, gamma=0.1, gen_length=128, temperature=0.,
                       cfg_scale=0., mask_id=126336, is_dream=False, eos_penalty=0.0):
    """
    Entropy-bounded decoding: repeatedly unmask the lowest-entropy masked positions.

    Each iteration runs the model on the current sequence, ranks the masked
    positions by the entropy of their predicted distribution (ascending), and
    commits the first ``k`` of them, where ``k`` is the number of ranked
    positions whose (cumulative entropy - running max entropy) is <= ``gamma``,
    clamped to at least 1 and at most the number of masked positions.

    Args:
        model: Callable returning an object with ``.logits``; ``model.device``
            decides where the working sequence lives.
        prompt: Prompt token ids; only ``prompt.shape[1]`` and its values are
            used, and the working sequence has a single row.
        gamma: Entropy budget used to decide how many positions to unmask per step.
        gen_length: Number of masked positions appended after the prompt.
        temperature: Forwarded to ``add_gumbel_noise``.
        cfg_scale: When > 0, a second forward row with the prompt masked out is
            used for classifier-free guidance.
        mask_id: Token id that marks a masked position.
        is_dream: When true, logits are shifted right by one position
            (the first position is duplicated, the last one dropped).
        eos_penalty: Forwarded to ``apply_eos_penalty``.

    Returns:
        The sequence tensor of shape ``(1, prompt.shape[1] + gen_length)`` once
        no ``mask_id`` entries remain.
    """
    prompt_len = prompt.shape[1]
    x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt_len] = prompt.clone()
    prompt_index = (x != mask_id)
    use_cfg = cfg_scale > 0.

    while True:
        mask_index = (x == mask_id)
        if not mask_index.any():
            break

        # Build the model input: with CFG, stack the conditional row on top of
        # an unconditional row whose prompt tokens are masked out.
        if use_cfg:
            un_x = x.clone()
            un_x[prompt_index] = mask_id
            model_input = torch.cat([x, un_x], dim=0)
        else:
            model_input = x

        logits = model(model_input).logits
        # Dream models need their logits shifted by one position.
        if is_dream:
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
        if use_cfg:
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
        # The EOS penalty is applied last (after CFG when CFG is active).
        logits = apply_eos_penalty(logits, model, eos_penalty)

        noisy_logits = add_gumbel_noise(logits, temperature=temperature)
        predicted_tokens = torch.argmax(noisy_logits, dim=-1)

        # Entropy of every masked position, used as the error proxy.
        entropies = torch.distributions.Categorical(logits=logits[mask_index]).entropy()
        masked_columns = mask_index.nonzero(as_tuple=True)[1]

        # Rank masked positions from most to least certain.
        order = torch.argsort(entropies)
        ranked_positions = masked_columns[order]
        ranked_entropies = entropies[order]

        slack = torch.cumsum(ranked_entropies, dim=0) - torch.cummax(ranked_entropies, dim=0).values
        k = (slack <= gamma).sum()
        # Always make progress, and never ask for more positions than exist.
        k = torch.clamp(k, min=1, max=len(ranked_positions))

        chosen = ranked_positions[:k]
        x[0, chosen] = predicted_tokens[0, chosen]

    return x

@torch.no_grad()
def generate_with_fast_dllm(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             remasking='low_confidence', mask_id=126336, threshold=None, factor=None, is_dream=False, eos_penalty=0.0):
    '''
    Block-wise decoding loop that fills one block of masked tokens at a time.

    Args:
        model: Mask predictor; called as ``model(x)`` and read through ``.logits``.
        prompt: Prompt token ids; the working sequence has a single row of
            length ``prompt.shape[1] + gen_length``.
        steps: Total sampling steps; must be divisible by the number of blocks.
        gen_length: Generated answer length; must be divisible by ``block_length``.
        block_length: Length of each block decoded in turn.
        temperature: Forwarded to the transfer-index helpers.
        remasking: Remasking strategy name, forwarded to the transfer-index helpers.
        mask_id: Token id that marks a masked position.
        threshold: Forwarded to ``get_transfer_index``; when it is not None the
            per-step transfer quota is not passed (None is passed instead).
        factor: When not None, ``get_transfer_index_dynamic`` is used instead of
            ``get_transfer_index``.
        is_dream: Whether this is a Dream model (logits are shifted by one position).
        eos_penalty: Forwarded to ``apply_eos_penalty``.

    Returns:
        ``(x, nfe)``: the filled sequence and the number of model forward passes.
    '''
    prompt_len = prompt.shape[1]
    x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt_len] = prompt.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    nfe = 0
    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length

        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        i = 0
        while True:
            nfe += 1
            mask_index = (x == mask_id)
            logits = model(x).logits
            # Dream models need their logits shifted by one position.
            if is_dream:
                logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

            logits = apply_eos_penalty(logits, model, eos_penalty)

            # Positions after the current block are not eligible for transfer yet.
            mask_index[:, block_end:] = 0
            if factor is None:
                quota = num_transfer_tokens[:, i] if threshold is None else None
                x0, transfer_index = get_transfer_index(logits, temperature, remasking, mask_index, x, quota, threshold)
            else:
                x0, transfer_index = get_transfer_index_dynamic(logits, temperature, remasking, mask_index, x, None, factor)
            x[transfer_index] = x0[transfer_index]
            i += 1

            # Move on once the current block has no masked position left.
            if not (x[:, block_start:block_end] == mask_id).any():
                break
    return x, nfe


# ============================================================================
# Info-Gain Sampler (Information-Gain Planner) Helper Functions
# ============================================================================

def compute_entropy_info_gain(probs=None, logits=None, eps=1e-12):
    """
    Compute the entropy H(p) = -sum(p * log(p)) over the last dimension.

    Args:
        probs: Probabilities; used only when ``logits`` is None.
        logits: Unnormalised scores; when given they take precedence over
            ``probs`` and are converted with a softmax over the last dimension.
        eps: Floor applied to the probabilities and added inside the
            renormalisation and the logarithm for numerical stability.

    Returns:
        Entropy tensor with the last dimension reduced, clamped to be >= 0.

    Raises:
        ValueError: If neither ``probs`` nor ``logits`` is provided.
    """
    if logits is None and probs is None:
        raise ValueError("Either probs or logits must be provided")

    if logits is not None:
        probs = F.log_softmax(logits, dim=-1).exp()

    # Floor the probabilities, then renormalise so they sum to (about) one again.
    floored = probs.clamp(min=eps, max=1.0)
    normalised = floored / (floored.sum(dim=-1, keepdim=True) + eps)

    entropy = -(normalised * (normalised + eps).log()).sum(dim=-1)
    return entropy.clamp(min=0.0)

def get_confidence_scores_info_gain(probs, heuristic='confidence'):
    """
    Turn per-position probabilities into a 1-D tensor of confidence scores.

    Args:
        probs: Probability tensor; scores are computed over its last dimension.
        heuristic: One of
            'confidence'  - the largest probability,
            'margin'      - top-1 minus top-2 probability (top-1 alone when the
                            last dimension has a single entry),
            'neg_entropy' - negated entropy from ``compute_entropy_info_gain``.

    Returns:
        A 1-D tensor of scores (0-dim results are unsqueezed, higher-rank
        results are flattened). An empty tensor is returned when
        ``probs.shape[0] == 0``.

    Raises:
        ValueError: If ``heuristic`` is not one of the names above.
    """
    if probs.shape[0] == 0:
        return torch.tensor([], dtype=probs.dtype, device=probs.device)

    if heuristic == 'confidence':
        confidence = probs.max(dim=-1).values
    elif heuristic == 'margin':
        ranked = torch.sort(probs, dim=-1, descending=True).values
        confidence = ranked[..., 0]
        if ranked.shape[-1] > 1:
            confidence = confidence - ranked[..., 1]
    elif heuristic == 'neg_entropy':
        confidence = -compute_entropy_info_gain(probs=probs)
    else:
        raise ValueError(f"Unknown heuristic: {heuristic}")

    # Always hand back a flat 1-D tensor.
    rank = confidence.dim()
    if rank == 0:
        return confidence.unsqueeze(0)
    if rank > 1:
        return confidence.flatten()
    return confidence

def position_sampler_info_gain(confidence, K, tau=1.0, device=None):
    """
    Position sampler: pick ``K`` indices using Gumbel-perturbed log-confidence.

    Args:
        confidence: Confidence scores; 0-dim input is unsqueezed and
            higher-rank input is flattened, so indices refer to the flat view.
        K: Number of indices to return; capped at the number of scores.
        tau: Scale of the Gumbel noise. ``tau == 0.0`` selects the top-``K``
            scores deterministically.
        device: Device of the empty tensor returned when nothing can be
            sampled; defaults to ``confidence.device``.

    Returns:
        A 1-D long tensor with ``min(K, number of scores)`` indices, or an
        empty long tensor when that count is not positive.
    """
    if device is None:
        device = confidence.device

    if confidence.dim() == 0:
        confidence = confidence.unsqueeze(0)
    elif confidence.dim() > 1:
        confidence = confidence.flatten()

    num_masked = confidence.shape[0]
    K = min(K, num_masked)

    # K <= 0 also covers num_masked == 0; a negative K would otherwise reach
    # torch.topk and raise.
    if K <= 0:
        return torch.tensor([], dtype=torch.long, device=device)

    if tau == 0.0:
        return torch.topk(confidence, K).indices

    try:
        import torch.distributions as dists
        # Keep the noise on the same device as the scores it is added to.
        gumbel = dists.Gumbel(0, 1).sample(confidence.shape).to(confidence.device)
    except Exception:
        # Best-effort fallback: draw Gumbel noise by hand. Only ordinary
        # errors are caught so KeyboardInterrupt / SystemExit still propagate.
        gumbel = -torch.log(-torch.log(torch.rand_like(confidence) + 1e-10) + 1e-10)

    scores = torch.log(confidence.clamp(min=1e-10)) + tau * gumbel
    return torch.topk(scores, K).indices

def token_sampler_info_gain(logits, temperature=0.0, top_p=None, top_k=None, device=None):
    """
    Token sampler: choose one token id per position via a noisy argmax.

    The logits are perturbed with ``add_gumbel_noise`` (same approach as the
    PC-Sampler path) and the argmax over the last dimension is returned.
    ``top_p``, ``top_k`` and ``device`` are accepted but do not influence the
    result of this function.
    """
    device = logits.device if device is None else device

    # Perturb, then take the best-scoring token at every position.
    perturbed = add_gumbel_noise(logits, temperature=temperature)
    return perturbed.argmax(dim=-1)

class ActionSet:
    """A candidate action set: positions paired with the tokens to place there."""
    def __init__(self, positions, tokens):
        # Both values are stored exactly as given (no copy, no validation).
        self.positions, self.tokens = positions, tokens

def action_sampler_info_gain(logits, mask_positions, uncertainty_scores, K, beam_actions=1, position_temperature=1.0, 
                       token_temperature=0.0, top_p=None, top_k=None, heuristic='confidence',
                       similarity_threshold=0.5, max_resample_attempts=3, device=None):
    """
    Action sampler: build up to ``beam_actions`` candidate action sets.

    For every candidate the procedure is:
      (1) Token sampling: draw a token for every masked position with
          ``token_sampler_info_gain`` (temperature ``token_temperature``).
      (2) Position sampling: pick ``K`` distinct positions, either greedily
          (``position_temperature == 0.0``) or from a softmax over the
          uncertainty scores divided by ``position_temperature``, one position
          at a time without replacement.
    The selected positions and their sampled tokens are paired in an ``ActionSet``.
    When ``K > 1`` a candidate whose position overlap with an already accepted
    candidate exceeds ``similarity_threshold`` is resampled; once the number of
    rejections exceeds ``max_resample_attempts`` the candidate is kept anyway.

    Args:
        logits: Per-position logits; ``logits.shape[0]`` is the number of masked positions.
        mask_positions: Sequence position of each masked entry, indexed like ``logits``.
        uncertainty_scores: One score per masked position, computed by the caller.
        K: Number of positions per candidate (capped at the number of masked positions).
        beam_actions: Number of candidates to generate.
        position_temperature: Temperature for position sampling.
        token_temperature: Temperature for token sampling.
        top_p, top_k: Forwarded to ``token_sampler_info_gain``.
        heuristic: Kept for compatibility; not used here.
        similarity_threshold: Maximum tolerated overlap ratio between candidates.
        max_resample_attempts: Rejections tolerated before a candidate is accepted regardless.
        device: Device for the index tensor; defaults to ``logits.device``.

    Returns:
        A list of ``ActionSet`` objects; empty when ``K`` resolves to 0 or the
        uncertainty scores do not line up with ``logits``.
    """
    if device is None:
        device = logits.device

    num_masked = logits.shape[0]
    K = min(K, num_masked)
    if K == 0:
        return []

    # Use the caller-supplied uncertainty scores as they are.
    uncertainty = uncertainty_scores
    if uncertainty.shape[0] != num_masked or uncertainty.shape[0] == 0:
        return []

    greedy_positions = (position_temperature == 0.0)
    candidates = []

    for _ in range(beam_actions):
        rejections = 0

        while rejections <= max_resample_attempts:
            # (1) Token sampling for every masked position.
            all_tokens = token_sampler_info_gain(logits, token_temperature, top_p, top_k, device)

            # (2) Position sampling.
            if K == 1:
                if greedy_positions:
                    # Deterministic: take the highest-scoring position.
                    pick = torch.topk(uncertainty, 1).indices[0]
                else:
                    weights = torch.softmax(uncertainty / position_temperature, dim=0)
                    pick = torch.multinomial(weights, 1)[0]

                row = pick.item() if isinstance(pick, torch.Tensor) else pick
                action_set = ActionSet(
                    positions=mask_positions[row:row+1],
                    tokens=all_tokens[row:row+1]
                )
            else:
                # Draw K distinct positions, one at a time, without replacement.
                pool = list(range(num_masked))
                picked = []
                for _step in range(K):
                    if not pool:
                        break

                    pool_scores = uncertainty[pool]
                    if greedy_positions:
                        local = torch.argmax(pool_scores).item()
                    else:
                        weights = torch.softmax(pool_scores / position_temperature, dim=0)
                        local = torch.multinomial(weights, 1)[0].item()

                    # Pool entries are unique, so popping by local index removes
                    # exactly the chosen global index.
                    picked.append(pool.pop(local))

                rows = torch.tensor(picked, device=device, dtype=torch.long)
                action_set = ActionSet(
                    positions=mask_positions[rows],
                    tokens=all_tokens[rows]
                )

            # Overlap check against accepted candidates (only meaningful for K > 1).
            accepted = True
            if K > 1:
                for existing in candidates:
                    prior = set(existing.positions.cpu().tolist())
                    current = set(action_set.positions.cpu().tolist())
                    larger = max(len(prior), len(current))
                    if larger > 0 and len(prior & current) / larger > similarity_threshold:
                        accepted = False
                        rejections += 1
                        break

            if accepted or rejections > max_resample_attempts:
                candidates.append(action_set)
                break

    return candidates

def action_selector_info_gain(model, x, candidates, mask_token_id, current_logits, mask_positions, 
                       num_masks=-1, device=None):
    """
    Action selector: apply the candidate action set with the lowest score.

    With a single candidate it is applied directly (no model call). Otherwise
    every candidate is applied to its own copy of ``x``, all copies go through
    one batched forward pass, and each candidate is scored as

        sum of current entropies at its positions
        + (average next-state entropy over remaining masks) * (number of its positions)

    The candidate with the smallest score wins; ties keep the earliest one.

    Args:
        model: Callable returning an object with ``.logits``; only used when
            there is more than one candidate.
        x: Current sequence; row 0 is the one that gets modified in the result.
        candidates: Objects exposing tensor attributes ``positions`` and ``tokens``.
        mask_token_id: Token id that marks a masked position.
        current_logits: Logits of the currently masked positions, indexed like
            ``mask_positions``.
        mask_positions: Sequence position of each currently masked entry.
        num_masks: When > 0, only the first ``num_masks`` remaining masked
            positions contribute to the next-state entropy sum.
        device: Device for intermediate tensors; defaults to ``x.device``.

    Returns:
        A copy of ``x`` with the chosen action applied.

    Raises:
        ValueError: If ``candidates`` is empty.
    """
    if len(candidates) == 0:
        raise ValueError("No candidates to select from")

    # Resolve the device like the sibling samplers do, so index/value tensors
    # and scratch scalars end up next to ``x`` when the caller passes nothing.
    if device is None:
        device = x.device

    def write_action(batch, row, action):
        # Write the action's tokens into batch[row] at the action's positions.
        positions = action.positions.to(device).long()
        tokens = action.tokens.to(device).long()
        if len(positions) > 0:
            batch[row, positions] = tokens

    if len(candidates) == 1:
        x_next = x.clone()
        write_action(x_next, 0, candidates[0])
        return x_next

    # Entropy of every currently masked position, keyed by sequence position.
    current_all_entropy = compute_entropy_info_gain(logits=current_logits)
    current_entropy_by_pos = {
        pos.item(): current_all_entropy[i] for i, pos in enumerate(mask_positions)
    }

    # One successor state per candidate, evaluated in a single forward pass.
    x_next_batch = x.repeat(len(candidates), 1)
    for i, action in enumerate(candidates):
        write_action(x_next_batch, i, action)

    with torch.no_grad():
        # MDM models predict the token at the current position directly, so no shift is applied.
        batch_next_logits = model(x_next_batch).logits

    best_action = None
    best_score = float('inf')

    for i, action in enumerate(candidates):
        # Approximate loss: summed current entropy at the chosen positions.
        approx_loss = torch.tensor(0.0, device=device)
        for pos in action.positions:
            entropy_here = current_entropy_by_pos.get(pos.item())
            if entropy_here is not None:
                approx_loss = approx_loss + entropy_here

        num_transfer = len(action.positions)
        next_mask_positions = torch.where(x_next_batch[i] == mask_token_id)[0]
        num_remaining_masks = len(next_mask_positions)

        if num_remaining_masks > 0:
            considered = next_mask_positions[:num_masks] if num_masks > 0 else next_mask_positions
            next_entropy = compute_entropy_info_gain(logits=batch_next_logits[i][considered])
            # NOTE(review): the sum covers only the considered positions but is
            # divided by the count of ALL remaining masks; kept as in the
            # original, confirm this is intended when num_masks > 0.
            avg_next_entropy = next_entropy.sum() / num_remaining_masks
        else:
            avg_next_entropy = torch.tensor(0.0, device=device)

        # Total score, lower is better.
        score_value = float((approx_loss + avg_next_entropy * num_transfer).item())
        if score_value < best_score:
            best_score = score_value
            best_action = action

    # No score compared below infinity (e.g. NaN/inf scores): use the first candidate.
    if best_action is None:
        best_action = candidates[0]

    x_next = x.clone()
    write_action(x_next, 0, best_action)
    return x_next

@torch.no_grad()
def generate_with_info_gain(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
                     cfg_scale=0., remasking='low_confidence', mask_id=126336, return_order=False,
                     position_temperature=0.1, candidate_number=8,
                     lambd=0.0, alpha=10, baseline_name='../data/baseline/reference_corpus.json', debug=True,
                     prefilled_positions=None, heuristic='confidence', return_cumulative_entropy=False,
                     tokens_per_step=None, is_dream=False, save_monotone_residual_path=None, eos_penalty=0.0):
    """
    Info-Gain Sampler entry point (an extension of the PC-Sampler).

    This function only forwards its arguments, by keyword, to ``generate`` and
    returns whatever ``generate`` returns. The parameter notes below are carried
    over from the original documentation; the behaviour itself is implemented
    in ``generate``.

    Args:
        position_temperature: Position sampling temperature; documented as
            falling back to the PC-Sampler when <= 0.
        candidate_number: Number of candidate actions; documented as falling
            back to the PC-Sampler when <= 1.
        lambd: The PC-Sampler lambda parameter.
        alpha: The PC-Sampler alpha parameter.
        baseline_name: Path to the baseline file.
        debug: Whether to print debugging output (on by default).
        prefilled_positions: Pre-filled positions as [(position, token_id), ...].
        heuristic: Heuristic name; documented options are 'pc', 'confidence',
            'neg_entropy', 'margin', 'uniform'.
        return_cumulative_entropy: Whether to return the cumulative entropy.
        tokens_per_step: Tokens decoded per step (K); documented as implying
            steps = num_masks // K when set.
        is_dream: Whether this is a Dream model (logits need shifting).
        save_monotone_residual_path: File path (JSON) for saving monotone residual data.
    """
    forwarded = dict(
        model=model,
        prompt=prompt,
        steps=steps,
        gen_length=gen_length,
        block_length=block_length,
        lambd=lambd,
        alpha=alpha,
        baseline_name=baseline_name,
        temperature=temperature,
        cfg_scale=cfg_scale,
        remasking=remasking,
        mask_id=mask_id,
        return_order=return_order,
        candidate_number=candidate_number,
        position_temperature=position_temperature,
        debug=debug,
        prefilled_positions=prefilled_positions,
        heuristic=heuristic,
        return_cumulative_entropy=return_cumulative_entropy,
        tokens_per_step=tokens_per_step,
        is_dream=is_dream,
        save_monotone_residual_path=save_monotone_residual_path,
        eos_penalty=eos_penalty,
    )
    return generate(**forwarded)