{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "71a52b17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import math\n",
    "import openai\n",
    "import jsonlines\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4dd52b7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "system = {\n",
    "    'Hemonc': 'You are a doctor with a professional medical background.',\n",
    "    'PubMedQA': 'You are a doctor with a professional medical background.',\n",
    "    'NQ': 'You are a helpful AI assistant.',\n",
    "    'HotpotQA': 'You are a helpful AI assistant.'\n",
    "}\n",
    "instruction_doc = {\n",
    "    True: 'You are given some documents and a multiple-choice question.\\nBased on the document, select the most appropriate answer from the options provided.',\n",
    "    False: 'Without relying on any external document, select the most appropriate answer from the options provided.'\n",
    "}\n",
    "instruction = 'First, explain your reasoning briefly step-by-step based on the provided information.\\nThen, select the most appropriate option and present your response in the required format.'\n",
    "# Named `response_format` rather than `response` so that a later cell reusing the\n",
    "# common name `response` for model output cannot overwrite this template.\n",
    "response_format = 'Provide your response in the following format:\\n<answer>Option [number]</answer>'\n",
    "\n",
    "def get_message(row, qid, if_doc=False, augmentation_name=None, dataset_name=None, num_options=3):\n",
    "    \"\"\"Build the chat-completion messages for one question variant of one row.\n",
    "\n",
    "    Args:\n",
    "        row: One dataset row (e.g. from ``DataFrame.iterrows()``), indexed by column\n",
    "            name: 'question {qid}', 'option 1'..'option {num_options}', and\n",
    "            'evidence {augmentation_name}' when ``if_doc`` is True.\n",
    "        qid: Number of the 'question {qid}' column to ask.\n",
    "        if_doc: If True, include the evidence document in the prompt.\n",
    "        augmentation_name: Evidence column suffix; defaults to the notebook-level\n",
    "            ``augmentation`` variable (only read when ``if_doc`` is True).\n",
    "        dataset_name: Key into ``system``; defaults to the notebook-level\n",
    "            ``args_dataset`` variable.\n",
    "        num_options: Number of 'option i' columns to list (1-based, default 3).\n",
    "\n",
    "    Returns:\n",
    "        ``[system_message, user_message]`` as ``{'role', 'content'}`` dicts.\n",
    "    \"\"\"\n",
    "    if dataset_name is None:\n",
    "        dataset_name = args_dataset\n",
    "    # '\\n' keeps the two instruction paragraphs apart; plain concatenation\n",
    "    # produced '...options provided.First, explain...'.\n",
    "    prompt = ['### Instruction:\\n' + instruction_doc[if_doc] + '\\n' + instruction]\n",
    "    if if_doc:\n",
    "        if augmentation_name is None:\n",
    "            augmentation_name = augmentation\n",
    "        prompt.append('### Documents:\\n' + row[f'evidence {augmentation_name}'])\n",
    "    prompt.append('### Question:\\n' + row[f'question {qid}'])\n",
    "    choices = '\\n'.join(f'Option {i}: {row[f\"option {i}\"]}' for i in range(1, num_options + 1))\n",
    "    prompt.append('### Choices:\\n' + choices)\n",
    "    prompt.append(response_format)\n",
    "    return [\n",
    "        {'role': 'system', 'content': system[dataset_name]},\n",
    "        {'role': 'user', 'content': '\\n\\n'.join(prompt)},\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbe2f5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "cwd = os.getcwd()\n",
    "args_dataset = 'HotpotQA'\n",
    "augmentation = 'naive'\n",
    "dataset = pd.read_csv(f'{cwd}/Data/Augmentation/Input/{args_dataset}.csv')\n",
    "num_resp, num_qn = 100, 20  # responses per row; number of 'question {qid}' columns to sample from\n",
    "seed = 0\n",
    "rng = np.random.default_rng(seed)  # seeded so the sampled question ids are reproducible\n",
    "\n",
    "# messages_con[idx_resp][idx_query] is the prompt for dataset row idx_query in response\n",
    "# slot idx_resp, built from a question id drawn uniformly from 1..num_qn.\n",
    "messages_con = []\n",
    "for _ in range(num_resp):\n",
    "    qids = rng.integers(1, num_qn + 1, size=len(dataset))\n",
    "    # Index qids by position: iterrows() yields index labels, which only equal\n",
    "    # positions when the frame has a default RangeIndex.\n",
    "    messages_con.append([get_message(row, qids[pos], if_doc=True)\n",
    "                         for pos, (_, row) in enumerate(dataset.iterrows())])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd732274",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_token = 1024\n",
    "batch_size = 300\n",
    "num_batch = math.ceil(len(dataset) / batch_size)\n",
    "\n",
    "# Write one Batch-API request file per chunk of `batch_size` rows; each row\n",
    "# contributes `num_resp` requests (one per response slot).\n",
    "for idx_batch in range(num_batch):\n",
    "    start = idx_batch * batch_size\n",
    "    stop = min(start + batch_size, len(dataset))  # the final chunk may be short\n",
    "    request_path = f'{cwd}/Data/Augmentation/GPT-4o/input-{args_dataset}-{augmentation}-{idx_batch}.jsonl'\n",
    "    with open(request_path, 'w') as file:\n",
    "        for idx_query in range(start, stop):\n",
    "            for idx_resp in range(num_resp):\n",
    "                request = {\n",
    "                    \"custom_id\": f'{args_dataset}-{idx_query}-{augmentation}-{idx_resp}',\n",
    "                    \"method\": \"POST\",\n",
    "                    \"url\": '/v1/chat/completions',\n",
    "                    \"body\": {\n",
    "                        \"model\": 'gpt-4o',\n",
    "                        \"messages\": messages_con[idx_resp][idx_query],\n",
    "                        \"max_tokens\": max_token,\n",
    "                    },\n",
    "                }\n",
    "                file.write(json.dumps(request) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "747b0adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload each request file and start one Batch-API job per file. Job ids are\n",
    "# appended to Batch2JobID.jsonl as {batch_name: job_id} for the retrieval cell.\n",
    "batch_endpoint = '/v1/chat/completions'  # must match the \"url\" of every request line\n",
    "client = openai.OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
    "for idx_batch in range(num_batch):\n",
    "    batch_name = f'{args_dataset}-{augmentation}-{idx_batch}'\n",
    "    with open(f\"{cwd}/Data/Augmentation/GPT-4o/input-{batch_name}.jsonl\", \"rb\") as request_file:\n",
    "        batch_input_file = client.files.create(file=request_file, purpose=\"batch\")\n",
    "    batch_job = client.batches.create(\n",
    "        input_file_id=batch_input_file.id,\n",
    "        endpoint=batch_endpoint,\n",
    "        completion_window=\"24h\",\n",
    "        metadata={\"description\": batch_name},\n",
    "    )\n",
    "    # Record each id as soon as its job exists, so an error later in the loop\n",
    "    # does not lose track of jobs that were already submitted.\n",
    "    with open(f'{cwd}/Data/Augmentation/Batch2JobID.jsonl', 'a') as file:\n",
    "        file.write(json.dumps({batch_name: batch_job.id}) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "db79b00b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch Job: HotpotQA-naive-0 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-1 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-2 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-3 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-4 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-5 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-6 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-7 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-8 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-9 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-10 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-11 | Status: completed | Retrieving Now\n",
      "Batch Job: HotpotQA-naive-12 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-13 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-14 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-15 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-16 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-17 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-18 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-19 | Status: completed | Already Retrieved\n",
      "Batch Job: HotpotQA-naive-20 | Status: completed | Already Retrieved\n"
     ]
    }
   ],
   "source": [
    "data = 'HotpotQA'\n",
    "augmentation = 'naive'\n",
    "client = openai.OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
    "\n",
    "cwd = os.getcwd()\n",
    "with jsonlines.open(f'{cwd}/Data/Augmentation/Batch2JobID.jsonl') as reader:\n",
    "    for line in reader:\n",
    "        # Each line is a single {batch_name: job_id} pair.\n",
    "        batch, job_id = next(iter(line.items()))\n",
    "        parts = batch.split('-')\n",
    "        if parts[0] != data or parts[1] != augmentation: continue\n",
    "        batch_job = client.batches.retrieve(job_id)\n",
    "        if batch_job.status != 'completed':\n",
    "            print(f'Batch Job: {batch} | Status: {batch_job.status}')\n",
    "            continue\n",
    "        file_path = f'{cwd}/Data/Augmentation/GPT-4o/output-{batch}.jsonl'\n",
    "        if os.path.exists(file_path):\n",
    "            print(f'Batch Job: {batch} | Status: {batch_job.status} | Already Retrieved')\n",
    "            continue\n",
    "        # output_file_id is nullable (e.g. when no request in the batch succeeded);\n",
    "        # report it instead of passing None to files.content().\n",
    "        if batch_job.output_file_id is None:\n",
    "            print(f'Batch Job: {batch} | Status: {batch_job.status} | No Output File (errors: {batch_job.error_file_id})')\n",
    "            continue\n",
    "        print(f'Batch Job: {batch} | Status: {batch_job.status} | Retrieving Now')\n",
    "        if batch_job.error_file_id is not None:\n",
    "            print(f'Batch Job: {batch} | Some requests failed; see error file {batch_job.error_file_id}')\n",
    "        output_text = client.files.content(batch_job.output_file_id).text\n",
    "        # Skip blank lines: ''.split('\\n') yields [''], which json.loads rejects.\n",
    "        batch_response = [json.loads(raw) for raw in output_text.splitlines() if raw.strip()]\n",
    "        with jsonlines.open(file_path, 'w') as writer:\n",
    "            writer.write_all(batch_response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f5d0c664",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = 'HotpotQA'\n",
    "augmentation = 'naive'\n",
    "batch_size = 300  # must equal the batch_size used when the request files were written\n",
    "cwd = os.getcwd()\n",
    "dataset = pd.read_csv(f'{cwd}/Data/Augmentation/Input/{data}.csv')\n",
    "num_batch = math.ceil(len(dataset) / batch_size)  # derived instead of a hard-coded count\n",
    "\n",
    "# Batch output lines are not guaranteed to follow input order, so collect\n",
    "# (idx_resp, content) pairs per query and sort them by idx_resp before saving.\n",
    "collected = [[] for _ in range(len(dataset))]\n",
    "failed = []\n",
    "for idx_batch in range(num_batch):\n",
    "    with jsonlines.open(f'{cwd}/Data/Augmentation/GPT-4o/output-{data}-{augmentation}-{idx_batch}.jsonl') as reader:\n",
    "        for line in reader:\n",
    "            custom_id = line['custom_id']\n",
    "            # custom_id is f'{data}-{idx_query}-{augmentation}-{idx_resp}': take the field after\n",
    "            # the dataset prefix and the last field, so a '-' inside `augmentation` is harmless.\n",
    "            idx_query = int(custom_id[len(data) + 1:].split('-', 1)[0])\n",
    "            idx_resp = int(custom_id.rsplit('-', 1)[1])\n",
    "            result = line.get('response')\n",
    "            if line.get('error') or not result or result.get('status_code') != 200:\n",
    "                failed.append(custom_id)\n",
    "                continue\n",
    "            content = result['body']['choices'][0]['message']['content']\n",
    "            collected[idx_query].append((idx_resp, content))\n",
    "if failed:\n",
    "    print(f'{len(failed)} requests failed and are missing from the output, e.g. {failed[:5]}')\n",
    "\n",
    "responses = [{'context': [text for _, text in sorted(pairs, key=lambda pair: pair[0])]} for pairs in collected]\n",
    "with open(f'{cwd}/Data/Augmentation/GPT-4o/{data}-output_augmentation={augmentation}.json', 'w') as out_file:\n",
    "    json.dump(responses, out_file, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
