{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import base64\n",
    "from datetime import datetime\n",
    "\n",
    "import pandas as pd\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
    "OPENAI_BASE_URL = \"https://api.openai.com/v1\"\n",
    "# Per-run scratch directory. The time part is joined with '-' instead of ':'\n",
    "# because ':' is not allowed in Windows file names.\n",
    "WORKING_DIRECTORY_BASE_NAME = f\"tmp_{datetime.now().strftime('%d.%m.%Y-%H-%M-%S')}\"\n",
    "IMAGE_BASE_PATH = \"Image_path\"\n",
    "DATASET_PATH = \"Dataset_path\"\n",
    "\n",
    "# exist_ok=True keeps a re-run within the same second from raising FileExistsError.\n",
    "os.makedirs(WORKING_DIRECTORY_BASE_NAME, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# API client built from the key and base URL configured in the previous cell.\n",
    "client = OpenAI(\n",
    "    base_url=OPENAI_BASE_URL,\n",
    "    api_key=OPENAI_API_KEY,\n",
    ")\n",
    "# Dataset rows; later cells read the 'image_name', 'category_name' and 'answer' columns.\n",
    "df = pd.read_csv(filepath_or_buffer=DATASET_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def line_template_filler(id, question, image_bin, model=\"gpt-4o-2024-08-06\",\n",
    "                         mime_type=\"image/png\", max_tokens=256):\n",
    "    \"\"\"Build one batch request line that asks a counting question about an image.\n",
    "\n",
    "    Args:\n",
    "        id: Identifier of the request; stored stringified as ``custom_id``.\n",
    "        question: Text prompt placed after the image in the user message.\n",
    "        image_bin: Raw image bytes; embedded in the request as a base64 data URL.\n",
    "        model: Chat model name. Defaults to the model this notebook originally hard-coded.\n",
    "        mime_type: MIME type declared in the data URL. Defaults to image/png; pass\n",
    "            e.g. image/jpeg when the bytes are not PNG.\n",
    "        max_tokens: Completion token limit for the response.\n",
    "\n",
    "    Returns:\n",
    "        dict: JSON-serialisable request with ``custom_id``, ``method``, ``url`` and ``body``.\n",
    "        The body asks for a strict JSON-schema answer with a string ``reasoning``\n",
    "        and an integer ``count``.\n",
    "    \"\"\"\n",
    "    encoded_image = base64.b64encode(image_bin).decode('utf-8')\n",
    "    return {\n",
    "        \"custom_id\": f\"{id}\",\n",
    "        \"method\": \"POST\",\n",
    "        \"url\": \"/v1/chat/completions\",\n",
    "        \"body\": {\n",
    "            \"model\": model,\n",
    "            \"messages\": [\n",
    "                {\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": [\n",
    "                        {\n",
    "                            \"type\": \"image_url\",\n",
    "                            \"image_url\": {\n",
    "                                \"url\": f\"data:{mime_type};base64,{encoded_image}\"\n",
    "                            }\n",
    "                        },\n",
    "                        {\n",
    "                            \"type\": \"text\",\n",
    "                            \"text\": question\n",
    "                        }\n",
    "                    ]\n",
    "                }\n",
    "            ],\n",
    "            \"temperature\": 1,\n",
    "            \"max_tokens\": max_tokens,\n",
    "            \"top_p\": 1,\n",
    "            \"frequency_penalty\": 0,\n",
    "            \"presence_penalty\": 0,\n",
    "            \"response_format\": {\n",
    "                \"type\": \"json_schema\",\n",
    "                \"json_schema\": {\n",
    "                    \"name\": \"object_counting_response_with_reasoning\",\n",
    "                    \"strict\": True,\n",
    "                    \"schema\": {\n",
    "                        \"type\": \"object\",\n",
    "                        \"properties\": {\n",
    "                            \"reasoning\": {\"type\": \"string\"},\n",
    "                            \"count\": {\"type\": \"integer\"}\n",
    "                        },\n",
    "                        \"additionalProperties\": False,\n",
    "                        \"required\": [\"count\", \"reasoning\"]\n",
    "                    },\n",
    "                },\n",
    "            }\n",
    "        }\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = \"File_name\"\n",
    "query = \"0 <= index < 200\"\n",
    "request_path = f\"{WORKING_DIRECTORY_BASE_NAME}/{filename}.jsonl\"\n",
    "\n",
    "# Write one JSON request per line for every row selected by `query`.\n",
    "with open(request_path, 'w') as request_file:\n",
    "    for index, row in df.query(query).iterrows():\n",
    "        with open(f\"{IMAGE_BASE_PATH}/{row['image_name']}\", 'rb') as f:\n",
    "            image = f.read()\n",
    "        question = f\"How many {row['category_name']} are visible?\"\n",
    "        jsonl = line_template_filler(id=index, question=question, image_bin=image)\n",
    "        request_file.write(json.dumps(jsonl))\n",
    "        request_file.write(\"\\n\")\n",
    "\n",
    "# Upload inside a context manager so the file handle is closed once the request returns.\n",
    "with open(request_path, \"rb\") as upload_file:\n",
    "    batch_input_file = client.files.create(\n",
    "        file=upload_file,\n",
    "        purpose=\"batch\"\n",
    "    )\n",
    "\n",
    "batch_input_file_id = batch_input_file.id\n",
    "\n",
    "# Keep the created batch in a variable so later cells can refer to it;\n",
    "# it is still displayed as the cell output.\n",
    "batch = client.batches.create(\n",
    "    input_file_id=batch_input_file_id,\n",
    "    endpoint=\"/v1/chat/completions\",\n",
    "    completion_window=\"24h\"\n",
    ")\n",
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hard-coded file id from one specific run; replace it with the\n",
    "# output file id of the batch you want to download.\n",
    "file_response = client.files.content(\"file-4cM5zyPKBhXp4nx8xkAJwB\")\n",
    "# Write as UTF-8 explicitly: the platform default encoding can fail on\n",
    "# non-ASCII characters in the response text.\n",
    "with open(\"output.jsonl\", \"w\", encoding=\"utf-8\") as f:\n",
    "    f.write(file_response.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# custom_id -> predicted count, collected from every matching output file.\n",
    "llm_count = dict()\n",
    "# custom_ids whose response was missing, malformed, or had no 'count'.\n",
    "failed_requests = []\n",
    "for file in glob.glob(\"JSONL_OUTPUT_FILE_PATH\"):\n",
    "    with open(file, encoding=\"utf-8\") as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank or trailing lines\n",
    "            res = json.loads(line)\n",
    "            try:\n",
    "                content = res['response']['body']['choices'][0]['message']['content']\n",
    "                llm_count[res['custom_id']] = json.loads(content)['count']\n",
    "            except (KeyError, IndexError, TypeError, ValueError):\n",
    "                # A null response, a missing field, or content that is not valid JSON\n",
    "                # must not abort the whole parse; record the request and move on.\n",
    "                failed_requests.append(res.get('custom_id'))\n",
    "\n",
    "if failed_requests:\n",
    "    print(f\"{len(failed_requests)} responses had no usable count: {failed_requests}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start every row with an empty prediction; the next cell fills in parsed counts.\n",
    "df['llm_count'] = [None for _ in range(len(df))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# custom_id values are the stringified row index written into the request file,\n",
    "# so convert them back to int to locate the matching row.\n",
    "for custom_id, predicted_count in llm_count.items():\n",
    "    row_mask = df.index == int(custom_id)\n",
    "    df.loc[row_mask, 'llm_count'] = predicted_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_metrics(df):\n",
    "    \"\"\"Print exact-match accuracy (EA), MAE and RMSE of 'llm_count' against 'answer'.\n",
    "\n",
    "    Only rows where both columns hold a numeric value are scored, so rows that\n",
    "    never received a prediction no longer break the arithmetic. When every row\n",
    "    has a prediction the printed values match the previous implementation.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with an 'llm_count' column (predictions) and an 'answer'\n",
    "            column (ground truth).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no row has both a numeric 'llm_count' and 'answer'.\n",
    "    \"\"\"\n",
    "    predicted = pd.to_numeric(df['llm_count'], errors='coerce')\n",
    "    expected = pd.to_numeric(df['answer'], errors='coerce')\n",
    "    scored = (predicted.notna() & expected.notna()).to_numpy()\n",
    "    n_total = len(df)\n",
    "    n_scored = int(scored.sum())\n",
    "    if n_scored == 0:\n",
    "        raise ValueError('no rows with both a numeric llm_count and answer to evaluate')\n",
    "    if n_scored < n_total:\n",
    "        print(f'Skipping {n_total - n_scored} of {n_total} rows without a numeric llm_count/answer')\n",
    "\n",
    "    # Signed error per scored row; all three metrics derive from it.\n",
    "    error = predicted.to_numpy(dtype=float)[scored] - expected.to_numpy(dtype=float)[scored]\n",
    "    print(\"EA\", float(np.mean(error == 0)))\n",
    "    print(\"MAE\", float(np.mean(np.abs(error))))\n",
    "    print(\"RMSE\", float(np.sqrt(np.mean(error ** 2))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the frame, including the llm_count column, without writing the index.\n",
    "df.to_csv(path_or_buf=\"output.csv\", index=False)\n",
    "# Print EA / MAE / RMSE for the predictions stored in df.\n",
    "calculate_metrics(df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
