import json
from collections import defaultdict
import math
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import csv  # 引入CSV模块


# --- Core metric functions ---
# These functions are called once per data_source group of samples.
def calculate_brier_score(samples):
    """Return the Brier score of ``samples``.

    The score is the mean of ``(confidence - accuracy) ** 2`` over all
    samples, where each sample is indexed by the keys ``'confidence'`` and
    ``'accuracy'``. An empty ``samples`` yields ``0.0``.
    """
    if samples:
        squared_errors = [
            (entry['confidence'] - entry['accuracy']) ** 2 for entry in samples
        ]
        return sum(squared_errors) / len(samples)
    return 0.0


def calculate_ece(samples, num_bins=10):
    """Compute the Expected Calibration Error (ECE) of ``samples``.

    Samples are grouped into ``num_bins`` equal-width confidence bins. The
    result is the sum over non-empty bins of
    ``(bin size / total samples) * |mean accuracy - mean confidence|``.

    Args:
        samples: Sequence of mappings providing ``'confidence'`` and
            ``'accuracy'``.
        num_bins: Number of confidence bins.

    Returns:
        The ECE as a float; ``0.0`` when ``samples`` is empty.
    """
    if not samples:
        return 0.0
    n = len(samples)
    bins = defaultdict(list)
    for sample in samples:
        raw_index = int(sample['confidence'] * num_bins)
        # Clamp on both sides. A confidence of exactly 1.0 (or above) belongs
        # to the last bin, and a negative confidence belongs to the first one.
        # Without the lower clamp such a sample lands in a negative bin that
        # the loop below never visits, so it is dropped while still counted
        # in n.
        bin_index = max(0, min(raw_index, num_bins - 1))
        bins[bin_index].append(sample)
    ece = 0.0
    for m in range(num_bins):
        # .get() avoids creating empty entries in the defaultdict.
        bin_samples = bins.get(m)
        if not bin_samples:
            continue
        num_in_bin = len(bin_samples)
        avg_accuracy_in_bin = sum(s['accuracy'] for s in bin_samples) / num_in_bin
        avg_confidence_in_bin = sum(s['confidence'] for s in bin_samples) / num_in_bin
        weight = num_in_bin / n
        ece += weight * abs(avg_accuracy_in_bin - avg_confidence_in_bin)
    return ece


