import json
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

def calculate_metrics_averages(
    data: List[Dict[str, Any]],
    metric_names: Tuple[str, ...] = ("helpfulness", "clarity", "factuality", "depth", "engagement"),
    score_range: Tuple[float, float] = (1.0, 5.0),
) -> Tuple[Dict[str, Optional[float]], Optional[float]]:
    """Average each metric's score across records, then average those averages.

    Each record is expected to look like
    ``{"parsed_result": {metric: {"score": <number or numeric string>}, ...}}``.
    Records, metrics, or scores that are missing, of the wrong type, not
    convertible to ``float``, or outside ``score_range`` (inclusive) are skipped
    rather than aborting the whole computation.

    Args:
        data: Evaluation records, typically the list loaded from a JSON file.
        metric_names: Metrics to aggregate; defaults to the five evaluation metrics.
        score_range: Inclusive ``(low, high)`` bounds a score must fall within.

    Returns:
        ``(averages, overall)``. ``averages`` maps every metric name to its mean
        score, or ``None`` when no valid score was found for it. ``overall`` is
        the unweighted mean of the non-``None`` averages, or ``None`` if all
        of them are ``None``.
    """
    low, high = score_range
    sums = {name: 0.0 for name in metric_names}
    counts = {name: 0 for name in metric_names}

    for item in data:
        # Guard every level with isinstance: `in`/`[...]` on an int, str or list
        # would otherwise raise TypeError or do a substring match.
        if not isinstance(item, dict):
            continue
        parsed_result = item.get("parsed_result")
        if not isinstance(parsed_result, dict):
            continue  # record has no usable evaluation result

        for name in metric_names:
            metric_data = parsed_result.get(name)
            if not isinstance(metric_data, dict) or "score" not in metric_data:
                continue
            try:
                score_val = float(metric_data["score"])
            except (ValueError, TypeError):
                continue  # unparsable score
            # NaN fails both comparisons, so it is dropped here as well.
            if low <= score_val <= high:
                sums[name] += score_val
                counts[name] += 1

    averages: Dict[str, Optional[float]] = {
        name: (sums[name] / counts[name] if counts[name] > 0 else None)
        for name in metric_names
    }
    valid = [avg for avg in averages.values() if avg is not None]
    overall_average = sum(valid) / len(valid) if valid else None
    return averages, overall_average

