{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import glob\n",
    "import base64\n",
    "from datetime import datetime\n",
    "\n",
    "import pandas as pd\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
    "OPENAI_BASE_URL = \"https://api.openai.com/v1\"\n",
    "WORKING_DIRECTORY_BASE_NAME = f\"tmp_mask_to_gpt4o_count_{datetime.now().strftime('%d.%m.%Y-%H:%M:%S')}\"\n",
    "\n",
    "os.mkdir(WORKING_DIRECTORY_BASE_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)\n",
    "df = pd.read_json(DATASET_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def line_template_filler(id, question, image_bin):\n",
    "    return {\n",
    "        \"custom_id\": f\"{id}\",\n",
    "        \"method\": \"POST\",\n",
    "        \"url\": \"/v1/chat/completions\",\n",
    "        \"body\": {\n",
    "            \"model\": \"gpt-4o-2024-08-06\",\n",
    "            \"messages\": [\n",
    "                {\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": [\n",
    "                        {\n",
    "                            \"type\": \"image_url\",\n",
    "                            \"image_url\": {\n",
    "                                \"url\": f\"data:image/png;base64,{base64.b64encode(image_bin).decode('utf-8')}\"\n",
    "                            }\n",
    "                        },\n",
    "                        {\n",
    "                            \"type\": \"text\",\n",
    "                            \"text\": question\n",
    "                        }\n",
    "                    ]\n",
    "                }\n",
    "            ],\n",
    "            \"temperature\": 1,\n",
    "            \"max_tokens\": 256,\n",
    "            \"top_p\": 1,\n",
    "            \"frequency_penalty\": 0,\n",
    "            \"presence_penalty\": 0,\n",
    "            \"response_format\": {\n",
    "                \"type\": \"json_schema\",\n",
    "                \"json_schema\": {\n",
    "                    \"name\": \"object_counter\",\n",
    "                    \"strict\": True,\n",
    "                    \"schema\": {\n",
    "                        \"type\": \"object\",\n",
    "                        \"properties\": {\n",
    "                            \"count\": {\n",
    "                                \"type\": \"integer\"\n",
    "                            }\n",
    "                        },\n",
    "                        \"additionalProperties\": False,\n",
    "                        \"required\": [\n",
    "                            \"count\"\n",
    "                        ]\n",
    "                    }\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dir = \"path_to_masks_folder\"\n",
    "df = pd.read_csv(\"path_to_labels.csv\")\n",
    "\n",
    "filename = \"emoji_benchmark\"\n",
    "with open(f\"{WORKING_DIRECTORY_BASE_NAME}/{filename}.jsonl\", 'w') as request_file:\n",
    "    for index, row in df.iterrows():\n",
    "        folder_name = f\"{row['id']}.{row['object_of_interest']}\"\n",
    "        for i, mask_img in enumerate(glob.glob(f\"{base_dir}/{folder_name}/**_crops/**/mask.png\", recursive=True)):\n",
    "            with open(mask_img, 'rb') as f:\n",
    "                image = f.read()\n",
    "            question = \"How many masks are visible in the image?\"\n",
    "            jsonl = line_template_filler(id=f\"{row['id']}.{i}\", question=question, image_bin=image)\n",
    "            request_file.write(json.dumps(jsonl))\n",
    "            request_file.write(\"\\n\")\n",
    "\n",
    "batch_input_file = client.files.create(\n",
    "  file=open(f\"{WORKING_DIRECTORY_BASE_NAME}/{filename}.jsonl\", \"rb\"),\n",
    "  purpose=\"batch\"\n",
    ")\n",
    "\n",
    "batch_input_file_id = batch_input_file.id\n",
    "\n",
    "client.batches.create(\n",
    "    input_file_id=batch_input_file_id,\n",
    "    endpoint=\"/v1/chat/completions\",\n",
    "    completion_window=\"24h\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"path_to_labels.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_count = dict()\n",
    "for file in glob.glob(\"path_to_gpt4o_batch_output.jsonl\"):\n",
    "    with open(file) as f:\n",
    "        for line in f:\n",
    "            res = json.loads(line)\n",
    "            llm_count[res['custom_id']] = json.loads(res['response']['body']['choices'][0]['message']['content'])['count']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['llm_count'] = [0]*len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key, value in llm_count.items():\n",
    "    _key = key.split(\".\")[0]\n",
    "    df.loc[df['id'] == int(_key), 'llm_count'] += value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "df.to_csv(\"output0.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_metrics(df):\n",
    "    ea = (df['llm_count'].to_numpy() == df['answer'].to_numpy()).sum() / len(df)\n",
    "    mae = sum(abs(df['llm_count'].to_numpy() - df['answer'].to_numpy())) / len(df)\n",
    "    rmse = sum(abs(df['llm_count'].to_numpy() - df['answer'].to_numpy())) / len(df)\n",
    "    print(\"EA\", ea)\n",
    "    print(\"MAE\", mae)\n",
    "    print(\"RMSE\", rmse)\n",
    "\n",
    "    print(f\"{ea}\\t{mae}\\t{rmse}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"path_to_results.csv\")\n",
    "calculate_metrics(df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
