{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import json\n",
    "import csv\n",
    "import pandas as pd\n",
    "\n",
    "# Score locations. The root is an absolute path on the machine this notebook\n",
    "# was run on; the other two directories are derived from it.\n",
    "target_score_dir = '/home/ubuntu/MMSci/mmsci-exps/eval/eval_scores/image_caption_generation'\n",
    "base_score_dir = os.path.join(target_score_dir, 'textgen')  # per-setting folders of <model>.json score files\n",
    "csv_dir = os.path.join(target_score_dir, 'csv')  # destination of the summary CSVs\n",
    "\n",
    "# exist_ok=True already tolerates an existing directory, so a separate\n",
    "# os.path.exists() guard is unnecessary.\n",
    "os.makedirs(csv_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_list = ['blip2', 'kosmos2', 'qwen', 'llava', 'llava-next', 'llava-next-mistral', 'gpt-4-turbo', 'gpt-4o', 'llava-next-mmsci']\n",
    "metric_list = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4', 'METEOR', 'rougeL', 'BERTSCORE', 'CLIPScore', 'RefCLIPScore']\n",
    "\n",
    "# CSV header: the model name, then a mean/std column pair for every metric.\n",
    "fields = ['model']\n",
    "for metric in metric_list:\n",
    "    fields.extend([f'{metric} [mean]', f'{metric} [std]'])\n",
    "\n",
    "\n",
    "def summarize_scores(model, all_scores):\n",
    "    \"\"\"Build the CSV row for one model.\n",
    "\n",
    "    Args:\n",
    "        model: Name stored in the row's 'model' column.\n",
    "        all_scores: Object loaded from the model's JSON score file; it is\n",
    "            looked up by metric name (presumably a dict of per-sample\n",
    "            score lists - confirm against the score files).\n",
    "\n",
    "    Returns:\n",
    "        A dict keyed by the entries of `fields`. Metrics present in\n",
    "        `all_scores` get np.mean / np.std of their stored values; metrics\n",
    "        that are absent get 0.0 in both columns.\n",
    "    \"\"\"\n",
    "    row = {'model': model}\n",
    "    for metric in metric_list:\n",
    "        if metric in all_scores:\n",
    "            scores = all_scores[metric]\n",
    "            row[f'{metric} [mean]'] = np.mean(scores)\n",
    "            row[f'{metric} [std]'] = np.std(scores)\n",
    "        else:\n",
    "            # NOTE(review): 0.0 stands in for a metric that is missing from the\n",
    "            # file, so it cannot be told apart from a genuine zero in the CSV.\n",
    "            row[f'{metric} [mean]'] = 0.0\n",
    "            row[f'{metric} [std]'] = 0.0\n",
    "    return row\n",
    "\n",
    "\n",
    "# One CSV per setting; the combination with both flags True is skipped.\n",
    "for w_abs in [False, True]:\n",
    "    for w_ctx in [False, True]:\n",
    "        if w_abs and w_ctx:\n",
    "            continue\n",
    "        tag = f'abs{w_abs}_ctx{w_ctx}'\n",
    "\n",
    "        # Collect every row before touching the output file, so a failure while\n",
    "        # reading scores does not leave a truncated CSV behind.\n",
    "        rows = []\n",
    "        for model in model_list:\n",
    "            filename = os.path.join(base_score_dir, tag, f'{model}.json')\n",
    "            # Models without a score file for this setting are left out of the CSV.\n",
    "            if not os.path.exists(filename):\n",
    "                continue\n",
    "            # The context manager closes the handle that a bare open() would leak.\n",
    "            with open(filename) as score_file:\n",
    "                all_scores = json.load(score_file)\n",
    "            rows.append(summarize_scores(model, all_scores))\n",
    "\n",
    "        # newline='' is how the csv module documentation says to open files it writes.\n",
    "        with open(os.path.join(csv_dir, f'{tag}.csv'), 'w', newline='') as csvfile:\n",
    "            writer = csv.DictWriter(csvfile, fieldnames=fields)\n",
    "            writer.writeheader()\n",
    "            writer.writerows(rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['model', 'BLEU-1 [mean]', 'BLEU-1 [std]', 'BLEU-2 [mean]',\n",
      "       'BLEU-2 [std]', 'BLEU-3 [mean]', 'BLEU-3 [std]', 'BLEU-4 [mean]',\n",
      "       'BLEU-4 [std]', 'METEOR [mean]', 'METEOR [std]', 'rougeL [mean]',\n",
      "       'rougeL [std]', 'BERTSCORE [mean]', 'BERTSCORE [std]',\n",
      "       'CLIPScore [mean]', 'CLIPScore [std]', 'RefCLIPScore [mean]',\n",
      "       'RefCLIPScore [std]'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# Choose the evaluation setting whose summary table is inspected.\n",
    "w_abs, w_ctx = True, False\n",
    "tag = f'abs{w_abs}_ctx{w_ctx}'\n",
    "\n",
    "# Read that setting's CSV into a DataFrame and show its column index.\n",
    "csv_root = os.path.join(target_score_dir, \"csv\")\n",
    "cap_score_path = os.path.join(csv_root, f'{tag}.csv')\n",
    "df = pd.read_csv(cap_score_path)\n",
    "print(df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                model  BLEU-1 [mean]  BLEU-2 [mean]  BLEU-3 [mean]  \\\n",
      "0               blip2          32.88           4.18           0.45   \n",
      "1             kosmos2          22.28           2.91           0.61   \n",
      "2                qwen          38.27           8.75           2.22   \n",
      "3               llava          30.78           4.50           0.66   \n",
      "4          llava-next          19.79           3.70           0.68   \n",
      "5  llava-next-mistral          19.50           3.95           0.76   \n",
      "6         gpt-4-turbo          22.95           5.63           1.56   \n",
      "7              gpt-4o          21.06           5.58           1.76   \n",
      "8    llava-next-mmsci          45.89          16.96           8.12   \n",
      "\n",
      "   BLEU-4 [mean]  METEOR [mean]  rougeL [mean]  BERTSCORE [mean]  \\\n",
      "0           0.09           7.32           9.14             79.72   \n",
      "1           0.20          19.50          11.81             79.09   \n",
      "2           0.70          16.02          15.38             81.87   \n",
      "3           0.18          14.54          14.00             81.20   \n",
      "4           0.18          20.86          12.88             80.86   \n",
      "5           0.20          21.49          12.75             80.84   \n",
      "6           0.50          27.59          15.66             82.37   \n",
      "7           0.58          28.41          16.32             81.82   \n",
      "8           4.08          24.77          20.69             84.46   \n",
      "\n",
      "   CLIPScore [mean]  RefCLIPScore [mean]  \n",
      "0               0.0                  0.0  \n",
      "1               0.0                  0.0  \n",
      "2               0.0                  0.0  \n",
      "3               0.0                  0.0  \n",
      "4               0.0                  0.0  \n",
      "5               0.0                  0.0  \n",
      "6               0.0                  0.0  \n",
      "7               0.0                  0.0  \n",
      "8               0.0                  0.0  \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1336798/2271840411.py:9: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
      "  filtered_df[numeric_cols] = filtered_df[numeric_cols].applymap(lambda x: x * 100)\n",
      "/tmp/ipykernel_1336798/2271840411.py:9: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  filtered_df[numeric_cols] = filtered_df[numeric_cols].applymap(lambda x: x * 100)\n",
      "/tmp/ipykernel_1336798/2271840411.py:10: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  filtered_df[numeric_cols] = filtered_df[numeric_cols].round(2)\n"
     ]
    }
   ],
   "source": [
    "# Keep the model name plus every \"[mean]\" column; the \"[std]\" columns are left out.\n",
    "filtered_columns = [col for col in df.columns if \"[mean]\" in col or \"model\" in col]\n",
    "\n",
    "# .copy() yields an independent frame, so the assignment below no longer writes\n",
    "# into a slice of `df` (what triggered pandas' SettingWithCopyWarning).\n",
    "filtered_df = df[filtered_columns].copy()\n",
    "\n",
    "# Scale the numeric columns by 100 and round to two decimals. Vectorised\n",
    "# arithmetic replaces the deprecated DataFrame.applymap call.\n",
    "numeric_cols = filtered_df.select_dtypes(include='number').columns\n",
    "filtered_df[numeric_cols] = (filtered_df[numeric_cols] * 100).round(2)\n",
    "\n",
    "# Print the filtered DataFrame\n",
    "print(filtered_df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mace_new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
