{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "from scipy.special import softmax\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "from vla_calibration.utils import *\n",
    "from vla_calibration.calibration import *\n",
    "\n",
    "# Global plot style; `pal` holds that style's colour cycle as a list of colours.\n",
    "plt.style.use('seaborn-v0_8')\n",
    "_color_cycle = plt.rcParams['axes.prop_cycle'].by_key()\n",
    "pal = _color_cycle['color']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_experiment(\n",
    "        task_name, \n",
    "        quant=None,\n",
    "        alternate_set=1, \n",
    "        n_prompts=20, \n",
    "        n_cal_bins=12,\n",
    "):\n",
    "    \"\"\"Score baseline vs. prompt-ensemble calibration for one LIBERO task suite.\n",
    "\n",
    "    Loads baseline probabilities and success labels with ``get_base_data`` and the\n",
    "    logits saved for ``n_prompts`` rephrased prompts, turns the reprompt logits into\n",
    "    probabilities with a softmax, and scores both confidence estimates against the labels.\n",
    "\n",
    "    Args:\n",
    "        task_name: suite name; results are read from ``../results/libero_{task_name}``.\n",
    "        quant: optional results sub-directory (reported in the ``Model`` column);\n",
    "            ``None`` uses the top-level directory and is reported as ``\"Full\"``.\n",
    "        alternate_set: which rephrasing set to load (1, 2 or 3); selects the\n",
    "            ``episode_data_prompt_{i}.pkl`` / ``_v2.pkl`` / ``_v3.pkl`` files.\n",
    "        n_prompts: number of rephrased prompts to ensemble over.\n",
    "        n_cal_bins: number of bins passed to ``get_ece``.\n",
    "\n",
    "    Returns:\n",
    "        A two-row DataFrame (``baseline``, ``reprompt``) with columns Dataset, Model,\n",
    "        Method, ECE-1, ECE-2, Brier, NLL and Accuracy.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``alternate_set`` is not 1, 2 or 3, or if the reprompt files\n",
    "            hold a different number of episodes than the baseline data.\n",
    "    \"\"\"\n",
    "    # File-name suffix of each rephrasing set; validated before any data is loaded.\n",
    "    suffix_by_set = {1: \"\", 2: \"_v2\", 3: \"_v3\"}\n",
    "    if alternate_set not in suffix_by_set:\n",
    "        raise ValueError(\n",
    "            f\"alternate_set must be one of {sorted(suffix_by_set)}, got {alternate_set!r}\"\n",
    "        )\n",
    "    suffix = suffix_by_set[alternate_set]\n",
    "\n",
    "    data_save_dir = f\"../results/libero_{task_name}\"\n",
    "    if quant is not None:\n",
    "        data_save_dir += f\"/{quant}\"\n",
    "\n",
    "    # Only the first step of each episode is loaded and scored.\n",
    "    top_n_steps = 1\n",
    "\n",
    "    base_probs, _, correct = get_base_data(data_save_dir, top_n_steps)\n",
    "    # Singleton prompt axis at position 2 so the baseline matches the ensemble layout.\n",
    "    base_probs = np.expand_dims(base_probs, axis=2)\n",
    "\n",
    "    all_probs = []\n",
    "    for i in range(n_prompts):\n",
    "        data_save_str = f\"{data_save_dir}/episode_data_prompt_{i}{suffix}.pkl\"\n",
    "        # NOTE: pickle.load can run arbitrary code; only load trusted result files.\n",
    "        with open(data_save_str, \"rb\") as f:\n",
    "            data = pkl.load(f)\n",
    "\n",
    "        # One stacked array of per-step probabilities per episode, stacked over episodes.\n",
    "        prompt_probs = [\n",
    "            np.stack([softmax(step[\"logits\"], -1) for step in episode[\"steps\"][:top_n_steps]])\n",
    "            for episode in data\n",
    "        ]\n",
    "        all_probs.append(np.stack(prompt_probs))\n",
    "\n",
    "    # Stacked as (prompt, episode, step, ...); move the prompt axis to position 2.\n",
    "    all_probs = np.stack(all_probs)\n",
    "    ens_probs = np.transpose(all_probs, (1, 2, 0, 3, 4))\n",
    "\n",
    "    # Fail clearly instead of with a broadcasting error in the metrics below.\n",
    "    if ens_probs.shape[0] != base_probs.shape[0]:\n",
    "        raise ValueError(\n",
    "            f\"{data_save_dir}: reprompt data has {ens_probs.shape[0]} episodes, \"\n",
    "            f\"baseline has {base_probs.shape[0]}\"\n",
    "        )\n",
    "\n",
    "    # Keep step 0, then take the max probability of each distribution.\n",
    "    base_probs = np.max(base_probs[:, 0], -1)\n",
    "    ens_probs = np.max(ens_probs[:, 0], -1)\n",
    "\n",
    "    # Average over the prompt axis, then over the remaining last axis.\n",
    "    mean_base_conf = np.mean(np.mean(base_probs, -2), -1)\n",
    "    mean_ens_conf = np.mean(np.mean(ens_probs, -2), -1)\n",
    "\n",
    "    def _scores(conf):\n",
    "        \"\"\"Return (ECE-1, ECE-2, Brier, NLL) of ``conf`` against ``correct``.\"\"\"\n",
    "        return (\n",
    "            round(get_ece(conf, correct, n_cal_bins, p=1), 3),\n",
    "            round(get_ece(conf, correct, n_cal_bins, p=2), 3),\n",
    "            round(np.mean((conf - correct)**2), 3),\n",
    "            cross_entropy(correct, conf),\n",
    "        )\n",
    "\n",
    "    quant_tag = quant if quant is not None else \"Full\"\n",
    "    accuracy = np.mean(correct)\n",
    "    rows = [\n",
    "        [task_name, quant_tag, \"baseline\", *_scores(mean_base_conf), accuracy],\n",
    "        [task_name, quant_tag, \"reprompt\", *_scores(mean_ens_conf), accuracy],\n",
    "    ]\n",
    "\n",
    "    df = pd.DataFrame(rows, columns=[\"Dataset\",\"Model\",\"Method\",\"ECE-1\",\"ECE-2\",\"Brier\",\"NLL\",\"Accuracy\"])\n",
    "    return df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_bins = 12\n",
    "alternate_set = 1\n",
    "\n",
    "full_df = pd.DataFrame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the results table from scratch so re-running this cell does not append\n",
    "# a second copy of every row to the table left over from a previous run.\n",
    "experiment_frames = []\n",
    "for alternate_set in (1, 2, 3):\n",
    "    for task_name in (\"spatial\", \"object\", \"goal\"):\n",
    "        df = run_experiment(\n",
    "            task_name,\n",
    "            alternate_set=alternate_set,\n",
    "            n_cal_bins=n_bins,\n",
    "            n_prompts=20,\n",
    "        )\n",
    "        df[\"Prompt\"] = alternate_set\n",
    "        experiment_frames.append(df)\n",
    "\n",
    "full_df = pd.concat(experiment_frames)\n",
    "display(full_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline metrics do not depend on the rephrasing set, so keep one baseline row per\n",
    "# (Dataset, Model) - the lowest Prompt tag - and one reprompt row per prompt set.\n",
    "# Selecting by key avoids de-duplicating on exact equality of float metric values.\n",
    "ordered = full_df.sort_values([\"Dataset\", \"Prompt\", \"Method\"])\n",
    "baseline_rows = ordered[ordered[\"Method\"] == \"baseline\"].drop_duplicates([\"Dataset\", \"Model\"])\n",
    "reprompt_rows = ordered[ordered[\"Method\"] == \"reprompt\"].drop_duplicates([\"Dataset\", \"Model\", \"Prompt\"])\n",
    "df = (\n",
    "    pd.concat([baseline_rows, reprompt_rows])\n",
    "    .sort_values([\"Dataset\", \"Prompt\", \"Method\"])\n",
    "    [[\"Dataset\", \"Method\", \"Prompt\", \"ECE-1\", \"ECE-2\", \"NLL\", \"Brier\"]]\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_metric(ax, table, metric, ylabel, dataset_order, xtick_fontsize, legend):\n",
    "    \"\"\"Draw ``metric`` from ``table`` on ``ax`` as grouped bars, one group per dataset.\n",
    "\n",
    "    Bars are labelled from the pivoted (Method, Prompt) columns rather than a\n",
    "    hard-coded list, so labels stay correct if the set of prompts changes.\n",
    "    \"\"\"\n",
    "    pivot_df = table.pivot(index='Dataset', columns=['Method', 'Prompt'], values=metric)\n",
    "    pivot_df.columns = [\n",
    "        'Baseline' if method == 'baseline' else f'Prompt {int(prompt)}'\n",
    "        for method, prompt in pivot_df.columns\n",
    "    ]\n",
    "    ordered_df = pivot_df.reindex(dataset_order)\n",
    "    ordered_df.plot(ax=ax, kind='bar', legend=legend)\n",
    "    ax.set_ylabel(ylabel, fontsize=18)\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_xticks(np.arange(len(dataset_order)), [d.title() for d in dataset_order], rotation=0, fontsize=xtick_fontsize)\n",
    "    ax.tick_params(axis='y', labelsize=12)\n",
    "\n",
    "\n",
    "dataset_order = [\"spatial\", \"object\", \"goal\"]\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(13, 3.25))\n",
    "_plot_metric(axs[0], df, 'ECE-1', r\"$\\text{ECE}_1$\", dataset_order, xtick_fontsize=18, legend=True)\n",
    "axs[0].legend(loc='upper left', ncol=2, fontsize=15)\n",
    "_plot_metric(axs[1], df, 'NLL', \"NLL\", dataset_order, xtick_fontsize=16, legend=False)\n",
    "\n",
    "fig.suptitle(\"Rephrasing Prompt Ablation\", fontsize=18, y=0.95)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../plots/rephrasings_ablation_both.png\", dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
