{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from mistralai import Mistral\n",
    "from dotenv import load_dotenv\n",
    "from common import BongardDataset\n",
    "import base64\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import re\n",
    "from typing import List\n",
    "import json\n",
    "\n",
    "# Load variables from a local .env file into the process environment (python-dotenv);\n",
    "# the Mistral API key is read from os.environ further down in this notebook.\n",
    "load_dotenv()\n",
    "\n",
    "\n",
    "def encode_image(image_path):\n",
    "    \"\"\"Read the file at ``image_path`` and return its bytes as base64 text.\n",
    "\n",
    "    Errors are printed instead of raised: a missing file, or any other\n",
    "    exception while opening, reading or encoding, makes the function\n",
    "    return ``None``.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        with open(image_path, \"rb\") as handle:\n",
    "            raw_bytes = handle.read()\n",
    "            encoded = base64.b64encode(raw_bytes)\n",
    "            return encoded.decode('utf-8')\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: The file {image_path} was not found.\")\n",
    "    except Exception as exc:\n",
    "        print(f\"Error: {exc}\")\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_prompts(text) -> List[str]:\n",
    "    \"\"\"Extract the first JSON array of strings embedded in ``text``.\n",
    "\n",
    "    Every ``[`` in ``text`` is tried in order as the start of a JSON value and\n",
    "    handed to a real JSON decoder, so a ``]`` inside one of the strings does\n",
    "    not cut the array short and bracketed text that is not valid JSON is\n",
    "    skipped. Arrays containing non-string items are skipped as well.\n",
    "\n",
    "    Args:\n",
    "        text: Raw model reply; may contain prose or code fences around the array.\n",
    "\n",
    "    Returns:\n",
    "        The decoded list of strings, or ``[]`` (after printing an error) when\n",
    "        no such array is found. Malformed JSON never raises.\n",
    "    \"\"\"\n",
    "    decoder = json.JSONDecoder()\n",
    "    start = text.find(\"[\")\n",
    "    while start != -1:\n",
    "        try:\n",
    "            # raw_decode stops at the end of the first complete JSON value and\n",
    "            # ignores whatever follows it (closing code fence, commentary, ...).\n",
    "            candidate, _ = decoder.raw_decode(text[start:])\n",
    "        except json.JSONDecodeError:\n",
    "            candidate = None\n",
    "        if isinstance(candidate, list) and all(isinstance(item, str) for item in candidate):\n",
    "            return candidate\n",
    "        start = text.find(\"[\", start + 1)\n",
    "\n",
    "    print(f\"Error: Could not parse prompts from text: {text}\")\n",
    "    return []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instruction template sent to the chat model. It is filled in with str.format, so it\n",
    "# must keep exactly three placeholders: {prompt}, {concept} and {n_augmentations}.\n",
    "# Any other literal curly brace added to the text has to be doubled ({{ or }}).\n",
    "# The reply is expected to contain a JSON array of strings, which parse_prompts extracts.\n",
    "PROMPT = \"\"\"\n",
    "You are tasked with modifying a given prompt for a diffusion model. \n",
    "Your primary objective is to preserve the specified **concept** while altering unrelated details or environments. \n",
    "You are encouraged to create diverse, creative, and unique augmentations that stay true to the concept but introduce variety in interpretation.\n",
    "\n",
    "### Example:\n",
    "Prompt: \"An empty white bowl with a thin black rim placed on a solid blue background.\"\n",
    "Concept: \"Empty picture\"\n",
    "Output Augmentations: [\n",
    "    \"An empty red plate on a wooden table.\",\n",
    "    \"A clear glass cup sitting on a marble countertop.\",\n",
    "    \"A white ceramic vase on a patterned fabric.\"\n",
    "]\n",
    "\n",
    "Now, it's your turn: \n",
    "Prompt: {prompt}\n",
    "Concept: {concept}\n",
    "\n",
    "### Instructions:\n",
    "1. Generate **{n_augmentations} unique augmentations** for the provided prompt.\n",
    "2. Output the results in the following JSON string array format:\n",
    "[\n",
    "    \"<augmented_prompt_1>\",\n",
    "    \"<augmented_prompt_2>\",\n",
    "    ...\n",
    "]\n",
    "\n",
    "Ensure that each augmentation aligns with the concept and introduces creative variations in other details.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "840it [2:14:49,  9.63s/it]\n"
     ]
    }
   ],
   "source": [
    "# Resumable augmentation run: for every dataset fragment that has no rows in\n",
    "# OUTPUT_FILE yet, ask the chat model for augmented prompts and checkpoint the\n",
    "# combined result to OUTPUT_FILE after each fragment.\n",
    "OUTPUT_FILE = 'bonagrd_rwr_prompts_augmented_plus_15.csv'\n",
    "N_AUGMENTATIONS = 15  # augmented prompts requested per fragment\n",
    "REQUEST_DELAY_SECONDS = 5  # pause after every attempted API call, successful or not\n",
    "\n",
    "prompts = pd.read_csv(\"bonagrd_rwr_prompts.csv\")\n",
    "# Rows written by a previous run; fragments already present here are skipped below.\n",
    "augmented = pd.read_csv(OUTPUT_FILE) if os.path.exists(OUTPUT_FILE) else pd.DataFrame(columns=[\"problem_id\", \"file\", \"side\", \"positive\", \"negative\"])\n",
    "api_key = os.environ[\"MISTRAL_API_KEY\"]\n",
    "model = \"pixtral-12b-2409\"\n",
    "client = Mistral(api_key=api_key)\n",
    "dataset = BongardDataset(\"../data/bongard-rwr\")\n",
    "answers = []\n",
    "\n",
    "\n",
    "def _select_fragment(frame, problem_id, file_name, side):\n",
    "    \"\"\"Return the rows of ``frame`` belonging to one (problem_id, file, side) fragment.\n",
    "\n",
    "    The two string keys are bound with ``@`` instead of being pasted into the\n",
    "    query text, so a quote character in a file name cannot break the expression.\n",
    "    ``problem_id`` is still interpolated as a bare literal.\n",
    "    \"\"\"\n",
    "    # Format the keys the same way an inline f-string would, then bind them by reference.\n",
    "    file_key, side_key = f\"{file_name}\", f\"{side}\"\n",
    "    return frame.query(f\"problem_id == {problem_id} and file == @file_key and side == @side_key\")\n",
    "\n",
    "\n",
    "for problem_id, file_name, side, file_path in tqdm(dataset.all_fragments()):\n",
    "    if _select_fragment(augmented, problem_id, file_name, side).shape[0] > 0:\n",
    "        continue\n",
    "\n",
    "    api_called = False\n",
    "    try:\n",
    "        prompt_data = _select_fragment(prompts, problem_id, file_name, side)\n",
    "        # .iloc[0] raises IndexError when the fragment has no source prompt;\n",
    "        # the handler below reports it and the loop moves on.\n",
    "        prompt = prompt_data['positive'].iloc[0]\n",
    "        negative = prompt_data['negative'].iloc[0]\n",
    "\n",
    "        concept = dataset.get_label(problem_id, side)\n",
    "        question = PROMPT.format(prompt=prompt, concept=concept, n_augmentations=N_AUGMENTATIONS)\n",
    "\n",
    "        api_called = True\n",
    "        chat_response = client.chat.complete(\n",
    "            model=model,\n",
    "            messages=[\n",
    "                {\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": [\n",
    "                        {\n",
    "                            \"type\": \"text\",\n",
    "                            \"text\": question\n",
    "                        },\n",
    "                    ]\n",
    "                },\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        new_prompts = parse_prompts(chat_response.choices[0].message.content.strip())\n",
    "\n",
    "        # One output row per augmented prompt; the negative prompt is carried over unchanged.\n",
    "        answers.extend({\n",
    "            \"problem_id\": problem_id,\n",
    "            \"file\": file_name,\n",
    "            \"side\": side,\n",
    "            \"positive\": positive,\n",
    "            \"negative\": negative\n",
    "        } for positive in new_prompts)\n",
    "\n",
    "        # Checkpoint after every fragment so an interrupted run can be resumed.\n",
    "        df = pd.concat([pd.DataFrame(answers), augmented])\n",
    "        df.to_csv(OUTPUT_FILE, index=False)\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best effort: report the failing fragment and keep going; it has no rows in\n",
    "        # OUTPUT_FILE, so the next run will retry it.\n",
    "        print(f\"Error: {e}\")\n",
    "        print(f\"Problem ID: {problem_id}, File: {file_name}, Side: {side}\")\n",
    "    finally:\n",
    "        # Throttle after failed requests too, so an API error (e.g. a rate limit)\n",
    "        # is not followed immediately by the next request.\n",
    "        if api_called:\n",
    "            time.sleep(REQUEST_DELAY_SECONDS)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
