# import json
# import torch
# from transformers import AutoModelForCausalLM, AutoTokenizer
# import re

# class QwenEvaluator:
#     def __init__(self, model_path):
#         """初始化Qwen模型"""
#         print("正在加载Qwen模型...")
#         self.tokenizer = AutoTokenizer.from_pretrained(
#             model_path, 
#             trust_remote_code=True
#         )
#         self.model = AutoModelForCausalLM.from_pretrained(
#             model_path,
#             torch_dtype=torch.float16,
#             device_map="auto",
#             trust_remote_code=True
#         )
#         print("模型加载完成!")
    
#     def load_jsonl(self, file_path):
#         """读取jsonl文件"""
#         data = []
#         with open(file_path, 'r', encoding='utf-8') as f:
#             for line in f:
#                 data.append(json.loads(line.strip()))
#         return data
    
#     def create_prompt(self, question, choices):
#         """创建问答prompt"""
#         choice_text = ""
#         for label, text in zip(choices["label"], choices["text"]):
#             choice_text += f"{label}. {text}\n"
        
#         prompt = f"""请回答以下选择题，只需要返回选项字母(A、B、C、D或E)。

#                 问题: {question}

#                 选项:
#                 {choice_text}
#                 答案:"""
#         # print(prompt)
#         return prompt
    
#     def extract_answer(self, response):
#         """从模型回复中提取答案"""
#         # 查找第一个出现的选项字母
#         print(response)
#         matches = re.findall(r'[ABCDE]', response.upper())
#         if matches:
#             answer = matches[-1]
#             print(f"提取到答案: {answer} (共找到 {len(matches)} 个字母: {matches})")
#         return answer
    
#     def model_predict(self, question, choices):
#         """使用Qwen模型进行预测"""
#         prompt = self.create_prompt(question, choices)
        
#         try:
#             # 对话格式
import argparse
import json
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class QwenEvaluator:
    """Batched multiple-choice evaluator for a Qwen causal language model.

    Each JSONL record is expected to contain ``question``, ``choices`` (a dict
    with parallel ``label`` and ``text`` lists) and ``answerKey``. Questions
    are sent to the model in batches of ``batch_size``; accuracy and token
    usage are reported by :meth:`evaluate`.
    """

    # Patterns used by extract_answer, compiled once.
    # Reasoning block that some chat models emit before the final answer.
    _THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
    # Option letter at the very start of the reply: "B", "B. ...", "(C)", "D选项".
    _LEADING_LETTER_RE = re.compile(r"[（(]?\s*([A-E])(?![A-Za-z])")
    # Explicitly marked answer: "答案是B", "答案：C", "The answer is (D)".
    _MARKED_ANSWER_RE = re.compile(
        r"(?:答案|[Aa]nswer)\s*(?:是|为|[Ii]s)?\s*[:：]?\s*[（(]?\s*([A-E])(?![A-Za-z])"
    )
    # A capital option letter that is not part of a Latin word.
    _STANDALONE_LETTER_RE = re.compile(r"(?<![A-Za-z])[A-E](?![A-Za-z])")
    # Same, but also accepting lower-case letters (last-resort fallback).
    _STANDALONE_ANY_CASE_RE = re.compile(r"(?<![A-Za-z])[A-Ea-e](?![A-Za-z])")

    def __init__(self, model_path, batch_size=4, max_new_tokens=512, do_sample=False):
        """Load the tokenizer and model from *model_path*.

        Args:
            model_path: path or Hub id accepted by ``from_pretrained``.
            batch_size: number of questions passed to ``generate`` at once.
            max_new_tokens: generation budget per question.
            do_sample: ``False`` (default) decodes greedily so that repeated
                runs give the same score; ``True`` enables sampling.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Batched generation with a decoder-only model needs left padding so
        # every prompt ends exactly where generation starts.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            padding_side="left",
        )
        if self.tokenizer.pad_token is None:
            # No pad token defined by the checkpoint: reuse EOS for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.batch_size = batch_size
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.lock = threading.Lock()  # serializes access to the model/GPU
        print("模型加载完成!")

    def load_jsonl(self, file_path):
        """Read a JSONL file into a list of dicts, one per non-blank line.

        Blank lines (e.g. an empty line at the end of the file) are skipped
        instead of raising ``json.JSONDecodeError``.
        """
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    data.append(json.loads(line))
        return data

    def create_prompt(self, question, choices):
        """Build the (Chinese) multiple-choice prompt for one question.

        Args:
            question: question text.
            choices: dict with parallel ``label`` and ``text`` lists; each pair
                becomes one "<label>. <text>" line.

        Returns:
            The prompt string; it ends with "答案:".
        """
        choice_text = "".join(
            f"{label}. {text}\n" for label, text in zip(choices["label"], choices["text"])
        )

        prompt = f"""请回答以下选择题，只需要返回选项字母(A、B、C、D或E)。

