{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from src.bongard_problems.data import print_classification_summary\n",
    "from src.llm_messenger.messengers.claude_messenger import ClaudeMessenger\n",
    "from src.llm_messenger.messengers.google_messenger import GoogleMessenger\n",
    "from src.llm_messenger.messengers.gpt_messenger import GPTMessenger\n",
    "from src.prompting_techniques.prompt import CommonPrompts, ImageToSidePrompts\n",
    "from src.prompting_techniques.two_at_once_classification import \\\n",
    "    resolve_bongard_with_last_image_to_side_classification_two_sides_at_once\n",
    "\n",
    "load_dotenv()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard AVR"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Root directory of this benchmark's run: it is passed to every messenger as\n",
    "# log_directory and is the parent folder of each per-model answers file.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/synthetic_binary-classification_joint-image-to-side-shuffle\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# gpt-4o messenger; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# gpt-4-turbo messenger; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# gemini-1.5-pro messenger; the API key is read from the GOOGLE_API_KEY environment variable.\n",
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# claude-3-5-sonnet-20240620 messenger; the API key is read from the\n",
    "# ANTHROPIC_API_KEY environment variable.\n",
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard OpenWorld"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Root directory of this benchmark's run: it is passed to every messenger as\n",
    "# log_directory and is the parent folder of each per-model answers file.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/openworld_binary-classification_joint-image-to-side-shuffle\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# gpt-4o messenger; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# gpt-4-turbo messenger; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# gemini-1.5-pro messenger; the API key is read from the GOOGLE_API_KEY environment variable.\n",
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# claude-3-5-sonnet-20240620 messenger; the API key is read from the\n",
    "# ANTHROPIC_API_KEY environment variable.\n",
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_open_world_splitted\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard HOI"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Root directory of this benchmark's run: it is passed to every messenger as\n",
    "# log_directory and is the parent folder of each per-model answers file.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/hoi_binary-classification_joint-image-to-side-shuffle\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# gpt-4o messenger; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# gpt-4-turbo messenger; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# gemini-1.5-pro messenger; the API key is read from the GOOGLE_API_KEY environment variable.\n",
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# claude-3-5-sonnet-20240620 messenger; the API key is read from the\n",
    "# ANTHROPIC_API_KEY environment variable.\n",
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on problems 1-100,\n",
    "# with hallucination intervention and test-panel shuffling switched on.\n",
    "# answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=list(range(1, 101)),\n",
    "    splitted_data_path=\"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Bongard RWR"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Keep only the problem ids in 1..100 that appear as entries of the raw RWR data folder.\n",
    "splitted_data_path = \"../data/raw/bongard_rwr\"\n",
    "# Ignore non-numeric entries (e.g. .DS_Store, .gitkeep): int() would raise ValueError\n",
    "# on them and abort the whole cell.\n",
    "rwr_problem_ids = {\n",
    "    int(entry) for entry in os.listdir(splitted_data_path) if entry.isdecimal()\n",
    "}\n",
    "problem_ids = [pid for pid in range(1, 101) if pid in rwr_problem_ids]"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Root directory of this benchmark's run: it is passed to every messenger as\n",
    "# log_directory and is the parent folder of each per-model answers file.\n",
    "experiment_dir = \"../data/processed/bongard/experiments/rwr_binary-classification_joint-image-to-side-shuffle\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### GPT-4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# gpt-4o messenger; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4o\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on the RWR problem\n",
    "# ids selected above (problem_ids), with hallucination intervention and test-panel\n",
    "# shuffling switched on. answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# gpt-4-turbo messenger; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "model = GPTMessenger(\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on the RWR problem\n",
    "# ids selected above (problem_ids), with hallucination intervention and test-panel\n",
    "# shuffling switched on. answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Gemini 1.5 Pro"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# gemini-1.5-pro messenger; the API key is read from the GOOGLE_API_KEY environment variable.\n",
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on the RWR problem\n",
    "# ids selected above (problem_ids), with hallucination intervention and test-panel\n",
    "# shuffling switched on. answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Claude 3.5 Sonnet"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# claude-3-5-sonnet-20240620 messenger; the API key is read from the\n",
    "# ANTHROPIC_API_KEY environment variable.\n",
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "# Per-model answers path inside experiment_dir; the bare name on the last line\n",
    "# shows the resolved path as the cell output.\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Run the joint image-to-side (two sides at once) classification on the RWR problem\n",
    "# ids selected above (problem_ids), with hallucination intervention and test-panel\n",
    "# shuffling switched on. answers_file is handed over as the output file of the run.\n",
    "resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(\n",
    "    problem_ids=problem_ids,\n",
    "    splitted_data_path=splitted_data_path,\n",
    "    output_file=answers_file,\n",
    "    model=model,\n",
    "    problem_description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=ImageToSidePrompts.QUESTION,\n",
    "    use_joint_side_image=True,\n",
    "    do_hallucination_intervention=True,\n",
    "    do_shuffle_test_panels=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Report the classification summary for the answers file of the run above.\n",
    "print_classification_summary(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
