{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c7dedcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "from scipy.special import softmax\n",
    "import numpy as np\n",
    "from sklearn.calibration import calibration_curve\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from vla_calibration.utils import *\n",
    "from vla_calibration.calibration import *\n",
    "\n",
    "# Use matplotlib's bundled 'seaborn-v0_8' style sheet for every figure below.\n",
    "plt.style.use('seaborn-v0_8')\n",
    "# Colours of the active style's property cycle; produce_plots picks the\n",
    "# reliability-scatter colours from this list.\n",
    "pal = plt.rcParams['axes.prop_cycle'].by_key()['color']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0f9adb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend labels keyed by aggregation-method name.\n",
    "# \"sliding_10\" is listed because it is part of run_experiment's default\n",
    "# `methods`; without a label here the legend lookup raises KeyError.\n",
    "methods_map = {\n",
    "    \"last_step\": \"Current\",\n",
    "    \"sliding_5\": \"Window (5)\",\n",
    "    \"sliding_10\": \"Window (10)\",\n",
    "    \"avg_all\": \"Avg. All\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73e5299c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_timestep_data(data_save_dir, x_points, methods):\n",
    "    \"\"\"Load, or compute and cache, per-episode confidences over task progress.\n",
    "\n",
    "    For each episode the softmax of every step's logits is reduced to its\n",
    "    maximum along the last axis. At each requested fraction of the episode\n",
    "    those per-step confidences are summarised by every entry of `methods`.\n",
    "\n",
    "    Args:\n",
    "        data_save_dir: Directory containing `episode_data_true_prompt.pkl`;\n",
    "            the cache file `base_probs.pkl` is read from and written to it.\n",
    "        x_points: Fractions of the episode length at which to evaluate.\n",
    "        methods: Aggregation names: `last_step`, `avg_all`, or a name\n",
    "            containing `sliding` whose last `_`-separated token is the\n",
    "            integer window size (e.g. `sliding_5`).\n",
    "\n",
    "    Returns:\n",
    "        `(all_probs, correct)` where `all_probs[percent][method]` is a list\n",
    "        holding one confidence per episode and `correct` is the matching\n",
    "        list of `int(episode[\"done\"])` values.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If an entry of `methods` is not recognised.\n",
    "    \"\"\"\n",
    "    data_dir = Path(data_save_dir).expanduser().resolve()\n",
    "    cache_path = data_dir / \"base_probs.pkl\"\n",
    "    verbose = True\n",
    "    protocol = pkl.HIGHEST_PROTOCOL\n",
    "\n",
    "    # NOTE(review): unpickling can execute arbitrary code; only point this at\n",
    "    # result files that you produced yourself.\n",
    "    if cache_path.is_file():\n",
    "        if verbose:\n",
    "            print(f\"[load_or_create_pickle] Loading existing pickle: {cache_path}\")\n",
    "        with cache_path.open(\"rb\") as f:\n",
    "            all_probs, correct = pkl.load(f)\n",
    "\n",
    "        # The cache is keyed only by directory, so it may have been written for\n",
    "        # other x_points/methods. Reuse it only when it covers this request;\n",
    "        # otherwise fall through and rebuild it.\n",
    "        cache_is_complete = all(\n",
    "            percent in all_probs and method in all_probs[percent]\n",
    "            for percent in x_points\n",
    "            for method in methods\n",
    "        )\n",
    "        if cache_is_complete:\n",
    "            return all_probs, correct\n",
    "        if verbose:\n",
    "            print(f\"[load_or_create_pickle] Cache lacks requested x_points/methods, recomputing: {cache_path}\")\n",
    "\n",
    "    with (data_dir / \"episode_data_true_prompt.pkl\").open(\"rb\") as f:\n",
    "        episodes = pkl.load(f)\n",
    "\n",
    "    all_probs = {}\n",
    "    correct = []\n",
    "\n",
    "    for episode in episodes:\n",
    "        steps = episode[\"steps\"]\n",
    "\n",
    "        episode_probs = np.stack([softmax(step[\"logits\"], -1) for step in steps])\n",
    "        # Confidence = highest probability along the last axis, per step.\n",
    "        episode_conf = np.max(episode_probs, -1)\n",
    "\n",
    "        for percent in x_points:\n",
    "            per_method = all_probs.setdefault(percent, {})\n",
    "\n",
    "            # Step index reached at this fraction, clamped to the final step.\n",
    "            last_step = min(int(np.ceil(len(steps) * percent)), len(steps) - 1)\n",
    "\n",
    "            for method in methods:\n",
    "                if \"sliding\" in method:\n",
    "                    window_size = int(method.split(\"_\")[-1])\n",
    "                    # NOTE(review): this slice spans window_size + 1 steps (the\n",
    "                    # current step plus window_size earlier ones); confirm intended.\n",
    "                    min_step = max(0, last_step - window_size)\n",
    "                    conf = np.mean(np.mean(episode_conf[min_step:last_step + 1, :], -1), -1)\n",
    "                elif method == \"avg_all\":\n",
    "                    conf = np.mean(np.mean(episode_conf[:last_step + 1, :], -1), -1)\n",
    "                elif method == \"last_step\":\n",
    "                    conf = np.mean(episode_conf[last_step], -1)\n",
    "                else:\n",
    "                    # Without this branch a stale `conf` from the previous\n",
    "                    # iteration would be recorded under the unknown name.\n",
    "                    raise ValueError(f\"Unknown aggregation method: {method!r}\")\n",
    "\n",
    "                per_method.setdefault(method, []).append(conf)\n",
    "\n",
    "        correct.append(int(episode[\"done\"]))\n",
    "\n",
    "    cache_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "    with cache_path.open(\"wb\") as f:\n",
    "        pkl.dump((all_probs, correct), f, protocol=protocol)\n",
    "\n",
    "    return all_probs, correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "756e150a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def produce_plots(\n",
    "        all_probs,\n",
    "        correct,\n",
    "        task_name,\n",
    "        x_points,\n",
    "        methods,\n",
    "        quant,\n",
    "        save_string,\n",
    "        n_cal_bins=12,\n",
    "        save_fig=True,\n",
    "        trailing_k=5,\n",
    "        title_addition=\"\",\n",
    "):\n",
    "    \"\"\"Plot calibration metrics over task completion plus reliability diagrams.\n",
    "\n",
    "    Figure 1 has four panels plotted against `x_points` (shown in percent):\n",
    "    ECE from `get_ece`, Brier score, and the two smoothed averages returned\n",
    "    by `average_confidences`. Figure 2 is a grid of reliability diagrams for\n",
    "    the `last_step` and `avg_all` methods at 0%, 50%, 75% and 99% completion;\n",
    "    a panel stays empty when its method or completion level is not part of\n",
    "    `methods` / `x_points`.\n",
    "\n",
    "    Args:\n",
    "        all_probs: `all_probs[percent][method]` gives per-episode confidences.\n",
    "        correct: Per-episode outcomes, aligned with the confidences.\n",
    "        task_name: Used in the figure title and the output file names.\n",
    "        x_points: Completion fractions; each must be a key of `all_probs`.\n",
    "        methods: Aggregation names to plot; the first one supplies the dashed\n",
    "            baseline in the ECE and Brier panels.\n",
    "        quant: Optional tag added to the title and file names; None to omit.\n",
    "        save_string: Suffix inserted into the output file names.\n",
    "        n_cal_bins: Bin count passed to `get_ece` and `calibration_curve`.\n",
    "        save_fig: When true, both figures are written as PNGs under `../plots`.\n",
    "        trailing_k: Window passed to `trailing_average` for smoothing.\n",
    "        title_addition: Extra text appended to the first figure's title.\n",
    "    \"\"\"\n",
    "    quant_save_string = \"\" if quant is None else f\"_{quant}\"\n",
    "    quant_string = \"\" if quant is None else f\" ({str.title(quant)})\"\n",
    "\n",
    "    x_percent = np.array(x_points) * 100\n",
    "    plots_dir = Path(\"../plots\")\n",
    "    if save_fig:\n",
    "        # savefig does not create missing directories.\n",
    "        plots_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    def legend_label(method):\n",
    "        # Fall back to the raw method name when no display label is registered.\n",
    "        return methods_map.get(method, method)\n",
    "\n",
    "    ece_results = {}\n",
    "    brier_results = {}\n",
    "\n",
    "    fig, axs = plt.subplots(\n",
    "        1, 4,\n",
    "        figsize=(13, 3.5),\n",
    "    )\n",
    "\n",
    "    for method in methods:\n",
    "        ece_results[method] = []\n",
    "        brier_results[method] = []\n",
    "        for percent in x_points:\n",
    "            probs = np.array(all_probs[percent][method])\n",
    "            ece_results[method].append(get_ece(probs, correct, n_cal_bins))\n",
    "            brier_results[method].append(np.mean((probs - correct)**2))\n",
    "\n",
    "        axs[0].plot(x_percent, trailing_average(np.array(ece_results[method]), k=trailing_k), \"--\", label=legend_label(method), lw=2)\n",
    "        axs[1].plot(x_percent, trailing_average(np.array(brier_results[method]), k=trailing_k), \"--\", label=legend_label(method), lw=2)\n",
    "\n",
    "    # Horizontal reference: the first method's unsmoothed value at the first x point.\n",
    "    baseline = ece_results[methods[0]][0]\n",
    "    axs[0].plot(x_percent, [baseline]*len(x_points), \"--\", color=\"k\", alpha=0.5)\n",
    "\n",
    "    baseline = brier_results[methods[0]][0]\n",
    "    axs[1].plot(x_percent, [baseline]*len(x_points), \"--\", color=\"k\", alpha=0.5)\n",
    "\n",
    "    # Running y-range shared by the two average-confidence panels.\n",
    "    ax23_ymin = 100.0\n",
    "    ax23_ymax = 0.0\n",
    "\n",
    "    # Completion levels (as str(float)) shown in the reliability diagrams.\n",
    "    reliability_data = {\n",
    "        \"0.0\": dict(),\n",
    "        \"0.5\": dict(),\n",
    "        \"0.75\": dict(),\n",
    "        \"0.99\": dict(),\n",
    "    }\n",
    "    rd_pcts = list(reliability_data.keys())\n",
    "\n",
    "    high_conf_threshold = 0.9\n",
    "    print(\"-\"*20)\n",
    "    print(\"accuracy:\", np.mean(correct))\n",
    "    score_names = [\"avg_incorrect\", \"avg_correct\", \"avg_all\", \"std_incorrect\", \"std_correct\", \"pct_incorrect\", \"pct_correct\", \"pct_all\"]\n",
    "    conf_stats = {}\n",
    "    for method in methods:\n",
    "        stats = conf_stats[method] = {score: [] for score in score_names}\n",
    "\n",
    "        for percent in x_points:\n",
    "            probs = np.array(all_probs[percent][method])\n",
    "            avg_incorrect, avg_correct = average_confidences(correct, probs)\n",
    "            std_incorrect, std_correct = std_confidences(correct, probs)\n",
    "            pct_incorrect, pct_correct = pct_high_confidence(correct, probs, high_conf_threshold)\n",
    "            stats[\"avg_incorrect\"].append(avg_incorrect)\n",
    "            stats[\"avg_correct\"].append(avg_correct)\n",
    "            stats[\"avg_all\"].append(np.mean(probs))\n",
    "            stats[\"std_incorrect\"].append(std_incorrect)\n",
    "            stats[\"std_correct\"].append(std_correct)\n",
    "            stats[\"pct_incorrect\"].append(pct_incorrect)\n",
    "            stats[\"pct_correct\"].append(pct_correct)\n",
    "            stats[\"pct_all\"].append((probs > high_conf_threshold).mean())\n",
    "\n",
    "            # Normalise through float so 0, 0.0 and np.float64(0.0) all match \"0.0\".\n",
    "            pct_key = str(float(percent))\n",
    "            if pct_key in reliability_data:\n",
    "                reliability_data[pct_key][method] = (probs, correct)\n",
    "\n",
    "        # Only the two averages are plotted; the other statistics are collected\n",
    "        # but not displayed by this function.\n",
    "        avg_correct = trailing_average(np.array(stats[\"avg_correct\"]), k=trailing_k)\n",
    "        avg_incorrect = trailing_average(np.array(stats[\"avg_incorrect\"]), k=trailing_k)\n",
    "\n",
    "        axs[2].plot(x_percent, avg_correct, \"--\", label=legend_label(method), lw=2)\n",
    "        axs[3].plot(x_percent, avg_incorrect, \"--\", label=legend_label(method), lw=2)\n",
    "\n",
    "        ax23_ymin = min(min(avg_correct), ax23_ymin)\n",
    "        ax23_ymin = min(min(avg_incorrect), ax23_ymin)\n",
    "\n",
    "        ax23_ymax = max(max(avg_correct), ax23_ymax)\n",
    "        ax23_ymax = max(max(avg_incorrect), ax23_ymax)\n",
    "\n",
    "    axs[0].set_title(\"Calibration Error\", fontsize=18)\n",
    "    axs[1].set_title(\"Calibration Error\", fontsize=18)\n",
    "\n",
    "    axs[0].set_ylabel(r\"$\\text{ECE}_1$\", fontsize=18)\n",
    "    axs[1].set_ylabel(\"Brier Score\", fontsize=18)\n",
    "\n",
    "    axs[2].set_title(\"Successful Trials\", fontsize=18)\n",
    "    axs[3].set_title(\"Failed Trials\", fontsize=18)\n",
    "\n",
    "    axs[2].set_ylabel(\"Avg. Conf.\", fontsize=18)\n",
    "    axs[3].set_ylabel(\"Avg. Conf.\", fontsize=18)\n",
    "\n",
    "    for ax in axs:\n",
    "        ax.set_xlabel(\"% Task Completion\", fontsize=16)\n",
    "        ax.tick_params(axis='y', labelsize=11)\n",
    "        ax.set_xticks([0, 25, 50, 75, 100], [0, 25, 50, 75, 100], fontsize=12)\n",
    "\n",
    "    for ax in axs[2:]:\n",
    "        ax.set_ylim(ax23_ymin*0.995, ax23_ymax*1.005)\n",
    "\n",
    "    axs[-1].legend(fontsize=12)\n",
    "\n",
    "    fig.suptitle(f\"{str.title(task_name)}{quant_string}{title_addition}\", y=0.94, fontsize=18)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    if save_fig:\n",
    "        fig.savefig(plots_dir / f\"across_time_{task_name}{quant_save_string}_{save_string}.png\", dpi=600, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "    # Release the figure; this function runs once per suite in a loop.\n",
    "    plt.close(fig)\n",
    "\n",
    "    fig, axs = plt.subplots(2, len(rd_pcts), sharex=True, sharey=True, figsize=(9, 4))\n",
    "\n",
    "    for pct_idx, pct in enumerate(rd_pcts):\n",
    "\n",
    "        for i, method in enumerate([\"last_step\", \"avg_all\"]):\n",
    "            ax = axs[i, pct_idx]\n",
    "            panel = reliability_data[pct].get(method)\n",
    "            if panel is not None:\n",
    "                # Absent when this level or method was not requested; the\n",
    "                # panel then shows only the diagonal.\n",
    "                panel_probs, panel_correct = panel\n",
    "                prob_true, prob_pred = calibration_curve(np.array(panel_correct), panel_probs, n_bins=n_cal_bins, strategy=\"quantile\")\n",
    "                ax.scatter(prob_pred, prob_true, color=pal[i+3])\n",
    "            ax.set_xlim(.5, 1.02)\n",
    "            ax.set_ylim(0.0, 1.0)\n",
    "            ax.plot([0, 1], [0, 1], \"--\", color=\"k\", alpha=0.5)\n",
    "\n",
    "        axs[0, pct_idx].set_title(f\"{int(float(pct)*100)}% Complete\", fontsize=14)\n",
    "\n",
    "    axs[0, 0].set_ylabel(\"Accuracy\", fontsize=14)\n",
    "    axs[1, 0].set_ylabel(\"Accuracy\", fontsize=14)\n",
    "\n",
    "    # Row labels placed to the left of the axes grid.\n",
    "    fig.text(-0.12, 0.72, \"Current\", fontsize=14, fontweight=\"bold\")\n",
    "    fig.text(-0.12, 0.3, \"Avg. All\", fontsize=14, fontweight=\"bold\")\n",
    "\n",
    "    for ax in axs[1]:\n",
    "        ax.set_xlabel(\"Confidence\", fontsize=14)\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "    if save_fig:\n",
    "        fig.savefig(plots_dir / f\"across_time_{task_name}{quant_save_string}_{save_string}_reliability_diagrams.png\", dpi=600, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "    print(\"-\"*20)\n",
    "\n",
    "\n",
    "def run_experiment(\n",
    "        task_name, \n",
    "        x_points=[0.1, 0.5, 0.99],\n",
    "        methods=[\"last_step\", \"sliding_5\", \"sliding_10\", \"avg_all\"],\n",
    "        quant=None,\n",
    "        n_cal_bins=12,\n",
    "        save_fig=True,\n",
    "        trailing_k=5\n",
    "):\n",
    "    \"\"\"Fetch the confidence data for one task suite and plot it.\n",
    "\n",
    "    Data is read from `../results/libero_<task_name>`, or from the `<quant>`\n",
    "    sub-directory of it when `quant` is given, via `get_timestep_data`, and\n",
    "    handed to `produce_plots` with the save suffix `_baseline`.\n",
    "\n",
    "    Args:\n",
    "        task_name: Suite name used for the results directory and the plots.\n",
    "        x_points: Completion fractions to evaluate.\n",
    "        methods: Aggregation-method names to evaluate.\n",
    "        quant: Optional sub-directory / tag for a quantized run.\n",
    "        n_cal_bins: Calibration bin count forwarded to `produce_plots`.\n",
    "        save_fig: Forwarded to `produce_plots`.\n",
    "        trailing_k: Smoothing window forwarded to `produce_plots`.\n",
    "    \"\"\"\n",
    "    results_dir = f\"../results/libero_{task_name}\"\n",
    "    data_save_dir = results_dir if quant is None else results_dir + f\"/{quant}\"\n",
    "\n",
    "    confidences, outcomes = get_timestep_data(data_save_dir, x_points, methods)\n",
    "\n",
    "    plot_options = dict(\n",
    "        save_string=\"_baseline\",\n",
    "        n_cal_bins=n_cal_bins,\n",
    "        save_fig=save_fig,\n",
    "        trailing_k=trailing_k,\n",
    "        title_addition=\"\",\n",
    "    )\n",
    "    produce_plots(confidences, outcomes, task_name, x_points, methods, quant, **plot_options)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75bef5f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of calibration bins passed to run_experiment as n_cal_bins.\n",
    "n_bins = 12\n",
    "# Completion fractions 0.00, 0.01, ..., 0.99 at which confidences are evaluated.\n",
    "x_points = list(np.arange(100)/100)\n",
    "# Aggregation methods to compare; each has a legend label in methods_map.\n",
    "methods=[\n",
    "    \"last_step\", \n",
    "    \"sliding_5\", \n",
    "    \"avg_all\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "568882aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One pair of figures per suite, with `quant` left at its default (None).\n",
    "for suite in (\"spatial\", \"object\", \"goal\"):\n",
    "    run_experiment(\n",
    "        task_name=suite,\n",
    "        x_points=x_points,\n",
    "        methods=methods,\n",
    "        n_cal_bins=n_bins,\n",
    "        trailing_k=3,\n",
    "    )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c567e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same sweep as the previous cell, reading each suite's \"quant8\" results.\n",
    "for suite in (\"spatial\", \"object\", \"goal\"):\n",
    "    run_experiment(\n",
    "        task_name=suite,\n",
    "        x_points=x_points,\n",
    "        methods=methods,\n",
    "        quant=\"quant8\",\n",
    "        n_cal_bins=n_bins,\n",
    "        trailing_k=3,\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
