#!/usr/bin/env python3
# -*- coding: utf-8 -*-
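"""
AI judge script - quality assessment of question/answer pairs.
Reads a JSON file whose samples are already grouped by model pair (each sample
holds one question and a pair of answers), asks an LLM "Oracle Judge" which of
the two answers is better, and writes the preferences to an output JSON file,
skipping samples that were already judged in a previous run.
"""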

import json
import os
import time
import shutil
from datetime import datetime
from openai import OpenAI
import traceback
from itertools import combinations
import asyncio
import aiohttp
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import Optional, Dict, Any

class AIJudgeResponses:
    """Collect LLM ("Oracle Judge") preferences for pre-paired answers.

    Input and output files are JSON objects mapping a model-pair key to a list
    of sample dicts. Each sample is read as ``question`` plus
    ``response_pair.response_a.answer`` / ``response_pair.response_b.answer``;
    the judge adds ``oracle_preference`` (1, 2 or None) and
    ``judgment_success`` to each sample. Samples already judged in the output
    file are skipped, so an interrupted run can be resumed.
    """

    # Chat model used as the judge; override on an instance or subclass.
    # Alternative model names: "gemini-2.5-pro", "gpt-5-chat-latest",
    # "deepseek-v3.1", "glm-4.5".
    judge_model = "claude-sonnet-4-20250514"

    def __init__(self, api_key, base_url="https://aihubmix.com/v1", max_concurrent=5, max_retries=3):
        """Create the OpenAI-compatible client and store batching/retry settings.

        Args:
            api_key: API key handed to the OpenAI client.
            base_url: Base URL of the OpenAI-compatible endpoint.
            max_concurrent: Number of worker threads used per batch.
            max_retries: Attempts per judgment before giving up.
        """
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        self.api_key = api_key
        self.base_url = base_url
        self.batch_size = 20  # samples per batch; progress is saved after each batch
        self.max_concurrent = max_concurrent
        self.max_retries = max_retries
        self.retry_delay = 1.0  # seconds to wait between retries
        self.lock = threading.Lock()  # not used by the methods of this class

    @staticmethod
    def _sample_key(sample):
        """Return a sample's identity: (question, answer A, answer B)."""
        response_pair = sample.get('response_pair', {})
        return (
            sample.get('question', ''),
            response_pair.get('response_a', {}).get('answer', ''),
            response_pair.get('response_b', {}).get('answer', ''),
        )

    def load_classified_results(self, file_path):
        """Load the input file (model-pair key -> list of samples) and print a summary.

        Returns:
            The parsed mapping, or None if the file cannot be read/parsed or
            does not have the expected mapping-of-lists shape.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            print(f"成功加载文件: {file_path}")

            total_samples = sum(len(samples) for samples in data.values())
            print(f"包含 {len(data)} 个模型对")
            print(f"总共 {total_samples} 个样本")

            # Per-pair sample counts
            for pair_key, samples in data.items():
                print(f"  {pair_key}: {len(samples)} 个样本")

            return data
        except Exception as e:
            # Best-effort loader: any failure is reported and signalled with None.
            print(f"加载文件失败: {e}")
            return None

    def load_existing_judgments(self, file_path):
        """Load a previous output file (same layout as the input) to resume from.

        Returns:
            The parsed mapping, or None if the file does not exist or cannot
            be read/parsed.
        """
        if not os.path.exists(file_path):
            print(f"未找到已存在的判断文件: {file_path}")
            return None

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            print(f"成功加载已存在的判断文件: {file_path}")

            total_processed = sum(
                1
                for samples in data.values()
                for sample in samples
                if sample.get('oracle_preference') is not None
            )
            total_samples = sum(len(samples) for samples in data.values())
            print(f"已处理 {total_processed}/{total_samples} 个样本")
            return data
        except Exception as e:
            print(f"加载已存在判断文件失败: {e}")
            return None

    def find_samples_to_process(self, input_data, existing_data):
        """List the input samples that still need an Oracle judgment.

        A sample counts as done when ``existing_data`` holds, under the same
        pair key, a sample with the same (question, answer A, answer B) and a
        non-None ``oracle_preference``.

        Returns:
            A list of ``{'pair_key': ..., 'sample': ...}`` dicts.
        """
        samples_to_process = []

        for pair_key, samples in input_data.items():
            print(f"\n检查模型对: {pair_key}")

            existing_samples = existing_data.get(pair_key, []) if existing_data else []
            processed_keys = {
                self._sample_key(sample)
                for sample in existing_samples
                if sample.get('oracle_preference') is not None
            }

            unprocessed_samples = [
                {'pair_key': pair_key, 'sample': sample}
                for sample in samples
                if self._sample_key(sample) not in processed_keys
            ]

            print(f"  总样本: {len(samples)}")
            print(f"  已处理: {len(processed_keys)}")
            print(f"  需处理: {len(unprocessed_samples)}")

            samples_to_process.extend(unprocessed_samples)

        print(f"\n总共需要处理 {len(samples_to_process)} 个样本")
        return samples_to_process

    def judge_sample(self, sample_data):
        """Judge one queued sample.

        Args:
            sample_data: ``{'pair_key': ..., 'sample': ...}`` as produced by
                ``find_samples_to_process``.

        Returns:
            ``(updated_sample, success)`` where ``updated_sample`` is a shallow
            copy of the sample with ``oracle_preference`` and
            ``judgment_success`` set.
        """
        pair_key = sample_data['pair_key']
        sample = sample_data['sample']

        question, response_a, response_b = self._sample_key(sample)
        preference = self.judge_response_pair_with_retry(question, response_a, response_b, f"{pair_key}")

        updated_sample = sample.copy()
        updated_sample['oracle_preference'] = preference
        updated_sample['judgment_success'] = preference is not None

        return updated_sample, preference is not None

    def judge_response_pair_with_retry(self, question, response1, response2, pair_id=""):
        """Ask the judge which answer is better, retrying on failure.

        Both exceptions and unusable replies (None) trigger a retry after
        ``self.retry_delay`` seconds, up to ``self.max_retries`` attempts.

        Returns:
            1 or 2, or None when every attempt failed.
        """
        for attempt in range(1, self.max_retries + 1):
            is_last_attempt = attempt == self.max_retries
            try:
                result = self._single_judgment_request(question, response1, response2)
            except Exception as e:
                if is_last_attempt:
                    print(f"    配对 {pair_id}: 所有重试都失败了: {e}")
                    break
                print(f"    配对 {pair_id}: 第{attempt}次尝试异常 ({e})，{self.retry_delay}秒后重试...")
            else:
                if result is not None:
                    return result
                if is_last_attempt:
                    print(f"    配对 {pair_id}: 所有重试都失败了")
                    break
                print(f"    配对 {pair_id}: 第{attempt}次尝试失败，{self.retry_delay}秒后重试...")
            time.sleep(self.retry_delay)

        return None

    def _single_judgment_request(self, question, response1, response2):
        """Send one judgment request and parse the JSON reply.

        Returns:
            1 or 2 (numeric strings "1"/"2" are accepted and converted), or
            None when the reply is not a JSON object with a valid preference.
        """
        prompt = f"""Please judge which of the following two answers is better. Only return the result in JSON format, without any explanation.
Question: {question}
Answer 1: {response1}
Answer 2: {response2}
Please respond strictly in the following JSON format:
{{"preference": 1}}  or  {{"preference": 2}}
Where 1 means Answer 1 is better, and 2 means Answer 2 is better."""

        response = self.client.chat.completions.create(
            model=self.judge_model,
            messages=[
                {"role": "system", "content": "You are a professional text quality assessment expert. Please carefully compare the quality of two answers, focusing on: 1) Accuracy - whether the information is correct; 2) Relevance - whether it addresses the question; 3) Clarity - whether the expression is clear and understandable; 4) Conciseness - whether it is concise and avoids redundancy; 5) Depth - whether it has insights; 6) Logic - whether it is well-organized; 7) Practicality - whether it is helpful to the questioner. Find the best balance between information content and readability. Only return the result in JSON format, without any explanation"},
                {"role": "user", "content": prompt}
            ],
            response_format={"type": "json_object"},
        )

        # message.content is optional in the chat API; treat a missing body as unparseable.
        result_text = (response.choices[0].message.content or "").strip()

        try:
            result = json.loads(result_text)
        except json.JSONDecodeError:
            print(f"无法解析AI回复: {result_text}")
            return None
        if not isinstance(result, dict):
            print(f"无法解析AI回复: {result_text}")
            return None

        preference = result.get("preference")
        # bool is an int subclass (True == 1), so it must be excluded explicitly.
        if not isinstance(preference, bool) and preference in (1, 2, "1", "2"):
            return int(preference)
        print(f"无效的preference值: {preference}")
        return None

    @staticmethod
    def _error_result(sample_data, error):
        """Build the ``(sample, success, pair_key)`` tuple recorded for a failed judgment."""
        error_sample = dict(sample_data.get('sample', {}))
        error_sample['oracle_preference'] = None
        error_sample['judgment_success'] = False
        error_sample['error'] = str(error)
        return error_sample, False, sample_data.get('pair_key', 'unknown')

    def judge_samples_batch_concurrent(self, samples_data):
        """Judge a batch of queued samples on a thread pool.

        Returns:
            A list of ``(updated_sample, success, pair_key)`` tuples in the
            same order as ``samples_data``. Failures are recorded as samples
            with ``oracle_preference=None`` and an ``error`` field instead of
            being dropped.
        """
        def process_single_sample(sample_data):
            pair_key = sample_data['pair_key']
            sample_idx = sample_data.get('sample_idx', 0)
            total_samples = sample_data.get('total_samples', 1)

            try:
                print(f"  样本 {sample_idx+1}/{total_samples} ({pair_key})...")

                updated_sample, success = self.judge_sample(sample_data)

                if success:
                    preference = updated_sample.get('oracle_preference')
                    print(f"    ✓ 判断结果: 答案{preference}更好")
                else:
                    print("    ✗ 判断失败")

                return updated_sample, success, pair_key

            except Exception as e:
                print(f"    处理样本时出错: {e}")
                return self._error_result(sample_data, e)

        results = []
        with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
            futures = [executor.submit(process_single_sample, sample_data) for sample_data in samples_data]

            # zip keeps each future paired with its input so a crashed worker
            # still yields a result tied to the right sample and pair key.
            for sample_data, future in zip(samples_data, futures):
                try:
                    results.append(future.result())
                except Exception as e:
                    print(f"线程执行失败: {e}")
                    results.append(self._error_result(sample_data, e))

        return results

    def safe_save_json(self, data, output_path):
        """Write ``data`` as JSON to ``output_path`` via a verified temp file.

        The previous file (if any) is copied to ``<output_path>.backup``; the new
        content is written to ``<output_path>.tmp``, re-read to verify it, and
        then moved over the target with ``os.replace`` (atomic on the same
        filesystem). ``output_path`` may be a str or path-like object.

        Returns:
            True on success, False if anything failed (the error is printed).
        """
        output_path = os.fspath(output_path)
        temp_path = output_path + ".tmp"
        backup_path = output_path + ".backup"

        try:
            if os.path.exists(output_path):
                shutil.copy2(output_path, backup_path)

            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)

            # Re-read to make sure the written file is valid JSON before swapping it in.
            with open(temp_path, 'r', encoding='utf-8') as f:
                json.load(f)

            os.replace(temp_path, output_path)

            print(f"文件已安全保存: {output_path}")
            return True

        except Exception as e:
            print(f"保存文件失败: {e}")
            try:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                if os.path.exists(backup_path) and not os.path.exists(output_path):
                    shutil.copy2(backup_path, output_path)
                    print("已从备份恢复文件")
            except OSError as cleanup_error:
                print(f"保存文件失败: {cleanup_error}")
            return False

    def _init_output_data(self, input_data, existing_data):
        """Build the mapping that judged samples are merged into.

        Resumes from ``existing_data`` when it is non-empty; otherwise starts
        from the input samples with the judgment fields reset. New lists and
        sample dicts are created so the loaded structures are not mutated.
        """
        if existing_data:
            return {pair_key: list(samples) for pair_key, samples in existing_data.items()}
        return {
            pair_key: [
                dict(sample, oracle_preference=None, judgment_success=False)
                for sample in samples
            ]
            for pair_key, samples in input_data.items()
        }

    def _build_sample_index(self, output_data):
        """Map pair key -> {sample identity -> index of its first occurrence}."""
        index = {}
        for pair_key, samples in output_data.items():
            positions = index.setdefault(pair_key, {})
            for i, sample in enumerate(samples):
                positions.setdefault(self._sample_key(sample), i)
        return index

    def _merge_batch_results(self, output_data, sample_index, batch_results):
        """Write judged samples back into ``output_data`` (and keep the index in sync).

        A result replaces the first stored sample with the same identity. If
        the output has no such sample or pair key (e.g. the input gained
        samples since the output file was written), it is appended so the
        judgment is not lost.
        """
        for updated_sample, _success, pair_key in batch_results:
            samples = output_data.setdefault(pair_key, [])
            positions = sample_index.setdefault(pair_key, {})
            key = self._sample_key(updated_sample)
            position = positions.get(key)
            if position is None:
                positions[key] = len(samples)
                samples.append(updated_sample)
            else:
                samples[position] = updated_sample

    @staticmethod
    def _report_progress(processed_samples, successful_judgments, total_samples, elapsed_time):
        """Print success count, elapsed time, throughput and ETA."""
        print(f"成功判断: {successful_judgments}/{processed_samples}")
        print(f"已用时间: {elapsed_time:.1f}秒")
        # Guard elapsed_time > 0: a coarse clock can report 0 for very fast batches.
        if processed_samples > 0 and elapsed_time > 0:
            rate = processed_samples / elapsed_time
            print(f"平均速度: {rate:.2f}样本/秒")
            remaining_samples = total_samples - processed_samples
            if remaining_samples > 0:
                print(f"预计剩余时间: {remaining_samples / rate:.1f}秒")

    def process_classified_results(self, input_file, output_file):
        """Judge every pending sample of ``input_file`` and save into ``output_file``.

        Already-judged samples found in ``output_file`` are skipped. Samples are
        processed in batches of ``self.batch_size`` and the output is saved after
        every batch.

        Returns:
            True when everything was saved (or nothing needed processing),
            False if loading the input or saving the output failed.
        """
        input_data = self.load_classified_results(input_file)
        if input_data is None:
            return False

        existing_data = self.load_existing_judgments(output_file)
        samples_to_process = self.find_samples_to_process(input_data, existing_data)

        if not samples_to_process:
            print("所有样本都已完整处理，无需继续！")
            return True

        total_samples_to_process = len(samples_to_process)
        print(f"\n开始处理 {total_samples_to_process} 个样本...")
        print("=" * 60)

        output_data = self._init_output_data(input_data, existing_data)
        sample_index = self._build_sample_index(output_data)

        processed_samples = 0
        successful_judgments = 0
        start_time = time.perf_counter()

        # max(1, ...) prevents a zero step in range() if batch_size is misconfigured.
        batch_size = max(1, min(self.batch_size, total_samples_to_process))
        total_batches = (total_samples_to_process - 1) // batch_size + 1

        for batch_start in range(0, total_samples_to_process, batch_size):
            batch_end = min(batch_start + batch_size, total_samples_to_process)
            batch_samples = samples_to_process[batch_start:batch_end]

            print(f"\n处理批次 {batch_start//batch_size + 1}/{total_batches}")
            print(f"样本 {batch_start+1}-{batch_end}/{total_samples_to_process}")

            # Position info used only for progress output in the worker threads.
            for offset, sample_data in enumerate(batch_samples):
                sample_data['sample_idx'] = batch_start + offset
                sample_data['total_samples'] = total_samples_to_process

            batch_results = self.judge_samples_batch_concurrent(batch_samples)
            self._merge_batch_results(output_data, sample_index, batch_results)

            batch_successes = sum(1 for _, ok, _ in batch_results if ok)
            processed_samples += len(batch_results)
            successful_judgments += batch_successes
            print(f"  批次完成: {batch_successes}/{len(batch_results)} 样本成功")

            # Save after every batch so an interruption loses at most one batch.
            print(f"\n--- 保存进度 (已处理 {processed_samples}/{total_samples_to_process} 样本) ---")
            if not self.safe_save_json(output_data, output_file):
                print("保存失败！")
                return False
            self._report_progress(
                processed_samples,
                successful_judgments,
                total_samples_to_process,
                time.perf_counter() - start_time,
            )

        print("\n--- 最终保存 ---")
        success = self.safe_save_json(output_data, output_file)

        total_time = time.perf_counter() - start_time
        print("\n" + "=" * 60)
        print("处理完成!")
        print(f"本次处理样本数: {processed_samples}/{total_samples_to_process}")
        print(f"成功判断: {successful_judgments}/{processed_samples}")
        if processed_samples > 0:
            success_rate = (successful_judgments / processed_samples) * 100
            print(f"成功率: {success_rate:.1f}%")
        print(f"总用时: {total_time:.1f}秒")
        if processed_samples > 0 and total_time > 0:
            print(f"平均速度: {processed_samples/total_time:.2f}样本/秒")
        print(f"结果已保存到: {output_file}")

        return success

def main():
    """Script entry point: judge all pending samples of INPUT_FILE into OUTPUT_FILE.

    Edit the configuration constants below before running. Progress and
    errors are reported on stdout; nothing is returned.
    """
    # Configuration
    API_KEY = ""
    BASE_URL = ""  # leave empty to use AIJudgeResponses' default endpoint
    # INPUT_FILE = "/root/gMad/4_oracle_judge/origin/classified_reward_scores_by_pair_1.json"
    INPUT_FILE = "/root/gMad/4_oracle_judge/random/random_sampled_reward_pairs_5.json"
    OUTPUT_FILE = "/root/gMad/4_oracle_judge/random_result/classified_reward_scores_by_pair_5.json"

    print("AI Oracle Judge 脚本启动（处理按模型对分类的数据）")
    print("=" * 60)
    print(f"输入文件: {INPUT_FILE}")
    print(f"输出文件: {OUTPUT_FILE}")
    print(f"API Base URL: {BASE_URL}")

    if not os.path.exists(INPUT_FILE):
        print(f"错误: 输入文件不存在: {INPUT_FILE}")
        return

    if not API_KEY:
        print("警告: API_KEY 为空，API 请求很可能会失败")

    # os.path.dirname() is "" for a bare file name, and os.makedirs("") raises.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # An empty string is not a usable endpoint, so BASE_URL is only passed when
    # set; otherwise the class default base_url applies.
    client_args = (API_KEY, BASE_URL) if BASE_URL else (API_KEY,)
    judge = AIJudgeResponses(
        *client_args,
        max_concurrent=3,  # kept low to avoid API rate limits
        max_retries=3,     # attempts per judgment
    )

    print(f"并发配置: 最大并发数={judge.max_concurrent}, 最大重试次数={judge.max_retries}")
    print("=" * 60)

    success = judge.process_classified_results(INPUT_FILE, OUTPUT_FILE)

    if success:
        print("\n✓ 所有处理完成！")
        print(f"✓ 结果已保存到: {OUTPUT_FILE}")
    else:
        print("\n✗ 处理过程中出现错误")

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main() 