{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "\n",
    "from vla_calibration.utils import *\n",
    "from vla_calibration.calibration import *\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Global plot style plus its colour cycle; `pal` supplies the bar colours used in run_exp.\n",
    "plt.style.use('seaborn-v0_8')\n",
    "pal = [style['color'] for style in plt.rcParams['axes.prop_cycle']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recal_experiment(\n",
    "    action_conf,\n",
    "    action_logits,\n",
    "    correct,\n",
    "    test_size=0.4,\n",
    "    n_trials=100,\n",
    "    n_cal_bins=10,\n",
    "):\n",
    "    \"\"\"Compare recalibration methods over repeated random train/test splits.\n",
    "\n",
    "    Each trial re-splits the data with ``train_test_split_three_way`` (seeded\n",
    "    with the trial index, so runs are reproducible) and measures the test-set\n",
    "    ECE (via ``get_ece`` with ``n_cal_bins`` bins) of:\n",
    "\n",
    "    * the uncalibrated confidence (``action_conf`` averaged over its last axis),\n",
    "    * ``PlattScaler`` fitted on that averaged confidence,\n",
    "    * ``ActionPlattScaler`` (``combine_method='mean'``) fitted on ``action_conf``,\n",
    "    * ``TempScaler`` fitted on ``action_logits``.\n",
    "\n",
    "    Args:\n",
    "        action_conf: per-sample confidences; averaged over the last axis for the\n",
    "            uncalibrated and plain-Platt baselines.\n",
    "        action_logits: logits used for temperature scaling.\n",
    "        correct: per-sample correctness labels.\n",
    "        test_size: fraction held out for evaluation in each split.\n",
    "        n_trials: number of random splits.\n",
    "        n_cal_bins: number of bins passed to ``get_ece``.\n",
    "\n",
    "    Returns:\n",
    "        dict with the mean ECE of each method across trials and the standard\n",
    "        deviations (``*_std``; ``uncal_ece_std`` is included as well).\n",
    "        Per-trial ECEs are kept at full precision for these aggregates --\n",
    "        rounding them first would quantize the std / error bars; only the\n",
    "        printed summary is rounded.\n",
    "\n",
    "    NOTE(review): temperature scaling moves the test logits to CUDA, so a\n",
    "    GPU is required.\n",
    "    \"\"\"\n",
    "    uncal_eces = []\n",
    "    recal_eces = []\n",
    "    action_recal_eces = []\n",
    "    temp_scale_eces = []\n",
    "\n",
    "    for trial_no in tqdm(range(n_trials)):\n",
    "        (conf_train, conf_test, logits_train, logits_test,\n",
    "         correct_train, correct_test) = train_test_split_three_way(\n",
    "            action_conf, action_logits, correct, test_size=test_size, random_state=trial_no\n",
    "        )\n",
    "\n",
    "        mean_train_conf = np.mean(conf_train, -1)\n",
    "        mean_test_conf = np.mean(conf_test, -1)\n",
    "\n",
    "        # Uncalibrated baseline (stored unrounded so the aggregates stay exact).\n",
    "        uncal_eces.append(float(get_ece(mean_test_conf, correct_test, n_cal_bins)))\n",
    "\n",
    "        # Platt scaling on the averaged confidence.\n",
    "        scaler = PlattScaler(max_iter=200, tol=1e-6)\n",
    "        scaler.fit(mean_train_conf, correct_train)\n",
    "        platt_probs = scaler.predict(mean_test_conf)\n",
    "        recal_eces.append(float(get_ece(platt_probs, correct_test, n_cal_bins)))\n",
    "\n",
    "        # Action-wise Platt scaling on the per-dimension confidences.\n",
    "        action_scaler = ActionPlattScaler(max_iter=200, tol=1e-8, combine_method=\"mean\")\n",
    "        action_scaler.fit(conf_train, correct_train)\n",
    "        action_probs = action_scaler.predict(conf_test)\n",
    "        action_recal_eces.append(float(get_ece(action_probs, correct_test, n_cal_bins)))\n",
    "\n",
    "        # Temperature scaling on the logits. Train logits stay on the CPU while the\n",
    "        # test logits go to CUDA -- presumably TempScaler keeps its temperature on\n",
    "        # the GPU after set_temperature; confirm before making this device-agnostic.\n",
    "        temp_scaler = TempScaler()\n",
    "        temp_scaler.set_temperature(torch.Tensor(logits_train), torch.FloatTensor(correct_train))\n",
    "        temp_scaler.eval()\n",
    "        with torch.no_grad():\n",
    "            scaled_test_conf = temp_scaler.temperature_scale(torch.Tensor(logits_test).cuda()).cpu().numpy()\n",
    "        temp_scale_eces.append(float(get_ece(scaled_test_conf, correct_test, n_cal_bins)))\n",
    "\n",
    "    results = {\n",
    "        \"uncal_ece\": np.mean(uncal_eces),\n",
    "        \"recal_ece\": np.mean(recal_eces),\n",
    "        \"action_recal_ece\": np.mean(action_recal_eces),\n",
    "        \"temp_scale_ece\": np.mean(temp_scale_eces),\n",
    "        \"recal_ece_std\": np.std(recal_eces),\n",
    "        \"action_recal_ece_std\": np.std(action_recal_eces),\n",
    "        \"temp_scale_std\": np.std(temp_scale_eces),\n",
    "        \"uncal_ece_std\": np.std(uncal_eces),\n",
    "    }\n",
    "\n",
    "    print(f\"uncal ece: {results['uncal_ece']:.4f} | recal ece: {results['recal_ece']:.4f} | action recal ece: {results['action_recal_ece']:.4f}\")\n",
    "    print(f\"temp scale ece: {results['temp_scale_ece']:.4f}\")\n",
    "    return results\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_exp(task_name, quant, n_bins=12, test_size=0.8, n_prompts=20, n_trials=1000, alternate_set=1):\n",
    "    \"\"\"Run the baseline recalibration experiment for one task and plot the results.\n",
    "\n",
    "    Loads data with ``get_scaling_data``, runs ``recal_experiment`` on the baseline\n",
    "    confidences/logits (printing its summary), and draws a two-panel figure:\n",
    "\n",
    "    * left: the per-action-dimension scores in ``by_dim_results['baseline']``;\n",
    "    * right: mean ECE of temperature scaling, Platt scaling and action-wise\n",
    "      Platt scaling, with error bars of std / sqrt(n_trials).\n",
    "\n",
    "    The figure is saved to ``../plots/`` (created if missing) and shown.\n",
    "\n",
    "    Args:\n",
    "        task_name: task passed to ``get_scaling_data`` (also used in title/filename).\n",
    "        quant: quantization tag (e.g. ``'quant8'``) or ``None`` for no tag.\n",
    "        n_bins: number of ECE bins.\n",
    "        test_size: held-out fraction in each random split.\n",
    "        n_prompts: forwarded to ``get_scaling_data``.\n",
    "        n_trials: number of random splits.\n",
    "        alternate_set: forwarded to ``get_scaling_data``.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    base_probs, _, base_logits, _, correct, by_dim_results = get_scaling_data(\n",
    "        task_name,\n",
    "        alternate_set=alternate_set,\n",
    "        n_prompts=n_prompts,\n",
    "        quant=quant,\n",
    "    )\n",
    "\n",
    "    # Max over the last axis, mean over the second-to-last, then index 0 of what\n",
    "    # remains. NOTE(review): axis meanings depend on get_scaling_data's layout.\n",
    "    base_conf = np.mean(np.max(base_probs, -1), -2)[:, 0]\n",
    "\n",
    "    print(\"---------------------\\nBase Recalibration\")\n",
    "    baseline_results = recal_experiment(\n",
    "        base_conf,\n",
    "        base_logits[:, 0, 0],\n",
    "        correct,\n",
    "        test_size=test_size,\n",
    "        n_trials=n_trials,\n",
    "        n_cal_bins=n_bins,\n",
    "    )\n",
    "    print(baseline_results)\n",
    "\n",
    "    bar_colors = pal[1:4]\n",
    "    scale_factor = 1.25  # headroom below the smallest bar when zooming the y-axis\n",
    "\n",
    "    fig, axs = plt.subplots(1, 2, figsize=(10, 3.25), width_ratios=[0.45, 0.55])\n",
    "\n",
    "    # Left panel: per-dimension baseline scores.\n",
    "    dim_scores = by_dim_results[\"baseline\"]\n",
    "    dim_x = np.arange(len(dim_scores))\n",
    "    axs[0].bar(dim_x, dim_scores)\n",
    "    axs[0].set_xticks(dim_x, [f\"{i + 1}\" for i in range(len(dim_scores))], fontsize=15)\n",
    "    axs[0].set_xlabel(\"Action Dimension\", fontsize=18)\n",
    "\n",
    "    # Right panel: ECE of each recalibration method. np.asarray makes the\n",
    "    # std / sqrt(n) error-bar arithmetic explicit instead of relying on a\n",
    "    # Python list being coerced by a numpy scalar.\n",
    "    # NOTE(review): the splits resample the same data, so std / sqrt(n_trials)\n",
    "    # likely understates the true uncertainty.\n",
    "    method_scores = np.asarray([baseline_results[\"temp_scale_ece\"], baseline_results[\"recal_ece\"], baseline_results[\"action_recal_ece\"]])\n",
    "    method_stds = np.asarray([baseline_results[\"temp_scale_std\"], baseline_results[\"recal_ece_std\"], baseline_results[\"action_recal_ece_std\"]])\n",
    "    method_labels = [\"Temp\\nScaling\", \"Platt\\nScaling\", \"Action-Wise\\nPlatt Scaling\"]\n",
    "    method_x = np.arange(len(method_scores))\n",
    "    axs[1].bar(method_x, method_scores, color=bar_colors, yerr=method_stds / np.sqrt(n_trials), error_kw=dict(ecolor='dimgrey', lw=2, capsize=3, capthick=2))\n",
    "    axs[1].set_xticks(method_x, method_labels, fontsize=15)\n",
    "\n",
    "    for ax in axs:\n",
    "        ax.set_ylabel(r\"$\\text{ECE}_1$\", fontsize=18)\n",
    "        ax.tick_params(axis=\"y\", labelsize=13)\n",
    "\n",
    "    # Lower y-limit from ALL plotted method scores: using only the two Platt\n",
    "    # variants let a lower temperature-scaling ECE fall below the axis and vanish.\n",
    "    axs[1].set_ylim(method_scores.min() / scale_factor, None)\n",
    "\n",
    "    quant_tag = f\" ({quant})\" if quant is not None else \"\"\n",
    "    fig.suptitle(f\"{str.title(task_name)}{str.title(quant_tag)}\", fontsize=18, y=0.95)\n",
    "    fig.tight_layout()\n",
    "\n",
    "    quant_save_string = quant_tag.replace(\"(\", \"\").replace(\")\", \"\").strip()\n",
    "\n",
    "    # savefig raises if the output directory is missing, so create it first.\n",
    "    os.makedirs(\"../plots\", exist_ok=True)\n",
    "    fig.savefig(f\"../plots/action_scaling_{task_name}_{quant_save_string}_baseline_w_temp_scaling.png\", dpi=600, bbox_inches=\"tight\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared experiment settings; the quantized runs in the following cells reuse them.\n",
    "n_bins = 10        # number of ECE bins\n",
    "test_size = 0.8    # held-out fraction per random split\n",
    "n_trials = 1000    # number of random train/test splits\n",
    "\n",
    "# No quantization tag, both tasks.\n",
    "for task in (\"spatial\", \"goal\"):\n",
    "    run_exp(task, quant=None, n_bins=n_bins, test_size=test_size, n_trials=n_trials)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'quant8' runs; reuses n_bins / test_size / n_trials from the settings cell above.\n",
    "for task in (\"spatial\", \"goal\"):\n",
    "    run_exp(task, quant=\"quant8\", n_bins=n_bins, test_size=test_size, n_trials=n_trials)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'quant4' runs; reuses n_bins / test_size / n_trials from the settings cell above.\n",
    "for task in (\"spatial\", \"goal\"):\n",
    "    run_exp(task, quant=\"quant4\", n_bins=n_bins, test_size=test_size, n_trials=n_trials)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
