{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle as pkl\n",
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.special import softmax\n",
    "from tqdm import tqdm\n",
    "\n",
    "from vla_calibration.utils import *\n",
    "from vla_calibration.calibration import *\n",
    "\n",
    "# seaborn-v0_8 styling for every figure; `pal` is that style's colour cycle, used to colour the metric panels.\n",
    "plt.style.use(\"seaborn-v0_8\")\n",
    "pal = list(plt.rcParams[\"axes.prop_cycle\"].by_key()[\"color\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Caches shared across run_experiment calls so already-loaded results are not re-read from disk.\n",
    "all_base_data = dict()\n",
    "all_prompt_probs = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filename suffix for each alternate prompt set: episode_data_prompt_{i}{suffix}.pkl\n",
    "_PROMPT_FILE_SUFFIX = {1: \"\", 2: \"_v2\", 3: \"_v3\"}\n",
    "\n",
    "\n",
    "def _load_prompt_probs(data_save_dir, prompt_idx, alternate_set, top_n_steps):\n",
    "    \"\"\"Read one prompt's episode pickle and softmax the logits of its first steps.\n",
    "\n",
    "    Returns an array stacked as (episode, step, *logit_shape), with the softmax\n",
    "    taken over the last logit axis.\n",
    "    \"\"\"\n",
    "    suffix = _PROMPT_FILE_SUFFIX[alternate_set]\n",
    "    data_save_str = f\"{data_save_dir}/episode_data_prompt_{prompt_idx}{suffix}.pkl\"\n",
    "    # pickle.load can execute arbitrary code: only point this at locally produced results.\n",
    "    with open(data_save_str, \"rb\") as f:\n",
    "        episodes = pkl.load(f)\n",
    "    return np.stack([\n",
    "        np.stack([softmax(step[\"logits\"], -1) for step in episode[\"steps\"][:top_n_steps]])\n",
    "        for episode in episodes\n",
    "    ])\n",
    "\n",
    "\n",
    "def _episode_confidence(probs):\n",
    "    \"\"\"Reduce stacked probabilities to one confidence value per episode.\n",
    "\n",
    "    Keeps index 0 of axis 1 (the first step), takes the max over the last axis,\n",
    "    averages over the second-to-last remaining axis (the ensemble-member axis for\n",
    "    the layouts built in run_experiment), then averages over the last remaining axis.\n",
    "    \"\"\"\n",
    "    top_prob = np.max(probs[:, 0], axis=-1)\n",
    "    member_mean = np.mean(top_prob, axis=-2)\n",
    "    return np.mean(member_mean, axis=-1)\n",
    "\n",
    "\n",
    "def _calibration_scores(conf, correct, n_cal_bins):\n",
    "    \"\"\"Return [ECE-1, ECE-2, Brier, NLL] of per-episode confidences against `correct`.\n",
    "\n",
    "    ECE and Brier are rounded to 3 decimals; NLL is left unrounded.\n",
    "    \"\"\"\n",
    "    ece1 = round(get_ece(conf, correct, n_cal_bins, p=1), 3)\n",
    "    ece2 = round(get_ece(conf, correct, n_cal_bins, p=2), 3)\n",
    "    brier = round(np.mean((conf - correct) ** 2), 3)\n",
    "    nll = cross_entropy(correct, conf)\n",
    "    return [ece1, ece2, brier, nll]\n",
    "\n",
    "\n",
    "def run_experiment(\n",
    "        task_name,\n",
    "        quant=None,\n",
    "        alternate_set=1,\n",
    "        n_prompts=20,\n",
    "        n_cal_bins=12,\n",
    "        total_prompts=20,\n",
    "):\n",
    "    \"\"\"Score baseline vs. prompt-ensemble calibration for one LIBERO task suite.\n",
    "\n",
    "    Loads baseline predictions (via get_base_data) and per-prompt logits from\n",
    "    ../results/libero_{task_name}[/{quant}], caching both in the module-level\n",
    "    all_base_data / all_prompt_probs dicts. It then reduces them to one confidence\n",
    "    per episode and scores those confidences against the correctness labels.\n",
    "\n",
    "    Args:\n",
    "        task_name: Suffix of the results directory (spatial, object, goal, ...).\n",
    "        quant: Optional results sub-directory; None is reported as Full in the\n",
    "            Model column.\n",
    "        alternate_set: Prompt file variant: 1 -> episode_data_prompt_{i}.pkl,\n",
    "            2 -> episode_data_prompt_{i}_v2.pkl, 3 -> episode_data_prompt_{i}_v3.pkl.\n",
    "        n_prompts: Ensemble size. Below total_prompts a random subset is drawn\n",
    "            with Python's random module (seed it for reproducibility); otherwise\n",
    "            every prompt is used.\n",
    "        n_cal_bins: Number of bins passed to get_ece.\n",
    "        total_prompts: Number of prompt files available per task.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one baseline row and one reprompt row and columns\n",
    "        Dataset, Model, Method, ECE-1, ECE-2, Brier, NLL, Accuracy.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: alternate_set is not 1, 2 or 3, or n_prompts is outside\n",
    "            [1, total_prompts].\n",
    "    \"\"\"\n",
    "    # Validate up front so a bad argument fails even when the data is already cached.\n",
    "    if alternate_set not in _PROMPT_FILE_SUFFIX:\n",
    "        raise ValueError(f\"alternate_set must be 1, 2 or 3, got {alternate_set!r}\")\n",
    "    if not 1 <= n_prompts <= total_prompts:\n",
    "        raise ValueError(f\"n_prompts must be in [1, {total_prompts}], got {n_prompts!r}\")\n",
    "\n",
    "    data_save_dir = f\"../results/libero_{task_name}\"\n",
    "    if quant is not None:\n",
    "        data_save_dir += f\"/{quant}\"\n",
    "\n",
    "    # Only the first step of each episode is scored.\n",
    "    top_n_steps = 1\n",
    "\n",
    "    # Cache keys cover every argument that changes which files are read. Keying on\n",
    "    # task_name alone would reuse one model's data for another `quant`, and one\n",
    "    # prompt set's logits for another `alternate_set`.\n",
    "    base_key = (task_name, quant)\n",
    "    if base_key not in all_base_data:\n",
    "        base_probs, _, correct = get_base_data(data_save_dir, top_n_steps)\n",
    "        # Store only after a successful load, so a failure does not leave an empty entry.\n",
    "        all_base_data[base_key] = {\"base_probs\": base_probs, \"correct\": correct}\n",
    "    base_probs = all_base_data[base_key][\"base_probs\"]\n",
    "    correct = all_base_data[base_key][\"correct\"]\n",
    "\n",
    "    # Shuffle-and-truncate consumes Python's `random` stream, so the chosen subset\n",
    "    # is fixed by random.seed.\n",
    "    if n_prompts < total_prompts:\n",
    "        prompt_ids = list(range(total_prompts))\n",
    "        random.shuffle(prompt_ids)\n",
    "        prompt_ids = prompt_ids[:n_prompts]\n",
    "    else:\n",
    "        prompt_ids = list(range(total_prompts))\n",
    "\n",
    "    prompt_cache = all_prompt_probs.setdefault((task_name, quant, alternate_set), {})\n",
    "    members = []\n",
    "    for prompt_idx in prompt_ids:\n",
    "        if prompt_idx not in prompt_cache:\n",
    "            prompt_cache[prompt_idx] = _load_prompt_probs(\n",
    "                data_save_dir, prompt_idx, alternate_set, top_n_steps\n",
    "            )\n",
    "        members.append(prompt_cache[prompt_idx])\n",
    "\n",
    "    # Stack to (prompt, episode, step, ...) and move the prompt axis to position 2.\n",
    "    # The baseline gets a singleton axis at the same position, so both share one reduction.\n",
    "    # NOTE(review): the transpose requires rank-4 per-prompt arrays, presumably\n",
    "    # (episode, step, action_dim, bin) -- confirm.\n",
    "    ens_probs = np.transpose(np.stack(members), (1, 2, 0, 3, 4))\n",
    "    base_probs = np.expand_dims(base_probs, axis=2)\n",
    "\n",
    "    base_conf = _episode_confidence(base_probs)\n",
    "    ens_conf = _episode_confidence(ens_probs)\n",
    "\n",
    "    quant_tag = \"Full\" if quant is None else quant\n",
    "    accuracy = np.mean(correct)\n",
    "    rows = [\n",
    "        [task_name, quant_tag, \"baseline\", *_calibration_scores(base_conf, correct, n_cal_bins), accuracy],\n",
    "        [task_name, quant_tag, \"reprompt\", *_calibration_scores(ens_conf, correct, n_cal_bins), accuracy],\n",
    "    ]\n",
    "    return pd.DataFrame(\n",
    "        rows,\n",
    "        columns=[\"Dataset\", \"Model\", \"Method\", \"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\", \"Accuracy\"],\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empty results frame; the experiment loop cell populates full_df.\n",
    "full_df = pd.DataFrame()\n",
    "\n",
    "n_bins = 12  # number of confidence bins passed to get_ece via run_experiment\n",
    "alternate_set = 1  # prompt file variant for run_experiment (1, 2 or 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_prompt_list = [1,5,10,20]\n",
    "\n",
    "n_trials = 1000\n",
    "\n",
    "random.seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "# Tasks are evaluated in this fixed order on every trial. Each run_experiment call\n",
    "# with n_prompts < 20 draws its prompt subset from Python's `random` stream, so\n",
    "# reordering the tasks would change which subsets are sampled.\n",
    "task_order = [\"spatial\", \"object\", \"goal\"]\n",
    "\n",
    "# Collect one small frame per run and concatenate once at the end. Growing a\n",
    "# DataFrame with pd.concat inside the loop is quadratic in the number of runs.\n",
    "trial_frames = []\n",
    "\n",
    "for n_prompts in n_prompt_list:\n",
    "\n",
    "    # With all 20 prompts there is nothing to sample, so a single trial is enough.\n",
    "    n_repeats = 1 if n_prompts == 20 else n_trials\n",
    "\n",
    "    for _ in tqdm(range(n_repeats), desc=f\"n_prompts={n_prompts}\"):\n",
    "        for task in task_order:\n",
    "            trial_df = run_experiment(\n",
    "                task,\n",
    "                alternate_set=alternate_set,\n",
    "                n_cal_bins=n_bins,\n",
    "                n_prompts=n_prompts,\n",
    "            )\n",
    "            trial_df[\"n_prompts\"] = n_prompts\n",
    "            trial_frames.append(trial_df)\n",
    "\n",
    "# Rebuilt from scratch (not appended to an earlier full_df), so re-running this cell\n",
    "# does not duplicate rows.\n",
    "full_df = pd.concat(trial_frames, ignore_index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "group_cols = [\"Dataset\", \"Model\", \"Method\", \"n_prompts\"]\n",
    "\n",
    "# One row per configuration: each metric averaged over that configuration's trials.\n",
    "grouped_df = (\n",
    "    full_df\n",
    "    .groupby(by=group_cols, as_index=False)\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean metrics ordered by dataset and ensemble size. drop_duplicates on the metric\n",
    "# values collapses rows that repeat exactly; the baseline does not depend on n_prompts.\n",
    "df = (\n",
    "    grouped_df\n",
    "    .sort_values([\"Dataset\", \"n_prompts\", \"Method\"])\n",
    "    .drop_duplicates([\"Dataset\", \"Model\", \"Method\", \"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\", \"Accuracy\"])\n",
    "    .loc[:, [\"Dataset\", \"Method\", \"n_prompts\", \"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\"]]\n",
    ")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-configuration standard deviation of each metric across trials\n",
    "# (NaN for configurations that were run only once).\n",
    "grouped_std_df = full_df.groupby(by=group_cols, as_index=False).std()\n",
    "\n",
    "std_df = (\n",
    "    grouped_std_df\n",
    "    .sort_values([\"Dataset\", \"n_prompts\", \"Method\"])\n",
    "    .drop_duplicates([\"Dataset\", \"Model\", \"Method\", \"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\", \"Accuracy\"])\n",
    "    .loc[:, [\"Dataset\", \"Method\", \"n_prompts\", \"ECE-1\", \"ECE-2\", \"Brier\", \"NLL\"]]\n",
    ")\n",
    "std_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"spatial\", \"goal\", \"object\"]\n",
    "\n",
    "# One panel per metric: (column in df / std_df, y-axis label).\n",
    "metric_panels = [\n",
    "    (\"ECE-1\", r\"$\\text{ECE}_1$\"),\n",
    "    (\"ECE-2\", r\"$\\text{ECE}_2$\"),\n",
    "    (\"Brier\", \"Brier score\"),\n",
    "    (\"NLL\", \"NLL\"),\n",
    "]\n",
    "\n",
    "# Create the output directory so savefig does not fail on a fresh checkout.\n",
    "os.makedirs(\"../plots\", exist_ok=True)\n",
    "\n",
    "for dataset in datasets:\n",
    "\n",
    "    dataset_df = df[(df[\"Dataset\"] == dataset) & (df[\"Method\"] == \"reprompt\")]\n",
    "    display(dataset_df)\n",
    "\n",
    "    dataset_std_df = std_df[(std_df[\"Dataset\"] == dataset) & (std_df[\"Method\"] == \"reprompt\")]\n",
    "    display(dataset_std_df)\n",
    "\n",
    "    # Index by ensemble size so the x values come from the data and each std is paired\n",
    "    # with the mean of the same ensemble size. This avoids assuming the rows line up\n",
    "    # one-to-one (and in order) with n_prompt_list.\n",
    "    means = dataset_df.set_index(\"n_prompts\").sort_index()\n",
    "    stds = dataset_std_df.set_index(\"n_prompts\").reindex(means.index)\n",
    "    ensemble_sizes = means.index.to_numpy()\n",
    "\n",
    "    fig, axs = plt.subplots(1, len(metric_panels), figsize=(13, 3))\n",
    "\n",
    "    for k, (ax, (metric, ylabel)) in enumerate(zip(axs, metric_panels)):\n",
    "        # Standard error of the mean over n_trials random prompt subsets. A configuration\n",
    "        # run only once (n_prompts=20 in the experiment loop) has a NaN std.\n",
    "        sem = stds[metric].to_numpy() / np.sqrt(n_trials)\n",
    "        ax.errorbar(ensemble_sizes, means[metric].to_numpy(), fmt=\"--\", yerr=sem, color=pal[k])\n",
    "        ax.set_ylabel(ylabel, fontsize=18)\n",
    "        ax.set_xlabel(\"Ensemble Size\", fontsize=18)\n",
    "        ax.set_xticks(n_prompt_list, n_prompt_list, size=16)\n",
    "        ax.tick_params(axis=\"y\", labelsize=13)\n",
    "\n",
    "    fig.suptitle(f\"Prompt Ensemble Ablation (Dataset: {dataset.title()})\", y=0.925, fontsize=18)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f\"../plots/n_prompts_ablation_{dataset}_with_errors.png\", dpi=600, bbox_inches=\"tight\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"spatial\", \"goal\", \"object\"]\n",
    "\n",
    "# One panel per metric: (column in df, y-axis label).\n",
    "metric_panels = [\n",
    "    (\"ECE-1\", r\"$\\text{ECE}_1$\"),\n",
    "    (\"ECE-2\", r\"$\\text{ECE}_2$\"),\n",
    "    (\"Brier\", \"Brier score\"),\n",
    "    (\"NLL\", \"NLL\"),\n",
    "]\n",
    "\n",
    "# Create the output directory so savefig does not fail on a fresh checkout.\n",
    "os.makedirs(\"../plots\", exist_ok=True)\n",
    "\n",
    "for dataset in datasets:\n",
    "\n",
    "    dataset_df = df[(df[\"Dataset\"] == dataset) & (df[\"Method\"] == \"reprompt\")]\n",
    "    display(dataset_df)\n",
    "\n",
    "    # Take x values from the data (sorted by ensemble size) instead of assuming the\n",
    "    # rows line up one-to-one with n_prompt_list.\n",
    "    means = dataset_df.set_index(\"n_prompts\").sort_index()\n",
    "    ensemble_sizes = means.index.to_numpy()\n",
    "\n",
    "    fig, axs = plt.subplots(1, len(metric_panels), figsize=(13, 3.25))\n",
    "\n",
    "    for k, (ax, (metric, ylabel)) in enumerate(zip(axs, metric_panels)):\n",
    "        ax.plot(ensemble_sizes, means[metric].to_numpy(), \"--o\", color=pal[k])\n",
    "        ax.set_ylabel(ylabel, fontsize=18)\n",
    "        ax.set_xlabel(\"Ensemble Size\", fontsize=18)\n",
    "        ax.set_xticks(n_prompt_list, n_prompt_list, size=16)\n",
    "        ax.tick_params(axis=\"y\", labelsize=12)\n",
    "\n",
    "    fig.suptitle(f\"Prompt Ensemble Ablation (Dataset: {dataset.title()})\", y=0.925, fontsize=18)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f\"../plots/n_prompts_ablation_{dataset}.png\", dpi=600, bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
