{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import glob\n",
    "import csv\n",
    "import statistics\n",
    "from pathlib import Path\n",
    "\n",
    "def process_single_file(file_path):\n",
    "    \"\"\"Parse one result log and return its raw metric values.\n",
    "\n",
    "    The text is split on separators such as ``=== Summary ===``; in each\n",
    "    (whitespace-stripped) chunk, every line that starts with ``A_auc`` is\n",
    "    scanned for ``<name> <decimal>`` pairs (single space between them).\n",
    "    Values of the four tracked metrics are multiplied by 100 and appended\n",
    "    in the order they appear in the file.\n",
    "\n",
    "    Args:\n",
    "        file_path: path of the text log to read.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'A_auc', 'A_avg', 'A_last', 'F_last' (in that order),\n",
    "        each mapping to a possibly empty list of floats.\n",
    "    \"\"\"\n",
    "    with open(file_path, 'r') as handle:\n",
    "        text = handle.read()\n",
    "\n",
    "    collected = {name: [] for name in ('A_auc', 'A_avg', 'A_last', 'F_last')}\n",
    "    pair_pattern = r'([A-Za-z_]+) (\\d+\\.\\d+)'\n",
    "\n",
    "    for chunk in re.split(r'=+ Summary =+', text):\n",
    "        # Only lines beginning with 'A_auc' carry the metric pairs.\n",
    "        metric_lines = [ln for ln in chunk.strip().split('\\n') if ln.startswith('A_auc')]\n",
    "        for ln in metric_lines:\n",
    "            for name, number in re.findall(pair_pattern, ln):\n",
    "                # Names outside the tracked set are ignored; the x100 scaling\n",
    "                # presumably turns fractions into percentages (confirm log format).\n",
    "                if name in collected:\n",
    "                    collected[name].append(float(number) * 100)\n",
    "    return collected\n",
    "\n",
    "def _summarize_values(values):\n",
    "    \"\"\"Return ``(mean, sample_std)`` of ``values``.\n",
    "\n",
    "    Both are 0.0 for an empty list, and the standard deviation is 0.0 when\n",
    "    there are fewer than two values (``statistics.stdev`` needs at least two).\n",
    "    \"\"\"\n",
    "    if not values:\n",
    "        return 0.0, 0.0\n",
    "    mean = sum(values) / len(values)\n",
    "    std = statistics.stdev(values) if len(values) >= 2 else 0.0\n",
    "    return mean, std\n",
    "\n",
    "\n",
    "def batch_process_files(pattern, csv_path=\"results.csv\", verbose=False):\n",
    "    \"\"\"Summarize every log matching ``pattern`` into a CSV plus global stats.\n",
    "\n",
    "    Files are processed in sorted order so the CSV row order is reproducible.\n",
    "    For each file, the mean and sample standard deviation of every metric\n",
    "    returned by ``process_single_file`` are written as one CSV row (rounded to\n",
    "    4 decimals). Across files, the mean and sample standard deviation of the\n",
    "    per-file means are computed for each metric.\n",
    "\n",
    "    Args:\n",
    "        pattern: glob pattern selecting the log files.\n",
    "        csv_path: destination CSV file; overwritten if it exists.\n",
    "        verbose: if True, print the cross-file summary after writing the CSV.\n",
    "\n",
    "    Returns:\n",
    "        ``(csv_data, global_stats)``: the list of row dicts written to the CSV,\n",
    "        and a dict with ``<metric>_global_mean`` / ``<metric>_global_std`` keys.\n",
    "    \"\"\"\n",
    "    metric_names = ('A_auc', 'A_avg', 'A_last', 'F_last')\n",
    "    headers = [\"Filename\"] + [f\"{m}_{s}\" for m in metric_names for s in ('mean', 'std')]\n",
    "\n",
    "    csv_data = []\n",
    "    # Per-metric lists of per-file means, used for the cross-file statistics.\n",
    "    file_means = {m: [] for m in metric_names}\n",
    "\n",
    "    for file_path in sorted(glob.glob(pattern)):\n",
    "        try:\n",
    "            raw_metrics = process_single_file(file_path)\n",
    "            file_stats = {}\n",
    "            parsed_means = {}\n",
    "            for metric in metric_names:\n",
    "                values = raw_metrics.get(metric, [])\n",
    "                mean, std = _summarize_values(values)\n",
    "                file_stats[f\"{metric}_mean\"] = mean\n",
    "                file_stats[f\"{metric}_std\"] = std\n",
    "                # A metric with no parsed values keeps its 0.0 placeholder in\n",
    "                # the CSV but is left out of the cross-file statistics, so a\n",
    "                # missing/unparsed run does not drag the global mean to 0.\n",
    "                if values:\n",
    "                    parsed_means[metric] = mean\n",
    "        except Exception as e:\n",
    "            # Best effort: report the failing file and move on to the next.\n",
    "            print(f\"处理文件 {file_path} 出错: {str(e)}\")\n",
    "            continue\n",
    "\n",
    "        csv_row = {\"Filename\": Path(file_path).name}\n",
    "        csv_row.update({key: round(file_stats[key], 4) for key in headers[1:]})\n",
    "        csv_data.append(csv_row)\n",
    "        # Commit to the global collector only once the whole file succeeded.\n",
    "        for metric, mean in parsed_means.items():\n",
    "            file_means[metric].append(mean)\n",
    "\n",
    "    with open(csv_path, 'w', newline='', encoding='utf-8') as f:\n",
    "        writer = csv.DictWriter(f, fieldnames=headers)\n",
    "        writer.writeheader()\n",
    "        writer.writerows(csv_data)\n",
    "\n",
    "    global_stats = {}\n",
    "    for metric in metric_names:\n",
    "        global_mean, global_std = _summarize_values(file_means[metric])\n",
    "        global_stats[f\"{metric}_global_mean\"] = global_mean\n",
    "        global_stats[f\"{metric}_global_std\"] = global_std\n",
    "\n",
    "    if verbose:\n",
    "        print(\"\\n=== 全局统计 ===\")\n",
    "        for metric in metric_names:\n",
    "            print(f\"{metric}:\")\n",
    "            print(f\"  平均值（跨文件）: {global_stats[f'{metric}_global_mean']:.4f}\")\n",
    "            print(f\"  标准差（跨文件）: {global_stats[f'{metric}_global_std']:.4f}\")\n",
    "\n",
    "    return csv_data, global_stats\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Directory whose *.txt result logs are summarized; uncomment one of the\n",
    "    # alternatives to switch experiment. NOTE(review): these are absolute,\n",
    "    # machine-specific paths; consider making the root configurable.\n",
    "    # root_path = '/home/sunguanglong/DGIL/MISA/results/output/vit_others'\n",
    "    # root_path = '/home/sunguanglong/DGIL/MISA/results/output/vit_in21k'\n",
    "    root_path = '/home/sunguanglong/DGIL/MISA/results/output/vit_fly'\n",
    "    save_path = 'results.csv'\n",
    "    # glob.escape stops characters such as '[' or '?' in the directory name\n",
    "    # from acting as wildcards; only the '*.txt' part is a pattern.\n",
    "    results, global_stats = batch_process_files(f\"{glob.escape(root_path)}/*.txt\", save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
