import argparse
import glob
import os
import re
from collections import defaultdict

import numpy as np

# 命令行参数解析
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, required=True, choices=['Diabetes', 'Adults'])
parser.add_argument('--output_dir', type=str, default='Aggregated_Results')
args = parser.parse_args()

trial_dir = f"{args.dataset}_Trials"
os.makedirs(args.output_dir, exist_ok=True)

# 收集所有 CSV 文件
trial_files = glob.glob(os.path.join(trial_dir, '*.csv'))
if not trial_files:
    raise RuntimeError(f"No .csv files found in {trial_dir}")

# 按 (k, lambda) 分组
grouped = defaultdict(list)
for path in trial_files:
    basename = os.path.basename(path)
    parts = basename.replace('.csv', '').split('-')
    k_part = next(p for p in parts if p.startswith('k='))
    l_part = next(p for p in parts if p.startswith('lambda='))
    key = (k_part, l_part)
    grouped[key].append(path)

# 每组文件：合并、计算均值与置信区间
for (k_str, l_str), files in grouped.items():
    all_data = [np.loadtxt(f, delimiter=',') for f in files]
    arr = np.vstack(all_data)  # shape = (num_trials, num_metrics)
    mean = arr.mean(axis=0)
    ci95 = 1.96 * arr.std(axis=0) / np.sqrt(len(files))
    out = np.vstack([mean, ci95])
    out_path = os.path.join(args.output_dir, f"{k_str}-{l_str}-{args.dataset}.csv")
    np.savetxt(out_path, out, delimiter=',', fmt='%.6f')
    print(f"Saved: {out_path}")