问题: {question}

选项:
{choice_text}
答案:"""
        return prompt

    def extract_answer(self, response):
        """Extract the predicted option letter (A-E) from a model reply.

        If the reply contains a ``<think>...</think>`` block it is removed
        first. Then, in order:

          1. an option letter at the start of the reply;
          2. the last explicitly marked answer ("答案是X", "answer is X");
          3. the last capital A-E that is not part of a Latin word;
          4. the last stand-alone letter a-e/A-E in either case.

        Letters inside ordinary words ("Because" -> B, E, C, A, E) are never
        counted.

        Returns:
            The upper-case letter, or ``"A"`` when nothing matches. NOTE: this
            fallback is kept for backward compatibility; it credits
            unparseable replies whose gold answer happens to be "A".
        """
        text = self._THINK_BLOCK_RE.sub("", response).strip()

        leading = self._LEADING_LETTER_RE.match(text)
        if leading:
            return leading.group(1)

        marked = self._MARKED_ANSWER_RE.findall(text)
        if marked:
            return marked[-1]

        standalone = (
            self._STANDALONE_LETTER_RE.findall(text)
            or self._STANDALONE_ANY_CASE_RE.findall(text)
        )
        if standalone:
            return standalone[-1].upper()

        return "A"

    def _count_generated_tokens(self, token_ids):
        """Return how many of *token_ids* (a list of ints) were really generated.

        ``generate`` fills rows that finished early with pad tokens up to the
        longest row in the batch, so counting stops at (and includes) the
        first EOS; a row without EOS counts in full minus trailing pad tokens.
        """
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None and eos_id in token_ids:
            return token_ids.index(eos_id) + 1
        pad_id = self.tokenizer.pad_token_id
        count = len(token_ids)
        while count and token_ids[count - 1] == pad_id:
            count -= 1
        return count

    def _fallback_stats(self, n):
        """Stats dict used for an empty batch or a failed batch of *n* items."""
        return {
            'answers': ["A"] * n,
            'input_tokens': 0,
            'output_tokens': 0,
            'total_tokens': 0,
            'batch_size': n,
        }

    def batch_predict(self, batch_data):
        """Run one batched generation pass and extract an answer per item.

        Args:
            batch_data: list of records, each with ``question`` and ``choices``.

        Returns:
            dict with ``answers`` (one letter per item, in input order),
            ``input_tokens`` / ``output_tokens`` / ``total_tokens`` (token
            counts excluding padding) and ``batch_size``. If anything raises,
            the error is printed and every answer falls back to ``"A"`` with
            zero token counts.
        """
        if not batch_data:
            return self._fallback_stats(0)

        prompts = [self.create_prompt(item["question"], item["choices"]) for item in batch_data]

        try:
            with self.lock:  # protect GPU access
                chat_texts = [
                    self.tokenizer.apply_chat_template(
                        [{"role": "user", "content": prompt}],
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    for prompt in prompts
                ]

                # NOTE(review): truncation uses the tokenizer's truncation side;
                # if that is "right", over-long prompts lose the assistant header.
                model_inputs = self.tokenizer(
                    chat_texts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=2048,
                ).to(self.model.device)

                # Real prompt tokens only: every row is padded to the longest
                # prompt, so shape[1] * batch would also count the padding.
                input_tokens_count = int(model_inputs["attention_mask"].sum().item())

                with torch.no_grad():
                    generated = self.model.generate(
                        **model_inputs,
                        max_new_tokens=self.max_new_tokens,
                        do_sample=self.do_sample,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        return_dict_in_generate=True,
                        output_scores=False,
                    )

                # Sequences are prompt + continuation; all prompts share the
                # padded length, so every continuation starts at this offset.
                prompt_length = model_inputs["input_ids"].shape[1]
                output_tokens_count = 0
                responses = []
                for sequence in generated.sequences:
                    new_tokens = sequence[prompt_length:]
                    output_tokens_count += self._count_generated_tokens(new_tokens.tolist())
                    responses.append(
                        self.tokenizer.decode(new_tokens, skip_special_tokens=True)
                    )

            answers = [self.extract_answer(response) for response in responses]

            return {
                'answers': answers,
                'input_tokens': input_tokens_count,
                'output_tokens': output_tokens_count,
                'total_tokens': input_tokens_count + output_tokens_count,
                'batch_size': len(batch_data),
            }

        except Exception as e:
            # Best effort: keep the evaluation running, but make the failure visible.
            print(f"批量预测错误: {e}")
            return self._fallback_stats(len(batch_data))

    def evaluate(self, jsonl_file_path):
        """Evaluate accuracy and token usage on a JSONL test file.

        Prints one line every 10 questions, a running summary after every
        batch and a final report.

        Args:
            jsonl_file_path: path to the JSONL test set.

        Returns:
            dict with ``accuracy``, ``total_samples``, ``correct_samples``,
            ``tokens_statistics`` and ``performance_statistics``.
        """
        test_data = self.load_jsonl(jsonl_file_path)
        total = len(test_data)
        correct = 0

        # Token usage accumulated over all batches
        total_input_tokens = 0
        total_output_tokens = 0
        total_tokens = 0
        total_batches = 0

        print(f"开始评估，共{total}道题，批量大小: {self.batch_size}")
        start_time = time.perf_counter()

        for i in range(0, total, self.batch_size):
            batch_data = test_data[i:i + self.batch_size]
            batch_result = self.batch_predict(batch_data)

            total_input_tokens += batch_result['input_tokens']
            total_output_tokens += batch_result['output_tokens']
            total_tokens += batch_result['total_tokens']
            total_batches += 1

            for j, (item, predicted_answer) in enumerate(zip(batch_data, batch_result['answers'])):
                true_answer = item["answerKey"]
                if predicted_answer == true_answer:
                    correct += 1

                current_idx = i + j + 1
                if current_idx % 10 == 0 or current_idx == total:  # progress every 10 questions
                    print(f"题目 {current_idx}/{total}: 预测={predicted_answer}, 正确={true_answer}, {'✓' if predicted_answer == true_answer else '✗'}")

            # Running summary after every batch; processed >= 1 inside the loop.
            processed = i + len(batch_data)
            print(f"进度: {processed}/{total} | 准确率: {correct / processed:.2%} | "
                  f"平均输入tokens: {total_input_tokens / processed:.1f} | 平均输出tokens: {total_output_tokens / processed:.1f} | "
                  f"平均总tokens: {total_tokens / processed:.1f}")

        total_time = time.perf_counter() - start_time

        accuracy = correct / total if total > 0 else 0
        avg_input_tokens_per_sample = total_input_tokens / total if total > 0 else 0
        avg_output_tokens_per_sample = total_output_tokens / total if total > 0 else 0
        avg_total_tokens_per_sample = total_tokens / total if total > 0 else 0
        tokens_per_second = total_tokens / total_time if total_time > 0 else 0
        samples_per_second = total / total_time if total_time > 0 else 0

        # Final report
        print(f"\n{'='*60}")
        print(f"=== 最终评估结果 ===")
        print(f"{'='*60}")
        print(f"总题数: {total}")
        print(f"正确数: {correct}")
        print(f"准确率: {accuracy:.4f} ({accuracy:.2%})")
        print(f"\n=== Tokens 使用统计 ===")
        print(f"总输入 tokens: {total_input_tokens}")
        print(f"总输出 tokens: {total_output_tokens}")
        print(f"总 tokens: {total_tokens}")
        print(f"\n=== 平均值统计 ===")
        print(f"平均输入 tokens/样本: {avg_input_tokens_per_sample:.2f}")
        print(f"平均输出 tokens/样本: {avg_output_tokens_per_sample:.2f}")
        print(f"平均总 tokens/样本: {avg_total_tokens_per_sample:.2f}")
        print(f"\n=== 性能统计 ===")
        print(f"总耗时: {total_time:.2f} 秒")
        print(f"处理速度: {samples_per_second:.2f} 样本/秒")
        print(f"Tokens 处理速度: {tokens_per_second:.2f} tokens/秒")
        print(f"批量大小: {self.batch_size}")
        print(f"总批次数: {total_batches}")
        print(f"{'='*60}")

        return {
            'accuracy': accuracy,
            'total_samples': total,
            'correct_samples': correct,
            'tokens_statistics': {
                'total_input_tokens': total_input_tokens,
                'total_output_tokens': total_output_tokens,
                'total_tokens': total_tokens,
                'avg_input_tokens_per_sample': avg_input_tokens_per_sample,
                'avg_output_tokens_per_sample': avg_output_tokens_per_sample,
                'avg_total_tokens_per_sample': avg_total_tokens_per_sample,
            },
            'performance_statistics': {
                'total_time_seconds': total_time,
                'samples_per_second': samples_per_second,
                'tokens_per_second': tokens_per_second,
                'batch_size': self.batch_size,
                'total_batches': total_batches,
            },
        }

def save_results(results, output_file="evaluation_results.json"):
    """Write *results* to *output_file* as indented UTF-8 JSON.

    An empty or ``None`` ``output_file`` falls back to
    ``"evaluation_results.json"``; ``main`` ships with an empty placeholder
    path, and ``open("")`` would otherwise fail only after the whole
    evaluation has run. A missing parent directory is created first.

    Args:
        results: JSON-serializable results dict.
        output_file: destination path.
    """
    output_file = output_file or "evaluation_results.json"
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\n评估结果已保存到: {output_file}")

def main(argv=None):
    """Command-line entry point: evaluate a model on a JSONL file and save results.

    The placeholder values below are the defaults; each can be overridden on
    the command line, e.g.
    ``python eval.py --model-path /models/qwen --data test.jsonl --output out.json``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The results dict produced by ``QwenEvaluator.evaluate``.

    Raises:
        SystemExit: (from argparse) if the model/data path is empty or the
            batch size is not positive — checked before loading the model.
    """
    model_path = ""
    jsonl_path = ""
    output_file = ""

    batch_size = 128

    parser = argparse.ArgumentParser(
        description="Evaluate a Qwen model on a multiple-choice JSONL test set."
    )
    parser.add_argument("--model-path", default=model_path, help="model directory or Hub id")
    parser.add_argument("--data", default=jsonl_path, help="path to the JSONL test set")
    parser.add_argument("--output", default=output_file, help="where to write the results JSON")
    parser.add_argument("--batch-size", type=int, default=batch_size, help="questions per generate() call")
    args = parser.parse_args(argv)

    # Fail fast with a clear message instead of deep inside from_pretrained/range.
    if not args.model_path or not args.data:
        parser.error("model path and data path must be set (edit main() or pass --model-path/--data)")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")

    evaluator = QwenEvaluator(args.model_path, batch_size=args.batch_size)
    results = evaluator.evaluate(args.data)

    # Save results
    save_results(results, args.output)

    return results


if __name__ == "__main__":
    results = main()