{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import base64\n",
    "\n",
    "base_output_dir = '../../mmsci-data/openai-input-w-setting'\n",
    "for task in ['matching', 'generation']:\n",
    "    os.makedirs(os.path.join(base_output_dir, task), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to encode the image\n",
    "def encode_image(image_path):\n",
    "  with open(image_path, \"rb\") as image_file:\n",
    "    return base64.b64encode(image_file.read()).decode('utf-8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare input jsonl for openai batch calls - matching task\n",
    "\n",
    "base_dir = '../../mmsci-data/'\n",
    "image_dir = os.path.join(base_dir, 'benchmark/dev/images/')\n",
    "input_filename = os.path.join(base_dir, 'benchmark/dev/image_caption_matching_data.json')\n",
    "\n",
    "def get_prompt_for_matching(question, w_cot):\n",
    "    final_instr = 'Answer only with A, B, C, or D.'\n",
    "    cot_prompt = 'Please first thoroughly analyze and think about this problem, and then come to your final answer.'\n",
    "    if not w_cot:\n",
    "        prompt = f'{question}\\n{final_instr}'\n",
    "    else:\n",
    "        prompt = f'{question}\\n{cot_prompt}'\n",
    "    return prompt\n",
    "\n",
    "def form_openai_input_for_matching(model, w_cot, input_path=None, images_dir=None, output_dir=None):\n",
    "    \"\"\"Write an OpenAI Batch API JSONL request file for the image-caption matching task.\n",
    "\n",
    "    The input JSON is a list of settings, each a list of items with 'image' and\n",
    "    'question' keys. One chat-completions request (prompt text + inline image,\n",
    "    n=5 samples) is written per item, with custom_id '{setting + 1}_{item index}'.\n",
    "\n",
    "    Args:\n",
    "        model: OpenAI model name; also used as the output file-name prefix.\n",
    "        w_cot: if truthy, ask for reasoning before the answer instead of a bare letter.\n",
    "        input_path, images_dir, output_dir: optional overrides for the notebook globals\n",
    "            input_filename, image_dir and base_output_dir (read at call time when None).\n",
    "            Pass them explicitly to avoid picking up values a later cell has rebound.\n",
    "\n",
    "    The output path, {output_dir}/matching/{model}_{w|wo}-cot_{input file name}l, is printed.\n",
    "    \"\"\"\n",
    "    input_path = input_filename if input_path is None else input_path\n",
    "    images_dir = image_dir if images_dir is None else images_dir\n",
    "    output_dir = base_output_dir if output_dir is None else output_dir\n",
    "\n",
    "    # `with` closes the input file instead of leaking the handle.\n",
    "    with open(input_path, 'r') as fin:\n",
    "        all_data = json.load(fin)\n",
    "    cot_tag = 'w' if w_cot else 'wo'\n",
    "    output_filename = os.path.join(output_dir, 'matching', f\"{model}_{cot_tag}-cot_{os.path.basename(input_path)}l\")\n",
    "    print(output_filename)\n",
    "\n",
    "    with open(output_filename, 'w') as fout:\n",
    "        for setting, data in enumerate(all_data):\n",
    "            for idx, item in enumerate(data):\n",
    "                base64_image = encode_image(os.path.join(images_dir, item['image']))\n",
    "                info = {\n",
    "                    \"custom_id\": f'{setting + 1}_{idx}',\n",
    "                    \"method\": \"POST\",\n",
    "                    \"url\": \"/v1/chat/completions\",\n",
    "                    \"body\": {\n",
    "                        \"model\": model,\n",
    "                        \"messages\": [\n",
    "                            {\n",
    "                                \"role\": \"user\",\n",
    "                                \"content\": [\n",
    "                                    {\n",
    "                                        \"type\": \"text\",\n",
    "                                        \"text\": get_prompt_for_matching(item['question'], w_cot),\n",
    "                                    },\n",
    "                                    {\n",
    "                                        \"type\": \"image_url\",\n",
    "                                        # NOTE(review): MIME type is always jpeg, whatever the file format.\n",
    "                                        \"image_url\": {\"url\": f\"data:image/jpeg;base64,{base64_image}\"},\n",
    "                                    },\n",
    "                                ],\n",
    "                            }\n",
    "                        ],\n",
    "                        \"max_tokens\": 1024,\n",
    "                        \"n\": 5,\n",
    "                        \"temperature\": 0.7,\n",
    "                    },\n",
    "                }\n",
    "                fout.write(json.dumps(info) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./openai-input-w-setting/matching/gpt-4o_w-cot_image_caption_matching_data.jsonl\n",
      "./openai-input-w-setting/matching/gpt-4o_wo-cot_image_caption_matching_data.jsonl\n",
      "./openai-input-w-setting/matching/gpt-4-turbo_w-cot_image_caption_matching_data.jsonl\n",
      "./openai-input-w-setting/matching/gpt-4-turbo_wo-cot_image_caption_matching_data.jsonl\n"
     ]
    }
   ],
   "source": [
    "# One batch-input file per (model, chain-of-thought) combination.\n",
    "for model, w_cot in [(m, c) for m in ('gpt-4o', 'gpt-4-turbo') for c in (True, False)]:\n",
    "    form_openai_input_for_matching(model, w_cot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare input jsonl for openai batch calls - captioning task\n",
    "base_dir = '../../mmsci-data/'\n",
    "image_dir = os.path.join(base_dir, 'benchmark/dev/images/')\n",
    "input_filename = os.path.join(base_dir, 'benchmark/dev/image_caption_generation_data.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare input jsonl for openai batch calls - captioning task\n",
    "\n",
    "import tiktoken\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Prompt budget for `truncate`, counted with the tiktoken encoding below.\n",
    "# NOTE(review): gpt-4o tokenizes with o200k_base rather than cl100k_base, so counts are approximate for it.\n",
    "max_tokens = 125_000\n",
    "embedding_encoding = \"cl100k_base\"\n",
    "encoding = tiktoken.get_encoding(embedding_encoding)\n",
    "\n",
    "\n",
    "def truncate(x, max_len=max_tokens):\n",
    "    \"\"\"Cut text ``x`` to at most ``max_len`` tokens of the global tiktoken ``encoding``.\n",
    "\n",
    "    The default budget is bound to ``max_tokens`` when this cell is executed.\n",
    "    \"\"\"\n",
    "    token_ids = encoding.encode(x)\n",
    "    head = token_ids[:max_len]\n",
    "    return encoding.decode(head)\n",
    "\n",
    "\n",
    "def format_prompt(abstract, content, with_abstract, with_content):\n",
    "    \"\"\"Build the captioning prompt, optionally embedding the abstract and/or article text.\n",
    "\n",
    "    Each included text is preceded by an 'Article:' line and followed by a newline;\n",
    "    the whole prompt is then cut to ``max_tokens`` tokens via ``truncate``.\n",
    "    NOTE(review): the abstract is also labelled 'Article:' -- confirm this is intended.\n",
    "    \"\"\"\n",
    "    header = 'Please write a detailed description of the given scientific figure based on the following content:\\n'\n",
    "    sections = [\n",
    "        f'Article:\\n{text}\\n'\n",
    "        for text, wanted in ((abstract, with_abstract), (content, with_content))\n",
    "        if wanted\n",
    "    ]\n",
    "    return truncate(header + ''.join(sections), max_tokens)\n",
    "\n",
    "\n",
    "def form_openai_input_for_captioning(model, with_abstract, with_content, input_path=None, images_dir=None, output_dir=None):\n",
    "    \"\"\"Write an OpenAI Batch API JSONL request file for the figure-captioning task.\n",
    "\n",
    "    The input JSON is a list of items with 'image', 'abstract' and 'content' keys. One\n",
    "    chat-completions request (prompt from format_prompt + inline image, n=3 samples)\n",
    "    is written per item, with custom_id str(item index).\n",
    "\n",
    "    Args:\n",
    "        model: OpenAI model name; also used as the output file-name prefix.\n",
    "        with_abstract, with_content: whether the abstract / article body go into the prompt.\n",
    "        input_path, images_dir, output_dir: optional overrides for the notebook globals\n",
    "            input_filename, image_dir and base_output_dir (read at call time when None).\n",
    "\n",
    "    Output: {output_dir}/generation/{model}_{w|wo}_abstract_{w|wo}_content_{input file name}l.\n",
    "    \"\"\"\n",
    "    input_path = input_filename if input_path is None else input_path\n",
    "    images_dir = image_dir if images_dir is None else images_dir\n",
    "    output_dir = base_output_dir if output_dir is None else output_dir\n",
    "    tag = f'{\"w\" if with_abstract else \"wo\"}_abstract_{\"w\" if with_content else \"wo\"}_content'\n",
    "\n",
    "    # `with` closes the input file instead of leaking the handle.\n",
    "    with open(input_path, 'r') as fin:\n",
    "        data = json.load(fin)\n",
    "    output_filename = os.path.join(output_dir, 'generation', f\"{model}_{tag}_{os.path.basename(input_path)}l\")\n",
    "\n",
    "    with open(output_filename, 'w') as fout:\n",
    "        # No need to materialise list(enumerate(data)); total= gives tqdm the length.\n",
    "        for idx, item in tqdm(enumerate(data), total=len(data), desc='prepare data'):\n",
    "            base64_image = encode_image(os.path.join(images_dir, item['image']))\n",
    "            info = {\n",
    "                \"custom_id\": str(idx),\n",
    "                \"method\": \"POST\",\n",
    "                \"url\": \"/v1/chat/completions\",\n",
    "                \"body\": {\n",
    "                    \"model\": model,\n",
    "                    \"messages\": [\n",
    "                        {\n",
    "                            \"role\": \"user\",\n",
    "                            \"content\": [\n",
    "                                {\n",
    "                                    \"type\": \"text\",\n",
    "                                    \"text\": format_prompt(item['abstract'], item['content'], with_abstract, with_content),\n",
    "                                },\n",
    "                                {\n",
    "                                    \"type\": \"image_url\",\n",
    "                                    # NOTE(review): MIME type is always jpeg, whatever the file format.\n",
    "                                    \"image_url\": {\"url\": f\"data:image/jpeg;base64,{base64_image}\"},\n",
    "                                },\n",
    "                            ],\n",
    "                        }\n",
    "                    ],\n",
    "                    \"max_tokens\": 1024,\n",
    "                    \"n\": 3,\n",
    "                    \"temperature\": 0.7,\n",
    "                },\n",
    "            }\n",
    "            fout.write(json.dumps(info) + '\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "prepare data: 100%|██████████| 1281/1281 [00:01<00:00, 885.92it/s]\n",
      "prepare data: 100%|██████████| 1281/1281 [00:13<00:00, 93.08it/s]\n",
      "prepare data: 100%|██████████| 1281/1281 [00:01<00:00, 1124.11it/s]\n",
      "prepare data: 100%|██████████| 1281/1281 [00:01<00:00, 884.44it/s]\n",
      "prepare data: 100%|██████████| 1281/1281 [00:13<00:00, 95.19it/s]\n",
      "prepare data: 100%|██████████| 1281/1281 [00:01<00:00, 1186.08it/s]\n"
     ]
    }
   ],
   "source": [
    "# Three context settings per model (abstract only, content only, neither);\n",
    "# the abstract+content combination is skipped.\n",
    "for model in ('gpt-4o', 'gpt-4-turbo'):\n",
    "    for with_abstract, with_content in [(a, c) for a in (True, False) for c in (True, False) if not (a and c)]:\n",
    "        form_openai_input_for_captioning(model, with_abstract, with_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split a file into several subfiles to fit for openai batch call max input file size\n",
    "\n",
    "def chunks(lst, n):\n",
    "    \"\"\"Yield successive n-sized chunks from lst.\"\"\"\n",
    "    for i in range(0, len(lst), n):\n",
    "        yield lst[i:i + n]\n",
    "\n",
    "def split_file(input_dir, output_dir, filename, n=4):\n",
    "    with open(os.path.join(input_dir, filename), 'r') as fin:\n",
    "        lines = fin.readlines()\n",
    "    line_chunks = list(chunks(lines, len(lines)//n))\n",
    "    for idx, chunk in enumerate(line_chunks):\n",
    "        with open(os.path.join(output_dir, filename.replace('.jsonl', f'_{idx}.jsonl')), 'w') as fout:\n",
    "            fout.writelines(chunk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Target maximum chunk size in bytes (100 MiB), to keep each upload under the\n",
    "# OpenAI batch input-file size limit.\n",
    "max_chunk_bytes = 100 * 2**20\n",
    "\n",
    "for task in ['matching', 'generation']:\n",
    "    input_dir = os.path.join(base_output_dir, task)\n",
    "    output_dir = os.path.join(base_output_dir, f'{task}_chunked')\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    for file in sorted(os.listdir(input_dir)):\n",
    "        if not file.endswith('.jsonl'):\n",
    "            continue  # skip stray entries such as .ipynb_checkpoints\n",
    "        file_size = os.path.getsize(os.path.join(input_dir, file))\n",
    "        # Use the exact byte count: flooring to whole MiB first (>> 20) gave 0 chunks\n",
    "        # (ZeroDivisionError in split_file) for files under 1 MiB, and let files just\n",
    "        # over 100 MiB through as a single chunk.\n",
    "        num_chunk = max(1, math.ceil(file_size / max_chunk_bytes))\n",
    "        split_file(input_dir, output_dir, file, n=num_chunk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create openai client\n",
    "\n",
    "from openai import OpenAI\n",
    "# The client reads the key from the OPENAI_API_KEY environment variable (and raises\n",
    "# if it is unset); never hard-code credentials in a notebook that may be shared.\n",
    "client = OpenAI()\n",
    "\n",
    "\n",
    "def create_batch_job(filename):\n",
    "    \"\"\"Upload ``filename`` as a Batch API input file and start a 24h chat-completions batch.\n",
    "\n",
    "    The batch metadata description is the file's base name. Prints the uploaded file id\n",
    "    and the batch job id, and returns the created batch object.\n",
    "    \"\"\"\n",
    "    print(f\"Creating batch job for {filename}\")\n",
    "\n",
    "    # Upload the input file; `with` closes the local handle once the upload returns.\n",
    "    with open(filename, \"rb\") as fin:\n",
    "        batch_input_file = client.files.create(file=fin, purpose=\"batch\")\n",
    "    batch_input_file_id = batch_input_file.id\n",
    "    print(batch_input_file_id)\n",
    "\n",
    "    # Create the batch job.\n",
    "    job = client.batches.create(\n",
    "        input_file_id=batch_input_file_id,\n",
    "        endpoint=\"/v1/chat/completions\",\n",
    "        completion_window=\"24h\",\n",
    "        metadata={\"description\": os.path.basename(filename)},\n",
    "    )\n",
    "    print(f\"Batch job ID:\\t{job.id}\")\n",
    "    return job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create batch jobs: one per chunk file. Only .jsonl files are uploaded, in sorted\n",
    "# order so submissions are reproducible.\n",
    "# NOTE(review): chunks left over from an earlier run with a different chunk count\n",
    "# are not removed from these folders and would be submitted too.\n",
    "\n",
    "for task in ['matching', 'generation']:\n",
    "    file_dir = os.path.join(base_output_dir, f'{task}_chunked')\n",
    "    for file in sorted(os.listdir(file_dir)):\n",
    "        if file.endswith('.jsonl'):\n",
    "            create_batch_job(os.path.join(file_dir, file))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mace",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
