#!/usr/bin/env python3
"""
合并多GPU并行任务的JSONL和JSON评估文件
"""

import os
import json
import argparse
import glob
import re
from typing import List, Dict, Any
from collections import defaultdict
import numpy as np


def find_matching_files(input_dir: str) -> Dict[str, Dict[str, str]]:
    """
    Pair every ``<base>.jsonl`` in *input_dir* with its ``<base>_metrics.json``.

    JSONL files are visited in sorted filename order so that the returned dict,
    and therefore the order in which shards are concatenated downstream, does
    not depend on the arbitrary order the filesystem lists entries in.
    A JSONL file without a matching metrics file is skipped with a warning.

    :param input_dir: directory to scan (non-recursive).
    :return: ``{base_name: {"jsonl": jsonl_path, "json": json_path}}``.
    """
    # Escape the directory part so names containing glob metacharacters
    # (e.g. "[", "*", "?") are matched literally.
    jsonl_pattern = os.path.join(glob.escape(input_dir), "*.jsonl")

    file_pairs = {}
    for jsonl_file in sorted(glob.glob(jsonl_pattern)):
        # Strip the ".jsonl" suffix to get the base name shared with the metrics file.
        base_name = os.path.basename(jsonl_file)[:-len(".jsonl")]
        json_file = os.path.join(input_dir, f"{base_name}_metrics.json")

        if os.path.exists(json_file):
            file_pairs[base_name] = {
                "jsonl": jsonl_file,
                "json": json_file,
            }
        else:
            print(f"Warning: 未找到对应的metrics文件: {json_file}")

    return file_pairs


def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Parse a UTF-8 JSON Lines file into a list, skipping whitespace-only lines."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


def load_json(file_path: str) -> Dict[str, Any]:
    """Read and return the JSON document stored (UTF-8) at *file_path*."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        document = json.load(handle)
    return document


def merge_jsonl_files(jsonl_data_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """
    Merge row-aligned JSONL shards into one record per problem.

    All shards must have the same number of rows, and row i of every shard
    must share the same ``idx`` and ``question``; otherwise ValueError is raised.
    Scalar fields are taken from the first shard. The list-like fields
    (code/pred/report/finish_reason/score) are concatenated across shards in
    shard order: list values are extended, other non-None values are appended,
    and missing/None values are skipped. Those fields are always present in
    the output, possibly as empty lists.
    """
    if not jsonl_data_list:
        return []

    first_rows, *other_shards = jsonl_data_list
    num_samples = len(first_rows)
    # File numbers in the message are 1-based; the first shard is file 1.
    for file_no, rows in enumerate(other_shards, start=2):
        if len(rows) != num_samples:
            raise ValueError(f"文件{file_no}的样本数({len(rows)})与第一个文件不匹配({num_samples})")

    # Fields whose per-shard values get concatenated instead of copied.
    array_fields = ['code', 'pred', 'report', 'finish_reason', 'score']

    merged_data = []
    for sample_idx, row_group in enumerate(zip(*jsonl_data_list)):
        head = row_group[0]
        for other in row_group[1:]:
            if other['idx'] != head['idx']:
                raise ValueError(f"样本{sample_idx}: idx不匹配 ({other['idx']} vs {head['idx']})")
            if other['question'] != head['question']:
                raise ValueError(f"样本{sample_idx}: question不匹配")

        merged = {key: value for key, value in head.items() if key not in array_fields}

        for field in array_fields:
            collected = []
            for row in row_group:
                value = row.get(field)
                if value is None:
                    continue
                if isinstance(value, list):
                    collected.extend(value)
                else:
                    collected.append(value)
            merged[field] = collected

        merged_data.append(merged)

    return merged_data


def pass_at_k(n, c, k):
    """
    Unbiased pass@k estimate: 1 - C(n-c, k) / C(n, k), computed as a stable product.

    :param n: total number of samples drawn for the problem
    :param c: how many of those samples are correct
    :param k: the k in pass@k
    """
    wrong = n - c
    # Fewer than k wrong samples: any k-subset must contain a correct one.
    if wrong < k:
        return 1.0
    miss_factors = 1.0 - k / np.arange(wrong + 1, n + 1)
    return 1.0 - np.prod(miss_factors)


def _run_identity(json_metrics_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return legend/benchmark/model_name from the first shard's metrics (None if absent)."""
    first = json_metrics_list[0] if json_metrics_list else {}
    return {key: first.get(key) for key in ("legend", "benchmark", "model_name")}


def _build_scores_array(merged_data: List[Dict[str, Any]]) -> np.ndarray:
    """Stack the non-empty per-problem ``score`` lists into a bool array, right-padding with False."""
    scores_2d = [sample['score'] for sample in merged_data if sample.get('score')]
    if not scores_2d:
        return np.zeros((0, 0), dtype=bool)

    lengths = {len(scores) for scores in scores_2d}
    max_samples = max(lengths)
    if len(lengths) > 1:
        # Padding counts the missing attempts as failures, which biases pass@k
        # downward for the shorter problems, so report it.
        print(f"Warning: 各问题的采样次数不一致 {sorted(lengths)}，较短的用 False 填充到 {max_samples}")

    padded = [list(scores) + [False] * (max_samples - len(scores)) for scores in scores_2d]
    return np.array(padded, dtype=bool)  # shape: (num_problems, N_sampling)


