{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from src.evaluation.model import evalute_answers_using_model, count_correct_answers\n",
    "from src.llm_messenger.messengers.claude_messenger import ClaudeMessenger\n",
    "from src.llm_messenger.messengers.google_messenger import GoogleMessenger\n",
    "from src.llm_messenger.messengers.gpt_messenger import GPTMessenger\n",
    "from src.prompting_techniques.direct import resolve_bongard_with_direct_prompting\n",
    "\n",
    "load_dotenv()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "5deccb6f9a7eab42",
   "metadata": {},
   "source": [
    "# Prompt text passed to every resolve_bongard_with_direct_prompting call below\n",
    "# (all datasets, all models). Any edit here changes the inputs of every experiment.\n",
    "PROBLEM_DESCRIPTION = \"\"\"\n",
    "A Bongard Problem is composed of left and right sides separated by a line. Each side contains six images. All images belonging to one side present a common concept, which is lacking in all images from the other side. The goal is to describe the rule that fits all images on the left side, but none on the right, and, conversely, the rule that fits all images on the right side, but none on the left. The description of the rule should be simple and concise.\n",
    "Example 1: All shapes on left are small. All shapes on right are big.\n",
    "Example 2: The left side contains circles. The right side contains triangles.\n",
    "\"\"\".strip()\n",
    "\n",
    "# Passed as `question=` alongside the description above.\n",
    "QUESTION = \"\"\"\n",
    "You are a vision understanding module designed to provide short, clear and accurate answers. Your goal is to solve the provided Bongard Problem. What is the difference between the two sides of the problem?\n",
    "\"\"\".strip()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "d05c7f9437f983b9",
   "metadata": {},
   "source": [
    "# Bongard AVR"
   ]
  },
  {
   "cell_type": "code",
   "id": "3aad270ef9dc22fd",
   "metadata": {},
   "source": [
    "# Synthetic Bongard (AVR) dataset. Every model below solves the same problems\n",
    "# from the same split directory and is scored against the same labels file.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/synthetic_generation_prompting-direct\"\n",
    "splitted_data_path = \"../data/raw/bongard_splitted\"\n",
    "labels_file = \"../data/raw/labels.csv\"\n",
    "problem_ids = range(1, 101)  # problems 1-100; a fresh list is built for each run"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "21410eae1f1f751c",
   "metadata": {},
   "source": [
    "## GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "id": "5e3bcb6d8ec42884",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "31b09f7a59909d38",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "557a0d416f2ebf6",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7d28493cf03ce46d",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "cd2de81d140b8114",
   "metadata": {},
   "source": [
    "## Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "id": "378112d14c9197a6",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d0eb573c87c4e337",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "865ce5385d042036",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "875850c95fdb4bd2",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "72d0c00e7e453ba0",
   "metadata": {},
   "source": [
    "## Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "id": "41828e827fc0fc41",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "542bddd467656eec",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "6f079454fa8a80a7",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "f047d3f707a16a72",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "b96990b10d9e5500",
   "metadata": {},
   "source": [
    "# Bongard OpenWorld"
   ]
  },
  {
   "cell_type": "code",
   "id": "e521424fbe722fc6",
   "metadata": {},
   "source": [
    "# Bongard-OpenWorld dataset. Every model below solves the same problems\n",
    "# from the same split directory and is scored against the same labels file.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/openworld_generation_prompting-direct\"\n",
    "splitted_data_path = \"../data/raw/bongard_open_world_splitted\"\n",
    "labels_file = \"../data/raw/bongard_open_world_labels.csv\"\n",
    "problem_ids = range(1, 101)  # problems 1-100; a fresh list is built for each run"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "56fe0bcb23e8575d",
   "metadata": {},
   "source": [
    "## GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "id": "fef8be60773e9fcf",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "9db8b731764bb1bd",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "364993b18d4e7209",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "43fe15a57d3021ce",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "feb02adb0534c260",
   "metadata": {},
   "source": [
    "## Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "id": "3840d56854a77016",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "2e70f2f5d6c5a3cd",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "9fe8ce71b7c962ab",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "39a9bcae2bf78457",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "138fda20c20763a",
   "metadata": {},
   "source": [
    "## Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "id": "bec2fd5cf65282a6",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "388fb00109396b80",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e5aca7233e503b2e",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "55f588808e819279",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "4a12d849c45cc53b",
   "metadata": {},
   "source": [
    "# Bongard HOI"
   ]
  },
  {
   "cell_type": "code",
   "id": "c42e44ca5313eee4",
   "metadata": {},
   "source": [
    "# Bongard-HOI dataset (mixed split). Every model below solves the same problems\n",
    "# from the same split directory and is scored against the same labels file.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/hoi_generation_prompting-direct\"\n",
    "splitted_data_path = \"../data/raw/bongard_hoi_splitted_mix\"\n",
    "labels_file = \"../data/raw/bongard_hoi_mix_labels.csv\"\n",
    "problem_ids = range(1, 101)  # problems 1-100; a fresh list is built for each run"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "b0be08411ffffa7e",
   "metadata": {},
   "source": [
    "## GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "id": "21033e5e74c7a72a",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "fbb4c5781ecd8fdd",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "24e7dbcef712994",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "86183d996e81ab2b",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "4e7583f9013ccda6",
   "metadata": {},
   "source": [
    "## Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "id": "f6fa9189634ede69",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "6c8c2ddcf265c92f",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a151f29420e0b903",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "63607b1d963e3af8",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "d865dce8888c3bcd",
   "metadata": {},
   "source": [
    "## Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "id": "399beb72a45b89b8",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "5843be4888457504",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=list(problem_ids),\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "765919eaff086dfc",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "fd344a7391e679a5",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "dc4b898de97bdbb4",
   "metadata": {},
   "source": [
    "# Bongard RWR"
   ]
  },
  {
   "cell_type": "code",
   "id": "4038b95a",
   "metadata": {},
   "source": [
    "# Bongard-RWR problem ids are the numerically named entries of the split dataset.\n",
    "# The same directory is passed to every resolve call below, so the ids and the\n",
    "# images they refer to always come from one place.\n",
    "splitted_data_path = \"../data/raw/bongard_rwr_splitted\"\n",
    "bongard_rwr_problem_ids = sorted(\n",
    "    int(name) for name in os.listdir(splitted_data_path) if name.isdigit()\n",
    ")\n",
    "print(bongard_rwr_problem_ids)\n",
    "print('Count:', len(bongard_rwr_problem_ids))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "77b15bb1d623acee",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/rwr_generation_prompting-direct\"\n",
    "# NOTE(review): the RWR evaluations previously used bongard_hoi_mix_labels.csv, i.e. the\n",
    "# labels of a different dataset. RWR ids look like original Bongard problem numbers, so\n",
    "# the original rules in labels.csv are presumably the right ground truth -- confirm, or\n",
    "# point this at a dedicated RWR labels file if one exists.\n",
    "labels_file = \"../data/raw/labels.csv\"\n",
    "# Fail loudly rather than scoring against a missing/wrong file.\n",
    "assert os.path.isfile(labels_file), f\"RWR labels file not found: {labels_file}\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "7bdccd8ef75169fd",
   "metadata": {},
   "source": [
    "## GPT-4 Turbo"
   ]
  },
  {
   "cell_type": "code",
   "id": "86138f65faaf780b",
   "metadata": {},
   "source": [
    "# NOTE(review): the other datasets are run with gpt-4o; confirm gpt-4-turbo is intended here.\n",
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "405d88fe3c9c4dd4",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "f803fee3ffc3828e",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "c3b38bc03f75cc27",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "23a1f3a0a797a61",
   "metadata": {},
   "source": [
    "## Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "id": "7517c2c16c598dee",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "9dfe1c2e84ff9b73",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "f0c948beb57253f4",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "c41eae074c9e327",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "b63a4c0611db3c42",
   "metadata": {},
   "source": [
    "## Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "id": "2d4f49a28d3ff9c5",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "5ed28415e98ca231",
   "metadata": {},
   "source": [
    "resolve_bongard_with_direct_prompting(\n",
    "    problem_ids=bongard_rwr_problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=PROBLEM_DESCRIPTION,\n",
    "    question=QUESTION,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "cadb4ca1d7077a21",
   "metadata": {},
   "source": [
    "evalute_answers_using_model(\n",
    "    answers_file=answers_file,\n",
    "    labels_file=labels_file,\n",
    "    model=model,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a548539959fe3c52",
   "metadata": {},
   "source": [
    "count_correct_answers(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
