{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from src.evaluation.model import evalute_answers_using_model, count_correct_answers\n",
    "from src.llm_messenger.messengers.claude_messenger import ClaudeMessenger\n",
    "from src.llm_messenger.messengers.google_messenger import GoogleMessenger\n",
    "from src.llm_messenger.messengers.gpt_messenger import GPTMessenger\n",
    "from src.prompting_techniques.iterative import get_iterative_descriptions, resolve_bongard_with_iterative_prompting\n",
    "\n",
    "# Load environment variables (e.g. from a .env file) so the API keys read via os.environ in the cells below are available.\n",
    "load_dotenv()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Passed as task_description to get_iterative_descriptions: the model is shown one side's images one by one\n",
    "# and asked to refine a single description of the shared concept after each image.\n",
    "DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT = \"\"\"\n",
    "You'll receive a sequence of images that are a part of a single side of a Bongard Problem. The images will be provided one by one. Your goal is to find a common concept presented in all images. Your description should be as concise as possible. Focus on the most important details. Try to enhance the description of the concept after each image.\n",
    "\n",
    "The image is always provided correctly. Respond only to the specific request. The first image will be provided in the next message.\n",
    "\"\"\".strip()\n",
    "\n",
    "# Passed as final_answer_prompt to get_iterative_descriptions; its wording tells the model the last image has been shown.\n",
    "FINAL_ANSWER_PROMPT = \"That was the last image. Now provide your final answer.\"\n",
    "\n",
    "# General background on Bongard Problems, passed as problem_description to resolve_bongard_with_iterative_prompting.\n",
    "PROBLEM_DESCRIPTION = \"\"\"\n",
    "A Bongard Problem is composed of left and right sides separated by a line. Each side contains six images. All images belonging to one side present a common concept, which is lacking in all images from the other side. The goal is to describe the rule that fits all images on the left side, but none on the right, and, conversely, the rule that fits all images on the right side, but none on the left. The description of the rule should be simple and concise.\n",
    "Example 1: All shapes on left are small. All shapes on right are big.\n",
    "Example 2: The left side contains circles. The right side contains triangles.\n",
    "\"\"\".strip()\n",
    "\n",
    "# Question template passed to resolve_bongard_with_iterative_prompting. The {left_description} and {right_description}\n",
    "# placeholders are presumably filled with str.format from the descriptions file — confirm in src.prompting_techniques.iterative.\n",
    "QUESTION = \"\"\"\n",
    "You are a vision understanding module designed to provide short, clear and accurate answers. Your goal is to solve the provided Bongard Problem using descriptions of two sides of the problem.\n",
    "\n",
    "LEFT SIDE DESCRIPTION:\n",
    "{left_description}\n",
    "\n",
    "RIGHT SIDE DESCRIPTION:\n",
    "{right_description}\n",
    "\n",
    "What is the difference between the two sides of the problem?\n",
    "\"\"\".strip()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard AVR"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/synthetic_generation_prompting-iterative\"\n",
    "# Dataset-specific inputs shared by every model run in this section, defined once so the runs below cannot drift apart.\n",
    "problem_ids = list(range(1, 101))\n",
    "splitted_data_path = \"../data/raw/bongard_splitted\"\n",
    "labels_file = \"../data/raw/labels.csv\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard OpenWorld"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/openworld_generation_prompting-iterative\"\n",
    "# Dataset-specific inputs shared by every model run in this section, defined once so the runs below cannot drift apart.\n",
    "problem_ids = list(range(1, 101))\n",
    "splitted_data_path = \"../data/raw/bongard_open_world_splitted\"\n",
    "labels_file = \"../data/raw/bongard_open_world_labels.csv\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    "    # NOTE(review): unlike the GPT/Gemini runs above, Claude runs without image history — document why.\n",
    "    keep_image_history=False,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard HOI"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/hoi_generation_prompting-iterative\"\n",
    "# Dataset-specific inputs shared by every model run in this section, defined once so the runs below cannot drift apart.\n",
    "problem_ids = list(range(1, 101))\n",
    "splitted_data_path = \"../data/raw/bongard_hoi_splitted_mix\"\n",
    "labels_file = \"../data/raw/bongard_hoi_mix_labels.csv\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Describe the full problem range: the answering step below asks for every id in problem_ids.\n",
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    "    # NOTE(review): unlike the GPT/Gemini runs above, Claude runs without image history — document why.\n",
    "    keep_image_history=False,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard RWR"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Problem ids are the numeric sub-directory names of the split RWR data; the same directory is read by get_iterative_descriptions below.\n",
    "splitted_data_path = \"../data/raw/bongard_rwr_splitted\"\n",
    "bongard_rwr_problem_ids = sorted(int(entry) for entry in os.listdir(splitted_data_path) if entry.isdigit())\n",
    "print(bongard_rwr_problem_ids)\n",
    "print('Count:', len(bongard_rwr_problem_ids))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/rwr_generation_prompting-iterative\"\n",
    "# Dataset-specific inputs shared by every model run in this section, defined once so the runs below cannot drift apart.\n",
    "problem_ids = bongard_rwr_problem_ids\n",
    "# TODO(review): confirm this file name — RWR answers must be scored against RWR labels, not bongard_hoi_mix_labels.csv.\n",
    "labels_file = \"../data/raw/bongard_rwr_labels.csv\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# NOTE(review): the other datasets use gpt-4o here — confirm gpt-4-turbo is intended for RWR.\n",
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "descriptions_file = f\"{experiment_dir}/{model.get_name()}_descriptions.json\"\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "descriptions_file, answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "get_iterative_descriptions(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    model=model,\n",
    "    output_file=descriptions_file,\n",
    "    task_description=DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,\n",
    "    final_answer_prompt=FINAL_ANSWER_PROMPT,\n",
    "    # NOTE(review): unlike the GPT/Gemini runs above, Claude runs without image history — document why.\n",
    "    keep_image_history=False,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "resolve_bongard_with_iterative_prompting(\n",
    "    problem_ids=problem_ids,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    descriptions_file=descriptions_file,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")\n",
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
