{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from src.evaluation.model import evalute_answers_using_model, count_correct_answers, get_resolved_problems, get_unresolved_problems, get_resolved_problems_with_voting, get_voting_mean_agreement\n",
    "from src.llm_messenger.messengers.claude_messenger import ClaudeMessenger\n",
    "from src.llm_messenger.messengers.google_messenger import GoogleMessenger\n",
    "from src.llm_messenger.messengers.gpt_messenger import GPTMessenger\n",
    "from src.prompting_techniques.descriptive import get_descriptions, resolve_bongard_with_descriptive_prompting\n",
    "from typing import List, Tuple\n",
    "from src.llm_messenger.classes.llm_messenger import LLMMessenger\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "DESCRIBE_IMAGE_PROMPT = \"\"\"\n",
    "The provided image is a part of an abstract visual reasoning problem. Describe all crucial properties of the image. Your description should be as concise as possible. Focus on the most important details. The image is provided correctly. Respond only with descriptions.\n",
    "\"\"\".strip()\n",
    "\n",
    "PROBLEM_DESCRIPTION = \"\"\"\n",
    "A Bongard Problem is composed of left and right sides separated by a line. Each side contains six images. All images belonging to one side present a common concept, which is lacking in all images from the other side. The goal is to describe the rule that fits all images on the left side, but none on the right, and, conversely, the rule that fits all images on the right side, but none on the left. The description of the rule should be simple and concise.\n",
    "Example 1: All shapes on left are small. All shapes on right are big.\n",
    "Example 2: The left side contains circles. The right side contains triangles.\n",
    "\"\"\".strip()\n",
    "\n",
    "QUESTION = \"\"\"\n",
    "You are a vision understanding module designed to provide short, clear and accurate answers. Your goal is to solve the provided Bongard Problem using descriptions of its images.\n",
    "\n",
    "LEFT IMAGES:\n",
    "{left_descriptions}\n",
    "\n",
    "RIGHT IMAGES:\n",
    "{right_descriptions}\n",
    "\n",
    "What is the difference between the two sides of the problem?\n",
    "\"\"\".strip()\n",
    "\n",
    "QUESTION_DESCRIPTIVE_DIRECT = \"\"\"\n",
    "You are a vision understanding module designed to provide short, clear and accurate answers. Your goal is to solve the provided Bongard Problem with the help of its image descriptions.\n",
    "\n",
    "LEFT IMAGES:\n",
    "{left_descriptions}\n",
    "\n",
    "RIGHT IMAGES:\n",
    "{right_descriptions}\n",
    "\n",
    "What is the difference between the two sides of the problem?\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_models(experiment_dir: str) -> Tuple[LLMMessenger, LLMMessenger, LLMMessenger]:\n",
    "    gpt4o = GPTMessenger(\n",
    "        model_name=\"gpt-4o\",\n",
    "        api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "        log_directory=experiment_dir,\n",
    "    )\n",
    "\n",
    "    gpt4turbo = GPTMessenger(\n",
    "        model_name=\"gpt-4-turbo\",\n",
    "        api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "        log_directory=experiment_dir,\n",
    "    )\n",
    "\n",
    "    gemini = GoogleMessenger(\n",
    "        api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "        model_name=\"gemini-1.5-pro\",\n",
    "        log_directory=experiment_dir,\n",
    "    )\n",
    "\n",
    "    claude = ClaudeMessenger(\n",
    "        api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "        model_name=\"claude-3-5-sonnet-20240620\",\n",
    "        log_directory=experiment_dir,\n",
    "    )\n",
    "\n",
    "    return gpt4o, gpt4turbo, gemini, claude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_dir = \"../data/processed/bongard/optimize_evaluation_prompt/descriptive\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare results for the optimization of the evalutation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt4o, gpt4turbo, gemini, claude = prepare_models(experiment_dir)\n",
    "all_models : List[LLMMessenger] = [gpt4o, gpt4turbo, gemini, claude]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('../data/processed/bongard/optimize_evaluation_prompt/descriptive/gpt-4o_descriptions.json',\n",
       " '../data/processed/bongard/optimize_evaluation_prompt/descriptive/gpt-4o_answers.json')"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "descriptions_file = f\"{experiment_dir}/{gpt4o.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{gpt4o.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [03:14<00:00,  9.75s/it]\n"
     ]
    }
   ],
   "source": [
    "get_descriptions(\n",
    "    problem_ids=list(range(101, 121)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    model=gpt4o,\n",
    "    output_file=descriptions_file, \n",
    "    question=DESCRIBE_IMAGE_PROMPT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:15<00:00,  1.32it/s]\n"
     ]
    }
   ],
   "source": [
    "resolve_bongard_with_descriptive_prompting(\n",
    "    problem_ids=list(range(101, 121)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=gpt4o,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<00:00, 19553.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:20<00:00,  1.02s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo {106, 108, 111, 115, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620 {106, 111, 115, 118, 119}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "    )\n",
    "    print(model.get_name(), get_resolved_problems(answers_file, model.get_name()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate with images attached"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+image {118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+image {106, 115, 118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+image {118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+image {106, 108, 115, 118, 119}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+image\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        with_image=True,\n",
    "        splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "        author=evaluation_author,\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompt optimization "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "STRICT_PROMPT = \"\"\"\n",
    "You are a logic module designed to provide accurate answers. \n",
    "In a Bongard Problem the objective is to spot the difference between the contents of images located on the two opposite sides of the problem. \n",
    "You are given correct labels of these sides and must decide whether the answer provided by the user is correct and matches with those labels. Answer with 'OK' or 'WRONG'.\n",
    "The user's answer has to strictly match the labels, as shown in examples.\n",
    "\n",
    "FIRST EXAMPLE: \n",
    "LEFT SIDE LABEL: All shapes are small.\n",
    "RIGHT SIDE LABEL: All shapes are big.\n",
    "USER ANSWER: On the left side, one of the shapes is small. On the right side, all of the shapes are big.\n",
    "EVALUATION: WRONG\n",
    "END OF FIRST EXAMPLE.\n",
    "\n",
    "SECOND EXAMPLE:\n",
    "LEFT SIDE LABEL: All shapes are small.\n",
    "RIGHT SIDE LABEL: All shapes are big.\n",
    "USER ANSWER: On the left side, all of the shapes are small. On the right side, all of the shapes are big.\n",
    "EVALUATION: OK\n",
    "END OF SECOND EXAMPLE.\n",
    "\n",
    "\n",
    "LEFT SIDE LABEL:\n",
    "{first_label}\n",
    "\n",
    "RIGHT SIDE LABEL:\n",
    "{second_label}\n",
    "\n",
    "USER ANSWER:\n",
    "{model_answer}\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+strict-prompt set()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:16<00:00,  1.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+strict-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+strict-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+strict-prompt set()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+strict-prompt\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:14<00:00,  1.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+logic-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:25<00:00,  1.29s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+logic-prompt {106, 115, 118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:17<00:00,  1.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+logic-prompt {118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:25<00:00,  1.26s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+logic-prompt {106, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+strict-prompt+second-try\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+image+strict-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:31<00:00,  1.59s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+image+strict-prompt {106, 118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+image+strict-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+image+strict-prompt set()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+image+strict-prompt\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        with_image=True,\n",
    "        splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "LOGIC_PROMPT = \"\"\"\n",
    "You are a logic module designed to provide accurate answers. \n",
    "In a Bongard Problem the objective is to spot the difference between the contents of images located on the two opposite sides of the problem. \n",
    "You are given correct labels of these sides and must decide whether the answer provided by the user is correct and matches with those labels. Answer with 'OK' or 'WRONG'.\n",
    "The user's answer must logically match the labels, as shown in the examples.\n",
    "\n",
    "FIRST EXAMPLE: \n",
    "LEFT SIDE LABEL: All shapes are small.\n",
    "RIGHT SIDE LABEL: All shapes are big.\n",
    "USER ANSWER: On the left side, one of the shapes is small. On the right side, all of the shapes are big.\n",
    "EVALUATION: WRONG\n",
    "END OF FIRST EXAMPLE.\n",
    "\n",
    "SECOND EXAMPLE:\n",
    "LEFT SIDE LABEL: All shapes are small.\n",
    "RIGHT SIDE LABEL: All shapes are big.\n",
    "USER ANSWER: On the left side, all of the shapes are small. On the right side, all of the shapes are big.\n",
    "EVALUATION: OK\n",
    "END OF SECOND EXAMPLE.\n",
    "\n",
    "\n",
    "LEFT SIDE LABEL:\n",
    "{first_label}\n",
    "\n",
    "RIGHT SIDE LABEL:\n",
    "{second_label}\n",
    "\n",
    "USER ANSWER:\n",
    "{model_answer}\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:13<00:00,  1.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+logic-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:20<00:00,  1.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+logic-prompt {106, 115, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:18<00:00,  1.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+logic-prompt {118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:26<00:00,  1.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+logic-prompt {106, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+logic-prompt\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=LOGIC_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:10<00:00,  1.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+logic-prompt+second-try {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:19<00:00,  1.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+logic-prompt+second-try {106, 115, 118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:17<00:00,  1.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+logic-prompt+second-try {118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:25<00:00,  1.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+logic-prompt+second-try {106, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+logic-prompt+second-try\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=LOGIC_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:40<00:00,  2.05s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+image+logic-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:37<00:00,  1.88s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+image+logic-prompt {106, 118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:22<00:00,  1.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+image+logic-prompt {106, 118, 119}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:43<00:00,  2.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+image+logic-prompt {106, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+image+logic-prompt\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        with_image=True,\n",
    "        splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "        author=evaluation_author,\n",
    "        prompt=LOGIC_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:37<00:00,  1.89s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+image+logic-prompt+second-try {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:46<00:00,  2.32s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+image+logic-prompt+second-try {106, 118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:23<00:00,  1.18s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+image+logic-prompt+second-try {106, 118, 119}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:38<00:00,  1.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+image+logic-prompt+second-try {106, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+image+logic-prompt+second-try\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        with_image=True,\n",
    "        splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "        author=evaluation_author,\n",
    "        prompt=LOGIC_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "STRICT_LOGIC_PROMPT = \"\"\"\n",
    "You are a logic module designed to provide accurate answers. \n",
    "In a Bongard Problem the objective is to spot the difference between the contents of images located on the two opposite sides of the problem. \n",
    "You are given correct labels of these sides and must decide whether the answer provided by the user is correct and matches with those labels. Answer with 'OK' or 'WRONG'.\n",
    "The user's answer has to strictly logically match the labels, as shown in examples.\n",
    "\n",
    "FIRST EXAMPLE: \n",
    "LEFT SIDE LABEL: All shapes are small.\n",
    "RIGHT SIDE LABEL: All shapes are big.\n",
    "USER ANSWER: On the left side, one of the shapes is small. On the right side, all of the shapes are big.\n",
    "EVALUATION: WRONG\n",
    "END OF FIRST EXAMPLE.\n",
    "\n",
    "SECOND EXAMPLE:\n",
    "LEFT SIDE LABEL: All shapes are small.\n",
    "RIGHT SIDE LABEL: All shapes are big.\n",
    "USER ANSWER: On the left side, all of the shapes are small. On the right side, all of the shapes are big.\n",
    "EVALUATION: OK\n",
    "END OF SECOND EXAMPLE.\n",
    "\n",
    "\n",
    "LEFT SIDE LABEL:\n",
    "{first_label}\n",
    "\n",
    "RIGHT SIDE LABEL:\n",
    "{second_label}\n",
    "\n",
    "USER ANSWER:\n",
    "{model_answer}\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:11<00:00,  1.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+strict-logic-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:21<00:00,  1.07s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+strict-logic-prompt {106, 115, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:16<00:00,  1.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+strict-logic-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:22<00:00,  1.13s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+strict-logic-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+strict-logic-prompt\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_LOGIC_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:45<00:00,  2.29s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+image+strict-logic-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:34<00:00,  1.70s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+image+strict-logic-prompt {118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:25<00:00,  1.28s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+image+strict-logic-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:38<00:00,  1.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+image+strict-logic-prompt {106, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+image+strict-logic-prompt\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        with_image=True,\n",
    "        splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_LOGIC_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{106, 118}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "authors = [f\"{model.get_name()}+image+logic-prompt\" for model in all_models]\n",
    "get_resolved_problems_with_voting(answers_file, authors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "STRICT_LOGIC_IMAGES_PROMPT = \"\"\"\n",
    "You are a logic module designed to provide accurate answers. \n",
    "In a Bongard Problem the objective is to spot the difference between the contents of images located on the two opposite sides of the problem. \n",
    "You are given correct labels of these sides and must decide whether the answer provided by the user is correct and matches with those labels. Answer with 'OK' or 'WRONG'.\n",
    "The user's answer should strictly logically match the labels, as shown in the examples. \n",
    "Some answers may be correct even if they do not strictly match the labels, as long as they logically describe the image.\n",
    "\n",
    "FIRST EXAMPLE: \n",
    "LEFT SIDE LABEL: All shapes are small.\n",
    "RIGHT SIDE LABEL: All shapes are big.\n",
    "USER ANSWER: On the left side, one of the shapes is small. On the right side, all of the shapes are big.\n",
    "EVALUATION: WRONG\n",
    "END OF FIRST EXAMPLE.\n",
    "\n",
    "SECOND EXAMPLE:\n",
    "LEFT SIDE LABEL: All shapes are small.\n",
    "RIGHT SIDE LABEL: All shapes are big.\n",
    "USER ANSWER: On the left side, all of the shapes are small. On the right side, all of the shapes are big.\n",
    "EVALUATION: OK\n",
    "END OF SECOND EXAMPLE.\n",
    "\n",
    "LEFT SIDE LABEL:\n",
    "{first_label}\n",
    "\n",
    "RIGHT SIDE LABEL:\n",
    "{second_label}\n",
    "\n",
    "USER ANSWER:\n",
    "{model_answer}\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:23<00:00,  1.16s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+strict-logic-images-prompt {118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:21<00:00,  1.07s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+strict-logic-images-prompt {106, 115, 118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:17<00:00,  1.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+strict-logic-images-prompt {118, 111}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:26<00:00,  1.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+strict-logic-images-prompt {106, 118}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+strict-logic-images-prompt\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_LOGIC_IMAGES_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate first 100 BPs with new prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "answers_file = \"../data/processed/bongard/experiments/synthetic_generation_prompting-descriptive/gpt-4o_strict_logic_prompt_evaluation_answers.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o {97, 36, 5, 6, 7, 8, 9, 10, 38, 43, 15, 53, 23, 31}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo {5, 6, 7, 8, 10, 15, 23, 24, 31, 36, 38, 43, 52, 53, 57, 94, 96, 97, 100}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro {5, 6, 7, 9, 10, 15, 23, 24, 27, 29, 31, 36, 38, 43, 47, 52, 53, 57, 97, 100}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620 {97, 36, 5, 6, 7, 8, 38, 10, 43, 100, 15, 53, 24, 57, 94}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = model.get_name()\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_LOGIC_PROMPT,\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{5, 6, 7, 8, 9, 10, 15, 23, 24, 31, 36, 38, 43, 52, 53, 57, 94, 97, 100}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "problems_resolved_by_voting = get_resolved_problems_with_voting(answers_file)\n",
    "problems_resolved_by_voting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.96"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_voting_mean_agreement(answers_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "answers_file = \"../data/processed/bongard/experiments/synthetic_generation_prompting-descriptive/gpt-4o_strict_logic__photo_prompt_evaluation_answers.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 21124.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o {97, 36, 5, 6, 7, 8, 9, 10, 38, 43, 100, 15, 53, 23, 24, 57, 31}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo {96, 97, 36, 5, 6, 7, 38, 100, 10, 43, 15, 52, 53, 24, 57, 94}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [05:51<00:00,  3.52s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro {5, 6, 7, 9, 10, 15, 23, 24, 27, 29, 31, 34, 36, 38, 43, 47, 52, 53, 57, 75, 94, 97, 100}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [04:39<00:00,  2.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620 {1, 5, 6, 7, 8, 10, 15, 24, 36, 38, 43, 52, 53, 57, 71, 94, 96, 97, 100}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = model.get_name()\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=answers_file,\n",
    "        labels_file=\"../data/raw/labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_LOGIC_PROMPT,\n",
    "        with_image=True,\n",
    "        splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{5, 6, 7, 8, 9, 10, 15, 23, 24, 31, 36, 38, 43, 52, 53, 57, 94, 97, 100}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "problems_resolved_by_voting_with_photo = get_resolved_problems_with_voting(answers_file)\n",
    "problems_resolved_by_voting_with_photo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9525"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_voting_mean_agreement(answers_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparing new prompts to the old one"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_answers_file = \"../data/processed/bongard/experiments/synthetic_generation_prompting-descriptive/gpt-4o_answers.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problems manually evaluated as correct (from the set of problems evaluated as correct by any model)\n",
    "marked_as_correct_manually = [\n",
    "    5, 6, 7, 8, 9, 10, 15, 23, 24, 31, 36, 38, 43, 52, 53, 57, 85, 94, 97, 100\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.89"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_gpt4 = get_resolved_problems(original_answers_file, 'gpt-4o')\n",
    "marked_as_correct_by_gemini = get_resolved_problems(original_answers_file, 'gemini-1.5-pro')\n",
    "marked_as_correct_by_claude = get_resolved_problems(original_answers_file, 'claude-3-5-sonnet-20240620')\n",
    "\n",
    "marked_as_correct_by_all = marked_as_correct_by_gpt4.intersection(marked_as_correct_by_gemini).intersection(marked_as_correct_by_claude)\n",
    "marked_as_correct_by_any = marked_as_correct_by_gpt4.union(marked_as_correct_by_gemini).union(marked_as_correct_by_claude)\n",
    "marked_as_wrong_by_all = set(range(1, 101)).difference(marked_as_correct_by_gpt4).difference(marked_as_correct_by_claude).difference(marked_as_correct_by_gemini)\n",
    "\n",
    "agreement_level = len(marked_as_correct_by_all.union(marked_as_wrong_by_all))/ 100\n",
    "agreement_level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1,\n",
       " 5,\n",
       " 6,\n",
       " 7,\n",
       " 8,\n",
       " 10,\n",
       " 15,\n",
       " 23,\n",
       " 24,\n",
       " 27,\n",
       " 29,\n",
       " 31,\n",
       " 36,\n",
       " 38,\n",
       " 43,\n",
       " 47,\n",
       " 52,\n",
       " 53,\n",
       " 57,\n",
       " 71,\n",
       " 85,\n",
       " 94,\n",
       " 97,\n",
       " 100}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "problems_resolved_by_original_voting = get_resolved_problems_with_voting(original_answers_file)\n",
    "problems_resolved_by_original_voting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1, 27, 29, 39, 47, 59, 71, 95, 96, 98}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_any.symmetric_difference(marked_as_correct_manually)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{9, 27, 29, 31, 47, 85, 94}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_all.symmetric_difference(marked_as_correct_manually)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1, 9, 27, 29, 47, 71}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "problems_resolved_by_original_voting.symmetric_difference(marked_as_correct_manually)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{85}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "problems_resolved_by_voting.symmetric_difference(marked_as_correct_manually)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{85}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "problems_resolved_by_voting_with_photo.symmetric_difference(marked_as_correct_manually)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1, 9, 27, 29, 47, 71, 85}"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "problems_resolved_by_original_voting.symmetric_difference(problems_resolved_by_voting)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9633333333333336"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_voting_mean_agreement(original_answers_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check new prompt on Bongard-OpenWorld"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "open_world_experiment_dir = \"../data/processed/bongard/optimize_evaluation_prompt/open_word_descriptive\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt4o, gpt4turbo, gemini, claude = prepare_models(open_world_experiment_dir)\n",
    "all_models : List[LLMMessenger] = [gpt4o, gpt4turbo, gemini, claude]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('../data/processed/bongard/optimize_evaluation_prompt/open_word_descriptive/gpt-4o_descriptions.json',\n",
       " '../data/processed/bongard/optimize_evaluation_prompt/open_word_descriptive/gpt-4o_answers.json')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "open_world_descriptions_file = f\"{open_world_experiment_dir}/{gpt4o.get_name()}_descriptions.json\"\n",
    "open_world_answers_file = f\"{open_world_experiment_dir}/{gpt4o.get_name()}_answers.json\"\n",
    "open_world_descriptions_file, open_world_answers_file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [13:24<00:00, 40.23s/it]\n"
     ]
    }
   ],
   "source": [
    "get_descriptions(\n",
    "    problem_ids=list(range(101, 121)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    model=gpt4o,\n",
    "    output_file=open_world_descriptions_file, \n",
    "    question=DESCRIBE_IMAGE_PROMPT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:32<00:00,  1.61s/it]\n"
     ]
    }
   ],
   "source": [
    "resolve_bongard_with_descriptive_prompting(\n",
    "    problem_ids=list(range(101, 121)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    output_file=open_world_answers_file,\n",
    "    model=gpt4o,\n",
    "    descriptions_file=open_world_descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:14<00:00,  1.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o {102, 103, 104, 105, 108, 109, 110, 111, 112, 116, 117, 118, 119}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:17<00:00,  1.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo {101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 115, 116, 117, 118, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:29<00:00,  1.46s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro {101, 102, 103, 104, 105, 108, 109, 111, 112, 116, 117, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:24<00:00,  1.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620 {101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 116, 117, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=open_world_answers_file,\n",
    "        labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "        model=model,\n",
    "    )\n",
    "    print(model.get_name(), get_resolved_problems(open_world_answers_file, model.get_name()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation_authors = [model.get_name() for model in all_models]\n",
    "marked_as_solved_with_initial_prompt_by_voting = get_resolved_problems_with_voting(open_world_answers_file, evaluation_authors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o+strict-logic-prompt {104, 105, 109, 117, 119}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo+strict-logic-prompt {101, 103, 104, 105, 109, 111, 112, 115, 116, 117, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro+strict-logic-prompt {102, 103, 104, 105, 109, 111, 112, 116, 117, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:00<00:00, 78840.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620+strict-logic-prompt {104, 105, 109, 103}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+strict-logic-prompt\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=open_world_answers_file,\n",
    "        labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_LOGIC_PROMPT\n",
    "    )\n",
    "    print(evaluation_author, get_resolved_problems(open_world_answers_file, evaluation_author))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation_authors = [f\"{model.get_name()}+strict-logic-prompt\" for model in all_models]\n",
    "marked_as_correct_with_final_prompt_by_voting = get_resolved_problems_with_voting(open_world_answers_file, evaluation_authors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{101, 102, 106, 108, 110, 118}"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_solved_with_initial_prompt_by_voting.symmetric_difference(marked_as_correct_with_final_prompt_by_voting)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "marked_as_correct_by_manual = get_resolved_problems(open_world_answers_file, 'manual')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{101, 102, 106, 108, 111, 118}"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_manual.symmetric_difference(marked_as_solved_with_initial_prompt_by_voting)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{110, 111}"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_manual.symmetric_difference(marked_as_correct_with_final_prompt_by_voting)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_gpt4 = get_resolved_problems(open_world_answers_file, 'gpt-4o+strict-logic-prompt')\n",
    "marked_as_correct_by_gemini = get_resolved_problems(open_world_answers_file, 'gemini-1.5-pro+strict-logic-prompt')\n",
    "marked_as_correct_by_claude = get_resolved_problems(open_world_answers_file, 'claude-3-5-sonnet-20240620+strict-logic-prompt')\n",
    "\n",
    "marked_as_correct_by_all = marked_as_correct_by_gpt4.intersection(marked_as_correct_by_gemini).intersection(marked_as_correct_by_claude)\n",
    "marked_as_correct_by_any = marked_as_correct_by_gpt4.union(marked_as_correct_by_gemini).union(marked_as_correct_by_claude)\n",
    "marked_as_wrong_by_all = set(range(101, 121)).difference(marked_as_correct_by_gpt4).difference(marked_as_correct_by_claude).difference(marked_as_correct_by_gemini)\n",
    "\n",
    "agreement_level = len(marked_as_correct_by_all.union(marked_as_wrong_by_all))/ 20\n",
    "agreement_level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{103, 110, 112, 116, 117, 119, 120}"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_all.symmetric_difference(marked_as_correct_by_manual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{102, 110, 111}"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_any.symmetric_difference(marked_as_correct_by_manual)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check new prompt on Bongard-OpenWorld with modified examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "STRICT_LOGIC_IMAGES_PROMPT_REAL_WORLD_EXAMPLES = \"\"\"\n",
    "You are a logic module designed to provide accurate answers. \n",
    "In a Bongard Problem the objective is to spot the difference between the contents of images located on the two opposite sides of the problem. \n",
    "You are given correct labels of these sides and must decide whether the answer provided by the user is correct and matches with those labels. Answer with 'OK' or 'WRONG'.\n",
    "The user's answer should strictly logically match the labels, as shown in the examples. \n",
    "Some answers may be correct even if they do not strictly match the labels, as long as they logically describe the image.\n",
    "\n",
    "FIRST EXAMPLE: \n",
    "LEFT SIDE LABEL: All images depict a running dog.\n",
    "RIGHT SIDE LABEL: All images depict a dog, but it's not running.\n",
    "USER ANSWER: On the left side, the dog is running with a boy. On the right side, the dog is alone.\n",
    "EVALUATION: WRONG\n",
    "END OF FIRST EXAMPLE.\n",
    "\n",
    "SECOND EXAMPLE:\n",
    "LEFT SIDE LABEL: All images depict a running dog.\n",
    "RIGHT SIDE LABEL: All images depict a dog, but it's not running.\n",
    "USER ANSWER: On the left, the dog is running, whereas on the right side, it is not.\n",
    "EVALUATION: OK\n",
    "END OF SECOND EXAMPLE.\n",
    "\n",
    "LEFT SIDE LABEL:\n",
    "{first_label}\n",
    "\n",
    "RIGHT SIDE LABEL:\n",
    "{second_label}\n",
    "\n",
    "USER ANSWER:\n",
    "{model_answer}\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:15<00:00,  1.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o {102, 103, 104, 105, 108, 109, 110, 111, 112, 116, 117, 118, 119}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:29<00:00,  1.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo {101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 115, 116, 117, 118, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:27<00:00,  1.40s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro {101, 102, 103, 104, 105, 108, 109, 111, 112, 116, 117, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:18<00:00,  1.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620 {101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 116, 117, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+strict-logic-prompt+real-world-examples\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=open_world_answers_file,\n",
    "        labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_LOGIC_IMAGES_PROMPT_REAL_WORLD_EXAMPLES,\n",
    "    )\n",
    "    print(model.get_name(), get_resolved_problems(open_world_answers_file, model.get_name()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation_authors = [f\"{model.get_name()}+strict-logic-prompt+real-world-examples\" for model in all_models]\n",
    "marked_as_correct_with_final_prompt_by_voting = get_resolved_problems_with_voting(open_world_answers_file, evaluation_authors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{101, 102, 106, 111}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_manual.symmetric_difference(marked_as_correct_with_final_prompt_by_voting)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_gpt4 = get_resolved_problems(open_world_answers_file, 'gpt-4o+strict-logic-prompt+real-world-examples')\n",
    "marked_as_correct_by_gemini = get_resolved_problems(open_world_answers_file, 'gemini-1.5-pro+strict-logic-prompt+real-world-examples')\n",
    "marked_as_correct_by_claude = get_resolved_problems(open_world_answers_file, 'claude-3-5-sonnet-20240620+strict-logic-prompt+real-world-examples')\n",
    "\n",
    "marked_as_correct_by_all = marked_as_correct_by_gpt4.intersection(marked_as_correct_by_gemini).intersection(marked_as_correct_by_claude)\n",
    "marked_as_correct_by_any = marked_as_correct_by_gpt4.union(marked_as_correct_by_gemini).union(marked_as_correct_by_claude)\n",
    "marked_as_wrong_by_all = set(range(101, 121)).difference(marked_as_correct_by_gpt4).difference(marked_as_correct_by_claude).difference(marked_as_correct_by_gemini)\n",
    "\n",
    "agreement_level = len(marked_as_correct_by_all.union(marked_as_wrong_by_all))/ 20\n",
    "agreement_level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{109, 110, 112, 116, 120}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_all.symmetric_difference(marked_as_correct_by_manual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{101, 102, 106, 108, 111}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_any.symmetric_difference(marked_as_correct_by_manual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "STRICT_LOGIC_IMAGES_PROMPT_REAL_WORLD_EXAMPLES_V2 = \"\"\"\n",
    "You are a logic module designed to provide accurate answers. \n",
    "In a Bongard Problem the objective is to spot the difference between the contents of images located on the two opposite sides of the problem. \n",
    "You are given correct labels of these sides and must decide whether the answer provided by the user is correct and matches with those labels. Answer with 'OK' or 'WRONG'.\n",
    "The user's answer should strictly logically match the labels, as shown in the examples. \n",
    "Some answers may be correct even if they do not strictly match the labels, as long as they logically describe the image.\n",
    "\n",
    "FIRST EXAMPLE: \n",
    "LEFT SIDE LABEL: All animals are dogs.\n",
    "RIGHT SIDE LABEL: All animals are cat.\n",
    "USER ANSWER: On the left side, one of the animals is a dog. On the right side, all of the animals are cats.\n",
    "EVALUATION: WRONG\n",
    "END OF FIRST EXAMPLE.\n",
    "\n",
    "SECOND EXAMPLE:\n",
    "LEFT SIDE LABEL: All animals are dogs.\n",
    "RIGHT SIDE LABEL: All animals are cat.\n",
    "USER ANSWER: On the left side, all of the animals are dogs. On the right side, all of the animals are cats.\n",
    "EVALUATION: OK\n",
    "END OF SECOND EXAMPLE.\n",
    "\n",
    "LEFT SIDE LABEL:\n",
    "{first_label}\n",
    "\n",
    "RIGHT SIDE LABEL:\n",
    "{second_label}\n",
    "\n",
    "USER ANSWER:\n",
    "{model_answer}\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:13<00:00,  1.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4o {102, 103, 104, 105, 108, 109, 110, 111, 112, 116, 117, 118, 119}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:17<00:00,  1.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-4-turbo {101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 115, 116, 117, 118, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:21<00:00,  1.06s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemini-1.5-pro {101, 102, 103, 104, 105, 108, 109, 111, 112, 116, 117, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:22<00:00,  1.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "claude-3-5-sonnet-20240620 {101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 116, 117, 119, 120}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for model in all_models:\n",
    "    evaluation_author = f\"{model.get_name()}+strict-logic-prompt+real-world-examples-v2\"\n",
    "    evalute_answers_using_model(\n",
    "        answers_file=open_world_answers_file,\n",
    "        labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "        model=model,\n",
    "        author=evaluation_author,\n",
    "        prompt=STRICT_LOGIC_IMAGES_PROMPT_REAL_WORLD_EXAMPLES_V2,\n",
    "    )\n",
    "    print(model.get_name(), get_resolved_problems(open_world_answers_file, model.get_name()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation_authors = [f\"{model.get_name()}+strict-logic-prompt+real-world-examples-v2\" for model in all_models]\n",
    "marked_as_correct_with_final_prompt_by_voting = get_resolved_problems_with_voting(open_world_answers_file, evaluation_authors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{102, 110, 111}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_with_final_prompt_by_voting.symmetric_difference(marked_as_correct_by_manual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.75"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_gpt4 = get_resolved_problems(open_world_answers_file, 'gpt-4o+strict-logic-prompt+real-world-examples-v2')\n",
    "marked_as_correct_by_gemini = get_resolved_problems(open_world_answers_file, 'gemini-1.5-pro+strict-logic-prompt+real-world-examples-v2')\n",
    "marked_as_correct_by_claude = get_resolved_problems(open_world_answers_file, 'claude-3-5-sonnet-20240620+strict-logic-prompt+real-world-examples-v2')\n",
    "\n",
    "marked_as_correct_by_all = marked_as_correct_by_gpt4.intersection(marked_as_correct_by_gemini).intersection(marked_as_correct_by_claude)\n",
    "marked_as_correct_by_any = marked_as_correct_by_gpt4.union(marked_as_correct_by_gemini).union(marked_as_correct_by_claude)\n",
    "marked_as_wrong_by_all = set(range(101, 121)).difference(marked_as_correct_by_gpt4).difference(marked_as_correct_by_claude).difference(marked_as_correct_by_gemini)\n",
    "\n",
    "agreement_level = len(marked_as_correct_by_all.union(marked_as_wrong_by_all))/ 20\n",
    "agreement_level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{110, 112, 116, 120}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_all.symmetric_difference(marked_as_correct_by_manual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{102, 110, 111}"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "marked_as_correct_by_any.symmetric_difference(marked_as_correct_by_manual)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
