{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e5cd5a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Work from the notebook's parent directory so relative paths such as\n",
    "# \"outputs/greedy\" used below resolve (assumed to be the project root).\n",
    "# The starting directory is recorded on the first run only, so re-running\n",
    "# this cell is idempotent instead of climbing one more level each time.\n",
    "if \"_NOTEBOOK_START_DIR\" not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_START_DIR, \"..\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ce674e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import json\n",
    "from pathlib import Path\n",
    "from types import SimpleNamespace\n",
    "\n",
    "import click\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0671798",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_results_jsonl(results_jsonl):\n",
    "    \"\"\"Load one model's regraded eval results and average them per dataset.\n",
    "\n",
    "    Each non-blank line of ``results_jsonl`` is parsed as a JSON object that must\n",
    "    contain ``dataset_name``, ``average_num_correct``, ``pass_at_num_rollouts``\n",
    "    and ``majority_at_num_rollouts``; any other keys are ignored.\n",
    "\n",
    "    Args:\n",
    "        results_jsonl: ``str`` or ``Path`` to a ``regraded_eval_results.jsonl`` file.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per ``dataset_name`` holding the three metrics\n",
    "        averaged over that dataset's records and multiplied by 100 (presumably\n",
    "        fractions -> percentages), plus a ``model_name`` column set to the name\n",
    "        of the file's parent directory. Empty (with the same columns) if the\n",
    "        file contains no records.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a record lacks one of the required keys.\n",
    "        json.JSONDecodeError: if a non-blank line is not valid JSON.\n",
    "    \"\"\"\n",
    "    results_jsonl = Path(results_jsonl)\n",
    "    metric_columns = [\"average_num_correct\", \"majority_at_num_rollouts\", \"pass_at_num_rollouts\"]\n",
    "    columns = [\"dataset_name\", *metric_columns]\n",
    "\n",
    "    records = []\n",
    "    with results_jsonl.open(\"r\", encoding=\"utf-8\") as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank lines, e.g. extra newlines at EOF\n",
    "            data = json.loads(line)\n",
    "            records.append({key: data[key] for key in columns})\n",
    "\n",
    "    # Explicit columns + float dtype keep the schema when the file has no\n",
    "    # records, so the groupby below does not fail on an empty frame.\n",
    "    df = pd.DataFrame(records, columns=columns).astype({c: float for c in metric_columns})\n",
    "    df_mean = df.groupby(\"dataset_name\").mean().reset_index()\n",
    "    df_mean[metric_columns] = df_mean[metric_columns] * 100  # to percentage\n",
    "    df_mean[\"model_name\"] = results_jsonl.parent.name\n",
    "    return df_mean\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9f4e5e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dir = Path(\"outputs/greedy\")\n",
    "\n",
    "\n",
    "def get_all_results_jsonl(base_dir):\n",
    "    \"\"\"Recursively collect every ``regraded_eval_results.jsonl`` under ``base_dir``.\n",
    "\n",
    "    Args:\n",
    "        base_dir: Directory to search (``str`` or ``Path``).\n",
    "\n",
    "    Returns:\n",
    "        Sorted list of matching ``Path`` objects. Sorting removes the dependence\n",
    "        on filesystem enumeration order, so re-runs produce the same order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no matching file is found.\n",
    "    \"\"\"\n",
    "    base_dir = Path(base_dir)\n",
    "    results_jsonl_files = sorted(base_dir.glob(\"**/regraded_eval_results.jsonl\"))\n",
    "    if not results_jsonl_files:\n",
    "        raise ValueError(f\"No regraded_eval_results.jsonl files found in {base_dir}\")\n",
    "    return results_jsonl_files\n",
    "\n",
    "\n",
    "results_jsonl_files = get_all_results_jsonl(base_dir)\n",
    "results_jsonl_files[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28323d9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected layout: <base_dir>/<version>/<model_name>/regraded_eval_results.jsonl,\n",
    "# e.g. outputs/greedy/v0/7b-m23k-checkpoint-4401/regraded_eval_results.jsonl,\n",
    "# so the grandparent directory name is used as the run version.\n",
    "df_list = [\n",
    "    read_results_jsonl(path).assign(version=path.parent.parent.name)\n",
    "    for path in tqdm.tqdm(results_jsonl_files)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c458bfc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_all = pd.concat(df_list, ignore_index=True)\n",
    "\n",
    "# Mean and sample std (pandas default ddof=1) of `average_num_correct` for each\n",
    "# (model_name, dataset_name) pair, presumably across the versions containing it.\n",
    "# The std is NaN for pairs that have only one row.\n",
    "mean_std_records = [\n",
    "    {\n",
    "        \"model_name\": model_name,\n",
    "        \"dataset_name\": dataset_name,\n",
    "        \"mean_average_num_correct\": sub_df[\"average_num_correct\"].mean(),\n",
    "        \"std_average_num_correct\": sub_df[\"average_num_correct\"].std(),\n",
    "    }\n",
    "    for (model_name, dataset_name), sub_df in df_all.groupby([\"model_name\", \"dataset_name\"])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "051f9373",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long-format summary table: one row per (model_name, dataset_name).\n",
    "mean_std_df = pd.DataFrame.from_records(mean_std_records)\n",
    "mean_std_df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa8f7877",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reorder_rows(df, rows, append_unseen_rows=False):\n",
    "    \"\"\"Return ``df`` with its rows reindexed to follow the order given in ``rows``.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame whose index holds the row labels.\n",
    "        rows: Iterable of index labels in the desired order. It is NOT modified.\n",
    "        append_unseen_rows: If True, labels of ``df.index`` missing from ``rows``\n",
    "            are appended after them (in ``Index.difference`` order, i.e. sorted\n",
    "            where possible). If False, those rows are dropped.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame. Labels in ``rows`` that are absent from ``df.index``\n",
    "        produce all-NaN rows (standard ``DataFrame.reindex`` behaviour).\n",
    "    \"\"\"\n",
    "    order = list(rows)  # copy so the caller's list is never mutated\n",
    "    if append_unseen_rows:\n",
    "        order.extend(df.index.difference(order))\n",
    "    return df.reindex(order)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbcf8a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wide table: one row per model_name, one column per dataset_name, values are\n",
    "# mean_average_num_correct (percent) from mean_std_df.\n",
    "mean_df = mean_std_df.pivot_table(\n",
    "    index=\"model_name\",\n",
    "    columns=\"dataset_name\",\n",
    "    values=\"mean_average_num_correct\",\n",
    ")\n",
    "# Row order of the exported table; models not listed here are excluded.\n",
    "rows = [\n",
    "    # 3b\n",
    "    \"Qwen2.5-VL-3B-Instruct\",\n",
    "    \"3b-m23k-checkpoint-4401\",\n",
    "    \"3b-pmc_vqa-checkpoint-12594\",\n",
    "    \"train-qwen2_5_vl_3b-pmc_vqa-m23k_sft_epoch_3-step_1150\",\n",
    "    \"train-qwen2_5_vl_3b-m23k-step_320\",\n",
    "    \"train-qwen2_5_vl_3b-pmc_vqa-step_451\",\n",
    "    \"train-qwen2_5_vl_3b-pmc_vqa-m23k_rl-step_1805\",\n",
    "    # 7b\n",
    "    \"Qwen2.5-VL-7B-Instruct\",\n",
    "    \"7b-m23k-checkpoint-4401\",\n",
    "    \"7b-pmc_vqa-checkpoint-12594\",\n",
    "    \"train-qwen2_5_vl_7b-pmc_vqa-m23k_sft_epoch_3-step_1805\",\n",
    "    \"train-qwen2_5_vl_7b-m23k-step_320\",\n",
    "    \"train-qwen2_5_vl_7b-pmc_vqa-step_451\",\n",
    "    \"train-qwen2_5_vl_7b-pmc_vqa-m23k_rl-step_1805\",\n",
    "    # other Qwen2.5-VL-based models (incl. 32B)\n",
    "    \"HuatuoGPT-Vision-7B-Qwen2.5VL\",\n",
    "    \"Qwen2.5-VL-32B-Instruct\",\n",
    "    \"train-qwen2_5_vl_32b-m23k-step_645\",\n",
    "    # others\n",
    "    \"llava-med-v1.5-mistral-7b-hf\",\n",
    "    \"HuatuoGPT-Vision-7B-hf\",\n",
    "    \"HuatuoGPT-Vision-34B-hf\",\n",
    "    \"medgemma-4b-it\",\n",
    "    \"medgemma-27b-it\",\n",
    "    \"gemma-3-4b-it\",\n",
    "    \"gemma-3-27b-it\",\n",
    "]\n",
    "# reindex silently inserts NaN rows for listed-but-missing models and drops\n",
    "# unlisted ones; surface both so results are not lost without notice.\n",
    "missing_models = [name for name in rows if name not in mean_df.index]\n",
    "unlisted_models = sorted(set(mean_df.index) - set(rows))\n",
    "if missing_models:\n",
    "    print(f\"No results found (rows will be NaN): {missing_models}\")\n",
    "if unlisted_models:\n",
    "    print(f\"Not listed in `rows`, excluded from the table: {unlisted_models}\")\n",
    "mean_df = reorder_rows(mean_df, rows, append_unseen_rows=False)\n",
    "mean_df_path = \"outputs/mean_average_num_correct.tsv\"\n",
    "mean_df.to_csv(mean_df_path, sep=\"\\t\")\n",
    "print(f\"Mean average number correct saved to {mean_df_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
