{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "def _read_accuracy_metrics(step_dir, count):\n",
    "    \"\"\"Read `count` consecutive accuracy values from the log file in `step_dir`.\n",
    "\n",
    "    The first regular file (in sorted name order) is treated as the log. The\n",
    "    values come from the line starting with 'INFO:root:total accuracy' and the\n",
    "    `count - 1` lines that follow it; each value is the text after the last\n",
    "    ':' on its line, converted to float.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the directory holds no regular file, the marker line is\n",
    "            missing, or fewer than `count` lines exist from the marker onwards.\n",
    "    \"\"\"\n",
    "    marker = 'INFO:root:total accuracy'\n",
    "    # os.listdir() returns entries in arbitrary order; sort so the chosen log is deterministic.\n",
    "    file_names = sorted(\n",
    "        name for name in os.listdir(step_dir)\n",
    "        if os.path.isfile(os.path.join(step_dir, name))\n",
    "    )\n",
    "    if not file_names:\n",
    "        raise ValueError(f'no log file found in {step_dir!r}')\n",
    "    log_path = os.path.join(step_dir, file_names[0])\n",
    "    with open(log_path, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    # Fail loudly when the marker is absent instead of reusing a stale index from a previous file.\n",
    "    start = next((i for i, line in enumerate(lines) if line.startswith(marker)), None)\n",
    "    if start is None:\n",
    "        raise ValueError(f'no line starting with {marker!r} in {log_path!r}')\n",
    "    block = lines[start:start + count]\n",
    "    if len(block) < count:\n",
    "        raise ValueError(f'expected {count} accuracy lines from the marker in {log_path!r}, found {len(block)}')\n",
    "    return [float(line.split(':')[-1]) for line in block]\n",
    "\n",
    "\n",
    "def show_acc_by_type(dir_name):\n",
    "    \"\"\"Plot grouped bars, with connecting lines, of five accuracy metrics per step directory.\n",
    "\n",
    "    Args:\n",
    "        dir_name: directory containing sub-directories named `<name>-<step>`,\n",
    "            where `<step>` is an integer. Sub-directories without a numeric\n",
    "            suffix are ignored. Each step directory must contain a log file\n",
    "            with a 'INFO:root:total accuracy' line followed by four more\n",
    "            accuracy lines (read as fully, highly, weakly, unknown, in that order).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no step directory is found or a log cannot be parsed.\n",
    "    \"\"\"\n",
    "    # Remember the real path of every step directory instead of rebuilding it from a\n",
    "    # guessed base name, which breaks when the base name itself contains '-'.\n",
    "    step_dirs = {}\n",
    "    for name in os.listdir(dir_name):\n",
    "        path = os.path.join(dir_name, name)\n",
    "        suffix = name.rpartition('-')[2]\n",
    "        if os.path.isdir(path) and suffix.isdecimal():\n",
    "            step_dirs[int(suffix)] = path\n",
    "    if not step_dirs:\n",
    "        raise ValueError(f'no <name>-<step> sub-directories found in {dir_name!r}')\n",
    "    step_list = sorted(step_dirs)\n",
    "\n",
    "    # (data key, legend label, colour of the connecting line), in the order the log lists them.\n",
    "    series = [\n",
    "        ('total_acc', 'Total Accuracy', 'skyblue'),\n",
    "        ('fully_acc', 'Fully Accuracy', 'lightgreen'),\n",
    "        ('highly_acc', 'Highly Accuracy', 'salmon'),\n",
    "        ('weakly_acc', 'Weakly Accuracy', 'gold'),\n",
    "        ('unknown_acc', 'Unknown Accuracy', 'violet'),\n",
    "    ]\n",
    "    metric_keys = [key for key, _, _ in series]\n",
    "\n",
    "    data = {}\n",
    "    for step in step_list:\n",
    "        values = _read_accuracy_metrics(step_dirs[step], len(metric_keys))\n",
    "        data[step] = dict(zip(metric_keys, values))\n",
    "\n",
    "    # Ticks are labelled 1..N by the rank of the step, not by the step number itself.\n",
    "    x_labels = list(range(1, len(step_list) + 1))\n",
    "    bar_width = 0.15  # width of each bar\n",
    "    index = np.arange(len(x_labels))\n",
    "    # Centre each group of five bars on its tick: offsets -2w, -w, 0, +w, +2w.\n",
    "    offsets = [(slot - 2) * bar_width for slot in range(len(series))]\n",
    "\n",
    "    plt.figure(figsize=(12, 8))\n",
    "    for offset, (key, label, _) in zip(offsets, series):\n",
    "        plt.bar(index + offset, [data[step][key] for step in step_list], width=bar_width, label=label)\n",
    "\n",
    "    # Add lines connecting the tops of the bars of each metric.\n",
    "    for offset, (key, _, color) in zip(offsets, series):\n",
    "        plt.plot(index + offset, [data[step][key] for step in step_list], color=color, marker='o', linestyle='-', linewidth=2, markersize=8)\n",
    "\n",
    "    plt.xlabel('Steps')\n",
    "    plt.ylabel('Accuracy')\n",
    "    plt.title('Accuracy by Type')\n",
    "    plt.xticks(index, x_labels)\n",
    "    plt.legend()\n",
    "    plt.grid(True)\n",
    "    plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def _read_total_accuracy(step_dir):\n",
    "    \"\"\"Return the total accuracy recorded in the log file inside `step_dir`.\n",
    "\n",
    "    The first regular file (in sorted name order) is treated as the log; the\n",
    "    value is the text after the last ':' on the first line that starts with\n",
    "    'INFO:root:total accuracy', converted to float.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the directory holds no regular file or the marker line is missing.\n",
    "    \"\"\"\n",
    "    marker = 'INFO:root:total accuracy'\n",
    "    # os.listdir() returns entries in arbitrary order; sort so the chosen log is deterministic.\n",
    "    file_names = sorted(\n",
    "        name for name in os.listdir(step_dir)\n",
    "        if os.path.isfile(os.path.join(step_dir, name))\n",
    "    )\n",
    "    if not file_names:\n",
    "        raise ValueError(f'no log file found in {step_dir!r}')\n",
    "    log_path = os.path.join(step_dir, file_names[0])\n",
    "    with open(log_path, 'r') as f:\n",
    "        for line in f:\n",
    "            if line.startswith(marker):\n",
    "                return float(line.split(':')[-1])\n",
    "    # Fail loudly instead of reusing an index left over from a previous file.\n",
    "    raise ValueError(f'no line starting with {marker!r} in {log_path!r}')\n",
    "\n",
    "\n",
    "def show_max_acc(policy_list=None, dir_prefix='', dir_suffix='-epoch-8'):\n",
    "    \"\"\"Render a table of the best total accuracy, and the epoch it occurred at, per policy.\n",
    "\n",
    "    For each policy the directory `dir_prefix + str(policy) + dir_suffix` is\n",
    "    scanned for sub-directories named `<name>-<step>` with an integer step;\n",
    "    other sub-directories are ignored. Steps are visited in ascending order and\n",
    "    the epoch is the 1-based rank of the step. Ties keep the earliest epoch.\n",
    "\n",
    "    Args:\n",
    "        policy_list: policies to include; defaults to [2, 3, 4, 5, 6, 7, 8].\n",
    "        dir_prefix: text placed before the policy in the directory name.\n",
    "        dir_suffix: text placed after the policy in the directory name.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a policy directory has no step sub-directory or a log cannot be parsed.\n",
    "    \"\"\"\n",
    "    if policy_list is None:\n",
    "        policy_list = [2, 3, 4, 5, 6, 7, 8]\n",
    "\n",
    "    data = {}\n",
    "    for policy in policy_list:\n",
    "        dir_name = dir_prefix + str(policy) + dir_suffix\n",
    "        # Remember the real path of every step directory instead of rebuilding it from a\n",
    "        # guessed base name, which breaks when the base name itself contains '-'.\n",
    "        step_dirs = {}\n",
    "        for name in os.listdir(dir_name):\n",
    "            path = os.path.join(dir_name, name)\n",
    "            suffix = name.rpartition('-')[2]\n",
    "            if os.path.isdir(path) and suffix.isdecimal():\n",
    "                step_dirs[int(suffix)] = path\n",
    "        if not step_dirs:\n",
    "            raise ValueError(f'no <name>-<step> sub-directories found in {dir_name!r}')\n",
    "\n",
    "        max_acc = 0\n",
    "        max_epoch = 0\n",
    "        for epoch, step in enumerate(sorted(step_dirs), start=1):\n",
    "            acc = _read_total_accuracy(step_dirs[step])\n",
    "            if acc > max_acc:\n",
    "                max_acc = acc\n",
    "                max_epoch = epoch\n",
    "        data[policy] = {'max_acc': max_acc, 'max_epoch': max_epoch}\n",
    "\n",
    "    # Convert the data to a DataFrame, one row per policy. Building it column-wise keeps\n",
    "    # the epoch column integral instead of upcasting it to float.\n",
    "    policies = list(data)\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            'Max Accuracy': [data[policy]['max_acc'] for policy in policies],\n",
    "            'Max Epoch': [data[policy]['max_epoch'] for policy in policies],\n",
    "        },\n",
    "        index=pd.Index(policies, name='Policy'),\n",
    "    )\n",
    "    # df.values would merge both columns into one float array, so format the cells explicitly.\n",
    "    cell_text = [[str(acc), str(epoch)] for acc, epoch in zip(df['Max Accuracy'], df['Max Epoch'])]\n",
    "\n",
    "    # Draw the table.\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    ax.axis('tight')\n",
    "    ax.axis('off')\n",
    "    table = ax.table(cellText=cell_text, colLabels=df.columns, rowLabels=df.index, cellLoc='center', loc='center')\n",
    "    table.auto_set_font_size(False)\n",
    "    table.set_fontsize(12)\n",
    "    table.scale(1.2, 1.2)\n",
    "\n",
    "    plt.title('Max Accuracy and Epoch by Policy')\n",
    "    plt.show()\n",
    "\n",
    "show_max_acc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the 'Accuracy by Type' chart for the sub-directories of the given path.\n",
    "# NOTE(review): the directory argument is an empty string -- confirm which results directory was intended.\n",
    "show_acc_by_type('')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qwen-sft",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
