{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from src.evaluation.model import evalute_answers_using_model, count_correct_answers\n",
    "from src.llm_messenger.messengers.claude_messenger import ClaudeMessenger\n",
    "from src.llm_messenger.messengers.google_messenger import GoogleMessenger\n",
    "from src.llm_messenger.messengers.gpt_messenger import GPTMessenger\n",
    "from src.prompting_techniques.contrastive import get_comparisons, resolve_bongard_with_contrastive_prompting\n",
    "\n",
    "load_dotenv()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Task description with two example answers; passed as problem_description to\n",
    "# get_comparisons and resolve_bongard_with_contrastive_prompting below.\n",
    "PROBLEM_DESCRIPTION = \"\"\"\n",
    "A Bongard Problem is composed of left and right sides separated by a line. Each side contains six images. All images belonging to one side present a common concept, which is lacking in all images from the other side. The goal is to describe the rule that fits all images on the left side, but none on the right, and, conversely, the rule that fits all images on the right side, but none on the left. The description of the rule should be simple and concise.\n",
    "Example 1: All shapes on left are small. All shapes on right are big.\n",
    "Example 2: The left side contains circles. The right side contains triangles.\n",
    "\"\"\".strip()\n",
    "\n",
    "# Prompt for comparing one left-side image with one right-side image;\n",
    "# passed as question to get_comparisons.\n",
    "COMPARE_IMAGES_PROMPT = \"\"\"\n",
    "You are given two images extracted from the left and right side of a Bongard Problem, respectively. Your goal is to compare the images. Your comparison should be as concise as possible.\n",
    "\"\"\".strip()\n",
    "\n",
    "# Prompt for the contrastive runs (called without a with_image argument).\n",
    "# {comparisons} is left unformatted here - presumably filled in by the solver; TODO confirm.\n",
    "QUESTION = \"\"\"\n",
    "You are a vision understanding module designed to provide short, clear and accurate answers. Your goal is to solve the provided Bongard Problem using comparisons between pairs of images. Each pair contains one image from the left and one from the right side of the problem.\n",
    "\n",
    "COMPARISONS:\n",
    "{comparisons}\n",
    "\n",
    "What is the difference between the two sides of the problem?\n",
    "\"\"\".strip()\n",
    "\n",
    "# Prompt for the contrastive-direct runs (called with with_image=True).\n",
    "# Differs from QUESTION only in the wording 'with the help of' instead of 'using'.\n",
    "QUESTION_CONTRASTIVE_DIRECT = \"\"\"\n",
    "You are a vision understanding module designed to provide short, clear and accurate answers. Your goal is to solve the provided Bongard Problem with the help of comparisons between pairs of images. Each pair contains one image from the left and one from the right side of the problem.\n",
    "\n",
    "COMPARISONS:\n",
    "{comparisons}\n",
    "\n",
    "What is the difference between the two sides of the problem?\n",
    "\"\"\".strip()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bongard AVR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompting: Contrastive"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Base directory for this experiment: passed to the messengers as log_directory\n",
    "# and used as the prefix of the per-model comparisons/answers JSON paths.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/synthetic_generation_prompting-contrastive\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# GPT-4o messenger: the key is read from OPENAI_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GPTMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gemini 1.5 Pro messenger: the key is read from GOOGLE_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GoogleMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Claude 3.5 Sonnet messenger: the key is read from ANTHROPIC_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = ClaudeMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompting: Contrastive-direct"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Base directory for the contrastive-direct run: used as log_directory and as the prefix of the\n",
    "# answers path; the comparisons path is derived from it by removing '-direct'.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/synthetic_generation_prompting-contrastive-direct\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# GPT-4o messenger: the key is read from OPENAI_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GPTMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    ")\n",
    "\n",
    "# Answers live under experiment_dir; the comparisons path points at the directory\n",
    "# obtained by removing '-direct' from experiment_dir.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct solving for problem ids 1-100: QUESTION_CONTRASTIVE_DIRECT prompt\n",
    "# with with_image=True; comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gemini 1.5 Pro messenger: the key is read from GOOGLE_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GoogleMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Answers live under experiment_dir; the comparisons path points at the directory\n",
    "# obtained by removing '-direct' from experiment_dir.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct solving for problem ids 1-100: QUESTION_CONTRASTIVE_DIRECT prompt\n",
    "# with with_image=True; comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Claude 3.5 Sonnet messenger: the key is read from ANTHROPIC_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = ClaudeMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Answers live under experiment_dir; the comparisons path points at the directory\n",
    "# obtained by removing '-direct' from experiment_dir.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct solving for problem ids 1-100: QUESTION_CONTRASTIVE_DIRECT prompt\n",
    "# with with_image=True; comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bongard OpenWorld"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompting: Contrastive"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Base directory for this experiment: passed to the messengers as log_directory\n",
    "# and used as the prefix of the per-model comparisons/answers JSON paths.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/openworld_generation_prompting-contrastive\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# GPT-4o messenger: the key is read from OPENAI_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GPTMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the OpenWorld labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gemini 1.5 Pro messenger: the key is read from GOOGLE_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GoogleMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the OpenWorld labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Claude 3.5 Sonnet messenger: the key is read from ANTHROPIC_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = ClaudeMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the OpenWorld labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompting: Contrastive-direct"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Base directory for the contrastive-direct run: used as log_directory and as the prefix of the\n",
    "# answers path; the comparisons path is derived from it by removing '-direct'.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/openworld_generation_prompting-contrastive-direct\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# GPT-4o messenger: the key is read from OPENAI_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GPTMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    ")\n",
    "\n",
    "# Answers live under experiment_dir; the comparisons path points at the directory\n",
    "# obtained by removing '-direct' from experiment_dir.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct solving for problem ids 1-100: QUESTION_CONTRASTIVE_DIRECT prompt\n",
    "# with with_image=True; comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the OpenWorld labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gemini 1.5 Pro messenger: the key is read from GOOGLE_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GoogleMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Answers live under experiment_dir; the comparisons path points at the directory\n",
    "# obtained by removing '-direct' from experiment_dir.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct solving for problem ids 1-100: QUESTION_CONTRASTIVE_DIRECT prompt\n",
    "# with with_image=True; comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the OpenWorld labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Claude 3.5 Sonnet messenger: the key is read from ANTHROPIC_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = ClaudeMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Answers live under experiment_dir; the comparisons path points at the directory\n",
    "# obtained by removing '-direct' from experiment_dir.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct solving for problem ids 1-100: QUESTION_CONTRASTIVE_DIRECT prompt\n",
    "# with with_image=True; comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the OpenWorld labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bongard HOI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompting: Contrastive"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Base directory for this experiment: passed to the messengers as log_directory\n",
    "# and used as the prefix of the per-model comparisons/answers JSON paths.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/hoi_generation_prompting-contrastive\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# GPT-4o messenger: the key is read from OPENAI_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GPTMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the HOI labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gemini 1.5 Pro messenger: the key is read from GOOGLE_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = GoogleMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the HOI labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Claude 3.5 Sonnet messenger: the key is read from ANTHROPIC_API_KEY and experiment_dir is passed as log_directory.\n",
    "model = ClaudeMessenger(\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Per-model JSON paths under experiment_dir, named after model.get_name().\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Last expression: display both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run get_comparisons for problem ids 1-100 with COMPARE_IMAGES_PROMPT; output_file is comparisons_file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive solving for problem ids 1-100 with the QUESTION prompt:\n",
    "# comparisons_file is the comparisons argument, answers_file the output_file.\n",
    "# with_image is not passed here, so the function's default applies.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate answers_file against the HOI labels CSV, passing the same messenger as model.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression: the result of count_correct_answers is the cell output.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompting: Contrastive-direct"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Base directory for the contrastive-direct run: used as log_directory and as the prefix of the\n",
    "# answers path; the comparisons path is derived from it by removing '-direct'.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/hoi_generation_prompting-contrastive-direct\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# GPT-4o messenger; log_directory is set to experiment_dir.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"gpt-4o\",\n",
    ")\n",
    "\n",
    "# Answers go under experiment_dir; the comparisons path drops '-direct' from that directory.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct run on HOI problems 1-100: uses QUESTION_CONTRASTIVE_DIRECT\n",
    "# and additionally passes with_image=True.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=[*range(1, 101)],\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the HOI labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gemini 1.5 Pro messenger; log_directory is set to experiment_dir.\n",
    "model = GoogleMessenger(\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Answers go under experiment_dir; the comparisons path drops '-direct' from that directory.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct run on HOI problems 1-100: uses QUESTION_CONTRASTIVE_DIRECT\n",
    "# and additionally passes with_image=True.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=[*range(1, 101)],\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the HOI labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Claude 3.5 Sonnet messenger; log_directory is set to experiment_dir.\n",
    "model = ClaudeMessenger(\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Answers go under experiment_dir; the comparisons path drops '-direct' from that directory.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct run on HOI problems 1-100: uses QUESTION_CONTRASTIVE_DIRECT\n",
    "# and additionally passes with_image=True.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=[*range(1, 101)],\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the HOI labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bongard RWR"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Problem ids are the purely numeric entry names found in the splitted RWR directory, sorted ascending.\n",
    "# NOTE(review): ids are listed from bongard_rwr_splitted, but the RWR cells below pass\n",
    "# splitted_data_path=../data/raw/bongard_rwr - confirm both refer to the same problem set.\n",
    "bongard_rwr_problem_ids = sorted(\n",
    "    int(entry) for entry in os.listdir(\"../data/raw/bongard_rwr_splitted\") if entry.isdigit()\n",
    ")\n",
    "# Fail fast: with an empty id list the RWR runs below would presumably have nothing to process.\n",
    "assert bongard_rwr_problem_ids, \"no numeric problem entries found in ../data/raw/bongard_rwr_splitted\"\n",
    "print(bongard_rwr_problem_ids)\n",
    "print('Count:', len(bongard_rwr_problem_ids))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompting: Contrastive"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Directory used by the RWR contrastive cells below (logs, comparisons and answer files).\n",
    "experiment_dir = (\n",
    "    \"../data/processed/bongard/experiments/rwr_generation_prompting-contrastive\"\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4 Turbo"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# GPT-4 Turbo messenger; log_directory is set to experiment_dir.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"gpt-4-turbo\",\n",
    ")\n",
    "\n",
    "# Per-model artifact paths, named after whatever model.get_name() returns.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Pairwise-comparison step for the listed RWR problems with the current model;\n",
    "# comparisons_file is handed over as the output file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Answer-generation step for the listed RWR problems: the resolver receives this model's\n",
    "# comparisons file as input and answers_file as the output file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the RWR labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gemini 1.5 Pro messenger; log_directory is set to experiment_dir.\n",
    "model = GoogleMessenger(\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Per-model artifact paths, named after whatever model.get_name() returns.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Pairwise-comparison step for the listed RWR problems with the current model;\n",
    "# comparisons_file is handed over as the output file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Answer-generation step for the listed RWR problems: the resolver receives this model's\n",
    "# comparisons file as input and answers_file as the output file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the RWR labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Claude 3.5 Sonnet messenger; log_directory is set to experiment_dir.\n",
    "model = ClaudeMessenger(\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Per-model artifact paths, named after whatever model.get_name() returns.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Pairwise-comparison step for the listed RWR problems with the current model;\n",
    "# comparisons_file is handed over as the output file.\n",
    "get_comparisons(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=COMPARE_IMAGES_PROMPT,\n",
    "    output_file=comparisons_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Answer-generation step for the listed RWR problems: the resolver receives this model's\n",
    "# comparisons file as input and answers_file as the output file.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the RWR labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompting: Contrastive-direct"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Directory used by the RWR contrastive-direct cells below (logs and answer files).\n",
    "experiment_dir = (\n",
    "    \"../data/processed/bongard/experiments/rwr_generation_prompting-contrastive-direct\"\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4 Turbo"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# GPT-4 Turbo messenger; log_directory is set to experiment_dir.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    "    model_name=\"gpt-4-turbo\",\n",
    ")\n",
    "\n",
    "# Answers go under experiment_dir; the comparisons path drops '-direct' from that directory.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct run on the listed RWR problems: uses QUESTION_CONTRASTIVE_DIRECT\n",
    "# and additionally passes with_image=True.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the RWR labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gemini 1.5 Pro messenger; log_directory is set to experiment_dir.\n",
    "model = GoogleMessenger(\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Answers go under experiment_dir; the comparisons path drops '-direct' from that directory.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct run on the listed RWR problems: uses QUESTION_CONTRASTIVE_DIRECT\n",
    "# and additionally passes with_image=True.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the RWR labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Claude 3.5 Sonnet messenger; log_directory is set to experiment_dir.\n",
    "model = ClaudeMessenger(\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    ")\n",
    "\n",
    "# Answers go under experiment_dir; the comparisons path drops '-direct' from that directory.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "comparisons_file = f\"{experiment_dir.replace('-direct', '')}/{model.get_name()}_comparisons.json\"\n",
    "\n",
    "# Echo both paths as the cell output.\n",
    "(comparisons_file, answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Contrastive-direct run on the listed RWR problems: uses QUESTION_CONTRASTIVE_DIRECT\n",
    "# and additionally passes with_image=True.\n",
    "resolve_bongard_with_contrastive_prompting(\n",
    "    model=model,\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=\"../data/raw/bongard_rwr\",\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION_CONTRASTIVE_DIRECT,\n",
    "    with_image=True,\n",
    "    comparisons_file=comparisons_file,\n",
    "    output_file=answers_file,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Evaluate the answers file against the RWR labels, passing the same model as the evaluator.\n",
    "evalute_answers_using_model(\n",
    "    model=model,\n",
    "    labels_file=\"../data/raw/labels.csv\",\n",
    "    answers_file=answers_file,\n",
    ")\n",
    "\n",
    "# Last expression of the cell: its return value is what the notebook displays.\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
