{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from src.evaluation.model import get_perfect_resolve_attempt, get_wrong_resolve_attempt, check_solution_correctness, get_resolved_problems, get_unresolved_problems\n",
    "from src.llm_messenger.messengers.claude_messenger import ClaudeMessenger\n",
    "from src.llm_messenger.messengers.google_messenger import GoogleMessenger\n",
    "from src.llm_messenger.messengers.gpt_messenger import GPTMessenger\n",
    "from src.prompting_techniques.prompt import IsLabelCorrectPrompts, CommonPrompts\n",
    "\n",
    "load_dotenv()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard AVR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary classification: correct answers as labels"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/synthetic_binary-classification_joint-correct-answers\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o and GPT-4 Turbo"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary classification: incorrect answers as labels"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/synthetic_binary-classification_joint-incorrect-answers\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Parameters forwarded to get_wrong_resolve_attempt by the cells in this section.\n",
    "# NOTE(review): presumably each problem is paired with the label of the problem\n",
    "# `offset` positions away, wrapping around at `circular_buffer_size` - confirm\n",
    "# against src.evaluation.model.\n",
    "offset = 20\n",
    "circular_buffer_size = 100"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o and GPT-4 Turbo"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# This section feeds the model incorrect labels: build the attempt with\n",
    "# get_wrong_resolve_attempt and persist it to answers_file for the checker below.\n",
    "# NOTE(review): offset / circular_buffer_size presumably select another problem's\n",
    "# label for each problem - confirm against src.evaluation.model.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# This section feeds the model incorrect labels: build the attempt with\n",
    "# get_wrong_resolve_attempt and persist it to answers_file for the checker below.\n",
    "# NOTE(review): offset / circular_buffer_size presumably select another problem's\n",
    "# label for each problem - confirm against src.evaluation.model.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# This section feeds the model incorrect labels: build the attempt with\n",
    "# get_wrong_resolve_attempt and persist it to answers_file for the checker below.\n",
    "# NOTE(review): offset / circular_buffer_size presumably select another problem's\n",
    "# label for each problem - confirm against src.evaluation.model.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# This section feeds the model incorrect labels: build the attempt with\n",
    "# get_wrong_resolve_attempt and persist it to answers_file for the checker below.\n",
    "# NOTE(review): offset / circular_buffer_size presumably select another problem's\n",
    "# label for each problem - confirm against src.evaluation.model.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard OpenWorld"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary classification: correct answers as labels"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/openworld_binary-classification_joint-correct-answers\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o and GPT-4 Turbo"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_open_world_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_open_world_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_open_world_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_open_world_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary classification: incorrect answers as labels"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/openworld_binary-classification_joint-incorrect-answers\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Parameters forwarded to get_wrong_resolve_attempt by the cells in this section.\n",
    "# NOTE(review): presumably each problem is paired with the label of the problem\n",
    "# `offset` positions away, wrapping around at `circular_buffer_size` - confirm\n",
    "# against src.evaluation.model.\n",
    "offset = 20\n",
    "circular_buffer_size = 100"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o and GPT-4 Turbo"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# This section feeds the model incorrect labels: build the attempt with\n",
    "# get_wrong_resolve_attempt and persist it to answers_file for the checker below.\n",
    "# NOTE(review): offset / circular_buffer_size presumably select another problem's\n",
    "# label for each problem - confirm against src.evaluation.model.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_open_world_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# This section feeds the model incorrect labels: build the attempt with\n",
    "# get_wrong_resolve_attempt and persist it to answers_file for the checker below.\n",
    "# NOTE(review): offset / circular_buffer_size presumably select another problem's\n",
    "# label for each problem - confirm against src.evaluation.model.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_open_world_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# This section feeds the model incorrect labels: build the attempt with\n",
    "# get_wrong_resolve_attempt and persist it to answers_file for the checker below.\n",
    "# NOTE(review): offset / circular_buffer_size presumably select another problem's\n",
    "# label for each problem - confirm against src.evaluation.model.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_open_world_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# This section feeds the model incorrect labels: build the attempt with\n",
    "# get_wrong_resolve_attempt and persist it to answers_file for the checker below.\n",
    "# NOTE(review): offset / circular_buffer_size presumably select another problem's\n",
    "# label for each problem - confirm against src.evaluation.model.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_open_world_labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_open_world_splitted\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard HOI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary classification: correct answers as labels"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/hoi_binary-classification_joint-correct-answers\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4o and GPT-4 Turbo"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary classification: incorrect answers as labels"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/hoi_binary-classification_joint-incorrect-answers\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "offset = 20\n",
    "circular_buffer_size = 100"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Attempt built by get_wrong_resolve_attempt for the incorrect-labels experiment;\n",
    "# named for what it holds, then persisted to answers_file.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Attempt built by get_wrong_resolve_attempt for the incorrect-labels experiment;\n",
    "# named for what it holds, then persisted to answers_file.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Attempt built by get_wrong_resolve_attempt for the incorrect-labels experiment;\n",
    "# named for what it holds, then persisted to answers_file.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Attempt built by get_wrong_resolve_attempt for the incorrect-labels experiment;\n",
    "# named for what it holds, then persisted to answers_file.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    list(range(1, 101)),\n",
    "    \"../data/raw/bongard_hoi_mix_labels.csv\",\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    \"../data/raw/bongard_hoi_splitted_mix\",\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bongard RWR"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "splitted_data_path = \"../data/raw/bongard_rwr\"\n",
    "labels_file = \"../data/raw/labels.csv\"\n",
    "\n",
    "# The directory listing is parsed into integer problem ids. Entries whose names\n",
    "# are not plain decimal numbers (e.g. .DS_Store, .gitkeep) are skipped so that a\n",
    "# stray file does not abort the notebook with a ValueError from int().\n",
    "rwr_problem_ids = {\n",
    "    int(entry)\n",
    "    for entry in os.listdir(splitted_data_path)\n",
    "    if entry.isdecimal()\n",
    "}\n",
    "\n",
    "# Ids 1..100 that have an entry in splitted_data_path, in ascending order.\n",
    "problem_ids = [pid for pid in range(1, 101) if pid in rwr_problem_ids]"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary classification: correct answers as labels"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/rwr_binary-classification_joint-correct-answers\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    problem_ids,\n",
    "    labels_file,\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    splitted_data_path,\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    problem_ids,\n",
    "    labels_file,\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    splitted_data_path,\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    problem_ids,\n",
    "    labels_file,\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    splitted_data_path,\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "perfect_resolve_attempt = get_perfect_resolve_attempt(\n",
    "    problem_ids,\n",
    "    labels_file,\n",
    "    answers_file,\n",
    ")\n",
    "perfect_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    splitted_data_path,\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_resolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary classification: incorrect answers as labels"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments/rwr_binary-classification_joint-incorrect-answers\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "offset = 20\n",
    "circular_buffer_size = 100"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPT-4"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4o\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Attempt built by get_wrong_resolve_attempt for the incorrect-labels experiment;\n",
    "# named for what it holds, then persisted to answers_file.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    problem_ids,\n",
    "    labels_file,\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    splitted_data_path,\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GPTMessenger(\n",
    "    model_name=\"gpt-4-turbo\",\n",
    "    api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Attempt built by get_wrong_resolve_attempt for the incorrect-labels experiment;\n",
    "# named for what it holds, then persisted to answers_file.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    problem_ids,\n",
    "    labels_file,\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    splitted_data_path,\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gemini 1.5 Pro"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = GoogleMessenger(\n",
    "    api_key=os.environ[\"GOOGLE_API_KEY\"],\n",
    "    model_name=\"gemini-1.5-pro\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Attempt built by get_wrong_resolve_attempt for the incorrect-labels experiment;\n",
    "# named for what it holds, then persisted to answers_file.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    problem_ids,\n",
    "    labels_file,\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    splitted_data_path,\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude 3.5 Sonnet"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model = ClaudeMessenger(\n",
    "    api_key=os.environ[\"ANTHROPIC_API_KEY\"],\n",
    "    model_name=\"claude-3-5-sonnet-20240620\",\n",
    "    log_directory=experiment_dir,\n",
    ")\n",
    "answers_file = f\"{experiment_dir}/{model.get_name()}_answers.json\"\n",
    "answers_file"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Attempt built by get_wrong_resolve_attempt for the incorrect-labels experiment;\n",
    "# named for what it holds, then persisted to answers_file.\n",
    "wrong_resolve_attempt = get_wrong_resolve_attempt(\n",
    "    problem_ids,\n",
    "    labels_file,\n",
    "    answers_file,\n",
    "    offset=offset,\n",
    "    circular_buffer_size=circular_buffer_size,\n",
    ")\n",
    "wrong_resolve_attempt.save(answers_file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "check_solution_correctness(\n",
    "    answers_file,\n",
    "    splitted_data_path,\n",
    "    model,\n",
    "    description=CommonPrompts.PROBLEM_DESCRIPTION,\n",
    "    question=IsLabelCorrectPrompts.QUESTION,\n",
    "    use_joint_image=True,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "correct_evaluations = get_unresolved_problems(answers_file)\n",
    "print(f\"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