def _average_pass_at_k(scores_array: np.ndarray) -> np.ndarray:
    """Return the problem-averaged pass@k for k = 1..N_sampling (entry k-1 holds pass@k)."""
    num_problems, n_sampling = scores_array.shape
    correct_counts = scores_array.sum(axis=1)
    k_values = range(1, n_sampling + 1)

    # pass@k depends only on (n, c, k), so evaluate each distinct c once and
    # weight the curve by how many problems share that c.
    unique_counts, multiplicity = np.unique(correct_counts, return_counts=True)
    total = np.zeros(n_sampling)
    for c, weight in zip(unique_counts, multiplicity):
        total += weight * np.array([pass_at_k(n_sampling, int(c), k) for k in k_values])
    return total / num_problems


def _select_pass_at_k(avg_pass_at_k: np.ndarray) -> Dict[str, float]:
    """Keep pass@1, every power of two up to N, and pass@N, as percentages rounded to 0.1."""
    n_sampling = len(avg_pass_at_k)
    ks = [1]
    k = 2
    while k <= n_sampling:
        ks.append(k)
        k *= 2
    if n_sampling not in ks:
        ks.append(n_sampling)
    # Plain floats keep numpy scalar types out of the JSON/printed output.
    return {str(k): round(float(avg_pass_at_k[k - 1]) * 100, 1) for k in ks}


