{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "import os\n",
    "from collections import Counter\n",
    "from tqdm import tqdm\n",
    "\n",
    "rst_dir = '/home/ubuntu/MMSci/mmsci-exps/eval/output/image_caption_matching'\n",
    "\n",
    "cot = False\n",
    "setting = 3\n",
    "k = 5\n",
    "rst_dir = os.path.join(rst_dir, \"w_cot\" if cot else \"wo_cot\", f\"setting-{setting}\", f\"k_{k}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_ABCD_colon(input_string):\n",
    "    input_string = input_string.strip().replace(\"\\n\", \" \")\n",
    "    pattern = r'(^| )([A-D]):'\n",
    "    match = re.search(pattern, input_string)\n",
    "    if match:\n",
    "        return match.group(2)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_prediction(str):\n",
    "    \n",
    "    if \"[/INST]\" in str:\n",
    "        str = str.split(\"[/INST]\")[1].strip()\n",
    "\n",
    "    if extract_ABCD_colon(str) in [\"A\", \"B\", \"C\", \"D\"]:\n",
    "        return extract_ABCD_colon(str).lower()\n",
    "    \n",
    "    choice_list = ['a', 'b', 'c', 'd']\n",
    "    # remove all punctuations\n",
    "    str = re.sub(r'[^\\w\\s]', '', str).lower()\n",
    "    # replace '\\n' with ' '\n",
    "    str = str.replace('\\n', ' ')\n",
    "    # remove duplicated spaces\n",
    "    str = re.sub(' +', ' ', str)\n",
    "    if len(str) == 0:\n",
    "        return 'wrong answer'\n",
    "    # only keep the first 10 words\n",
    "    str = ' '.join(str.split()[:10])\n",
    "\n",
    "    if str.split()[0] in choice_list:  # first token is a/b/c/d\n",
    "        return str.split()[0]\n",
    "    for choice in choice_list:\n",
    "        if str.endswith(f'is {choice}'):\n",
    "            return choice\n",
    "        if f'is {choice} ' in str:\n",
    "            return choice\n",
    "    return 'wrong answer'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A \n",
      " a \n",
      " ---\n",
      "(a) \n",
      " a \n",
      " ---\n",
      "(e) \n",
      " wrong answer \n",
      " ---\n",
      "is describing \n",
      " wrong answer \n",
      " ---\n",
      "bla bla bla is b \n",
      " b \n",
      " ---\n",
      "The caption describing Figure (c) is:\n",
      "\n",
      "A \n",
      " a \n",
      " ---\n",
      "The content in Figure (b) describes the temperature derivative of the resistivity, indicating aT-linear behaviour at high temperatures at ε=−0.3%. The figure also shows the fit-free analysis of theT4/3power-law behaviour for ε=−0.3%, theT5/3power-law behaviour for ε=−2.9%, and the upper and lower temperature limits where the resistivity crosses over to a linearTbehaviour. The figure also shows the thermal hysteresis coercivity versus strain and q-scans across the E′-type AFM reflection intensity at k=(1/4,1/4,1/4) observed by resonant X-ray diffraction at the Ni L3-edge at 12 K and 140 K. Therefore, the answer is C: The thermal hysteresis coercivity versus strain. \n",
      " c \n",
      " ---\n"
     ]
    }
   ],
   "source": [
    "test_list = [\n",
    "    'A',\n",
    "    '(a)',\n",
    "    '(e)',\n",
    "    'is describing',\n",
    "    'bla bla bla is b',\n",
    "    \"The caption describing Figure (c) is:\\n\\nA\",\n",
    "    \"The content in Figure (b) describes the temperature derivative of the resistivity, indicating aT-linear behaviour at high temperatures at \\u03b5=\\u22120.3%. The figure also shows the fit-free analysis of theT4/3power-law behaviour for \\u03b5=\\u22120.3%, theT5/3power-law behaviour for \\u03b5=\\u22122.9%, and the upper and lower temperature limits where the resistivity crosses over to a linearTbehaviour. The figure also shows the thermal hysteresis coercivity versus strain and q-scans across the E\\u2032-type AFM reflection intensity at k=(1/4,1/4,1/4) observed by resonant X-ray diffraction at the Ni L3-edge at 12\\u2009K and 140\\u2009K. Therefore, the answer is C: The thermal hysteresis coercivity versus strain.\"\n",
    "]\n",
    "for test_str in test_list:\n",
    "    print(test_str, '\\n', parse_prediction(test_str), '\\n', '-'*3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A\n",
      "A\n"
     ]
    }
   ],
   "source": [
    "def most_frequent_item(lst):\n",
    "    # Create a dictionary to store the count of each item\n",
    "    counts = {}\n",
    "    # Iterate over the list and count occurrences\n",
    "    for item in lst:\n",
    "        if item in counts:\n",
    "            counts[item] += 1\n",
    "        else:\n",
    "            counts[item] = 1\n",
    "    \n",
    "    # Find the item with the maximum count\n",
    "    max_count = 0\n",
    "    most_frequent = None\n",
    "    for item in lst:\n",
    "        if counts[item] > max_count:\n",
    "            max_count = counts[item]\n",
    "            most_frequent = item\n",
    "    \n",
    "    return most_frequent\n",
    "\n",
    "# Example usage\n",
    "input1 = [\"A\", \"D\", \"A\", \"A\", \"B\"]\n",
    "input2 = [\"A\", \"D\", \"A\", \"D\", \"B\"]\n",
    "print(most_frequent_item(input1))  # Output: \"A\"\n",
    "print(most_frequent_item(input2))  # Output: \"A\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1119 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1119/1119 [00:00<00:00, 13935.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llava-next-mistral.json: 1: 265/1119 = 23.68%\n",
      "llava-next-mistral.json: 1: 265/1119 = 23.68%\n",
      "\n",
      "llava-next-mistral.json: 3: 259/1119 = 23.15%\n",
      "llava-next-mistral.json: 3: 259/1119 = 23.15%\n",
      "\n",
      "llava-next-mistral.json: 5: 243/1119 = 21.72%\n",
      "llava-next-mistral.json: 5: 243/1119 = 21.72%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1119/1119 [00:00<00:00, 17334.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llava-next.json: 1: 244/1119 = 21.81%\n",
      "llava-next.json: 1: 244/1119 = 21.81%\n",
      "\n",
      "llava-next.json: 3: 249/1119 = 22.25%\n",
      "llava-next.json: 3: 249/1119 = 22.25%\n",
      "\n",
      "llava-next.json: 5: 255/1119 = 22.79%\n",
      "llava-next.json: 5: 255/1119 = 22.79%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1117/1117 [00:00<00:00, 211066.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o-w-ans.json: 1: 964/1117 = 86.30%\n",
      "gpt-4o-w-ans.json: 1: 964/1117 = 86.30%\n",
      "\n",
      "gpt-4o-w-ans.json: 3: 972/1117 = 87.02%\n",
      "gpt-4o-w-ans.json: 3: 972/1117 = 87.02%\n",
      "\n",
      "gpt-4o-w-ans.json: 5: 977/1117 = 87.47%\n",
      "gpt-4o-w-ans.json: 5: 977/1117 = 87.47%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1119/1119 [00:00<00:00, 9607.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "kosmos2.json: 1: 268/1119 = 23.95%\n",
      "kosmos2.json: 1: 268/1119 = 23.95%\n",
      "\n",
      "kosmos2.json: 3: 268/1119 = 23.95%\n",
      "kosmos2.json: 3: 268/1119 = 23.95%\n",
      "\n",
      "kosmos2.json: 5: 268/1119 = 23.95%\n",
      "kosmos2.json: 5: 268/1119 = 23.95%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1117/1117 [00:00<00:00, 208085.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo-w-ans.json: 1: 750/1117 = 67.14%\n",
      "gpt-4-turbo-w-ans.json: 1: 750/1117 = 67.14%\n",
      "\n",
      "gpt-4-turbo-w-ans.json: 3: 772/1117 = 69.11%\n",
      "gpt-4-turbo-w-ans.json: 3: 772/1117 = 69.11%\n",
      "\n",
      "gpt-4-turbo-w-ans.json: 5: 791/1117 = 70.81%\n",
      "gpt-4-turbo-w-ans.json: 5: 791/1117 = 70.81%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1117/1117 [00:00<00:00, 11618.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo.json: 1: 664/1117 = 59.44%\n",
      "gpt-4-turbo.json: 1: 664/1117 = 59.44%\n",
      "\n",
      "gpt-4-turbo.json: 3: 718/1117 = 64.28%\n",
      "gpt-4-turbo.json: 3: 718/1117 = 64.28%\n",
      "\n",
      "gpt-4-turbo.json: 5: 743/1117 = 66.52%\n",
      "gpt-4-turbo.json: 5: 743/1117 = 66.52%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1119/1119 [00:00<00:00, 69812.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llava-next-mmsci.json: 1: 850/1119 = 75.96%\n",
      "llava-next-mmsci.json: 1: 850/1119 = 75.96%\n",
      "\n",
      "llava-next-mmsci.json: 3: 861/1119 = 76.94%\n",
      "llava-next-mmsci.json: 3: 861/1119 = 76.94%\n",
      "\n",
      "llava-next-mmsci.json: 5: 864/1119 = 77.21%\n",
      "llava-next-mmsci.json: 5: 864/1119 = 77.21%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1117/1117 [00:00<00:00, 26128.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o.json: 1: 967/1117 = 86.57%\n",
      "gpt-4o.json: 1: 967/1117 = 86.57%\n",
      "\n",
      "gpt-4o.json: 3: 977/1117 = 87.47%\n",
      "gpt-4o.json: 3: 977/1117 = 87.47%\n",
      "\n",
      "gpt-4o.json: 5: 983/1117 = 88.00%\n",
      "gpt-4o.json: 5: 983/1117 = 88.00%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "for file in os.listdir(rst_dir):\n",
    "    rst_list = json.load(open(os.path.join(rst_dir, file)))\n",
    "    num_total = len(rst_list)\n",
    "    num_correct = [0, 0, 0] # 1, 3, 5\n",
    "    valid = [0, 0, 0] # 1, 3, 5\n",
    "    for item in tqdm(rst_list):\n",
    "        # counter = Counter\n",
    "        gt = item['answer'].lower()\n",
    "        preds = [] if 'prediction' not in item else item['prediction'] # only select the top_k pred for evaluation\n",
    "        preds = [pred[\"extracted_answer\"].lower() if \"w-ans\" in file else parse_prediction(pred[\"answer\"]) for pred in preds]\n",
    "        for idx, top_k in enumerate([1, 3, 5]):\n",
    "            top_preds = preds[:top_k]\n",
    "            if most_frequent_item(top_preds) == gt:\n",
    "                num_correct[idx] += 1\n",
    "            if top_preds:\n",
    "                valid[idx] += 1\n",
    "    for idx, top_k in enumerate([1, 3, 5]):\n",
    "        print(f'{file}: {top_k}: {num_correct[idx]}/{num_total} = {num_correct[idx]*100/num_total:.2f}%')\n",
    "        print(f'{file}: {top_k}: {num_correct[idx]}/{valid[idx]} = {num_correct[idx]*100/valid[idx]:.2f}%')\n",
    "        print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mace",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
