{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "from datasets import load_dataset\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocess data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'data' split of the results dataset.\n",
    "dataset = load_dataset('anonymous_research/MATH_with_LM_Judges_and_Reward_Model_Results_V2')['data']\n",
    "\n",
    "# One record per dataset row. The raw columns are spread first so that the\n",
    "# normalized keys below win on collision: 'correctness' prefers 'answer_correct'\n",
    "# and only falls back to a raw 'correctness' column when that is absent.\n",
    "samples = []\n",
    "for example in dataset:\n",
    "    samples.append({\n",
    "        **example,\n",
    "        'samples': example['samples'],  # required column; KeyError if missing\n",
    "        'correctness': example.get('answer_correct', example.get('correctness', None)),\n",
    "        'extracted_answers': example.get('extracted_answers', None),\n",
    "    })"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Diversity metrics (TODO: expand beyond Pearson correlation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import pearsonr\n",
    "\n",
    "\n",
    "def _flatten_scores(samples, verifier):\n",
    "    \"\"\"Concatenate every entry's `verifier` scores into one 1-D float array.\"\"\"\n",
    "    per_entry = [np.asarray(sample[verifier], dtype=float).ravel() for sample in samples]\n",
    "    # Concatenating per-entry arrays tolerates entries with differing numbers of\n",
    "    # scores, which np.array(...) on a ragged nested list does not.\n",
    "    return np.concatenate(per_entry) if per_entry else np.empty(0, dtype=float)\n",
    "\n",
    "\n",
    "def get_pearson(samples, verifier_1, verifier_2):\n",
    "    \"\"\"Return the Pearson correlation between two verifiers' scores.\n",
    "\n",
    "    Args:\n",
    "        samples: List of mappings; `sample[verifier]` holds that verifier's\n",
    "            score or list of scores for the entry.\n",
    "        verifier_1: Key of the first verifier's scores.\n",
    "        verifier_2: Key of the second verifier's scores.\n",
    "\n",
    "    Returns:\n",
    "        The correlation coefficient as a Python float, computed over the score\n",
    "        pairs for which both verifiers have a finite value.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the two verifiers have different numbers of scores, or\n",
    "            (from `pearsonr`) if fewer than two valid pairs remain.\n",
    "    \"\"\"\n",
    "    verifier_1_scores = _flatten_scores(samples, verifier_1)\n",
    "    verifier_2_scores = _flatten_scores(samples, verifier_2)\n",
    "    if verifier_1_scores.shape != verifier_2_scores.shape:\n",
    "        raise ValueError(\n",
    "            f'{verifier_1} and {verifier_2} have different numbers of scores: '\n",
    "            f'{verifier_1_scores.size} vs {verifier_2_scores.size}'\n",
    "        )\n",
    "    # Missing scores (None becomes NaN) would otherwise turn the whole coefficient into NaN.\n",
    "    valid = np.isfinite(verifier_1_scores) & np.isfinite(verifier_2_scores)\n",
    "    return float(pearsonr(verifier_1_scores[valid], verifier_2_scores[valid])[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = \"./results/\"\n",
    "\n",
    "# Read the result files in sorted order so all_results is reproducible across\n",
    "# runs; os.listdir makes no guarantee about the order it returns.\n",
    "all_results = []\n",
    "for file_name in sorted(os.listdir(results_dir)):\n",
    "    file_path = os.path.join(results_dir, file_name)\n",
    "    # Skip sub-directories and hidden entries (e.g. .ipynb_checkpoints, .DS_Store)\n",
    "    # that would otherwise make open()/json.load() fail.\n",
    "    if file_name.startswith('.') or not os.path.isfile(file_path):\n",
    "        continue\n",
    "    with open(file_path, \"r\") as f:\n",
    "        all_results.append(json.load(f))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (output key, baseline name, reference-accuracy key) for each improvement metric,\n",
    "# listed in the order the keys are written onto every result.\n",
    "improvement_specs = (\n",
    "    ('naive_ensemble_improvement_over_avg', 'highest_joint_score', 'average_verifier_accuracy'),\n",
    "    ('search_ensemble_improvement_over_avg', 'search_weighted_ensemble', 'average_verifier_accuracy'),\n",
    "    ('naive_ensemble_improvement_over_best', 'highest_joint_score', 'best_verifier_accuracy'),\n",
    "    ('search_ensemble_improvement_over_best', 'search_weighted_ensemble', 'best_verifier_accuracy'),\n",
    ")\n",
    "\n",
    "# Annotate each result in place with the verifier-pair correlation, the verifiers'\n",
    "# average and best accuracy, and each ensemble baseline's gain over those two references.\n",
    "for result in all_results:\n",
    "    first_verifier, second_verifier = result['verifiers']\n",
    "    result['pearson'] = get_pearson(samples, first_verifier, second_verifier)\n",
    "\n",
    "    judge_accuracies = [perf['accuracy'] for perf in result['lm_judges'].values()]\n",
    "    reward_model_accuracies = [perf['accuracy'] for perf in result['reward_models'].values()]\n",
    "    verifier_accuracies = judge_accuracies + reward_model_accuracies\n",
    "    result['average_verifier_accuracy'] = np.array(verifier_accuracies).mean()\n",
    "    result['best_verifier_accuracy'] = max(verifier_accuracies)\n",
    "\n",
    "    for output_key, baseline_name, reference_key in improvement_specs:\n",
    "        result[output_key] = result['baselines'][baseline_name]['accuracy'] - result[reference_key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verifier-pair Pearson correlation vs. the grid-search weighted ensemble's gain\n",
    "# over the average verifier accuracy (one point per result file).\n",
    "pearson_values = [entry['pearson'] for entry in all_results]\n",
    "improvements = [entry['search_ensemble_improvement_over_avg'] for entry in all_results]\n",
    "\n",
    "print(f\"Min improvement: {min(improvements)}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(pearson_values, improvements)\n",
    "ax.set_xlabel(\"Pearson correlation between verifiers\")\n",
    "ax.set_ylabel(\"Grid search weighted ensemble acc - avg verifier acc\", fontsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verifier-pair Pearson correlation vs. the grid-search weighted ensemble's gain\n",
    "# over the best single verifier accuracy (one point per result file).\n",
    "pearson_values = [entry['pearson'] for entry in all_results]\n",
    "improvements = [entry['search_ensemble_improvement_over_best'] for entry in all_results]\n",
    "\n",
    "print(f\"Min improvement: {min(improvements)}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(pearson_values, improvements)\n",
    "ax.set_xlabel(\"Pearson correlation between verifiers\")\n",
    "ax.set_ylabel(\"Grid search weighted ensemble acc - best verifier acc\", fontsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verifier-pair Pearson correlation vs. the unweighted ensemble's gain over the\n",
    "# average verifier accuracy (one point per result file).\n",
    "pearson_values = [entry['pearson'] for entry in all_results]\n",
    "improvements = [entry['naive_ensemble_improvement_over_avg'] for entry in all_results]\n",
    "\n",
    "print(f\"Min improvement: {min(improvements)}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(pearson_values, improvements)\n",
    "ax.set_xlabel(\"Pearson correlation between verifiers\")\n",
    "ax.set_ylabel(\"Unweighted ensemble acc - avg verifier acc\", fontsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verifier-pair Pearson correlation vs. the unweighted ensemble's gain over the\n",
    "# best single verifier accuracy (one point per result file).\n",
    "pearson_values = [entry['pearson'] for entry in all_results]\n",
    "improvements = [entry['naive_ensemble_improvement_over_best'] for entry in all_results]\n",
    "\n",
    "print(f\"Min improvement: {min(improvements)}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(pearson_values, improvements)\n",
    "ax.set_xlabel(\"Pearson correlation between verifiers\")\n",
    "ax.set_ylabel(\"Unweighted ensemble acc - best verifier acc\", fontsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average verifier accuracy vs. the unweighted ensemble's gain over that average\n",
    "# (one point per result file).\n",
    "average_accuracies = [entry['average_verifier_accuracy'] for entry in all_results]\n",
    "improvements = [entry['naive_ensemble_improvement_over_avg'] for entry in all_results]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(average_accuracies, improvements)\n",
    "ax.set_xlabel(\"Average verifier accuracy\")\n",
    "ax.set_ylabel(\"Unweighted ensemble acc - average verifier acc\", fontsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best single-verifier accuracy vs. the unweighted ensemble's gain over that best\n",
    "# verifier (one point per result file).\n",
    "x = [result['best_verifier_accuracy'] for result in all_results]\n",
    "y = [result['naive_ensemble_improvement_over_best'] for result in all_results]\n",
    "\n",
    "plt.scatter(x, y)\n",
    "plt.xlabel(\"Best verifier accuracy\")\n",
    "# The plotted series is the gain over the *best* verifier, so the label must say so.\n",
    "plt.ylabel(\"Unweighted ensemble acc - best verifier acc\", fontsize=8)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mayeeenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