# --- Per-data-source metric computation used by the CSV export ---
def calculate_per_source_metrics(json_path, num_bins_for_ece=10):
    """Compute accuracy, Brier score and ECE per data source for one JSON file.

    Args:
        json_path: Path to a UTF-8 JSON file holding a list of sample objects.
            Each sample is read through the keys ``'data_source'``,
            ``'accuracy'`` and ``'confidence'``.
        num_bins_for_ece: Number of bins forwarded to ``calculate_ece``.

    Returns:
        ``None`` if the file does not exist, is not valid JSON, or its
        top-level value is non-empty but not a list. ``{}`` if the file holds
        no data or no entry has a usable ``'data_source'``. Otherwise a dict
        mapping each data source to ``{'acc': ..., 'brier': ..., 'ece': ...}``.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return None
    if not data:
        return {}
    if not isinstance(data, list):
        # A top-level object or scalar is not the expected list of samples.
        # Treat it like an unreadable file instead of failing on iteration.
        return None

    grouped_samples = defaultdict(list)
    for sample in data:
        try:
            grouped_samples[sample['data_source']].append(sample)
        except (KeyError, TypeError):
            # KeyError: the sample has no 'data_source'.
            # TypeError: the entry is not an object (number, string, null,
            # list) or its 'data_source' is unhashable. Skip such entries.
            continue

    source_metrics = {}
    for source, samples in grouped_samples.items():
        # Every group holds at least one sample, so the division is safe.
        source_metrics[source] = {
            'acc': sum(s['accuracy'] for s in samples) / len(samples),
            'brier': calculate_brier_score(samples),
            'ece': calculate_ece(samples, num_bins=num_bins_for_ece),
        }

    return source_metrics

# (internal metric key, per-source column suffix, macro-average column label)
_SUMMARY_METRICS = (('acc', 'acc', 'ACC'), ('brier', 'bs', 'BS'), ('ece', 'ece', 'ECE'))


def _nanmean_or_nan(values):
    """Return the mean of the non-NaN entries of ``values``, or NaN if none exist."""
    arr = np.asarray(values, dtype=float)
    # np.nanmean warns ("Mean of empty slice") when nothing is left to average.
    if arr.size == 0 or np.all(np.isnan(arr)):
        return np.nan
    return np.nanmean(arr)


def _sample_std_or_nan(values):
    """Return the sample std (ddof=1) of the non-NaN entries, or NaN if fewer than two."""
    arr = np.asarray(values, dtype=float)
    # With ddof=1 and fewer than two values np.nanstd warns and returns NaN.
    if np.count_nonzero(~np.isnan(arr)) < 2:
        return np.nan
    return np.nanstd(arr, ddof=1)


def _collect_run_metrics(dir_path, steps, num_bins):
    """Average each data source's metrics over ``steps`` for one run directory."""
    per_source = defaultdict(lambda: defaultdict(list))
    for step in steps:
        f_path = os.path.join(dir_path, f'validation_step_{step}.json')
        if not os.path.exists(f_path):
            # Only this step is skipped; other steps of the run may still exist.
            print(f"  - 警告: 未找到文件 {f_path}，跳过 step {step}。")
            continue
        step_metrics = calculate_per_source_metrics(f_path, num_bins)
        if not step_metrics:
            print(f"  - 警告: 文件 {f_path} 无法解析或不含有效样本，跳过 step {step}。")
            continue
        for source, metrics in step_metrics.items():
            for key, _, _ in _SUMMARY_METRICS:
                per_source[source][key].append(metrics[key])
    return {
        source: {
            key: (np.mean(collected[key]) if collected[key] else np.nan)
            for key, _, _ in _SUMMARY_METRICS
        }
        for source, collected in per_source.items()
    }


def _summarize_method(method_name, run_final_metrics, sources):
    """Build one CSV row: per-source and macro mean/std across a method's runs."""
    row = {'Method': method_name}
    for source in sources:
        for key, suffix, _ in _SUMMARY_METRICS:
            values = [run[source].get(key, np.nan) for run in run_final_metrics if source in run]
            row[f'{source}_mean_{suffix}'] = _nanmean_or_nan(values)
            row[f'{source}_std_{suffix}'] = _sample_std_or_nan(values)
    for key, _, label in _SUMMARY_METRICS:
        # Macro average: first average over sources inside each run, then
        # take mean/std of those per-run averages across runs.
        per_run = [
            _nanmean_or_nan([metrics.get(key, np.nan) for metrics in run.values()])
            for run in run_final_metrics
        ]
        row[f'MACRO_AVG_{label}_mean'] = _nanmean_or_nan(per_run)
        row[f'MACRO_AVG_{label}_std'] = _sample_std_or_nan(per_run)
    return row


