import json
import os
from collections import defaultdict
import re

def load_model_results(results_dir, model_name, budgets):
    """
    Load one model's generation results for every requested budget.

    For each budget the file ``{model_name}_{budget}_test.jsonl`` is tried
    first, then ``{model_name}-{budget}_test.jsonl``. Budgets with no matching
    file (or whose file cannot be opened) are reported and left out of the
    returned dict.

    Every non-blank line has its ASCII control characters (0x00-0x1F, 0x7F)
    removed and is parsed as JSON. Only JSON objects containing ``prompt``,
    ``generated_text`` and ``answer`` are kept.

    NOTE: the rest of the analysis matches questions across models/budgets by
    list position only, so every skipped line shifts the index of all later
    questions in that file. A warning with the skip count is printed whenever
    that happens.

    Args:
        results_dir: directory containing the ``*_test.jsonl`` files.
        model_name: model name used as the file-name prefix.
        budgets: budgets to load.

    Returns:
        Dict[int, List[Dict]]: budget -> kept records, in file order.
    """
    # Compiled once for all files/lines; strips raw control characters that
    # would otherwise make json.loads reject the line.
    control_chars = re.compile(r'[\x00-\x1F\x7F]')
    required_keys = ('prompt', 'generated_text', 'answer')
    model_results = {}

    for budget in budgets:
        # Try the supported file naming patterns in order.
        candidates = (
            os.path.join(results_dir, f"{model_name}_{budget}_test.jsonl"),
            os.path.join(results_dir, f"{model_name}-{budget}_test.jsonl"),
        )
        file_path = next((path for path in candidates if os.path.exists(path)), None)
        if file_path is None:
            print(f"警告: 未找到模型 {model_name} 在budget {budget}下的文件")
            continue

        responses = []
        skipped = 0
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    clean_line = control_chars.sub('', line)
                    try:
                        data_point = json.loads(clean_line)
                    except json.JSONDecodeError as e:
                        skipped += 1
                        # Lines can hold very long generations; only echo a prefix.
                        print(f"解析错误: {e}，在第 {line_no} 行: {clean_line[:200]}")
                        continue
                    # A valid-but-non-object line (number, string, list) must not
                    # raise here: that used to discard the whole file.
                    if isinstance(data_point, dict) and all(key in data_point for key in required_keys):
                        responses.append(data_point)
                    else:
                        skipped += 1
        except OSError as e:
            print(f"错误: 无法加载文件 {file_path}: {e}")
            continue

        if skipped:
            print(f"警告: {file_path} 中有 {skipped} 行被跳过，后续问题的索引可能与其他文件错位")

        model_results[budget] = responses
        print(f"预算 {budget}: 加载了 {len(responses)} 个回答")

    return model_results

