import json
import os
from pathlib import Path

# Root folder to scan; each immediate subfolder is expected to hold the
# result JSON file(s) to summarise. An empty string resolves to the current
# working directory.
folder_path = ""

# Dataset names (matched against the "dataset" or "source_dataset" field of
# each record) that get a per-dataset breakdown.
valid_datasets = [
    "DeepInception",
    "GCG",
    "PAIR",
    "AutoDAN",
    "SAP30",
    "SQL",
    "Just-Eval",
    "helpful_base",
]

# Statistics helper: returns both the summed time and the sample count.
def calculate_time_stats(data_list, dataset_field, dataset_name=None):
    """Sum ``time_cost`` over records, optionally filtered by dataset.

    Args:
        data_list: Iterable of record dicts.
        dataset_field: Key holding each record's dataset name
            (e.g. ``"dataset"`` or ``"source_dataset"``).
        dataset_name: If given, only records whose ``dataset_field`` equals
            this value are included; if ``None``, every record is included.

    Returns:
        ``(total_time, count)``: the summed ``time_cost`` and the number of
        matching records.

    A record whose ``time_cost`` is missing or JSON ``null`` contributes 0 to
    the total but is still counted, so such records pull the average down.
    """
    total_time = 0.0
    count = 0
    for item in data_list:
        if dataset_name is not None and item.get(dataset_field) != dataset_name:
            continue
        time_cost = item.get("time_cost")
        # JSON null is treated like a missing key rather than raising
        # TypeError on `+=` (which would abort the whole file upstream).
        total_time += time_cost if time_cost is not None else 0
        count += 1
    return total_time, count

# Helpers for the main routine, followed by the main entry point.
def _find_target_json(subfolder):
    """Return the first JSON file (by name) in *subfolder* lacking ``_safe_eval``.

    ``Path.glob`` yields entries in arbitrary, filesystem-dependent order, so
    candidates are sorted to make the choice reproducible. Returns ``None``
    when no candidate exists.
    """
    candidates = sorted(
        file for file in subfolder.glob("*.json") if "_safe_eval" not in file.name
    )
    if len(candidates) > 1:
        # Only one file per subfolder is analysed; make that choice visible.
        print(f"警告: 子文件夹 {subfolder.name} 中有多个候选JSON文件，将使用 {candidates[0].name}")
    return candidates[0] if candidates else None


def _load_records(target_file):
    """Load *target_file* and return ``(records, dataset_field)``, or ``None``.

    A top-level JSON list is used directly, with records keyed by
    ``"dataset"``; a top-level object containing a ``"data"`` key supplies
    ``data["data"]``, keyed by ``"source_dataset"``. Any other layout returns
    ``None``. File and JSON decoding errors propagate to the caller.
    """
    with open(target_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if isinstance(data, list):
        return data, "dataset"
    if isinstance(data, dict) and "data" in data:
        return data["data"], "source_dataset"
    return None


def _print_summary(name, summary):
    """Print the overall and per-dataset statistics for one subfolder."""
    print(f"处理完成: {name}")
    print(f"  总时间: {summary['total_time']:.4f} 秒")
    print(f"  样本数: {summary['count']}")
    print(f"  平均时间: {summary['avg_time']:.4f} 秒")
    for ds, stats in summary["per_dataset"].items():
        print(f"  {ds}:")
        print(f"    总时间: {stats['total_time']:.2f} 秒")
        print(f"    样本数: {stats['count']}")
        print(f"    平均时间: {stats['avg_time']:.4f} 秒")
    print()


def process_all_json_files(root_folder, datasets=None):
    """Compute ``time_cost`` statistics for every subfolder of *root_folder*.

    Immediate subfolders are visited in name order. In each one, the first
    JSON file (by name) whose name does not contain ``_safe_eval`` is loaded
    and summarised. Subfolders with no such file, an unrecognised layout, or
    no records are reported and skipped. Any error while processing a file
    is printed and that subfolder is skipped.

    Args:
        root_folder: Directory whose immediate subfolders are scanned.
        datasets: Dataset names to break the statistics down by. Defaults to
            the module-level ``valid_datasets``.

    Returns:
        Dict mapping subfolder name to a dict with keys ``"total_time"``,
        ``"count"``, ``"avg_time"`` and ``"per_dataset"``. ``per_dataset`` maps
        each dataset name that has at least one record to a dict with
        ``"total_time"``, ``"count"`` and ``"avg_time"``.
    """
    if datasets is None:
        datasets = valid_datasets

    results = {}
    # iterdir() order is arbitrary; sort so processing, printed output and
    # the key order of the returned dict are stable across runs/platforms.
    subfolders = sorted(p for p in Path(root_folder).iterdir() if p.is_dir())
    for subfolder in subfolders:
        target_file = _find_target_json(subfolder)
        if target_file is None:
            print(f"警告: 在子文件夹 {subfolder.name} 中未找到符合条件的JSON文件")
            continue

        try:
            loaded = _load_records(target_file)
            if loaded is None:
                print(f"警告: 文件 {target_file.name} 的格式不符合预期")
                continue
            data_list, dataset_field = loaded

            if not data_list:
                print(f"警告: 文件 {target_file.name} 中没有数据")
                continue

            # Overall totals across every record in the file.
            total_all_time, total_count = calculate_time_stats(data_list, dataset_field)
            avg_all_time = total_all_time / total_count if total_count > 0 else 0

            # Per-dataset breakdown; datasets with no records are omitted.
            dataset_stats = {}
            for ds in datasets:
                ds_time, ds_count = calculate_time_stats(data_list, dataset_field, ds)
                if ds_count > 0:
                    dataset_stats[ds] = {
                        "total_time": ds_time,
                        "count": ds_count,
                        "avg_time": ds_time / ds_count,
                    }

            results[subfolder.name] = {
                "total_time": total_all_time,
                "count": total_count,
                "avg_time": avg_all_time,
                "per_dataset": dataset_stats,
            }
            _print_summary(subfolder.name, results[subfolder.name])

        except Exception as e:
            # Per-file boundary: report the failure and continue with the
            # remaining subfolders instead of aborting the whole run.
            print(f"错误: 处理文件 {target_file} 时出错: {e}")

    return results

# Run the aggregation over every subfolder of folder_path.
all_results = process_all_json_files(folder_path)

# Write the summary as time_cost_summary.json inside folder_path
# (an empty folder_path places it in the current working directory).
output_path = os.path.join(folder_path, "time_cost_summary.json")
serialized = json.dumps(all_results, ensure_ascii=False, indent=2)
with open(output_path, mode='w', encoding='utf-8') as out_file:
    out_file.write(serialized)

print(f"所有结果已保存到: {output_path}")