def export_summary_csv_at_steps_revised(experiments, base_dir, steps_to_summarize, output_filename, num_bins=10):
    """Write a summary CSV for all methods at the given step(s).

    Aggregation happens in two stages:

    1. For each single run, the per-data-source metrics are averaged over all
       requested steps, giving that run's "final performance".
    2. Across the independent runs of a method, the mean and the sample
       standard deviation (ddof=1) of those final values are reported. With
       fewer than two runs the standard deviation is undefined and is written
       as ``N/A``.

    Args:
        experiments: Mapping of method name to a list of run directories. Each
            directory is searched for ``validation_step_{step}.json`` files.
        base_dir: Directory under which ``csv_reports/`` is created.
        steps_to_summarize: A single step (int) or an iterable of steps.
        output_filename: File name of the CSV inside ``csv_reports/``.
        num_bins: Number of ECE bins forwarded to the metric computation.

    Returns:
        None. Progress, warnings and errors are printed; nothing is written
        when no method yields any data.
    """
    if isinstance(steps_to_summarize, (int, np.integer)):
        steps_to_summarize = [steps_to_summarize]
    else:
        # Materialize so one-shot iterables survive being iterated per run.
        steps_to_summarize = list(steps_to_summarize)

    steps_str = "-".join(map(str, steps_to_summarize))
    print(f"\n{'=' * 20} 开始生成总结性报告 (Steps: {steps_str}) {'=' * 20}")
    print("计算逻辑: 1. 单次运行内跨Step求均值 -> 2. 多次运行间求最终均值和样本标准差")

    all_methods_summary_data = []
    all_data_sources = set()

    for method_name, dir_list in experiments.items():
        run_final_metrics = []
        for dir_path in dir_list:
            run_metrics = _collect_run_metrics(dir_path, steps_to_summarize, num_bins)
            if not run_metrics:
                print(f"  - 警告: 在路径 {dir_path} 中未找到任何指定step的数据，跳过此次运行。")
                continue
            all_data_sources.update(run_metrics)
            run_final_metrics.append(run_metrics)
        print(method_name)
        if not run_final_metrics:
            print(f"  - 错误: 方法 '{method_name}' 未能收集到任何有效运行的数据。")
            continue

        all_methods_summary_data.append(
            _summarize_method(method_name, run_final_metrics, sorted(all_data_sources)))

    if not all_methods_summary_data:
        print("错误: 未能生成任何总结性数据。")
        return

    header = ['Method']
    for source in sorted(all_data_sources):
        for _, suffix, _ in _SUMMARY_METRICS:
            header.extend([f'{source}_mean_{suffix}', f'{source}_std_{suffix}'])
    for _, _, label in _SUMMARY_METRICS:
        header.extend([f'MACRO_AVG_{label}_mean', f'MACRO_AVG_{label}_std'])

    csv_output_dir = os.path.join(base_dir, 'csv_reports')
    os.makedirs(csv_output_dir, exist_ok=True)
    full_output_path = os.path.join(csv_output_dir, output_filename)
    try:
        # utf-8-sig writes a BOM at the start of the file.
        with open(full_output_path, 'w', newline='', encoding='utf-8-sig') as f:
            # restval fills columns for sources a method never reported.
            writer = csv.DictWriter(f, fieldnames=header, restval='N/A')
            writer.writeheader()
            for row_data in all_methods_summary_data:
                formatted_row = {
                    k: (f"{v:.4f}" if isinstance(v, (float, np.floating)) and not np.isnan(v) else 'N/A')
                    for k, v in row_data.items()
                }
                formatted_row['Method'] = row_data['Method']
                writer.writerow(formatted_row)
        print(f"✅ 成功将总结性报告保存到: {full_output_path}")
    except IOError as e:
        print(f"❌ 写入文件 '{full_output_path}' 时出错: {e}")


if __name__ == "__main__":
    # Root folder of the validation logs; it is also passed as the report base dir.
    base_dir = 'training_logs_valid'

    # IMPORTANT: define the experiments here. Each method name maps to a list
    # of run directories, one entry per independent run.
    experiments_to_analyze = {
        'CCGSPG': [os.path.join(base_dir, 'Qwen2.5-1.5B-Instruct_45epochs')],
    }

    # One report per step: 0, 5, ..., 405.
    for step in range(0, 406, 5):
        export_summary_csv_at_steps_revised(
            experiments_to_analyze,
            base_dir,
            steps_to_summarize=step,
            output_filename=f"summary_report_step_{step}.csv",
        )

    # To average several steps into a single report instead, pass a list, e.g.
    # steps_to_summarize=[80, 90, 100] with
    # output_filename="summary_report_steps_80_90_100_avg_all.csv".