{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "from scipy.special import softmax\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "from vla_calibration.utils import *\n",
    "from vla_calibration.calibration import *\n",
    "\n",
    "plt.style.use('seaborn-v0_8')\n",
    "pal = plt.rcParams['axes.prop_cycle'].by_key()['color']"
   ]
  },
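  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference sketches of the imported metric helpers. These show the behavior\n",
    "# *assumed* of get_ece (equal-width-bin L_p expected calibration error) and\n",
    "# cross_entropy (binary NLL) in this notebook. The names ece_sketch and\n",
    "# nll_sketch are hypothetical; the imported functions are what is used below.\n",
    "\n",
    "def ece_sketch(conf, correct, n_bins, p=1):\n",
    "    conf = np.asarray(conf, dtype=float)\n",
    "    correct = np.asarray(correct, dtype=float)\n",
    "    edges = np.linspace(0.0, 1.0, n_bins + 1)\n",
    "    # Bin by confidence; interior edges so conf == 1.0 lands in the last bin.\n",
    "    bins = np.digitize(conf, edges[1:-1])\n",
    "    ece = 0.0\n",
    "    for b in range(n_bins):\n",
    "        mask = bins == b\n",
    "        if not mask.any():\n",
    "            continue\n",
    "        # Bin-weighted |confidence - accuracy| gap.\n",
    "        gap = abs(conf[mask].mean() - correct[mask].mean())\n",
    "        ece += mask.mean() * gap ** p\n",
    "    return ece ** (1.0 / p)\n",
    "\n",
    "def nll_sketch(correct, conf, eps=1e-12):\n",
    "    # Negative log-likelihood of the binary success labels under conf.\n",
    "    conf = np.clip(np.asarray(conf, dtype=float), eps, 1 - eps)\n",
    "    correct = np.asarray(correct, dtype=float)\n",
    "    return -np.mean(correct * np.log(conf) + (1 - correct) * np.log(1 - conf))"
   ]
  },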
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_experiment(\n",
    "        task_name, \n",
    "        quant=None,\n",
    "        alternate_set=1, \n",
    "        n_prompts=20, \n",
    "        n_cal_bins=12,\n",
    "):\n",
    "    \n",
    "    data_save_dir = f\"../results/libero_{task_name}\"\n",
    "    if quant is not None:\n",
    "        data_save_dir += f\"/{quant}\"\n",
    "\n",
    "    top_n_steps=1\n",
    "\n",
    "    base_probs, _, correct = get_base_data(data_save_dir, top_n_steps)\n",
    "\n",
    "    base_probs = np.expand_dims(base_probs, axis=2)\n",
    "\n",
    "    all_probs = []\n",
    "\n",
    "    for i in range(n_prompts):\n",
    "\n",
    "        prompt_probs = []\n",
    "\n",
    "        if alternate_set == 1:\n",
    "            data_save_str = f\"{data_save_dir}/episode_data_prompt_{i}.pkl\"\n",
    "        elif alternate_set == 2:\n",
    "            data_save_str = f\"{data_save_dir}/episode_data_prompt_{i}_v2.pkl\"\n",
    "        elif alternate_set == 3:\n",
    "            data_save_str = f\"{data_save_dir}/episode_data_prompt_{i}_v3.pkl\"\n",
    "        else:\n",
    "            raise ValueError\n",
    "\n",
    "        with open(data_save_str, \"rb\") as f:  \n",
    "            data = pkl.load(f)\n",
    "\n",
    "        for episode in data:\n",
    "\n",
    "            episode_probs = []\n",
    "\n",
    "            steps = episode[\"steps\"]\n",
    "\n",
    "            for step in steps[:top_n_steps]:\n",
    "\n",
    "                logits = step[\"logits\"]\n",
    "                probs = softmax(logits, -1)\n",
    "\n",
    "                episode_probs.append(probs)\n",
    "\n",
    "            episode_probs = np.stack(episode_probs)\n",
    "            prompt_probs.append(episode_probs)\n",
    "\n",
    "        prompt_probs = np.stack(prompt_probs)\n",
    "\n",
    "        all_probs.append(prompt_probs)\n",
    "\n",
    "    \n",
    "    all_probs = np.stack(all_probs)\n",
    "    ens_probs = np.transpose(all_probs, (1,2,0,3,4))\n",
    "\n",
    "    base_probs = base_probs[:,0]\n",
    "    ens_probs = ens_probs[:,0]\n",
    "\n",
    "\n",
    "    base_probs = np.max(base_probs, -1)\n",
    "    ens_probs = np.max(ens_probs, -1)\n",
    "\n",
    "    base_conf = np.mean(base_probs, -2)\n",
    "    ens_conf = np.mean(ens_probs, -2)\n",
    "\n",
    "    mean_base_conf = np.mean(base_conf, -1)\n",
    "    mean_ens_conf = np.mean(ens_conf, -1)\n",
    "\n",
    "    base_ece1 = round(get_ece(mean_base_conf, correct, n_cal_bins, p=1), 3)\n",
    "    ens_ece1 = round(get_ece(mean_ens_conf, correct, n_cal_bins, p=1), 3)\n",
    "\n",
    "    base_ece2 = round(get_ece(mean_base_conf, correct, n_cal_bins, p=2), 3)\n",
    "    ens_ece2 = round(get_ece(mean_ens_conf, correct, n_cal_bins, p=2), 3)\n",
    "\n",
    "    base_brier = round(np.mean((mean_base_conf - correct)**2), 3)\n",
    "    ens_brier = round(np.mean((mean_ens_conf - correct)**2), 3)\n",
    "\n",
    "    base_ce = cross_entropy(correct, mean_base_conf)\n",
    "    ens_ce = cross_entropy(correct, mean_ens_conf)\n",
    "\n",
    "    if quant is not None:\n",
    "        quant_tag = quant\n",
    "    else:\n",
    "        quant_tag = \"Full\"\n",
    "\n",
    "    base_row = [task_name, quant_tag, \"baseline\", base_ece1, base_ece2, base_brier, base_ce, np.mean(correct)]\n",
    "    ens_row = [task_name, quant_tag, \"reprompt\", ens_ece1, ens_ece2, ens_brier, ens_ce, np.mean(correct)]\n",
    "    rows = [base_row, ens_row]\n",
    "\n",
    "    df = pd.DataFrame(rows, columns=[\"Dataset\",\"Model\",\"Method\",\"ECE-1\",\"ECE-2\",\"Brier\",\"NLL\",\"Accuracy\"])\n",
    "    return df\n",
    "    \n"
   ]
  },
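  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the confidence pipeline on synthetic data. The axis layout\n",
    "# (episode, step, prompt, token, vocab) mirrors what run_experiment builds;\n",
    "# the sizes here are made up purely for illustration.\n",
    "rng = np.random.default_rng(0)\n",
    "demo_logits = rng.normal(size=(5, 1, 20, 7, 256))\n",
    "demo_probs = softmax(demo_logits, -1)\n",
    "\n",
    "demo_probs = demo_probs[:, 0]       # keep first step   -> (ep, prompt, token, vocab)\n",
    "demo_top = np.max(demo_probs, -1)   # top-token prob    -> (ep, prompt, token)\n",
    "demo_conf = np.mean(demo_top, -2)   # mean over prompts -> (ep, token)\n",
    "demo_conf = np.mean(demo_conf, -1)  # mean over tokens  -> (ep,)\n",
    "print(demo_conf.shape, demo_conf.round(3))"
   ]
  },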
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_bins = 12\n",
    "alternate_set = 1\n",
    "\n",
    "full_df = pd.DataFrame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = run_experiment(\n",
    "    \"spatial\", \n",
    "    alternate_set=alternate_set, \n",
    "    n_cal_bins=n_bins,\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])\n",
    "\n",
    "df = run_experiment(\n",
    "    \"object\", \n",
    "    alternate_set=alternate_set, \n",
    "    n_cal_bins=n_bins,\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])\n",
    "\n",
    "df = run_experiment(\n",
    "    \"goal\", \n",
    "    alternate_set=alternate_set,  \n",
    "    n_cal_bins=n_bins,\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = run_experiment(\n",
    "    \"spatial\", \n",
    "    alternate_set=alternate_set, \n",
    "    n_cal_bins=n_bins,\n",
    "    quant=\"quant8\",\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])\n",
    "\n",
    "df = run_experiment(\n",
    "    \"object\", \n",
    "    alternate_set=alternate_set, \n",
    "    n_cal_bins=n_bins,\n",
    "    quant=\"quant8\",\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])\n",
    "\n",
    "df = run_experiment(\n",
    "    \"goal\", \n",
    "    alternate_set=alternate_set, \n",
    "    n_cal_bins=n_bins,\n",
    "    quant=\"quant8\",\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = run_experiment(\n",
    "    \"spatial\", \n",
    "    alternate_set=alternate_set, \n",
    "    n_cal_bins=n_bins,\n",
    "    quant=\"quant4\",\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])\n",
    "\n",
    "df = run_experiment(\n",
    "    \"object\", \n",
    "    alternate_set=alternate_set, \n",
    "    n_cal_bins=n_bins,\n",
    "    quant=\"quant4\",\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])\n",
    "\n",
    "df = run_experiment(\n",
    "    \"goal\", \n",
    "    alternate_set=alternate_set, \n",
    "    n_cal_bins=n_bins,\n",
    "    quant=\"quant4\",\n",
    "    n_prompts=20\n",
    ")\n",
    "full_df = pd.concat([full_df, df])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "main_df = full_df[[\"Model\", \"Dataset\", \"Method\", \"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\"]]\n",
    "main_df = main_df[main_df[\"Model\"] != \"quant4\"]\n",
    "display(main_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(main_df.to_latex(index=False, float_format=\"%.3f\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "app_df = full_df[[\"Model\", \"Dataset\", \"Method\", \"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\"]]\n",
    "app_df = app_df[app_df[\"Model\"] == \"quant4\"]\n",
    "display(app_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(app_df.to_latex(index=False, float_format=\"%.3f\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tradeoff_df = full_df[full_df[\"Method\"] == \"reprompt\"]\n",
    "display(tradeoff_df)\n",
    "\n",
    "metrics_list = [\"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "success_df = tradeoff_df[[\"Dataset\", \"Model\", \"Accuracy\"]]\n",
    "display(success_df)\n",
    "\n",
    "row1 = [\"Full\"] + success_df[success_df[\"Model\"] == \"Full\"][\"Accuracy\"].tolist()\n",
    "row2 = [\"Quant-8\"] + success_df[success_df[\"Model\"] == \"quant8\"][\"Accuracy\"].tolist()\n",
    "row3 = [\"Quant-4\"] + success_df[success_df[\"Model\"] == \"quant4\"][\"Accuracy\"].tolist()\n",
    "\n",
    "success_df = pd.DataFrame([row1, row2, row3], columns=[\"Model\", \"Spatial\", \"Object\", \"Goal\"])\n",
    "print(success_df.to_latex(index=False, float_format=\"%.3f\"))"
   ]
  },
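  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order-independent alternative to the manual table above, assuming exactly\n",
    "# one Accuracy value per (Model, Dataset) pair in tradeoff_df.\n",
    "pivot = tradeoff_df.pivot(index=\"Model\", columns=\"Dataset\", values=\"Accuracy\")\n",
    "pivot = pivot.loc[[\"Full\", \"quant8\", \"quant4\"], [\"spatial\", \"object\", \"goal\"]]\n",
    "pivot.index = [\"Full\", \"Quant-8\", \"Quant-4\"]\n",
    "pivot.columns = [\"Spatial\", \"Object\", \"Goal\"]\n",
    "print(pivot.to_latex(float_format=\"%.3f\"))"
   ]
  },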
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_df = full_df[full_df[\"Method\"] == \"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reprompt_df = full_df[full_df[\"Method\"] == \"reprompt\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1,4, figsize=(13,3.25))\n",
    "\n",
    "for i, metric in enumerate(metrics_list):\n",
    "\n",
    "    coord = i\n",
    "    ax = axs[coord]\n",
    "\n",
    "    baseline_scores = baseline_df[metric].tolist()\n",
    "    reprompt_scores = reprompt_df[metric].tolist()\n",
    "    ax.scatter(reprompt_scores, baseline_scores, color=pal[i], s=70)\n",
    "    ax.plot([0,1],[0,1], \"--\", color=\"k\", alpha=0.5)\n",
    "\n",
    "    ax_min = min(min(baseline_scores),min(reprompt_scores))*0.95\n",
    "    ax_max = max(max(baseline_scores),max(reprompt_scores))*1.05\n",
    "\n",
    "    ax.set_xlim(ax_min, ax_max)\n",
    "    ax.set_ylim(ax_min, ax_max)\n",
    "\n",
    "    ax.set_title(metric, fontsize=18)\n",
    "    ax.set_xlabel(\"Reprompt\", fontsize=18)\n",
    "    \n",
    "    if i == 0:\n",
    "        ax.set_ylabel(\"Baseline\", fontsize=18)\n",
    "\n",
    "\n",
    "axs[0].set_xticks([0.05, 0.10,0.15])\n",
    "axs[0].set_yticks([0.05, 0.10,0.15])\n",
    "\n",
    "axs[1].set_xticks([0.05, 0.10,0.15,0.2])\n",
    "axs[1].set_yticks([0.05, 0.10,0.15,0.2])\n",
    "\n",
    "axs[2].set_xticks([0.1, 0.15, 0.2,0.25])\n",
    "axs[2].set_yticks([0.1, 0.15, 0.2,0.25])\n",
    "\n",
    "axs[3].set_xticks([0.40, 0.50, 0.60, 0.70])\n",
    "axs[3].set_yticks([0.40, 0.50, 0.60, 0.70])\n",
    "\n",
    "for i in range(4):\n",
    "    axs[i].tick_params(axis='x', labelsize=13)\n",
    "    axs[i].tick_params(axis='y', labelsize=13)\n",
    "\n",
    "axs[0].set_title(r\"$\\text{ECE}_1$\", fontsize=18)\n",
    "axs[1].set_title(r\"$\\text{ECE}_2$\", fontsize=18)\n",
    "axs[2].set_title(\"Brier score\", fontsize=18)\n",
    "    \n",
    "fig.tight_layout()\n",
    "plt.savefig(\"../plots/total_results.png\", dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(6,3))\n",
    "\n",
    "for i, metric in enumerate(metrics_list):\n",
    "\n",
    "    bm = np.array(baseline_df[metric].tolist())\n",
    "    rm = np.array(reprompt_df[metric].tolist())\n",
    "\n",
    "    print(rm/bm)\n",
    "\n",
    "    plt.scatter([i/2]*len(bm), -(1-(rm/bm))*100, marker=\"x\", s=100, lw=4)\n",
    "    if i == 0:\n",
    "        plt.scatter([i/2], -(1-(np.mean(rm/bm)))*100, marker=\"+\", color=\"k\", label=\"Avg. % Change\", s=150, lw=3)\n",
    "    else:\n",
    "        plt.scatter([i/2], -(1-(np.mean(rm/bm)))*100, marker=\"+\", color=\"k\", s=150, lw=3)\n",
    "\n",
    "\n",
    "xticks = [r\"$\\text{ECE}_1$\",r\"$\\text{ECE}_2$\",\"Brier score\",\"NLL\"]\n",
    "plt.xticks(np.arange(4)/2, xticks, fontsize=18)\n",
    "plt.ylabel(\"% Change w/ Reprompt\", fontsize=15)\n",
    "plt.yticks(fontsize=14)\n",
    "plt.legend(fontsize=12, loc=\"upper left\")\n",
    "\n",
    "plt.xlim(-0.2,1.7)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "plt.savefig(\"../plots/pct_reduction.png\", dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
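  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric companion to the plot above: mean percent change per metric.\n",
    "for metric in metrics_list:\n",
    "    bm = np.array(baseline_df[metric].tolist())\n",
    "    rm = np.array(reprompt_df[metric].tolist())\n",
    "    print(f\"{metric}: {np.mean(rm / bm - 1) * 100:+.1f}%\")"
   ]
  },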
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tradeoff_df = full_df[full_df[\"Method\"] == \"baseline\"]\n",
    "display(tradeoff_df)\n",
    "\n",
    "metrics_list = [\"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\"]\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    1,4, \n",
    "    figsize=(13,3.25), \n",
    ")\n",
    "\n",
    "for i, metric in enumerate(metrics_list):\n",
    "\n",
    "    coord = i\n",
    "    ax = axs[coord]\n",
    "\n",
    "    for model in [\"Full\", \"quant8\", \"quant4\"]:\n",
    "\n",
    "        sub_df = tradeoff_df[tradeoff_df[\"Model\"] == model]\n",
    "        axs[coord].scatter(1-np.array(sub_df[\"Accuracy\"]), sub_df[metric], marker=\"+\", label=model, s=150, lw=5, color=pal[i])\n",
    "\n",
    "for i in range(4):\n",
    "    axs[i].set_xlabel(\"Task Error Rate\", fontsize=18)\n",
    "    axs[i].set_xticks([0.15, 0.2, 0.25], [0.15, 0.2, 0.25], fontsize=13)\n",
    "\n",
    "\n",
    "axs[0].set_ylabel(r\"$\\text{ECE}_1$\", fontsize=18)\n",
    "axs[1].set_ylabel(r\"$\\text{ECE}_2$\", fontsize=18)\n",
    "axs[2].set_ylabel(r\"Brier score\", fontsize=18)\n",
    "axs[3].set_ylabel(r\"NLL\", fontsize=18)\n",
    "\n",
    "\n",
    "axs[0].set_yticks([0.05, 0.10,0.15], [0.05, 0.10,0.15], fontsize=13)\n",
    "axs[1].set_yticks([0.05, 0.10,0.15,0.2], [0.05, 0.10,0.15,0.2], fontsize=13)\n",
    "axs[2].set_yticks([0.1, 0.15, 0.2,0.25], [0.1, 0.15, 0.2,0.25], fontsize=13)\n",
    "axs[3].set_yticks([0.40, 0.50, 0.60, 0.70], [0.40, 0.50, 0.60, 0.70], fontsize=13)\n",
    "\n",
    "fig.suptitle(\"Task Error vs. Calibration Error\", y=0.94, fontsize=18)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../plots/tradeoffs_baseline.png\", dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tradeoff_df = full_df[full_df[\"Method\"] == \"reprompt\"]\n",
    "display(tradeoff_df)\n",
    "\n",
    "metrics_list = [\"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\"]\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    1,4, \n",
    "    figsize=(13,3.25), \n",
    ")\n",
    "\n",
    "for i, metric in enumerate(metrics_list):\n",
    "\n",
    "    coord = i\n",
    "    ax = axs[coord]\n",
    "\n",
    "    for model in [\"Full\", \"quant8\", \"quant4\"]:\n",
    "\n",
    "        sub_df = tradeoff_df[tradeoff_df[\"Model\"] == model]\n",
    "        axs[coord].scatter(1-np.array(sub_df[\"Accuracy\"]), sub_df[metric], marker=\"+\", label=model, s=150, lw=5, color=pal[i])\n",
    "\n",
    "for i in range(4):\n",
    "    axs[i].set_xlabel(\"Task Error Rate\", fontsize=18)\n",
    "    axs[i].set_xticks([0.15, 0.2, 0.25], [0.15, 0.2, 0.25], fontsize=13)\n",
    "\n",
    "\n",
    "axs[0].set_ylabel(r\"$\\text{ECE}_1$\", fontsize=18)\n",
    "axs[1].set_ylabel(r\"$\\text{ECE}_2$\", fontsize=18)\n",
    "axs[2].set_ylabel(r\"Brier score\", fontsize=18)\n",
    "axs[3].set_ylabel(r\"NLL\", fontsize=18)\n",
    "\n",
    "\n",
    "axs[0].set_yticks([0.05, 0.10,0.15], [0.05, 0.10,0.15], fontsize=13)\n",
    "axs[1].set_yticks([0.05, 0.10,0.15,0.2], [0.05, 0.10,0.15,0.2], fontsize=13)\n",
    "axs[2].set_yticks([0.1, 0.15, 0.2,0.25], [0.1, 0.15, 0.2,0.25], fontsize=13)\n",
    "axs[3].set_yticks([0.40, 0.50, 0.60, 0.70], [0.40, 0.50, 0.60, 0.70], fontsize=13)\n",
    "\n",
    "fig.suptitle(\"Task Error vs. Calibration Error (Reprompt)\", y=0.95, fontsize=18)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../plots/tradeoffs_reprompt.png\", dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
