from tqdm import tqdm
import os
import torch.nn.functional as F
from watermark.our_watermark import OurBlacklistLogitsProcessor
from watermark.old_watermark2_sentence_fullGenerated import BlacklistLogitsProcessor
from watermark.gptwm import GPTWatermarkLogitsWarper
from watermark.watermark_v2 import WatermarkLogitsProcessor
from transformers import LogitsProcessorList, LogitsProcessor
# from transformers import StoppingCriteria, StoppingCriteriaList
import torch, gc
from typing import List
import time
from collections import Counter
import math
import numpy as np
import random
import re
import logging
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForSequenceClassification, GenerationConfig
# from sentence_transformers import SentenceTransformer, util
from transformers import LogitsProcessor
from transformers import DynamicCache

class TextMetrics:
    """Token-level BLEU and ROUGE-L for comparing generated texts.

    Texts are split with ``tokenizer.tokenize`` and all scores are computed on
    the resulting token lists.
    """

    def __init__(self, tokenizer, device):
        self.tokenizer = tokenizer
        # Stored for callers; none of the metrics below use it.
        self.device = device

    def tokenize(self, text):
        """Tokenize ``text`` with the wrapped tokenizer."""
        tokens = self.tokenizer.tokenize(text)
        return tokens

    def get_ngrams(self, tokens, n):
        """Return all contiguous n-grams of ``tokens`` as tuples ([] if ``len(tokens) < n``)."""
        if len(tokens) < n:
            return []
        return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

    def calculate_bleu(self, reference, candidate, max_n=4, weights=None):
        """Compute sentence-level BLEU of ``candidate`` against one ``reference``.

        No smoothing is applied: if any n-gram precision (n = 1..max_n) is zero
        the score is 0.

        Args:
            reference: reference sentence (str).
            candidate: candidate sentence (str).
            max_n: highest n-gram order, default 4.
            weights: weights of the per-order log precisions; uniform
                ``1/max_n`` when None.

        Returns:
            BLEU score as a float in [0, 1].
        """
        if weights is None:
            weights = [1.0 / max_n] * max_n

        ref_tokens = self.tokenize(reference)
        cand_tokens = self.tokenize(candidate)

        if len(cand_tokens) == 0:
            return 0.0

        # Modified (clipped) n-gram precision for every order.
        precisions = []
        for n in range(1, max_n + 1):
            ref_ngrams = Counter(self.get_ngrams(ref_tokens, n))
            cand_ngrams = Counter(self.get_ngrams(cand_tokens, n))

            if not cand_ngrams:
                # Candidate is shorter than n tokens.
                precisions.append(0.0)
                continue

            # Each candidate n-gram counts at most as often as it occurs in the reference.
            matches = sum(min(count, ref_ngrams[ngram]) for ngram, count in cand_ngrams.items())
            precisions.append(matches / sum(cand_ngrams.values()))

        if any(p == 0 for p in precisions):
            return 0.0

        # Weighted geometric mean, computed in log space.
        geometric_mean = math.exp(sum(w * math.log(p) for w, p in zip(weights, precisions)))

        # Brevity penalty; cand_len > 0 is guaranteed by the early return above.
        ref_len = len(ref_tokens)
        cand_len = len(cand_tokens)
        bp = 1.0 if cand_len > ref_len else math.exp(1 - ref_len / cand_len)

        return bp * geometric_mean

    def lcs_length(self, seq1, seq2):
        """Return the length of the longest common subsequence of two sequences.

        Classic dynamic programming in O(len(seq1) * len(seq2)) time, but only
        two DP rows are kept, so memory is O(min(len(seq1), len(seq2))).
        """
        # LCS length is symmetric; put the shorter sequence on the inner axis.
        if len(seq2) > len(seq1):
            seq1, seq2 = seq2, seq1

        prev = [0] * (len(seq2) + 1)
        for a in seq1:
            curr = [0] * (len(seq2) + 1)
            for j, b in enumerate(seq2, start=1):
                if a == b:
                    curr[j] = prev[j - 1] + 1
                else:
                    curr[j] = max(prev[j], curr[j - 1])
            prev = curr
        return prev[-1]

    def calculate_rouge_l(self, reference, candidate):
        """Compute ROUGE-L (LCS-based) between ``reference`` and ``candidate``.

        Args:
            reference: reference sentence (str).
            candidate: candidate sentence (str).

        Returns:
            dict with 'precision' (LCS / candidate length), 'recall'
            (LCS / reference length) and 'f1'; all 0.0 if either side has
            no tokens.
        """
        ref_tokens = self.tokenize(reference)
        cand_tokens = self.tokenize(candidate)

        if len(ref_tokens) == 0 or len(cand_tokens) == 0:
            return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

        lcs_len = self.lcs_length(ref_tokens, cand_tokens)

        precision = lcs_len / len(cand_tokens)
        recall = lcs_len / len(ref_tokens)

        if precision + recall == 0:
            f1 = 0.0
        else:
            f1 = 2 * precision * recall / (precision + recall)

        return {
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    def evaluate_batch(self, references, candidates):
        """Score aligned (reference, candidate) pairs and aggregate the results.

        Args:
            references: list of reference sentences.
            candidates: list of candidate sentences, same length as ``references``.

        Returns:
            dict with 'bleu' (mean, std and per-pair 'scores') and 'rouge_l'
            (mean and std of precision, recall and f1). With empty inputs the
            means/stds are NaN (numpy emits a RuntimeWarning).

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(references) != len(candidates):
            raise ValueError("参考句子和候选句子数量不匹配")

        bleu_scores = []
        rouge_scores = {'precision': [], 'recall': [], 'f1': []}

        for ref, cand in zip(references, candidates):
            bleu_scores.append(self.calculate_bleu(ref, cand))
            rouge = self.calculate_rouge_l(ref, cand)
            for key, values in rouge_scores.items():
                values.append(rouge[key])

        return {
            'bleu': {
                'mean': np.mean(bleu_scores),
                'std': np.std(bleu_scores),
                'scores': bleu_scores
            },
            'rouge_l': {
                key: {'mean': np.mean(values), 'std': np.std(values)}
                for key, values in rouge_scores.items()
            }
        }

# class ParallelLogitsProcessor(LogitsProcessor):
#     def __init__(self, processors: List[LogitsProcessor]):
#         self.processors = processors

#     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
#         batch_size = input_ids.shape[0]
#         batch_black_list = [None for _ in range(batch_size)]
        
#         if batch_size != len(self.processors):
#             raise ValueError(
#                 f"Batch size ({batch_size}) must match number of processors ({len(self.processors)})"
#             )

#         processed_scores = []
#         for i in range(batch_size):
#             processor = self.processors[i]
#             sample_input_ids = input_ids[i].unsqueeze(0)
#             sample_scores = scores[i].unsqueeze(0)
            
#             black_list = None
#             if processor is None:
#                 modified_scores = sample_scores
#             else:
#                 # 统一处理不同处理器的返回格式
#                 try:
#                     result = processor(sample_input_ids, sample_scores)
#                     if isinstance(result, tuple):
#                         modified_scores, black_list = result
#                     else:
#                         modified_scores = result
#                         black_list = None
#                 except Exception as e:
#                     print(f"Processor {i} failed: {e}")
#                     modified_scores = sample_scores
#                     black_list = None
            
#             # 检查并修复异常值
#             modified_scores = self._fix_scores(modified_scores)
#             batch_black_list[i] = black_list
#             processed_scores.append(modified_scores[0])
        
#         return torch.stack(processed_scores, dim=0), batch_black_list
    
#     def _fix_scores(self, scores):
#         """修复logits中的异常值"""
#         # 检查并替换inf和nan值
#         scores = torch.where(torch.isinf(scores), torch.tensor(-1e9, dtype=scores.dtype, device=scores.device), scores)
#         scores = torch.where(torch.isnan(scores), torch.tensor(-1e9, dtype=scores.dtype, device=scores.device), scores)
        
#         # 确保数值在合理范围内，防止溢出
#         scores = torch.clamp(scores, min=-1e9, max=1e9)
        
#         return scores

class ParallelLogitsProcessor(LogitsProcessor):
    """Apply a different logits processor to each row of a batch.

    ``processors[i]`` handles batch row ``i``; a ``None`` entry leaves that row's
    logits untouched (standard, un-watermarked sampling).

    NOTE(review): unlike a stock ``LogitsProcessor``, ``__call__`` returns a tuple
    ``(scores, batch_black_list)``; presumably this is consumed by the customised
    ``model.generate`` used in this project — confirm before using it with stock
    transformers.
    """

    def __init__(self, processors: List[LogitsProcessor]):
        """
        Args:
            processors: one entry per batch row; ``None`` means no processing
                for that row.
        """
        self.processors = processors

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> "tuple[torch.FloatTensor, list]":
        """Process every batch row with its own processor.

        Args:
            input_ids: token ids generated so far, [batch_size, seq_len].
            scores: next-token logits, [batch_size, vocab_size].

        Returns:
            ``(processed_scores, batch_black_list)``: the per-row logits stacked
            back to [batch_size, vocab_size], and for each row the black list
            reported by its processor (``None`` if the row has no processor or
            its processor returned only a scores tensor).

        Raises:
            ValueError: if the batch size differs from the number of processors.
        """
        batch_size = input_ids.shape[0]
        if batch_size != len(self.processors):
            raise ValueError(
                f"Batch size ({batch_size}) must match number of processors ({len(self.processors)})"
            )

        batch_black_list = [None for _ in range(batch_size)]
        processed_scores = []
        for i, processor in enumerate(self.processors):
            # Keep a leading batch dim of 1 so each processor sees a normal batch.
            sample_input_ids = input_ids[i].unsqueeze(0)  # [1, seq_len]
            sample_scores = scores[i].unsqueeze(0)        # [1, vocab_size]

            black_list = None
            if processor is None:
                modified_scores = sample_scores
            else:
                result = processor(sample_input_ids, sample_scores)
                # Watermark processors here return (scores, black_list); a stock
                # processor returns just the scores tensor, which must not be
                # unpacked along its batch dimension.
                if isinstance(result, (tuple, list)):
                    modified_scores, black_list = result
                else:
                    modified_scores = result
            batch_black_list[i] = black_list
            processed_scores.append(modified_scores[0])

        return torch.stack(processed_scores, dim=0), batch_black_list
    
class Generator():
    def __init__(self, args, tokenizer, model, dataset_name) -> None:
        """Store the run configuration, model and tokenizer for chunked watermark generation.

        Args:
            args: run configuration object. Attributes read here: ``model``,
                ``mode``, ``initial_seed``, ``dynamic_seed``, ``gamma``,
                ``delta``, ``bl_type``, ``num_beams``, ``sampling_temp``,
                ``seeding_scheme``, ``select_green_tokens``, ``K``,
                ``chunk_size`` and ``alpha``.
            tokenizer: tokenizer matching ``model``.
            model: language model; work is placed on the device of its first
                parameter.
            dataset_name: dataset label, used in the log file name.

        Watermark processors are not built here; ``create_logits_processor_list``
        builds them per seed.
        """
        self.dataset = dataset_name
        self.model_name = args.model
        self.mode = args.mode  # watermark mode
        self.init_seed = args.initial_seed
        self.dyna_seed = args.dynamic_seed
        self.gamma = args.gamma
        self.delta = args.delta
        self.bl_type = args.bl_type
        self.num_beams = args.num_beams
        self.sampling_temp = args.sampling_temp
        self.tokenizer = tokenizer
        self.model = model  # language model
        self.device = next(self.model.parameters()).device
        self.all_token_ids = list(tokenizer.get_vocab().values())
        self.vocab_size = len(self.all_token_ids)
        if "qwen" in self.model_name:
            # Hard-coded override, presumably Qwen's padded logit width (larger
            # than the tokenizer vocab). TODO(review): confirm, or read it from
            # model.config instead of hard-coding.
            self.vocab_size = 152064
        self.seeding_scheme = args.seeding_scheme
        self.select_green_tokens = args.select_green_tokens
        self.K = args.K  # number of watermarked candidates generated per chunk
        self.chunk_size = args.chunk_size
        self.logger = self.setup_logger()
        self.alpha = args.alpha
        self.metrics = TextMetrics(tokenizer, self.device)
        # KV cache passed to model.generate() by generate_chunk_unified; `generate`
        # resets it per call, but initialise it so the attribute always exists.
        self.past_key_values = None
        
    @staticmethod
    def simple_hash_safe(x, seed=42):
        """
        防溢出版本的整数哈希函数
        限制在32位有符号整数范围内
        
        Args:
            x: 输入整数
            seed: 种子值，默认42
        
        Returns:
            int: 哈希后的整数值（-2^31 到 2^31-1）
        """
        x = x & 0xFFFFFFFF  # 保持32位
        large_prime = 15485863
        
        result = (x * large_prime + seed) ^ (seed << 1)

        return result
    
    def setup_logger(self, log_file_path=None, console_output=False):
        """Configure and return the process-wide 'generate_test_sentence' logger.

        Args:
            log_file_path: log file path. When None, a timestamped path
                ``./log/{model}_{dataset}_g{gamma}_d{delta}_{mode}_{timestamp}.log``
                is built from this generator's settings.
            console_output: also echo records to a console stream handler
                (default False).

        Returns:
            logging.Logger named 'generate_test_sentence' at level INFO, with
            propagation to ancestor loggers disabled.

        Note:
            Handlers are attached only when the logger has none, so later calls
            (even with another path or ``console_output``) reuse the handlers
            created by the first call.
        """
        if log_file_path is None:
            # Timestamped file name built from this run's settings.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_file_path = f"./log/{self.model_name}_{self.dataset}_g{self.gamma}_d{self.delta}_{self.mode}_{timestamp}.log"

        logger = logging.getLogger('generate_test_sentence')
        logger.setLevel(logging.INFO)
        logger.propagate = False

        if not logger.handlers:
            # FileHandler does not create missing directories: a fresh ./log, or a
            # model name such as "org/model", would otherwise raise FileNotFoundError.
            log_dir = os.path.dirname(log_file_path)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)

            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )

            file_handler = logging.FileHandler(log_file_path, encoding='utf-8')
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

            if console_output:
                console_handler = logging.StreamHandler()
                console_handler.setLevel(logging.INFO)
                console_handler.setFormatter(formatter)
                logger.addHandler(console_handler)

        return logger

    def create_logits_processor_list(self, seeds):
        """Create one watermark logits processor per seed.

        The processor type follows ``self.mode``:

        * ``'new'`` -> ``OurBlacklistLogitsProcessor``
        * ``'old'`` -> ``BlacklistLogitsProcessor`` (also receives ``self.logger``)
        * ``'gpt'`` -> ``GPTWatermarkLogitsWarper`` (seed used as ``watermark_key``)
        * ``'v2'``  -> ``WatermarkLogitsProcessor``

        Every processor is wrapped in its own single-element ``LogitsProcessorList``.

        Args:
            seeds: integer seeds; one processor is built per seed, in order.

        Returns:
            list of ``LogitsProcessorList``; empty when ``self.mode`` is none of
            the modes above or ``seeds`` is empty.
        """
        def make_new(seed):
            return OurBlacklistLogitsProcessor(
                tokenizer=self.tokenizer,
                bad_words_ids=None,
                eos_token_id=self.tokenizer.eos_token_id,
                vocab=self.all_token_ids,
                all_vocab_size=self.vocab_size,
                bl_proportion=1 - self.gamma,
                bl_logit_bias=self.delta,
                bl_type=self.bl_type,
                initial_seed=self.init_seed,
                dynamic_seed=self.dyna_seed,
                hash_seed=seed,
            )

        def make_old(seed):
            return BlacklistLogitsProcessor(
                bad_words_ids=None,
                eos_token_id=self.tokenizer.eos_token_id,
                vocab=self.all_token_ids,
                vocab_size=self.vocab_size,
                bl_proportion=1 - self.gamma,
                bl_logit_bias=self.delta,
                bl_type=self.bl_type,
                initial_seed=self.init_seed,
                dynamic_seed=self.dyna_seed,
                hash_seed=seed,
                logger=self.logger,
            )

        def make_gpt(seed):
            return GPTWatermarkLogitsWarper(
                vocab_size=self.vocab_size,
                fraction=self.gamma,
                strength=self.delta,
                watermark_key=seed,
            )

        def make_v2(seed):
            # The vocab id list is re-read from the tokenizer for every processor.
            return WatermarkLogitsProcessor(
                vocab=list(self.tokenizer.get_vocab().values()),
                gamma=self.gamma,
                delta=self.delta,
                seeding_scheme=self.seeding_scheme,
                select_green_tokens=self.select_green_tokens,
                hash_seed=seed,
            )

        factory = {
            'new': make_new,
            'old': make_old,
            'gpt': make_gpt,
            'v2': make_v2,
        }.get(self.mode)
        if factory is None:
            return []
        return [LogitsProcessorList([factory(seed)]) for seed in seeds]
    
    # def generate_chunk_standard(self, input_ids, chunk_size):
    #     """生成标准输出（不加水印）的固定长度chunk"""
    #     device = self.device
    #     input_ids = input_ids.to(device)
        
    #     outputs = self.model.generate(
    #         input_ids,
    #         max_new_tokens=chunk_size,
    #         do_sample=True,
    #         top_k=0,
    #         temperature=self.sampling_temp,
    #         return_dict_in_generate=True,
    #         output_scores=True,
    #         pad_token_id=self.tokenizer.eos_token_id
    #     )
        
    #     input_length = input_ids.shape[-1]
    #     generated_tokens = outputs.sequences[0][input_length:]
    #     generated_scores = outputs.scores
        
    #     if len(generated_tokens) == 0:
    #         return {
    #             'text': '',
    #             'tokens': [],
    #             'scores': [],
    #             'sequences': input_ids,
    #             'finished': False
    #         }
        
    #     # 检查是否遇到EOS token
    #     eos_token_id = self.tokenizer.eos_token_id
    #     finished = False
    #     if eos_token_id in generated_tokens:
    #         eos_pos = (generated_tokens == eos_token_id).nonzero(as_tuple=True)[0][0].item()
    #         generated_tokens = generated_tokens[:eos_pos + 1]  # 包含EOS token
    #         generated_scores = generated_scores[:eos_pos + 1]
    #         finished = True
        
    #     generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        
    #     return {
    #         'text': generated_text,
    #         'tokens': generated_tokens.tolist(),
    #         'scores': list(generated_scores) if generated_scores else [],
    #         'sequences': outputs.sequences,
    #         'finished': finished
    #     }
    
    # def generate_chunk_parallel(self, input_ids, chunk_size, processors):
    #     """并行生成多个候选chunk"""
    #     device =self.device
    #     input_ids = input_ids.to(device)
        
    #     # 扩展输入以匹配处理器数量
    #     n_processors = len(processors)
    #     expanded_input_ids = input_ids.repeat(n_processors, 1)
        
    #     # 创建并行处理器
    #     parallel_processor = ParallelLogitsProcessor(processors)
        
    #     outputs = self.model.generate(
    #         expanded_input_ids,
    #         max_new_tokens=chunk_size,
    #         logits_processor=LogitsProcessorList([parallel_processor]),
    #         do_sample=True,
    #         top_k=0,
    #         temperature=self.sampling_temp,
    #         return_dict_in_generate=True,
    #         output_scores=True,
    #         pad_token_id=self.tokenizer.eos_token_id
    #     )
        
    #     input_length = input_ids.shape[-1]
    #     processed_outputs = []
        
    #     for i in range(n_processors):
    #         single_sequence = outputs.sequences[i]
    #         generated_tokens = single_sequence[input_length:]
            
    #         if len(generated_tokens) == 0:
    #             processed_outputs.append({
    #                 'text': '',
    #                 'tokens': [],
    #                 'scores': [],
    #                 'sequences': input_ids,
    #                 'finished': False
    #             })
    #             continue
            
    #         # 检查是否遇到EOS token
    #         eos_token_id = self.tokenizer.eos_token_id
    #         finished = False
    #         if eos_token_id in generated_tokens:
    #             eos_pos = (generated_tokens == eos_token_id).nonzero(as_tuple=True)[0][0].item()
    #             generated_tokens = generated_tokens[:eos_pos + 1]
    #             finished = True
            
    #         generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            
    #         processed_outputs.append({
    #             'text': generated_text,
    #             'tokens': generated_tokens.tolist(),
    #             'scores': [outputs.scores[j][i] for j in range(len(generated_tokens))] if outputs.scores else [],
    #             'sequences': single_sequence.unsqueeze(0),
    #             'finished': finished
    #         })
        
        # return processed_outputs
    @staticmethod
    def calculate_score(green_num_fraction, metrics, weights=[0, 1]):
        return weights[0] * green_num_fraction + weights[1] * metrics
    
    def generate_chunk_unified(self, input_ids, chunk_size, watermark_processors):
        """Generate one chunk for the standard output and every watermark candidate in one batch.

        Batch row 0 is sampled with no logits processor (standard generation);
        row ``k`` (k >= 1) uses ``watermark_processors[k - 1]``. Generation runs
        for ``chunk_size + 3`` new tokens; each row is truncated at its earliest
        stop token, decoded, re-encoded and cut to ``chunk_size`` tokens.

        Args:
            input_ids: prompt token ids; it is repeated once per row, and only
                row 0 is used to build the returned ``sequences`` (presumably
                shape [1, seq_len] — confirm with callers).
            chunk_size: number of tokens a complete chunk should contain.
            watermark_processors: list of per-candidate logits processors.

        Returns:
            ``(processed_outputs, past_key_values)``. ``processed_outputs`` has
            one dict per batch row (standard output first) with keys 'text',
            'tokens', 'sequences', 'finished', 'is_standard', 'available' and
            'green_num_fraction'; ``past_key_values`` comes from the generate
            output.
        """
        ultimate_chunk_size = chunk_size
        # Over-generate by 3 tokens; candidates are cut back to
        # `ultimate_chunk_size` after decode/re-encode (presumably to leave slack
        # for tokens that merge or split when re-tokenised).
        chunk_size += 3

        device = self.device
        input_ids = input_ids.to(device)

        # Row 0: no processor (standard generation); rows 1..: watermark processors.
        all_processors = [None] + watermark_processors
        n_total = len(all_processors)
        expanded_input_ids = input_ids.repeat(n_total, 1)
        parallel_processor = ParallelLogitsProcessor(all_processors)

        if "intern" in self.model_name:
            # For InternLM models "<eoa>" is treated as an extra end-of-sequence token.
            eos_token_id = [
                self.tokenizer.eos_token_id,
                self.tokenizer.convert_tokens_to_ids(["<eoa>"])[0]
            ]
            stop_ids = list(eos_token_id)
        else:
            eos_token_id = self.tokenizer.eos_token_id
            stop_ids = [eos_token_id]
        stop_ids = [tok for tok in stop_ids if tok is not None]
        pad_token_id = self.tokenizer.eos_token_id

        generate_kwargs = dict(
            max_new_tokens=chunk_size,
            logits_processor=LogitsProcessorList([parallel_processor]),
            do_sample=True,
            top_k=0,
            use_cache=True,
            return_dict_in_generate=True,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            # `generate` resets this attribute; getattr guards direct calls.
            past_key_values=getattr(self, "past_key_values", None),
            repetition_penalty=1.0,
        )
        if self.mode == 'gpt':
            # 'gpt' mode uses nucleus sampling and passes no temperature.
            generate_kwargs["top_p"] = 0.9
        else:
            generate_kwargs["temperature"] = self.sampling_temp
        # NOTE: this generate() returns (outputs, green_num_list), i.e. it is a
        # customised generation loop rather than stock transformers.
        outputs, green_num_list = self.model.generate(expanded_input_ids, **generate_kwargs)

        input_length = input_ids.shape[-1]
        processed_outputs = []

        for i in range(n_total):
            b_green_num_list = green_num_list[i]
            generated_tokens = outputs.sequences[i][input_length:]

            # Truncate at the EARLIEST stop token of any kind (the stop token is
            # dropped). Searching ids one at a time and stopping at the first id
            # present could cut at a later position and keep an earlier stop
            # token (e.g. "<eoa>") inside the chunk.
            finished = False
            stop_mask = torch.zeros_like(generated_tokens, dtype=torch.bool)
            for stop_id in stop_ids:
                stop_mask |= generated_tokens == stop_id
            stop_positions = stop_mask.nonzero(as_tuple=True)[0]
            if stop_positions.numel() > 0:
                generated_tokens = generated_tokens[:stop_positions[0].item()]
                finished = True

            if len(generated_tokens) == 0:
                # NOTE(review): reports finished=False even when the first token
                # was a stop token; kept unchanged for existing callers.
                processed_outputs.append({
                    'text': '',
                    'tokens': [],
                    'sequences': input_ids,
                    'finished': False,
                    'is_standard': (i == 0),  # row 0 is the standard generation
                    'available': False,
                    'green_num_fraction': 0
                })
                continue

            # Decode then re-encode so the kept tokens are the ones obtained by
            # tokenising the generated text.
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            actual_tokens = self.tokenizer.encode(
                generated_text, return_tensors="pt", truncation=True, add_special_tokens=False
            )[0].to(device)
            actual_tokens = actual_tokens[:ultimate_chunk_size]
            # NOTE(review): 'text' below is decoded from all generated tokens and
            # may extend past `actual_tokens` (up to 3 extra tokens) — confirm intended.

            fraction = 0
            if len(b_green_num_list) > 0 and len(actual_tokens) > 0:
                valid_l = min(len(b_green_num_list), len(actual_tokens))
                # Approximate: per-step green counts from generation are paired
                # with re-encoded tokens, which may not align one-to-one.
                fraction = sum(b_green_num_list[:valid_l]) / len(actual_tokens)

            # A chunk is usable when it reached full length or the sequence ended.
            available = len(actual_tokens) == ultimate_chunk_size or finished

            final_sequence = torch.cat((input_ids[0], actual_tokens), dim=0)

            processed_outputs.append({
                'text': generated_text,
                'tokens': actual_tokens.tolist(),
                'sequences': final_sequence.unsqueeze(0),
                'finished': finished,
                'is_standard': (i == 0),  # row 0 is the standard generation
                'available': available,
                'green_num_fraction': fraction
            })

        return processed_outputs, outputs.past_key_values
    

    
    def generate(self, input_ids, max_new_tokens):
        """Generate a completion for ``input_ids`` and return ``(text, token_dicts)``.

        When ``self.mode == 'no'`` a single unwatermarked ``self.model.generate``
        call is made and every returned token dict carries its log-probability.
        Otherwise text is produced chunk by chunk. For each chunk:

        1. ``self.K`` seeds are drawn from a seed pool.
        2. ``self.generate_chunk_unified`` produces one standard continuation
           plus one watermarked continuation per seed.
        3. Each watermarked candidate is scored against the standard one.
        4. The best candidate becomes the context for the next chunk.

        Args:
            input_ids: Prompt token ids. Only row 0 is read for seeding/slicing
                (presumably batch size 1 -- TODO confirm with callers).
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            ``(completions_text, completions_tokens)``, where
            ``completions_tokens`` is a list of ``{'text': ...}`` dicts. In
            ``'no'`` mode each dict also carries a ``'logprob'`` entry.
        """
        self.past_key_values = None
        self.seed_pool = list(range(1, 5000))  # candidate seed pool

        if self.mode == 'no':
            gen_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                return_dict_in_generate=True,
                output_scores=True
            )
            outputs, _ = self.model.generate(
                input_ids=input_ids, generation_config=gen_config,
            )
            scores = outputs.scores
            # The last len(scores) positions are the generated tokens. Slice from an
            # explicit start index so that zero generated tokens yields an empty
            # tensor; the former ``[-len(scores):]`` became ``[-0:]``, i.e. the
            # whole sequence including the prompt.
            sequence = outputs.sequences[0]
            output_ids = sequence[sequence.shape[0] - len(scores):]
            self.logger.info(f"生成的token序列为: {output_ids.cpu().tolist()}")
            # Log-probability of each emitted token under the step's scores.
            completions_tokens = []
            for score, token in zip(scores, output_ids):
                logprobs = F.log_softmax(score[0], dim=-1)
                completions_tokens.append({
                    'text': self.tokenizer.decode(token),
                    'logprob': logprobs[token].item(),
                })
            completions_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            return completions_text, completions_tokens

        # ---- chunked watermark generation (all modes other than 'no') ----
        total_start_time = time.time()
        device = self.device
        current_input_ids = input_ids.clone().to(device)
        total_generated = 0
        chunk_count = 0
        all_output_ids = []
        total_generation_time = 0
        total_rouge_time = 0

        self.logger.info(f"开始按chunk生成，chunk大小: {self.chunk_size}")
        self.logger.info(f"模型设备: {device}, 输入设备: {current_input_ids.device}")

        while total_generated < max_new_tokens:
            current_chunk_size = min(self.chunk_size, max_new_tokens - total_generated)
            self.logger.info(f"=== 生成第 {chunk_count + 1} 个chunk (大小: {current_chunk_size}) ===")

            # 1. Derive this chunk's RNG seed from the last (up to) 4 context tokens,
            #    then take the first K seeds of the shuffled pool.
            # NOTE(review): random.seed reseeds the module-level RNG, and the pool is
            # shuffled in place (it is only reset at the top of this method). The
            # seeds picked for a chunk therefore also depend on every earlier
            # chunk's shuffle -- confirm the detector replays the same sequence.
            pre_seed = 1
            for pre_token in current_input_ids[0][-4:]:
                pre_seed *= (pre_token.item() + 1)
            random.seed(Generator.simple_hash_safe(pre_seed))
            random.shuffle(self.seed_pool)
            cur_seeds = self.seed_pool[:self.K]
            self.logger.info(f'选择种子: {cur_seeds}')

            # 2. One logits processor per seed (first processor of each list).
            processors_list = self.create_logits_processor_list(cur_seeds)
            single_processors = [proc[0] for proc in processors_list]

            # 3. Batched generation: entry 0 is the standard (unwatermarked)
            #    output, the rest are watermarked outputs in cur_seeds order.
            self.logger.info("统一并行生成...")
            generation_start_time = time.time()
            unified_outputs, self.past_key_values = self.generate_chunk_unified(
                current_input_ids,
                current_chunk_size,
                single_processors
            )
            generation_time = time.time() - generation_start_time
            total_generation_time += generation_time
            self.logger.info(f"统一生成用时: {generation_time:.2f}秒")

            standard_outputs = unified_outputs[0]
            watermark_outputs_list = unified_outputs[1:]

            if len(standard_outputs['tokens']) == 0:
                self.logger.warning("标准生成没有生成新token，停止生成")
                break
            self.logger.info(f"标准生成内容: {standard_outputs['text']}")

            # 4. Score every non-empty watermarked candidate against the reference.
            rouge_start_time = time.time()
            candidate_list = []
            for i, watermark_outputs in enumerate(watermark_outputs_list):
                if watermark_outputs is None or len(watermark_outputs['tokens']) == 0:
                    self.logger.warning(f"种子 {cur_seeds[i]} 没有生成新token，跳过")
                    continue
                rouge_l = self.metrics.calculate_rouge_l(
                    standard_outputs['text'],
                    watermark_outputs['text']
                )
                f1 = Generator.calculate_score(green_num_fraction=watermark_outputs["green_num_fraction"],
                                               metrics=rouge_l['f1'],
                                               weights=[self.alpha, 1 - self.alpha])
                self.logger.info(f"种子 {cur_seeds[i]}: F1={f1:.4f}, finished={watermark_outputs['finished']}")
                self.logger.info(f"生成的文本为: {watermark_outputs['text']}")
                candidate_list.append((i, f1, watermark_outputs["available"]))

            if not candidate_list:
                self.logger.error("所有加水印生成都为空，结束生成")
                break

            # 5. Pick the best-scoring candidate, preferring "available" ones.
            best_processor_idx, best_f1 = self._pick_best_watermark_candidate(candidate_list)
            best_outputs = watermark_outputs_list[best_processor_idx]

            total_rouge_time += time.time() - rouge_start_time

            self.logger.info(f"选择最佳结果 (F1={best_f1:.4f}, 种子={cur_seeds[best_processor_idx]}): {best_outputs['text']}")
            self.logger.info(f"生成的token序列: {best_outputs['tokens']}")

            # 6. Continue from the chosen candidate's full sequence.
            current_input_ids = best_outputs['sequences'].to(device)
            total_generated += len(best_outputs["tokens"])
            all_output_ids.extend(best_outputs["tokens"])
            chunk_count += 1

            self.logger.info(f"已生成 {total_generated}/{max_new_tokens} tokens")
            if best_outputs['finished']:
                self.logger.info("遇到EOS token，结束生成")
                break

            # 7. Rebuild the KV cache from the chosen candidate for the next chunk.
            original_input_length = current_input_ids.shape[1] - len(best_outputs["tokens"])
            self.update_past_key_values_correctly(
                best_processor_idx,
                current_input_ids,
                original_input_length
            )

        completions_tokens = [{'text': self.tokenizer.decode(token)} for token in all_output_ids]
        completions_text = self.tokenizer.decode(all_output_ids, skip_special_tokens=True)

        total_time = time.time() - total_start_time
        # Floor the denominator of the ratio logs so a coarse clock reading of 0
        # cannot raise ZeroDivisionError.
        safe_total_time = total_time if total_time > 0 else 1e-9

        self.logger.info("="*60)
        self.logger.info("生成完成 - 时间统计")
        self.logger.info("="*60)
        self.logger.info(f"总耗时: {total_time:.2f}秒")
        self.logger.info(f"统一生成总耗时: {total_generation_time:.2f}秒 ({total_generation_time/safe_total_time*100:.1f}%)")
        self.logger.info(f"ROUGE计算总耗时: {total_rouge_time:.2f}秒 ({total_rouge_time/safe_total_time*100:.1f}%)")
        self.logger.info(f"其他处理耗时: {total_time - total_generation_time - total_rouge_time:.2f}秒")
        self.logger.info(f"生成效率: {len(all_output_ids)/safe_total_time:.2f} tokens/秒")
        self.logger.info(f"总共生成: {chunk_count} 个chunks，{len(all_output_ids)} 个tokens")
        self.logger.info(f"平均每chunk耗时: {total_time/chunk_count:.2f}秒" if chunk_count > 0 else "无chunk生成")
        self.logger.info("="*60)
        self.logger.info(f"完整生成内容: {completions_text}")
        self.logger.info("="*60)

        return completions_text, completions_tokens

    def _pick_best_watermark_candidate(self, candidate_list):
        """Return ``(index, score)`` of the best entry of a non-empty ``candidate_list``.

        Each entry is ``(index, score, available)``. Candidates whose
        ``available`` flag is truthy are preferred. If none is, the
        highest-scoring candidate overall is chosen and that fallback is
        logged. On ties the earliest candidate wins.
        """
        pool = [candidate for candidate in candidate_list if candidate[2]]
        if not pool:
            pool = candidate_list
            self.logger.info("全部编解码不一致，选择一个分数最高的")
        best_idx, best_score, _ = max(pool, key=lambda candidate: candidate[1])
        return best_idx, best_score
        
    def update_past_key_values_correctly(self, best_processor_idx, current_input_ids, original_input_length):
        """
        Rebuild ``self.past_key_values`` from the chosen candidate's batch row so
        it can seed the next batched chunk generation.

        For every layer's (key, value) pair the cache is processed in three steps:

        1. Take row ``best_processor_idx + 1`` (row 0 belongs to the standard,
           unwatermarked generation).
        2. Cut dim 2 to the first ``current_input_ids.shape[1] - 1`` positions;
           this assumes 4-D (batch, heads, seq, head_dim) tensors -- confirm for
           new model families.
        3. Repeat the result ``self.K + 1`` times along the batch dimension.

        Models whose name contains "internlm" keep the tuple-of-(key, value)
        format; all others receive a ``DynamicCache``. Does nothing when there
        is no cache.

        Args:
            best_processor_idx: Index of the chosen watermarked candidate.
            current_input_ids: Token ids including the newly generated chunk;
                only ``shape[1]`` is read.
            original_input_length: Context length before this chunk. Currently
                unused: the cache is cut to ``current_input_ids.shape[1] - 1``,
                not to this value. Kept for call compatibility.
        """
        cache = self.past_key_values
        if cache is None:
            return

        row = slice(best_processor_idx + 1, best_processor_idx + 2)  # +1 skips the standard row
        keep = current_input_ids.shape[1] - 1
        copies = self.K + 1

        rebuilt = []
        for layer_key, layer_value in cache:
            pair = []
            for tensor in (layer_key, layer_value):
                kept = tensor[row][:, :, :keep, :].contiguous()
                pair.append(kept.repeat(copies, *([1] * (kept.ndim - 1))))
            rebuilt.append(tuple(pair))

        if "internlm" in self.model_name.lower():
            # InternLM consumes the tuple-of-(key, value) cache format.
            self.past_key_values = tuple(rebuilt)
        else:
            dynamic = DynamicCache()
            for layer_idx, (key, value) in enumerate(rebuilt):
                dynamic.update(key, value, layer_idx)
            self.past_key_values = dynamic