{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Remember the directory the kernel started in so that re-running this cell is\n",
    "# idempotent: a bare os.chdir(\"../\") would climb one more level on every re-run.\n",
    "if \"_NOTEBOOK_START_DIR\" not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import json\n",
    "from pathlib import Path\n",
    "from pprint import pprint\n",
    "from typing import Iterable, List, Union\n",
    "\n",
    "# Third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def search_json(search_json_dir, filename=\"metrics.json\"):\n",
    "    \"\"\"Recursively collect metric JSON files under one or more directories.\n",
    "\n",
    "    Args:\n",
    "        search_json_dir: A single directory (str or Path) or an iterable of\n",
    "            directories to search.\n",
    "        filename: File name to match at any depth (default \"metrics.json\").\n",
    "\n",
    "    Returns:\n",
    "        A sorted list of ``Path`` objects, one per matching file.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If no matching file exists under any directory.\n",
    "    \"\"\"\n",
    "    # Wrap a lone str/Path so callers may pass one directory or many; a bare\n",
    "    # Path used to be iterated over directly and raised TypeError.\n",
    "    if isinstance(search_json_dir, (str, Path)):\n",
    "        search_json_dir = [search_json_dir]\n",
    "    search_json_dir = [Path(_dir) for _dir in search_json_dir]\n",
    "\n",
    "    # Sort so the result does not depend on filesystem traversal order.\n",
    "    json_path_list = sorted(\n",
    "        json_path for _dir in search_json_dir for json_path in _dir.rglob(filename)\n",
    "    )\n",
    "\n",
    "    if len(json_path_list) == 0:\n",
    "        raise FileNotFoundError(f\"No json file found in the search_json_dir: {search_json_dir}\")\n",
    "    return json_path_list\n",
    "\n",
    "\n",
    "def get_clean_exp_name(exp_name):\n",
    "    \"\"\"Reduce an experiment directory name to a sortable key.\n",
    "\n",
    "    Everything from the first \"-\" onward is dropped, and the last\n",
    "    \"_\"-separated token of the remainder is returned as an ``int`` if it\n",
    "    parses as one. Otherwise the untouched ``exp_name`` is returned, so the\n",
    "    result may be either an ``int`` or a ``str``.\n",
    "\n",
    "    Example: \"budget_1024-epoch5\" -> 1024, \"baseline\" -> \"baseline\".\n",
    "    \"\"\"\n",
    "    prefix = exp_name.split(\"-\")[0]\n",
    "    last_token = prefix.split(\"_\")[-1]\n",
    "    try:\n",
    "        return int(last_token)\n",
    "    except ValueError:\n",
    "        # Not numeric: fall back to the full original name.\n",
    "        return exp_name\n",
    "\n",
    "\n",
    "def load_metrics_to_df(json_path_list, use_clean_exp_name=False):\n",
    "    \"\"\"Load metric JSON files into a DataFrame with one row per file.\n",
    "\n",
    "    The experiment name is the name of each file's grandparent directory\n",
    "    (``json_path.parents[1]``). Every top-level value in a file must be a\n",
    "    mapping with an ``\"accuracy\"`` entry; that accuracy becomes a column named\n",
    "    after the key.\n",
    "\n",
    "    Args:\n",
    "        json_path_list: Iterable of ``Path`` objects pointing at metric files.\n",
    "        use_clean_exp_name: If truthy, ``clean_exp_name`` is derived with\n",
    "            ``get_clean_exp_name``; otherwise it equals ``exp_name``.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with ``path``, ``exp_name`` and ``clean_exp_name`` columns plus\n",
    "        one accuracy column per JSON key, sorted by ``clean_exp_name`` (integer\n",
    "        names before string names) with a fresh 0..n-1 index.\n",
    "    \"\"\"\n",
    "    metrics_dict_list = []\n",
    "\n",
    "    for json_path in json_path_list:\n",
    "        # Skip the intermediate directory (presumably \"version_*\"): the\n",
    "        # experiment name is the grandparent of the metrics file.\n",
    "        exp_name = json_path.parents[1].name\n",
    "        # Truthiness rather than `is True`, so 1 / numpy.bool_(True) also work.\n",
    "        clean_exp_name = get_clean_exp_name(exp_name) if use_clean_exp_name else exp_name\n",
    "\n",
    "        metrics_dict = {\n",
    "            \"path\": json_path,\n",
    "            \"exp_name\": exp_name,\n",
    "            \"clean_exp_name\": clean_exp_name,\n",
    "        }\n",
    "\n",
    "        with open(json_path, \"r\") as f:\n",
    "            metrics_dict_ = json.load(f)\n",
    "        metrics_dict.update({k: v[\"accuracy\"] for k, v in metrics_dict_.items()})\n",
    "        metrics_dict_list.append(metrics_dict)\n",
    "\n",
    "    if not metrics_dict_list:\n",
    "        # pd.DataFrame([]) has no columns, so sorting it would raise KeyError.\n",
    "        return pd.DataFrame(columns=[\"path\", \"exp_name\", \"clean_exp_name\"])\n",
    "\n",
    "    # With use_clean_exp_name, names that parse become ints while the rest stay\n",
    "    # strs; comparing int with str raises TypeError, so order ints first.\n",
    "    metrics_dict_list.sort(\n",
    "        key=lambda row: (isinstance(row[\"clean_exp_name\"], str), row[\"clean_exp_name\"])\n",
    "    )\n",
    "    return pd.DataFrame(metrics_dict_list)\n",
    "\n",
    "\n",
    "def gather_and_plot_metrics(\n",
    "    search_json_dir,\n",
    "    use_clean_exp_name=False,\n",
    "    *args,\n",
    "    **kwargs,\n",
    "):\n",
    "    \"\"\"Collect the metric files under ``search_json_dir`` into one DataFrame.\n",
    "\n",
    "    Despite the name, nothing is plotted here; plotting is done by ``plot_df``.\n",
    "    Extra positional and keyword arguments are accepted and ignored, so a single\n",
    "    per-experiment argument dict (including plot_* keys) can be splatted in.\n",
    "\n",
    "    Returns:\n",
    "        The DataFrame built by ``load_metrics_to_df``.\n",
    "    \"\"\"\n",
    "    found_paths = search_json(search_json_dir)\n",
    "    return load_metrics_to_df(found_paths, use_clean_exp_name)\n",
    "\n",
    "\n",
    "def plot_df(\n",
    "    df,\n",
    "    ax,\n",
    "    plot_title=None,\n",
    "    plot_xlabel=None,\n",
    "    plot_ylabel=None,\n",
    "    plot_xlim=None,\n",
    "    plot_ylim=None,\n",
    "    legend=True,\n",
    "    *args,\n",
    "    legend_loc=\"upper left\",\n",
    "    legend_bbox_to_anchor=(-0.6, 1),\n",
    "    **kwargs,\n",
    "):\n",
    "    \"\"\"Draw one line per metric column of ``df`` against ``clean_exp_name``.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame as built by ``load_metrics_to_df``; ``path`` and\n",
    "            ``exp_name`` are dropped and ``clean_exp_name`` becomes the x axis.\n",
    "        ax: Matplotlib Axes to draw on.\n",
    "        plot_title, plot_xlabel, plot_ylabel: Optional title and axis labels.\n",
    "        plot_xlim: Optional ``(left, right)`` x-axis limits.\n",
    "        plot_ylim: Optional ``(bottom, top)`` y-axis limits, widened when\n",
    "            necessary so that no data point is clipped.\n",
    "        legend: Whether to draw the legend.\n",
    "        legend_loc, legend_bbox_to_anchor: Keyword-only legend placement; the\n",
    "            defaults anchor the legend outside the axes, to their left.\n",
    "        *args, **kwargs: Ignored, so a full experiment argument dict can be\n",
    "            passed straight through.\n",
    "\n",
    "    Returns:\n",
    "        The ``ax`` that was drawn on.\n",
    "    \"\"\"\n",
    "    # Named metrics_df (not plot_df) so the local does not shadow this function.\n",
    "    metrics_df = df.drop([\"path\", \"exp_name\"], axis=1).set_index(\"clean_exp_name\")\n",
    "\n",
    "    sns.lineplot(data=metrics_df, linewidth=2.5, markers=True, ax=ax, legend=legend)\n",
    "\n",
    "    if legend:\n",
    "        sns.move_legend(ax, legend_loc, bbox_to_anchor=legend_bbox_to_anchor)\n",
    "    if plot_title is not None:\n",
    "        ax.set_title(plot_title)\n",
    "    if plot_xlabel is not None:\n",
    "        ax.set_xlabel(plot_xlabel)\n",
    "    if plot_ylabel is not None:\n",
    "        ax.set_ylabel(plot_ylabel)\n",
    "\n",
    "    if plot_xlim is not None:\n",
    "        ax.set_xlim(*plot_xlim)\n",
    "\n",
    "    if plot_ylim is not None:\n",
    "        y_bottom, y_top = plot_ylim\n",
    "        # Extend the requested range so it covers the observed min/max.\n",
    "        y_bottom = min(y_bottom, metrics_df.min().min())\n",
    "        y_top = max(y_top, metrics_df.max().max())\n",
    "        ax.set_ylim(y_bottom, y_top)\n",
    "\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_title = \"Thinking Budget, 1k SFT Data, 7B Model\"\n",
    "exp_arg_list = [\n",
    "    {\n",
    "        \"search_json_dir\": \"outputs/250318-eval-medical_llm\",\n",
    "        \"plot_title\": \"Epoch 5\",\n",
    "    },\n",
    "]\n",
    "# Settings shared by every experiment; per-experiment entries above override these.\n",
    "exp_shared_arg = {\n",
    "    \"plot_xlabel\": \"Thinking Budget (in # Tokens)\",\n",
    "    \"plot_ylabel\": \"Accuracy\",\n",
    "    \"plot_ylim\": (0.45, 0.78),\n",
    "    # \"plot_xlim\": (0, 7000),\n",
    "    \"use_clean_exp_name\": False,\n",
    "}\n",
    "\n",
    "# Build merged copies instead of mutating exp_arg_list in place; the previous\n",
    "# exp_arg.update(exp_shared_arg) let shared values clobber per-experiment ones.\n",
    "merged_exp_arg_list = [{**exp_shared_arg, **exp_arg} for exp_arg in exp_arg_list]\n",
    "\n",
    "df_list = [gather_and_plot_metrics(**exp_arg) for exp_arg in merged_exp_arg_list]\n",
    "all_df = pd.concat(df_list)\n",
    "output_path = f\"outputs/{fig_title.replace('/', '_')}.tsv\"\n",
    "Path(output_path).parent.mkdir(parents=True, exist_ok=True)\n",
    "all_df.to_csv(output_path, sep=\"\\t\", index=False)\n",
    "print(f\"Saved to {output_path}\")\n",
    "display(all_df)\n",
    "\n",
    "\n",
    "num_plots = len(df_list)\n",
    "# squeeze=False always returns a 2-D array of Axes, even for a single subplot.\n",
    "fig, axes = plt.subplots(1, num_plots, figsize=(7 * num_plots, 6), squeeze=False)\n",
    "for idx, (df, ax, exp_arg) in enumerate(zip(df_list, axes[0], merged_exp_arg_list)):\n",
    "    # Only the first panel draws the legend.\n",
    "    plot_df(df=df, ax=ax, **{**exp_arg, \"legend\": idx == 0})\n",
    "fig.suptitle(fig_title)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "anonymous-med_sipf-1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
