{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from mistralai import Mistral\n",
    "from dotenv import load_dotenv\n",
    "from common import BongardDataset\n",
    "import base64\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import re\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "\n",
    "def encode_image(image_path):\n",
    "    try:\n",
    "        with open(image_path, \"rb\") as image_file:\n",
    "            return base64.b64encode(image_file.read()).decode('utf-8')\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: The file {image_path} was not found.\")\n",
    "        return None\n",
    "    except Exception as e:\n",
    "        print(f\"Error: {e}\")\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_prompts(text):\n",
    "    regex = r\"Positive prompt:\\s*(.+?)\\s*Negative prompt:\\s*(.+)\"\n",
    "\n",
    "    match = re.search(regex, text, re.DOTALL)\n",
    "\n",
    "    if match:\n",
    "        positive_prompt = match.group(1).strip().replace(\"\\\"\", \"\")\n",
    "        negative_prompt = match.group(2).strip().replace(\"\\\"\", \"\")\n",
    "        return positive_prompt, negative_prompt\n",
    "    else:\n",
    "        print(f\"Error: Could not parse prompts from text: {text}\")\n",
    "        return \"\", \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "QUESTION = \"\"\"\n",
    "You are given an image and a general concept. Your task is to refine the concept into two concise, visually descriptive prompts: one that aligns with the image (positive prompt) and one that contrasts with it (negative prompt). Focus on making each prompt specific, clearly grounded in the image, and reflective of the core idea. You don’t need to match every detail—just convey the main visual concept.\n",
    "\n",
    "### Example:\n",
    "Image: Human legs wearing socks with vertical lines\n",
    "Concept: Vertical lines.\n",
    "Positive prompt: Socks with vertical lines\n",
    "Negative prompt: Socks with horizontal lines \n",
    "\n",
    "Now, it's your turn: \n",
    "Concept: {concept}\n",
    "\n",
    "### Instructions:\n",
    "1. Generate a positive and negative prompt based on the provided image and concept.\n",
    "2. Answer using the following format:\n",
    "Positive prompt:\n",
    "Negative prompt:\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "840it [00:08, 98.29it/s] \n"
     ]
    }
   ],
   "source": [
    "FILE_NAME = \"bonagrd_rwr_prompts.csv\"\n",
    "already_answered = pd.read_csv(FILE_NAME) if os.path.exists(FILE_NAME) else pd.DataFrame(columns=[\"problem_id\", \"file\", \"side\", \"positive\", \"negative\"])\n",
    "api_key = os.environ[\"MISTRAL_API_KEY\"]\n",
    "model = \"pixtral-12b-2409\"\n",
    "client = Mistral(api_key=api_key)\n",
    "dataset = BongardDataset(\"../data/bongard-rwr\")\n",
    "answers = []\n",
    "\n",
    "for problem_id, file_name, side, file_path in tqdm(dataset.all_fragments()):\n",
    "    if already_answered.query(f\"problem_id == {problem_id} and file == '{file_name}' and side == '{side}'\").shape[0] > 0:\n",
    "        continue\n",
    "\n",
    "    try: \n",
    "        left_label, right_label = dataset.get_labels(problem_id)\n",
    "        encoded_image = encode_image(file_path)\n",
    "\n",
    "        chat_response = client.chat.complete(\n",
    "            model= model,\n",
    "            messages = [\n",
    "                {\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": [\n",
    "                        {\n",
    "                            \"type\": \"text\",\n",
    "                            \"text\": QUESTION.format(concept=left_label if side == \"left\" else right_label)\n",
    "                        },\n",
    "                        {\n",
    "                            \"type\": \"image_url\",\n",
    "                            \"image_url\": f\"data:image/jpeg;base64,{encoded_image}\" \n",
    "                        }\n",
    "                    ]\n",
    "                },\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        positive_prompt, negative_prompt = parse_prompts(chat_response.choices[0].message.content)\n",
    "\n",
    "        answers.append({\n",
    "            \"problem_id\": problem_id,\n",
    "            \"file\": file_name,\n",
    "            \"side\": side,\n",
    "            \"positive\": positive_prompt,\n",
    "            \"negative\": negative_prompt,\n",
    "        })\n",
    "\n",
    "        df = pd.concat([pd.DataFrame(answers), already_answered])\n",
    "        df.to_csv(FILE_NAME, index=False)\n",
    "\n",
    "        time.sleep(5)\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(f\"Error: {e}\")\n",
    "        print(f\"Problem ID: {problem_id}, File: {file_name}, Side: {side}\")\n",
    "        continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(840, 6)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(\"bonagrd_rwr_prompts.csv\")\n",
    "df.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
