{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Put the merged OpenAI batch-call output `.jsonl` files in `./mmsci-exps/eval/output/openai_output_w_settings_merged/`.\n",
    "\n",
    "The directory structure should look like this:\n",
    "- ./mmsci-exps/eval/output/openai_output_w_settings_merged/\n",
    "  - gpt-4-turbo_wo_abstract_wo_content_image_caption_generation_data.jsonl\n",
    "  - gpt-4-turbo_wo_abstract_w_content_image_caption_generation_data.jsonl\n",
    "  - gpt-4-turbo_w_abstract_wo_content_image_caption_generation_data.jsonl\n",
    "  - gpt-4-turbo_wo-cot_image_caption_matching_data.jsonl\n",
    "  - gpt-4-turbo_w-cot_image_caption_matching_data.jsonl\n",
    "  - gpt-4o_wo_abstract_wo_content_image_caption_generation_data.jsonl\n",
    "  - gpt-4o_wo_abstract_w_content_image_caption_generation_data.jsonl\n",
    "  - gpt-4o_w_abstract_wo_content_image_caption_generation_data.jsonl\n",
    "  - gpt-4o_wo-cot_image_caption_matching_data.jsonl\n",
    "  - gpt-4o_w-cot_image_caption_matching_data.jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reformat openai captioning results\n",
    "import os\n",
    "import json\n",
    "import glob\n",
    "\n",
    "# NOTE: `base_input_dir`, `k`, `task`, `output_dir` and `input_data_mapping` are\n",
    "# notebook-level globals that the matching cell further down re-assigns;\n",
    "# re-run this cell before re-running the generation loop.\n",
    "base_input_dir = './output/openai_output_w_settings_merged'\n",
    "k = 3  # only used to name the output sub-directory (k_3)\n",
    "task = 'generation'\n",
    "output_dir = f'./output/image_caption_{task}'\n",
    "\n",
    "base_data_dir = '../../mmsci-data/benchmark/dev/'\n",
    "# Load through a context manager so the benchmark file handle is closed.\n",
    "with open(os.path.join(base_data_dir, f'image_caption_{task}_data.json'), 'r', encoding='utf-8') as data_file:\n",
    "    input_data = json.load(data_file)\n",
    "# Records are looked up by their position in the benchmark list.\n",
    "input_data_mapping = dict(enumerate(input_data))\n",
    "\n",
    "def reformat_caption_generation(input_filepath, output_filepath):\n",
    "    \"\"\"Attach caption-generation answers from a batch output file to benchmark records.\n",
    "\n",
    "    Every non-blank line of `input_filepath` is parsed as JSON. Its `custom_id`\n",
    "    is converted to an int and used as the key into the notebook-level\n",
    "    `input_data_mapping`; the `message.content` of each entry under\n",
    "    `response.body.choices` is collected into a list that is stored under the\n",
    "    record's `prediction` key. The resulting list of records is written to\n",
    "    `output_filepath` as indented JSON.\n",
    "\n",
    "    Args:\n",
    "        input_filepath: Path of the merged batch-output `.jsonl` file to read.\n",
    "        output_filepath: Path of the `.json` file to write.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a line lacks an expected field or its `custom_id` is not\n",
    "            a key of `input_data_mapping`.\n",
    "    \"\"\"\n",
    "    output_list = []\n",
    "    with open(input_filepath, 'r', encoding='utf-8') as fin:\n",
    "        for line in fin:  # stream the file instead of loading every line at once\n",
    "            if not line.strip():\n",
    "                continue  # a blank line would make json.loads raise\n",
    "            info = json.loads(line)\n",
    "            key = int(info['custom_id'])\n",
    "            answers = [choice['message']['content'] for choice in info['response']['body']['choices']]\n",
    "            # Shallow-copy the record so the shared entries of `input_data_mapping`\n",
    "            # are not mutated; repeated ids and later files stay independent.\n",
    "            record = dict(input_data_mapping[key])\n",
    "            record['prediction'] = answers\n",
    "            output_list.append(record)\n",
    "    with open(output_filepath, 'w', encoding='utf-8') as fout:\n",
    "        json.dump(output_list, fout, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo_wo_abstract_wo_content_image_caption_generation_data.jsonl\n",
      "gpt-4-turbo_wo_abstract_w_content_image_caption_generation_data.jsonl\n",
      "gpt-4o_wo_abstract_wo_content_image_caption_generation_data.jsonl\n",
      "gpt-4o_w_abstract_wo_content_image_caption_generation_data.jsonl\n",
      "gpt-4-turbo_w_abstract_wo_content_image_caption_generation_data.jsonl\n",
      "gpt-4o_wo_abstract_w_content_image_caption_generation_data.jsonl\n"
     ]
    }
   ],
   "source": [
    "# Uses `base_input_dir`, `task`, `output_dir` and `k` from the generation config cell.\n",
    "file_list = glob.glob(os.path.join(base_input_dir, f'*{task}*'))\n",
    "\n",
    "for filepath in file_list:\n",
    "    # basename() honours the platform path separator; splitting on '/' does not.\n",
    "    filename = os.path.basename(filepath)\n",
    "    print(filename)\n",
    "    # The model name is everything before the first underscore of the file name.\n",
    "    model_name = filename.split('_')[0]\n",
    "    # 'wo_abstract' / 'wo_content' do not contain 'w_abstract' / 'w_content',\n",
    "    # so a plain substring test tells the two settings apart.\n",
    "    w_abs = 'w_abstract' in filename\n",
    "    w_content = 'w_content' in filename\n",
    "    tag = f'abs{w_abs}_ctx{w_content}'\n",
    "    cur_output_dir = os.path.join(output_dir, tag, f'k_{k}')\n",
    "    os.makedirs(cur_output_dir, exist_ok=True)\n",
    "    reformat_caption_generation(\n",
    "        input_filepath=filepath,\n",
    "        output_filepath=os.path.join(cur_output_dir, f'{model_name}.json'),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys([1, 2, 3])\n"
     ]
    }
   ],
   "source": [
    "# reformat matching results\n",
    "import os\n",
    "import json\n",
    "import glob\n",
    "from collections import defaultdict\n",
    "\n",
    "# NOTE: this cell re-assigns the notebook-level globals (`task`, `k`, `output_dir`,\n",
    "# `input_data_mapping`, ...) that the generation cells above also use.\n",
    "base_input_dir = './output/openai_output_w_settings_merged'\n",
    "k = 5  # only used to name the output sub-directory (k_5)\n",
    "task = 'matching'\n",
    "output_dir = f'./output/image_caption_{task}'\n",
    "\n",
    "base_data_dir = '../../mmsci-data/benchmark/dev/'\n",
    "# Load through a context manager so the benchmark file handle is closed.\n",
    "with open(os.path.join(base_data_dir, f'image_caption_{task}_data.json'), 'r', encoding='utf-8') as data_file:\n",
    "    all_input_data = json.load(data_file)\n",
    "# Settings are numbered from 1; inside a setting, records are keyed by list position.\n",
    "input_data_mapping = defaultdict(dict)\n",
    "for setting, input_data in enumerate(all_input_data, start=1):\n",
    "    input_data_mapping[setting] = dict(enumerate(input_data))\n",
    "print(input_data_mapping.keys())\n",
    "\n",
    "def reformat_caption_matching(input_filepath, model_name, tag):\n",
    "    \"\"\"Attach caption-matching answers from a batch output file to benchmark records.\n",
    "\n",
    "    Every non-blank line of `input_filepath` is parsed as JSON. Its `custom_id`\n",
    "    is split on '_' into a setting number and a record index, both converted to\n",
    "    int and used to look up the record in the notebook-level\n",
    "    `input_data_mapping`. The `message.content` of each entry under\n",
    "    `response.body.choices` is wrapped as `{'answer': ...}` and the list is\n",
    "    stored under the record's `prediction` key. One JSON file per setting is\n",
    "    written to `output_dir/tag/setting-N/k_K/model_name.json`, where\n",
    "    `output_dir` and `k` are notebook-level globals.\n",
    "\n",
    "    Args:\n",
    "        input_filepath: Path of the merged batch-output `.jsonl` file to read.\n",
    "        model_name: Used as the stem of each output file name.\n",
    "        tag: Sub-directory name placed directly under `output_dir`.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a line lacks an expected field or its record index is not\n",
    "            present for the given setting.\n",
    "        ValueError: If a `custom_id` does not split into exactly two integers.\n",
    "    \"\"\"\n",
    "    all_output_list = defaultdict(list)\n",
    "    with open(input_filepath, 'r', encoding='utf-8') as fin:\n",
    "        for line in fin:  # stream the file instead of loading every line at once\n",
    "            if not line.strip():\n",
    "                continue  # a blank line would make json.loads raise\n",
    "            info = json.loads(line)\n",
    "            # custom_id has the form f'{setting+1}_{str(idx)}'\n",
    "            setting, key = (int(part) for part in info['custom_id'].split('_'))\n",
    "            answers = [{'answer': choice['message']['content']} for choice in info['response']['body']['choices']]\n",
    "            # Shallow-copy the record so the shared entries of `input_data_mapping`\n",
    "            # are not mutated; repeated ids and later files stay independent.\n",
    "            record = dict(input_data_mapping[setting][key])\n",
    "            record['prediction'] = answers\n",
    "            all_output_list[setting].append(record)\n",
    "    for setting, output_list in all_output_list.items():\n",
    "        cur_output_dir = os.path.join(output_dir, tag, f'setting-{setting}', f'k_{k}')\n",
    "        os.makedirs(cur_output_dir, exist_ok=True)\n",
    "        output_filepath = os.path.join(cur_output_dir, f'{model_name}.json')\n",
    "        with open(output_filepath, 'w', encoding='utf-8') as fout:\n",
    "            json.dump(output_list, fout, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo_w-cot_image_caption_matching_data.jsonl\n",
      "w_cot gpt-4-turbo\n",
      "gpt-4-turbo_wo-cot_image_caption_matching_data.jsonl\n",
      "wo_cot gpt-4-turbo\n",
      "gpt-4o_wo-cot_image_caption_matching_data.jsonl\n",
      "wo_cot gpt-4o\n",
      "gpt-4o_w-cot_image_caption_matching_data.jsonl\n",
      "w_cot gpt-4o\n"
     ]
    }
   ],
   "source": [
    "# Uses `base_input_dir` and `task` from the matching config cell.\n",
    "file_list = glob.glob(os.path.join(base_input_dir, f'*{task}*'))\n",
    "\n",
    "for filepath in file_list:\n",
    "    # basename() honours the platform path separator; splitting on '/' does not.\n",
    "    filename = os.path.basename(filepath)\n",
    "    print(filename)\n",
    "    # The model name is everything before the first underscore of the file name.\n",
    "    model_name = filename.split('_')[0]\n",
    "    # 'wo-cot' does not contain 'w-cot', so a substring test separates the two.\n",
    "    tag = 'w_cot' if 'w-cot' in filename else 'wo_cot'\n",
    "    print(tag, model_name)\n",
    "    reformat_caption_matching(\n",
    "        input_filepath=filepath,\n",
    "        model_name=model_name,\n",
    "        tag=tag,\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mace_new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
