{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import os.path\n",
    "from collections import defaultdict\n",
    "from functools import partial\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas.io.formats import style\n",
    "from tqdm import tqdm\n",
    "\n",
    "from src.bongard_problems.classes import BongardResolveAttempt, ClassificationAttempt\n",
    "from src.evaluation.model import count_voting_solutions"
   ],
   "id": "103fa92a4a702497",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "experiment_dir = \"../data/processed/bongard/experiments\"\n",
    "\n",
    "# Paid models. Their \"*_STRICT_LOGIC_PROMPT\" variants are the authors of the\n",
    "# evaluations used for the generation methods.\n",
    "paid_models = [\n",
    "    \"gpt-4o\",\n",
    "    \"gpt-4-turbo\",\n",
    "    \"gemini-1.5-pro\",\n",
    "    \"claude-3-5-sonnet-20240620\",\n",
    "]\n",
    "\n",
    "# Every evaluated model: the paid ones first, then the remaining ones.\n",
    "# Currently excluded: openbmb/MiniCPM-V-2_6, lmms-lab/llava-onevision-qwen2-7b-ov.\n",
    "models = paid_models + [\n",
    "    \"OpenGVLab/InternVL2-8B\",\n",
    "    \"llava-hf/llava-v1.6-mistral-7b-hf\",\n",
    "    \"microsoft/Phi-3.5-vision-instruct\",\n",
    "    \"mistralai/Pixtral-12B-2409\",\n",
    "]"
   ],
   "id": "39f4c0889dda5bfa",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Ensure all problems were evaluated",
   "id": "2c740ecaf4a0c543"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Each entry of the Bongard-rwr directory is expected to be named by its numeric problem id;\n",
    "# anything else (e.g. a hidden .DS_Store file) is skipped instead of crashing int().\n",
    "rwr_problem_ids = {int(pid) for pid in os.listdir(\"../data/raw/bongard_rwr\") if pid.isdecimal()}"
   ],
   "id": "c2947a3478d1de50",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Every method whose answers are loaded below (the non-joint binary-classification\n",
    "# variants are not evaluated at the moment).\n",
    "methods = [\n",
    "    \"generation_prompting-direct\",\n",
    "    \"generation_prompting-descriptive\",\n",
    "    \"generation_prompting-iterative\",\n",
    "    \"generation_prompting-descriptive-direct\",\n",
    "    \"generation_prompting-contrastive\",\n",
    "    \"generation_prompting-iterative-contrastive\",\n",
    "    \"generation_prompting-contrastive-direct\",\n",
    "    \"binary-classification_joint-correct-answers\",\n",
    "    \"binary-classification_joint-incorrect-answers\",\n",
    "    \"binary-classification_joint-image-to-side\",\n",
    "    \"binary-classification_joint-image-to-side-shuffle\",\n",
    "]\n",
    "for dataset in [\"synthetic\", \"hoi\", \"openworld\", \"rwr\"]:\n",
    "    # Bongard-rwr is checked against the problems on disk, the other datasets against problems 1-100.\n",
    "    expected_problem_ids = rwr_problem_ids if dataset == \"rwr\" else set(range(1, 101))\n",
    "    for method in methods:\n",
    "        for model in models:\n",
    "            answers_file = f\"{experiment_dir}/{dataset}_{method}/{model}_answers.json\"\n",
    "            if not os.path.exists(answers_file):\n",
    "                print(f\"{dataset} {method} {model} - No answers file\")\n",
    "                continue\n",
    "            resolve_attempt = BongardResolveAttempt.from_file(answers_file)\n",
    "            # The author of the evaluations depends on the method family.\n",
    "            if \"image-to-side\" in method:\n",
    "                authors = [\"oracle\"]\n",
    "            elif \"binary-classification\" in method:\n",
    "                authors = [model]\n",
    "            else:\n",
    "                authors = [f\"{judge}_STRICT_LOGIC_PROMPT\" for judge in paid_models]\n",
    "            for author in authors:\n",
    "                evaluations = resolve_attempt.get_evaluations(author, ignore_missing=True)\n",
    "                # None and empty strings both count as \"not evaluated\".\n",
    "                evaluated_problem_ids = {\n",
    "                    int(problem_id)\n",
    "                    for problem_id, answer in evaluations.items()\n",
    "                    if answer is not None and answer != \"\"\n",
    "                }\n",
    "                missing_problem_ids = expected_problem_ids - evaluated_problem_ids\n",
    "                if missing_problem_ids:\n",
    "                    print(f\"{dataset} {method} {model} - {len(missing_problem_ids)} problems not evaluated by {author}: {missing_problem_ids}\")"
   ],
   "id": "7e0c990ed0f4f5d7",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Load evaluations",
   "id": "3936076dce5cec43"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# One record per (dataset, method, model), appended by the loading cells below.\n",
    "# Re-run this cell before re-running a loading cell, otherwise records get duplicated.\n",
    "results = list()"
   ],
   "id": "ea5849bee20599e1",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Load binary classification evaluations",
   "id": "8ee14e33d71544c1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "methods = [\n",
    "    \"binary-classification_joint-correct-answers\",\n",
    "    \"binary-classification_joint-incorrect-answers\",\n",
    "]\n",
    "for dataset in [\"synthetic\", \"hoi\", \"openworld\", \"rwr\"]:\n",
    "    for method in tqdm(methods, desc=dataset):\n",
    "        # Counted verdict: \"WRONG\" for the incorrect-answers method, \"OK\" otherwise.\n",
    "        expected_verdict = \"WRONG\" if \"incorrect\" in method else \"OK\"\n",
    "        for model in models:\n",
    "            answers_file = f\"{experiment_dir}/{dataset}_{method}/{model}_answers.json\"\n",
    "            attempt = BongardResolveAttempt.from_file(answers_file)\n",
    "            # The model authors its own evaluations; None entries are skipped.\n",
    "            answered = [\n",
    "                verdict\n",
    "                for verdict in attempt.get_evaluations(model, ignore_missing=True).values()\n",
    "                if verdict is not None\n",
    "            ]\n",
    "            results.append({\n",
    "                \"model\": model,\n",
    "                \"method\": method,\n",
    "                \"num_correct_answers\": sum(expected_verdict in verdict for verdict in answered),\n",
    "                \"num_all_answers\": len(answered),\n",
    "                \"dataset\": dataset,\n",
    "            })"
   ],
   "id": "59760a7385071cb",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "methods = [\n",
    "    \"binary-classification_joint-image-to-side\",\n",
    "    \"binary-classification_joint-image-to-side-shuffle\",\n",
    "]\n",
    "for dataset in [\"synthetic\", \"hoi\", \"openworld\", \"rwr\"]:\n",
    "    for method in tqdm(methods, desc=dataset):\n",
    "        predicate = lambda evaluation: evaluation is not None and \"OK\" in evaluation\n",
    "        for model in models:\n",
    "            answers_file = f\"{experiment_dir}/{dataset}_{method}/{model}_answers.json\"\n",
    "            attempt = ClassificationAttempt.from_file(answers_file)\n",
    "            # Correct evaluations per problem id. Every problem with at least one non-None\n",
    "            # evaluation gets an entry (possibly 0), so len() counts the answered problems;\n",
    "            # adding entries only for correct sides would under-report num_all_answers.\n",
    "            num_correct_sides = defaultdict(int)\n",
    "            for (problem_id, _side), answer in attempt.get_evaluations().items():\n",
    "                if answer is not None:\n",
    "                    num_correct_sides[problem_id] += predicate(answer)\n",
    "            results.append({\n",
    "                \"model\": model,\n",
    "                \"method\": method,\n",
    "                # A problem is solved only when both of its evaluations are correct.\n",
    "                \"num_correct_answers\": sum(n == 2 for n in num_correct_sides.values()),\n",
    "                \"num_all_answers\": len(num_correct_sides),\n",
    "                \"dataset\": dataset,\n",
    "            })"
   ],
   "id": "67dd575fa15e667d",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Load generation evaluations",
   "id": "54cbdb6b7e5828ab"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "methods = [\n",
    "    \"generation_prompting-direct\",\n",
    "    \"generation_prompting-descriptive\",\n",
    "    \"generation_prompting-iterative\",\n",
    "    \"generation_prompting-descriptive-direct\",\n",
    "    \"generation_prompting-contrastive\",\n",
    "    \"generation_prompting-iterative-contrastive\",\n",
    "    \"generation_prompting-contrastive-direct\",\n",
    "]\n",
    "# Evaluation authors passed to count_voting_solutions.\n",
    "authors = [\n",
    "    \"gpt-4-turbo_STRICT_LOGIC_PROMPT\",\n",
    "    \"gpt-4o_STRICT_LOGIC_PROMPT\",\n",
    "    \"gemini-1.5-pro_STRICT_LOGIC_PROMPT\",\n",
    "    \"claude-3-5-sonnet-20240620_STRICT_LOGIC_PROMPT\",\n",
    "]\n",
    "# Presumably the minimum number of authors that must accept an answer for the problem to\n",
    "# count as solved - TODO confirm against count_voting_solutions.\n",
    "voting_threshold = 2\n",
    "# get_evaluations(..., ignore_missing=True) can yield None for unevaluated problems; treat\n",
    "# such entries as rejections instead of failing on `\"OK\" in None`.\n",
    "predicate = lambda evaluation: evaluation is not None and \"OK\" in evaluation\n",
    "for dataset in [\"synthetic\", \"hoi\", \"openworld\", \"rwr\"]:\n",
    "    for method in tqdm(methods, desc=dataset):\n",
    "        for model in models:\n",
    "            answers_file = f\"{experiment_dir}/{dataset}_{method}/{model}_answers.json\"\n",
    "            num_all_problems, num_solved_problems = count_voting_solutions(\n",
    "                answers_file, authors, predicate, voting_threshold, ignore_missing=True\n",
    "            )\n",
    "            results.append({\n",
    "                \"model\": model,\n",
    "                \"method\": method,\n",
    "                \"num_correct_answers\": num_solved_problems,\n",
    "                \"num_all_answers\": num_all_problems,\n",
    "                \"dataset\": dataset,\n",
    "            })"
   ],
   "id": "e9aaeecacc3b568a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# One row per (dataset, method, model) record collected above.\n",
    "df = pd.DataFrame(data=results)\n",
    "df"
   ],
   "id": "df6d77cb22c04676",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Correct answers per method (rows) and model (columns) on the synthetic BPs.\n",
    "d = df.query(\"dataset == 'synthetic'\").pivot_table(index=\"method\", columns=\"model\", values=[\"num_correct_answers\"])\n",
    "d.columns = [model_name for _value, model_name in d.columns]\n",
    "d[models]"
   ],
   "id": "86cd051de183219f",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Answered problems per method (rows) and model (columns) on the synthetic BPs.\n",
    "d = df.query(\"dataset == 'synthetic'\").pivot_table(index=\"method\", columns=\"model\", values=[\"num_all_answers\"])\n",
    "d.columns = [model_name for _value, model_name in d.columns]\n",
    "d[models]"
   ],
   "id": "cbbc18facbfb88be",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Correct answers per method (rows) and model (columns) on Bongard HOI.\n",
    "d = df.query(\"dataset == 'hoi'\").pivot_table(index=\"method\", columns=\"model\", values=[\"num_correct_answers\"])\n",
    "d.columns = [model_name for _value, model_name in d.columns]\n",
    "d[models]"
   ],
   "id": "cd0249fc8fc0dee5",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Answered problems per method (rows) and model (columns) on Bongard HOI.\n",
    "d = df.query(\"dataset == 'hoi'\").pivot_table(index=\"method\", columns=\"model\", values=[\"num_all_answers\"])\n",
    "d.columns = [model_name for _value, model_name in d.columns]\n",
    "d[models]"
   ],
   "id": "d51cffe7eaab9a1b",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Correct answers per method (rows) and model (columns) on Bongard-OpenWorld.\n",
    "d = df.query(\"dataset == 'openworld'\").pivot_table(index=\"method\", columns=\"model\", values=[\"num_correct_answers\"])\n",
    "d.columns = [model_name for _value, model_name in d.columns]\n",
    "d[models]"
   ],
   "id": "6acd3777fac9d137",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Answered problems per method (rows) and model (columns) on Bongard-OpenWorld.\n",
    "d = df.query(\"dataset == 'openworld'\").pivot_table(index=\"method\", columns=\"model\", values=[\"num_all_answers\"])\n",
    "d.columns = [model_name for _value, model_name in d.columns]\n",
    "d[models]"
   ],
   "id": "d60bbbcadbb2d144",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Correct answers per method (rows) and model (columns) on Bongard-rwr.\n",
    "d = df.query(\"dataset == 'rwr'\").pivot_table(index=\"method\", columns=\"model\", values=[\"num_correct_answers\"])\n",
    "d.columns = [model_name for _value, model_name in d.columns]\n",
    "d[models]"
   ],
   "id": "93b9fadc7b6e5d52",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Answered problems per method (rows) and model (columns) on Bongard-rwr.\n",
    "d = df.query(\"dataset == 'rwr'\").pivot_table(index=\"method\", columns=\"model\", values=[\"num_all_answers\"])\n",
    "d.columns = [model_name for _value, model_name in d.columns]\n",
    "d[models]"
   ],
   "id": "ae0cff0e1d9143b4",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Render LaTeX tables",
   "id": "6259a2c5324a9cdc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Display names of the models (table columns); the dict order is the column order.\n",
    "model_names = {\n",
    "    \"gpt-4o\": \"GPT-4 Omni\",\n",
    "    \"gpt-4-turbo\": \"GPT-4 Turbo\",\n",
    "    \"gemini-1.5-pro\": \"Gemini 1.5 Pro\",\n",
    "    \"claude-3-5-sonnet-20240620\": \"Claude 3.5 Sonnet\",\n",
    "    \"OpenGVLab/InternVL2-8B\": \"InternVL2 8B\",\n",
    "    \"llava-hf/llava-v1.6-mistral-7b-hf\": \"LLaVa Mistral 7B\",\n",
    "    \"microsoft/Phi-3.5-vision-instruct\": \"Phi 3.5V\",\n",
    "    \"mistralai/Pixtral-12B-2409\": \"Pixtral 12B\",\n",
    "}\n",
    "\n",
    "# Display names of the methods (table rows); the dict order is the row order.\n",
    "method_names = {\n",
    "    \"binary-classification_joint-correct-answers\": \"Ground-truth\",\n",
    "    \"binary-classification_joint-incorrect-answers\": \"Incorrect label\",\n",
    "    \"binary-classification_joint-image-to-side\": \"Images to sides - prev\",\n",
    "    \"binary-classification_joint-image-to-side-shuffle\": \"Images to sides\",\n",
    "    \"generation_prompting-direct\": \"Direct\",\n",
    "    \"generation_prompting-descriptive\": \"Descriptive\",\n",
    "    \"generation_prompting-iterative\": \"Descriptive-iter.\",\n",
    "    \"generation_prompting-descriptive-direct\": \"Descriptive-direct\",\n",
    "    \"generation_prompting-contrastive\": \"Contrastive\",\n",
    "    \"generation_prompting-iterative-contrastive\": \"Contrastive-iter.\",\n",
    "    \"generation_prompting-contrastive-direct\": \"Contrastive-direct\",\n",
    "}\n",
    "\n",
    "# Every table highlights the best and second-best result of each row, so every caption says so.\n",
    "highlight_note = (\n",
    "    \"The best result across models evaluated with a given strategy is marked in bold\"\n",
    "    \" and the second best is underlined.\"\n",
    ")\n",
    "captions = {\n",
    "    \"synthetic\": \"\\\\textbf{Synthetic BPs.} The table presents the number of correct answers to the first 100 hand-crafted BPs. \" + highlight_note,\n",
    "    \"hoi\": \"\\\\textbf{Bongard HOI.} The table presents the number of correct answers to the selected 100 problems from the Bongard HOI dataset. \" + highlight_note,\n",
    "    \"openworld\": \"\\\\textbf{Bongard-OpenWorld.} The table presents the number of correct answers to the selected 100 problems from the Bongard-OpenWorld dataset. \" + highlight_note,\n",
    "    \"rwr\": \"\\\\textbf{Bongard-rwr.} The table presents the number of correct answers to all of the 60 instances of Bongard-rwr. \" + highlight_note,\n",
    "}"
   ],
   "id": "c5a37d684ed78ed0",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def format_score(x: int, max_value: int, second_max_value: \"int | None\") -> str:\n",
    "    \"\"\"Render a score as LaTeX math.\n",
    "\n",
    "    Bold when it equals the row maximum, underlined when it equals the second-largest value\n",
    "    of the row (None if the row has a single score), plain otherwise.\n",
    "    \"\"\"\n",
    "    if x == max_value:\n",
    "        return f\"$\\\\textbf{{{x}}}$\"\n",
    "    if x == second_max_value:\n",
    "        return f\"$\\\\underline{{{x}}}$\"\n",
    "    return f\"${x}$\"\n",
    "\n",
    "\n",
    "def _highlight_row(row: pd.Series) -> pd.Series:\n",
    "    \"\"\"Format every score of one table row, highlighting its best and second-best values.\n",
    "\n",
    "    Ties count: if two models share the maximum, both are bold and nothing is underlined.\n",
    "    \"\"\"\n",
    "    ranked = np.sort(row.to_numpy())[::-1]\n",
    "    second_max_value = ranked[1] if len(ranked) > 1 else None\n",
    "    return row.apply(partial(format_score, max_value=ranked[0], second_max_value=second_max_value))\n",
    "\n",
    "\n",
    "def arrange(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"Build a LaTeX-ready table of correct-answer counts for one dataset.\n",
    "\n",
    "    Args:\n",
    "        df: rows of the results frame belonging to a single dataset.\n",
    "\n",
    "    Returns:\n",
    "        One row per method present in ``df`` (display names, ordered as in ``method_names``)\n",
    "        and one column per model (display names, ordered as in ``model_names``); every cell\n",
    "        is a LaTeX math string.\n",
    "    \"\"\"\n",
    "    table = df.pivot_table(index=\"method\", columns=\"model\", values=[\"num_correct_answers\"])\n",
    "    table.columns = [model_name for _value, model_name in table.columns]\n",
    "    table = table.rename(index=method_names, columns=model_names)\n",
    "    row_order = [name for name in method_names.values() if name in table.index]\n",
    "    # pivot_table averages (presumably one record per cell), so cast the counts back to int.\n",
    "    table = table.loc[row_order, list(model_names.values())].astype(int)\n",
    "    table.index.name = None\n",
    "    # Build a new string-valued frame instead of writing strings into the int columns in\n",
    "    # place, which pandas 2.1+ deprecates (setting an item of incompatible dtype).\n",
    "    return table.apply(_highlight_row, axis=1)"
   ],
   "id": "c8541c8735dfe5ee",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Synthetic BPs",
   "id": "7a1f8408bc7bf11"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "dataset = \"synthetic\"\n",
    "d = arrange(df.query(f\"dataset == '{dataset}'\"))\n",
    "# Writing the .tex file fails if the output folder does not exist (e.g. on a fresh checkout).\n",
    "os.makedirs(\"latex\", exist_ok=True)\n",
    "style.Styler(d).to_latex(\n",
    "    f\"latex/{dataset}.tex\",\n",
    "    # Left-aligned method names followed by one centred column per model.\n",
    "    column_format=\"l\" + \"c\" * len(d.columns),\n",
    "    caption=captions[dataset],\n",
    "    label=f\"tab:{dataset}\",\n",
    "    environment=\"table\",\n",
    "    position=\"t\",\n",
    "    position_float=\"centering\",\n",
    "    multicol_align=\"c\",\n",
    "    hrules=True,\n",
    ")\n",
    "d"
   ],
   "id": "1e1397b0b23af434",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Bongard HOI",
   "id": "cbd067ee06d3b243"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "dataset = \"hoi\"\n",
    "d = arrange(df.query(f\"dataset == '{dataset}'\"))\n",
    "# Writing the .tex file fails if the output folder does not exist (e.g. on a fresh checkout).\n",
    "os.makedirs(\"latex\", exist_ok=True)\n",
    "style.Styler(d).to_latex(\n",
    "    f\"latex/{dataset}.tex\",\n",
    "    # Left-aligned method names followed by one centred column per model.\n",
    "    column_format=\"l\" + \"c\" * len(d.columns),\n",
    "    caption=captions[dataset],\n",
    "    label=f\"tab:{dataset}\",\n",
    "    environment=\"table\",\n",
    "    position=\"t\",\n",
    "    position_float=\"centering\",\n",
    "    multicol_align=\"c\",\n",
    "    hrules=True,\n",
    ")\n",
    "d"
   ],
   "id": "7dc5c4ac1c58714c",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Bongard-OpenWorld",
   "id": "858fb6ccda699b91"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "dataset = \"openworld\"\n",
    "d = arrange(df.query(f\"dataset == '{dataset}'\"))\n",
    "# Writing the .tex file fails if the output folder does not exist (e.g. on a fresh checkout).\n",
    "os.makedirs(\"latex\", exist_ok=True)\n",
    "style.Styler(d).to_latex(\n",
    "    f\"latex/{dataset}.tex\",\n",
    "    # Left-aligned method names followed by one centred column per model.\n",
    "    column_format=\"l\" + \"c\" * len(d.columns),\n",
    "    caption=captions[dataset],\n",
    "    label=f\"tab:{dataset}\",\n",
    "    environment=\"table\",\n",
    "    position=\"t\",\n",
    "    position_float=\"centering\",\n",
    "    multicol_align=\"c\",\n",
    "    hrules=True,\n",
    ")\n",
    "d"
   ],
   "id": "f902ae6ba4b8f77e",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Bongard-rwr",
   "id": "a3fd50c497d0e0ba"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "dataset = \"rwr\"\n",
    "d = arrange(df.query(f\"dataset == '{dataset}'\"))\n",
    "# Writing the .tex file fails if the output folder does not exist (e.g. on a fresh checkout).\n",
    "os.makedirs(\"latex\", exist_ok=True)\n",
    "style.Styler(d).to_latex(\n",
    "    f\"latex/{dataset}.tex\",\n",
    "    # Left-aligned method names followed by one centred column per model.\n",
    "    column_format=\"l\" + \"c\" * len(d.columns),\n",
    "    caption=captions[dataset],\n",
    "    label=f\"tab:{dataset}\",\n",
    "    environment=\"table\",\n",
    "    position=\"t\",\n",
    "    position_float=\"centering\",\n",
    "    multicol_align=\"c\",\n",
    "    hrules=True,\n",
    ")\n",
    "d"
   ],
   "id": "9b663810f5323744",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "",
   "id": "cb2807d5d52e4f62",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