def analyze_questions_by_mode(results_dir, models, budgets, mode):
    """
    Load every model's results and select the questions matching ``mode``.

    Questions are matched across files purely by list position, so only the
    indices present in every loaded (model, budget) result list are
    considered. Missing budget files are tolerated: they simply do not
    constrain the common index range. The prompt/answer for each question are
    taken from ``models[0]``.

    Args:
        results_dir: directory containing the result files.
        models: model names; ``models[0]`` is the reference model.
        budgets: budgets to analyze.
        mode: one of ``'all_wrong'``, ``'all_correct'``,
            ``'early_correct_late_wrong'``, ``'early_wrong_late_correct'``.

    Returns:
        List[Dict] for ``all_wrong`` / ``all_correct``;
        Dict[str, List[Dict]] (grouped by model) for the two early/late modes.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    per_model_modes = ('early_correct_late_wrong', 'early_wrong_late_correct')
    pooled_modes = ('all_wrong', 'all_correct')
    # Validate before the (expensive) file loading instead of silently returning [].
    if mode not in per_model_modes + pooled_modes:
        raise ValueError(f"unknown mode: {mode!r}")

    print(f"开始分析模式: {mode}")

    if not models:
        return {} if mode in per_model_modes else []

    # model -> budget -> list of records
    all_model_results = {}
    for model_name in models:
        print(f"\n加载模型 {model_name} 的结果...")
        all_model_results[model_name] = load_model_results(results_dir, model_name, budgets)

    # Common indices = [0, shortest loaded list). Unlike anchoring on
    # models[0] at budgets[0], this still works when that one file is missing.
    loaded_lengths = [
        len(per_budget[budget])
        for per_budget in all_model_results.values()
        for budget in budgets
        if budget in per_budget
    ]
    common_questions = range(min(loaded_lengths)) if loaded_lengths else range(0)

    print(f"\n找到 {len(common_questions)} 个所有模型都有数据的问题")

    first_model = models[0]
    if mode in per_model_modes:
        # Modes 3 and 4: evaluated separately per model.
        return analyze_by_model_for_early_late_modes(
            all_model_results, models, budgets, common_questions, first_model, mode
        )

    # Modes 1 and 2: every model must satisfy the condition together.
    check = check_all_wrong if mode == 'all_wrong' else check_all_correct
    return [
        collect_question_info(all_model_results, models, budgets, question_idx, first_model)
        for question_idx in common_questions
        if check(all_model_results, models, budgets, question_idx)
    ]

def analyze_by_model_for_early_late_modes(all_model_results, models, budgets, common_questions, first_model, mode):
    """
    Evaluate an early/late budget pattern separately for each model.

    For every index in ``common_questions`` and every model that has an entry
    in ``all_model_results``, the single-model checker selected by ``mode`` is
    run. Matching questions are recorded with
    ``collect_question_info_for_single_model``. An unrecognised ``mode``
    matches nothing.

    Args:
        all_model_results: model -> budget -> list of records.
        models: model names; every one gets a key in the result.
        budgets: budgets under analysis.
        common_questions: iterable of question indices to examine.
        first_model: model whose data supplies prompt/answer.
        mode: ``'early_correct_late_wrong'`` or ``'early_wrong_late_correct'``.

    Returns:
        Dict[str, List[Dict]]: model name -> matching question records.
    """
    grouped = {}
    for name in models:
        grouped[name] = []

    for idx in common_questions:
        for name in models:
            if name not in all_model_results:
                continue

            if mode == 'early_correct_late_wrong':
                matched = check_early_correct_late_wrong_for_single_model(all_model_results[name], budgets, idx)
            elif mode == 'early_wrong_late_correct':
                matched = check_early_wrong_late_correct_for_single_model(all_model_results[name], budgets, idx)
            else:
                matched = False

            if matched:
                grouped[name].append(
                    collect_question_info_for_single_model(all_model_results, name, budgets, idx, first_model)
                )

    return grouped

def check_early_correct_late_wrong_for_single_model(model_results, budgets, question_idx,
                                                    early_budgets=(512, 1024, 2048)):
    """
    Return True if one model is correct at every observed early budget and
    wrong at every observed late budget for the given question.

    Late budgets are the entries of ``budgets`` not in ``early_budgets``.
    Budgets without a loaded result list, or whose list is too short for
    ``question_idx``, are skipped. At least one early AND one late observation
    are required, so a model with missing files is never reported as matching.

    Args:
        model_results: budget -> list of records for a single model.
        budgets: all budgets under analysis.
        question_idx: positional index of the question in each result list.
        early_budgets: budgets forming the "early" phase.

    Returns:
        bool
    """
    late_budgets = [budget for budget in budgets if budget not in early_budgets]

    def observed(phase_budgets):
        # Truthiness of 'correctness' at each budget that actually has this question.
        flags = []
        for budget in phase_budgets:
            records = model_results.get(budget)
            if records is not None and question_idx < len(records):
                flags.append(bool(records[question_idx].get('correctness', False)))
        return flags

    early = observed(early_budgets)
    late = observed(late_budgets)
    return bool(early) and bool(late) and all(early) and not any(late)

def check_early_wrong_late_correct_for_single_model(model_results, budgets, question_idx,
                                                    early_budgets=(512, 1024, 2048)):
    """
    Return True if one model is wrong at every observed early budget and
    correct at every observed late budget for the given question.

    Late budgets are the entries of ``budgets`` not in ``early_budgets``.
    Budgets without a loaded result list, or whose list is too short for
    ``question_idx``, are skipped. At least one early AND one late observation
    are required, so a model with missing files is never reported as matching.

    Args:
        model_results: budget -> list of records for a single model.
        budgets: all budgets under analysis.
        question_idx: positional index of the question in each result list.
        early_budgets: budgets forming the "early" phase.

    Returns:
        bool
    """
    late_budgets = [budget for budget in budgets if budget not in early_budgets]

    def observed(phase_budgets):
        # Truthiness of 'correctness' at each budget that actually has this question.
        flags = []
        for budget in phase_budgets:
            records = model_results.get(budget)
            if records is not None and question_idx < len(records):
                flags.append(bool(records[question_idx].get('correctness', False)))
        return flags

    early = observed(early_budgets)
    late = observed(late_budgets)
    return bool(early) and bool(late) and not any(early) and all(late)

def collect_question_info_for_single_model(all_model_results, target_model, budgets, question_idx, first_model):
    """
    Build the output record for one question, holding only ``target_model``'s results.

    The prompt and reference answer come from ``first_model`` at the first
    budget that has this question (``budgets[0]`` first). If ``first_model``
    has no such record, they fall back to ``target_model``'s own data. They
    stay ``None`` if neither has the question.

    Args:
        all_model_results: model -> budget -> list of records.
        target_model: model whose per-budget results are collected.
        budgets: budgets to collect, in order.
        question_idx: positional index of the question.
        first_model: preferred source of prompt/answer.

    Returns:
        Dict with ``question_index``, ``prompt``, ``answer`` and
        ``model_results`` ({target_model: {budget: {'generated_text',
        'correctness'}}}, or {} if target_model has no loaded results).
    """
    question_info = {
        'question_index': question_idx,
        'prompt': None,
        'answer': None,
        'model_results': {}
    }

    def records_for(model_name):
        # Yield this question's record for each budget where the model has it.
        per_budget = all_model_results.get(model_name, {})
        for budget in budgets:
            records = per_budget.get(budget)
            if records is not None and question_idx < len(records):
                yield records[question_idx]

    source = next((rec for name in (first_model, target_model) for rec in records_for(name)), None)
    if source is not None:
        question_info['prompt'] = source.get('prompt', '')
        question_info['answer'] = source.get('answer', '')

    # Only the target model's results are collected.
    if target_model in all_model_results:
        per_model = {}
        for budget in budgets:
            records = all_model_results[target_model].get(budget)
            if records is not None and question_idx < len(records):
                question_data = records[question_idx]
                per_model[budget] = {
                    'generated_text': question_data.get('generated_text', ''),
                    'correctness': question_data.get('correctness', False)
                }
        question_info['model_results'][target_model] = per_model

    return question_info

def check_all_wrong(all_model_results, models, budgets, question_idx):
    """
    Return True if every observed answer to the question, across all models
    and budgets, is wrong.

    (model, budget) pairs without loaded data, or whose list is too short for
    ``question_idx``, are skipped. With no observation at all the result is
    False, so missing data is never reported as "all wrong".

    Args:
        all_model_results: model -> budget -> list of records.
        models: models to check.
        budgets: budgets to check.
        question_idx: positional index of the question.

    Returns:
        bool
    """
    seen_any = False
    for model_name in models:
        per_budget = all_model_results.get(model_name)
        if per_budget is None:
            continue
        for budget in budgets:
            records = per_budget.get(budget)
            if records is None or question_idx >= len(records):
                continue
            if records[question_idx].get('correctness', False):
                return False
            seen_any = True
    return seen_any

def check_all_correct(all_model_results, models, budgets, question_idx):
    """
    Return True if every observed answer to the question, across all models
    and budgets, is correct.

    (model, budget) pairs without loaded data, or whose list is too short for
    ``question_idx``, are skipped. With no observation at all the result is
    False, so missing data is never reported as "all correct".

    Args:
        all_model_results: model -> budget -> list of records.
        models: models to check.
        budgets: budgets to check.
        question_idx: positional index of the question.

    Returns:
        bool
    """
    seen_any = False
    for model_name in models:
        per_budget = all_model_results.get(model_name)
        if per_budget is None:
            continue
        for budget in budgets:
            records = per_budget.get(budget)
            if records is None or question_idx >= len(records):
                continue
            if not records[question_idx].get('correctness', False):
                return False
            seen_any = True
    return seen_any

def check_early_correct_late_wrong(all_model_results, models, budgets, question_idx,
                                   early_budgets=(512, 1024, 2048)):
    """
    Return True if, pooled over all models, every observed early-budget answer
    is correct and every observed late-budget answer is wrong.

    Late budgets are the entries of ``budgets`` not in ``early_budgets``.
    Missing models/budgets and too-short result lists are skipped. At least
    one early AND one late observation are required.

    Args:
        all_model_results: model -> budget -> list of records.
        models: models to check.
        budgets: all budgets under analysis.
        question_idx: positional index of the question.
        early_budgets: budgets forming the "early" phase.

    Returns:
        bool
    """
    late_budgets = [budget for budget in budgets if budget not in early_budgets]

    def observed(phase_budgets):
        # Truthiness of 'correctness' for every (model, budget) that has this question.
        flags = []
        for model_name in models:
            per_budget = all_model_results.get(model_name)
            if per_budget is None:
                continue
            for budget in phase_budgets:
                records = per_budget.get(budget)
                if records is not None and question_idx < len(records):
                    flags.append(bool(records[question_idx].get('correctness', False)))
        return flags

    early = observed(early_budgets)
    late = observed(late_budgets)
    return bool(early) and bool(late) and all(early) and not any(late)

def check_early_wrong_late_correct(all_model_results, models, budgets, question_idx,
                                   early_budgets=(512, 1024, 2048)):
    """
    Return True if, pooled over all models, every observed early-budget answer
    is wrong and every observed late-budget answer is correct.

    Late budgets are the entries of ``budgets`` not in ``early_budgets``.
    Missing models/budgets and too-short result lists are skipped. At least
    one early AND one late observation are required.

    Args:
        all_model_results: model -> budget -> list of records.
        models: models to check.
        budgets: all budgets under analysis.
        question_idx: positional index of the question.
        early_budgets: budgets forming the "early" phase.

    Returns:
        bool
    """
    late_budgets = [budget for budget in budgets if budget not in early_budgets]

    def observed(phase_budgets):
        # Truthiness of 'correctness' for every (model, budget) that has this question.
        flags = []
        for model_name in models:
            per_budget = all_model_results.get(model_name)
            if per_budget is None:
                continue
            for budget in phase_budgets:
                records = per_budget.get(budget)
                if records is not None and question_idx < len(records):
                    flags.append(bool(records[question_idx].get('correctness', False)))
        return flags

    early = observed(early_budgets)
    late = observed(late_budgets)
    return bool(early) and bool(late) and not any(early) and all(late)

def collect_question_info(all_model_results, models, budgets, question_idx, first_model):
    """
    Build the output record for one question across all models.

    The prompt and reference answer come from ``first_model`` at the first
    budget that has this question (``budgets[0]`` first). If ``first_model``
    has no record, they fall back to the other models in ``models`` order.
    They stay ``None`` if no record exists anywhere.

    Args:
        all_model_results: model -> budget -> list of records.
        models: models whose results are collected.
        budgets: budgets to collect, in order.
        question_idx: positional index of the question.
        first_model: preferred source of prompt/answer.

    Returns:
        Dict with ``question_index``, ``prompt``, ``answer`` and
        ``model_results`` (model -> budget -> {'generated_text',
        'correctness'}). Every model present in ``all_model_results`` gets an
        entry, possibly empty.
    """
    question_info = {
        'question_index': question_idx,
        'prompt': None,
        'answer': None,
        'model_results': {}
    }

    def records_for(model_name):
        # Yield this question's record for each budget where the model has it.
        per_budget = all_model_results.get(model_name, {})
        for budget in budgets:
            records = per_budget.get(budget)
            if records is not None and question_idx < len(records):
                yield records[question_idx]

    source_order = [first_model] + [name for name in models if name != first_model]
    source = next((rec for name in source_order for rec in records_for(name)), None)
    if source is not None:
        question_info['prompt'] = source.get('prompt', '')
        question_info['answer'] = source.get('answer', '')

    # Collect each model's result at each budget that has this question.
    for model_name in models:
        if model_name not in all_model_results:
            continue
        per_model = {}
        for budget in budgets:
            records = all_model_results[model_name].get(budget)
            if records is not None and question_idx < len(records):
                question_data = records[question_idx]
                per_model[budget] = {
                    'generated_text': question_data.get('generated_text', ''),
                    'correctness': question_data.get('correctness', False)
                }
        question_info['model_results'][model_name] = per_model

    return question_info

def main():
    """
    Entry point: run all four analysis modes over the configured models and budgets.

    Reads result files from ``results_dir`` and writes the selected questions
    as JSON into ``output_dir``. The output directory is created if it does not
    exist. Both paths are relative to the current working directory.
    """
    # Configuration
    results_dir = "../results"  # directory holding the result files
    output_dir = "collected_questions"  # directory the JSON summaries are written to
    budgets = [512, 1024, 2048, 4096, 8192, 16384]  # budgets to analyze

    # Models to analyze
    models = [
        'l1-8b',
        'l1-1.5',  # corresponds to l1-1.5b
        'l1-8b-ours-deepscaler-LUFFY-style',  # corresponds to ours-luffy
        'l1-8b-ours-openr1',  # corresponds to our openr1
        'seed-36b'  # corresponds to seed
    ]

    # The four analysis modes: key -> human-readable description
    modes = {
        'all_wrong': '全错 - 所有模型在所有budget下都回答不对',
        'all_correct': '全对 - 所有模型在所有budget下都回答正确',
        'early_correct_late_wrong': '512 1024 2048 对 后面错',
        'early_wrong_late_correct': '512 1024 2048 错 后面对'
    }

    print("=" * 80)
    print("问题分析工具 - 支持四种模式")
    print("=" * 80)
    print(f"结果目录: {results_dir}")
    print(f"预算: {budgets}")
    print(f"模型: {models}")
    print("=" * 80)

    # Abort early if the input directory is missing.
    if not os.path.isdir(results_dir):
        print(f"错误: 目录 '{results_dir}' 不存在")
        return

    # open(..., 'w') below fails with FileNotFoundError if the directory is absent.
    os.makedirs(output_dir, exist_ok=True)

    for mode_key, mode_desc in modes.items():
        print(f"\n{'='*60}")
        print(f"分析模式: {mode_desc}")
        print(f"{'='*60}")

        result = analyze_questions_by_mode(results_dir, models, budgets, mode_key)

        if mode_key in ['early_correct_late_wrong', 'early_wrong_late_correct']:
            # Modes 3 and 4: one output file per model.
            total_questions = 0
            for model_name, questions in result.items():
                print(f"\n模型 {model_name}: 找到 {len(questions)} 个符合条件的问题")
                total_questions += len(questions)

                if questions:
                    output_file = os.path.join(output_dir, f"{mode_key}_{model_name}_questions.json")
                    with open(output_file, 'w', encoding='utf-8') as f:
                        json.dump(questions, f, ensure_ascii=False, indent=2)
                    print(f"结果已保存到: {output_file}")
                else:
                    print(f"模型 {model_name} 没有找到符合条件的问题")

            print(f"\n总计: {total_questions} 个问题")
        else:
            # Modes 1 and 2: a single file covering all models.
            print(f"\n找到 {len(result)} 个符合条件的问题")

            if result:
                output_file = os.path.join(output_dir, f"{mode_key}_questions.json")
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                print(f"结果已保存到: {output_file}")
            else:
                print(f"没有找到符合模式 '{mode_desc}' 的问题")

if __name__ == '__main__':
    # Script entry point: run every analysis mode when executed directly.
    main()