def calculate_metrics(merged_data: List[Dict[str, Any]], json_metrics_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Recompute evaluation metrics for the merged per-problem records.

    pass@k is estimated from each problem's concatenated ``score`` list.
    ``num_scores``, ``timeout_samples`` and ``empty_samples`` are summed over
    the per-shard metrics dicts, and ``time_use_in_second`` is averaged over them.
    ``legend``, ``benchmark`` and ``model_name`` are taken from the first shard.

    :param merged_data: records produced by ``merge_jsonl_files``.
    :param json_metrics_list: per-shard ``*_metrics.json`` dicts. When any score
        exists, each must contain ``num_scores`` and ``time_use_in_second``.
    :return: the merged metrics dict, or ``{}`` when ``merged_data`` is empty.
    """
    if not merged_data:
        return {}

    identity = _run_identity(json_metrics_list)
    scores_array = _build_scores_array(merged_data)

    if scores_array.shape[0] == 0:
        # No problem carries score data: report zeros but keep the output schema.
        return {
            "num_samples": len(merged_data),
            "num_scores": 0,
            "timeout_samples": 0,
            "empty_samples": 0,
            "acc": 0.0,
            "pass_acc": 0.0,
            "pass@k": {"1": 0.0},
            "time_use_in_second": 0.0,
            "time_use_in_minite": "0:00",
            **identity,
        }

    num_problems, n_sampling = scores_array.shape
    print(f"构建的 scores 数组形状: ({num_problems}, {n_sampling})")

    pass_at_k_dict = _select_pass_at_k(_average_pass_at_k(scores_array))
    # acc = pass@1 and pass_acc = pass@N; _select_pass_at_k always emits both keys.
    acc = pass_at_k_dict["1"]
    pass_acc = pass_at_k_dict[str(n_sampling)]

    total_num_scores = sum(metrics['num_scores'] for metrics in json_metrics_list)
    total_timeout_samples = sum(metrics.get('timeout_samples', 0) for metrics in json_metrics_list)
    total_empty_samples = sum(metrics.get('empty_samples', 0) for metrics in json_metrics_list)

    time_values = [metrics['time_use_in_second'] for metrics in json_metrics_list]
    avg_time_seconds = sum(time_values) / len(time_values) if time_values else 0.0
    avg_time_minutes = f"{int(avg_time_seconds // 60)}:{int(avg_time_seconds % 60):02d}"

    return {
        # Every shard covers the same problems (merge_jsonl_files checks idx and
        # question row by row), so the merged problem count is len(merged_data).
        # Summing the shards' counts would multiply it by the number of shards.
        "num_samples": len(merged_data),
        "num_scores": total_num_scores,
        "timeout_samples": total_timeout_samples,
        "empty_samples": total_empty_samples,
        "acc": acc,
        "pass_acc": pass_acc,
        "pass@k": pass_at_k_dict,
        "time_use_in_second": avg_time_seconds,
        "time_use_in_minite": avg_time_minutes,
        **identity,
    }


def save_jsonl(data: List[Dict[str, Any]], file_path: str):
    """Write *data* to *file_path* as UTF-8 JSON Lines (one record per line), overwriting it."""
    serialized_lines = (json.dumps(record, ensure_ascii=False) + '\n' for record in data)
    with open(file_path, mode='w', encoding='utf-8') as out_file:
        out_file.writelines(serialized_lines)


def save_json(data: Dict[str, Any], file_path: str):
    """Serialize *data* as 4-space-indented UTF-8 JSON to *file_path*, overwriting it."""
    dump_options = {"indent": 4, "ensure_ascii": False}
    with open(file_path, mode='w', encoding='utf-8') as out_file:
        json.dump(data, out_file, **dump_options)


def generate_output_name(base_names: List[str]) -> str:
    """
    Derive the merged output base name from the first shard's base name.

    Everything from the "_Gpu<N>" marker onward (the per-shard suffix) is
    replaced with "_merged". Names without such a marker get "_merged" appended.
    An empty list yields "merged".

    e.g. test_cot_-1_seed42_t0.6_Gpu0_s0_e-1 -> test_cot_-1_seed42_t0.6_merged
         run_Gpu3                            -> run_merged
         plain                               -> plain_merged
    """
    if not base_names:
        return "merged"

    first_name = base_names[0]

    # "_Gpu<N>" must either end the name or be followed by "_..." (the shard
    # suffix). The trailing-marker case used to keep the GPU tag in the output.
    merged_name, replaced = re.subn(r'_Gpu\d+(?:_.*)?$', '_merged', first_name)
    return merged_name if replaced else first_name + '_merged'


def main():
    """
    CLI entry point: merge every ``<base>.jsonl`` / ``<base>_metrics.json``
    shard pair found in ``--input_dir``.

    The merged JSONL is written to ``<output_dir>/JSONLS/<name>.jsonl`` and the
    recomputed metrics to ``<output_dir>/JSONS/<name>_metrics.json``. Errors
    are printed and the function returns early.
    """
    parser = argparse.ArgumentParser(description='合并多GPU并行任务的输出文件')
    # Required: a missing value would otherwise crash later in os.path.join(None, ...).
    parser.add_argument('--input_dir', required=True, help='输入目录路径')
    parser.add_argument('--output_dir', help='输出目录路径', default='/chencongliang/project/icl-diversity/limit-of-RLVR/math/outputs/final_res')

    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir
    output_jsonl_dir = os.path.join(output_dir, "JSONLS")
    output_json_dir = os.path.join(output_dir, "JSONS")

    # Results are written into these subdirectories, so create them (and
    # output_dir itself) up front instead of failing at save time.
    os.makedirs(output_jsonl_dir, exist_ok=True)
    os.makedirs(output_json_dir, exist_ok=True)

    print(f"输入目录: {input_dir}")
    print(f"输出目录: {output_dir}")

    # Find the matching jsonl/metrics file pairs.
    print("\n查找匹配的文件...")
    file_pairs = find_matching_files(input_dir)

    if not file_pairs:
        print("未找到匹配的文件对!")
        return

    # Use one fixed order for listing, loading and naming, so the concatenation
    # order of the merged arrays matches what is printed here.
    base_names = sorted(file_pairs)
    print(f"找到 {len(file_pairs)} 个文件对:")
    for base_name in base_names:
        print(f"  - {base_name}")

    # Load every shard.
    print("\n加载数据...")
    jsonl_data_list = []
    json_metrics_list = []
    for base_name in base_names:
        files = file_pairs[base_name]
        print(f"  加载 {base_name}...")
        jsonl_data_list.append(load_jsonl(files['jsonl']))
        json_metrics_list.append(load_json(files['json']))

    # Merge the JSONL rows.
    print("\n合并JSONL数据...")
    try:
        merged_data = merge_jsonl_files(jsonl_data_list)
    except ValueError as e:
        print(f"合并失败: {e}")
        return
    print(f"成功合并 {len(merged_data)} 个样本")

    if not merged_data:
        # calculate_metrics returns {} here, and the summary below would raise KeyError.
        print("没有可合并的样本!")
        return

    # Recompute the evaluation metrics.
    print("\n计算评估指标...")
    new_metrics = calculate_metrics(merged_data, json_metrics_list)

    print("合并后的指标:")
    print(f"  样本数: {new_metrics['num_samples']}")
    print(f"  acc (pass@1): {new_metrics['acc']}%")
    print(f"  pass_acc (pass@N): {new_metrics['pass_acc']}%")
    print("  pass@k指标:")
    for k, score in sorted(new_metrics['pass@k'].items(), key=lambda x: int(x[0])):
        print(f"    pass@{k}: {score}%")
    print(f"  平均时间: {new_metrics['time_use_in_minite']}")

    output_base_name = generate_output_name(base_names)

    print(f"\n保存结果到 {output_base_name}...")

    output_jsonl_path = os.path.join(output_jsonl_dir, f"{output_base_name}.jsonl")
    output_json_path = os.path.join(output_json_dir, f"{output_base_name}_metrics.json")

    save_jsonl(merged_data, output_jsonl_path)
    save_json(new_metrics, output_json_path)

    print(f"✓ JSONL文件已保存: {output_jsonl_path}")
    print(f"✓ 指标文件已保存: {output_json_path}")
    print("\n合并完成!")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()