def process_json_file(
    file_path: str, head_size: int = 800
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
    """Load a JSON list of evaluation records and summarize its metric scores.

    Args:
        file_path: Path to a UTF-8 JSON file whose top-level value must be a list.
        head_size: Number of leading records used for the second summary
            (defaults to 800, the "first 800" report).

    Returns:
        ``(all_stats, head_stats)``. Each is a dict with ``count`` (number of
        records considered), ``metrics_avg`` and ``overall_avg`` (as returned
        by ``calculate_metrics_averages``). Returns ``(None, None)`` after
        printing a message if the file cannot be read or decoded, is not valid
        JSON, is not a list, or aggregation fails.
    """
    try:
        # Keep the file open only while parsing; aggregation happens after close.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"JSON解析错误在文件 {file_path}: {e}")
        return None, None
    except (OSError, UnicodeDecodeError, RecursionError) as e:
        print(f"处理文件 {file_path} 时出错: {e}")
        return None, None

    # The aggregation expects a list of records.
    if not isinstance(data, list):
        print(f"文件 {file_path} 中的数据不是列表格式")
        return None, None

    head = data[:head_size]
    try:
        all_avg, all_overall = calculate_metrics_averages(data)
        head_avg, head_overall = calculate_metrics_averages(head)
    except Exception as e:  # per-file boundary: one bad file must not abort the batch
        print(f"处理文件 {file_path} 时出错: {e}")
        return None, None

    return {
        "count": len(data),
        "metrics_avg": all_avg,
        "overall_avg": all_overall,
    }, {
        "count": len(head),
        "metrics_avg": head_avg,
        "overall_avg": head_overall,
    }

def find_target_json_files(parent_dir: str) -> Dict[str, str]:
    """Map each immediate subdirectory of *parent_dir* to its evaluation JSON file.

    A regular file qualifies when its name ends with ``.json``, contains
    ``safe_eval`` and does not contain ``eval_res``. Subdirectories with no
    qualifying file are omitted. Names are visited in sorted order, because
    ``os.listdir`` returns entries in arbitrary order. When a folder has
    several candidates, "the first one" is therefore the alphabetically
    smallest name, chosen reproducibly.

    Args:
        parent_dir: Directory whose subdirectories are scanned (one level deep).

    Returns:
        Dict of subdirectory name -> path of the chosen JSON file.
    """
    target_files: Dict[str, str] = {}
    for item in sorted(os.listdir(parent_dir)):
        item_path = os.path.join(parent_dir, item)
        if not os.path.isdir(item_path):
            continue

        json_files = [
            os.path.join(item_path, name)
            for name in sorted(os.listdir(item_path))
            if name.endswith('.json')
            and 'safe_eval' in name
            and 'eval_res' not in name
            and os.path.isfile(os.path.join(item_path, name))
        ]
        if not json_files:
            continue
        if len(json_files) > 1:
            print(f"警告: 文件夹 {item} 中有多个符合条件的JSON文件，将使用第一个")
        target_files[item] = json_files[0]

    return target_files

def metrics_all_valid(metrics_avg: Dict[str, float]) -> bool:
    """Report whether every metric average is present (i.e. not ``None``).

    An empty mapping is considered valid.
    """
    for avg in metrics_avg.values():
        if avg is None:
            return False
    return True

def _write_stats_section(f: Any, title: str, count_label: str, section: Dict[str, Any], ending: str) -> None:
    """Write one statistics section (per-metric averages plus their mean) to *f*.

    Args:
        f: Text file opened for writing.
        title: Section heading, including its trailing newline.
        count_label: Label written before ``section['count']``.
        section: Dict with ``count``, ``metrics_avg`` and ``overall_avg``; all
            averages must be non-None (callers check ``metrics_all_valid`` first).
        ending: Text written after the overall-average value.
    """
    f.write(title)
    f.write(f"    {count_label}: {section['count']}\n")
    f.write("    各指标平均值:\n")
    for metric, avg in section['metrics_avg'].items():
        f.write(f"      {metric}: {avg:.4f}\n")
    f.write(f"    五个指标平均值的均值: {section['overall_avg']:.4f}{ending}")


def main(target_dir: Optional[str] = None) -> None:
    """Summarize metric averages for every model subfolder of *target_dir*.

    Writes ``metrics_analysis_results.txt`` into *target_dir*. Each folder gets
    a "first 800 records" section and an "all records" section; a section is
    written only when all of its metric averages are available.

    Args:
        target_dir: Directory to analyse. If omitted, the first command-line
            argument is used. The previous hard-coded value was ``""``, which
            never exists and made the script always exit early.
    """
    if target_dir is None:
        target_dir = sys.argv[1] if len(sys.argv) > 1 else ""
    # isdir (not exists): a plain file here would make os.listdir fail later.
    if not os.path.isdir(target_dir):
        print(f"错误: 目录 {target_dir} 不存在")
        return

    target_files = find_target_json_files(target_dir)
    if not target_files:
        print("错误: 未找到符合条件的JSON文件")
        return

    results = {}
    for subdir_name, file_path in target_files.items():
        print(f"处理文件夹: {subdir_name}")
        all_stats, first_800_stats = process_json_file(file_path)

        if all_stats and first_800_stats:
            results[subdir_name] = {
                "all": all_stats,
                "first_800": first_800_stats,
                "file_path": file_path,
            }
        else:
            print(f"警告: 文件夹 {subdir_name} 处理失败")

    output_file = os.path.join(target_dir, 'metrics_analysis_results.txt')
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("====== 指标分析结果 ======\n\n")

        for subdir_name, stats in results.items():
            first_ok = metrics_all_valid(stats['first_800']['metrics_avg'])
            all_ok = metrics_all_valid(stats['all']['metrics_avg'])

            f.write(f"文件夹: {subdir_name}\n")
            f.write(f"文件路径: {stats['file_path']}\n\n")

            # count is len(data[:800]), not a count of validated records.
            if first_ok:
                _write_stats_section(f, "  前800条数据统计:\n", "条目数", stats['first_800'], "\n\n")
            if all_ok:
                _write_stats_section(f, "  所有数据统计:\n", "总条目数", stats['all'], "\n")
            if not (first_ok or all_ok):
                f.write("  指标不完整，未输出统计\n")

            # Always terminate the folder entry so headers never run together.
            f.write("-" * 60 + "\n\n")

        # Portable replacement for shelling out to `date`.
        finished_at = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")
        f.write(f"分析完成时间: {finished_at}\n")

    print(f"分析完成，结果已保存到 {output_file}")

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()