{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e2380b3b-dea9-44d5-bdc2-dd925d83d2c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import statistics\n",
    "from itertools import groupby\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "import matplotlib\n",
    "from plotting_style import *\n",
    "from risk_control_utils import (get_label_order, rc_main, get_all_confidences, get_all_accuracies, get_ground_truth_by_type, \n",
    "                                get_relative_labels, apply_risk_control, load_all_data, get_losses_and_exits_confidence)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32f6603d-6024-4a77-a521-60993cc2b171",
   "metadata": {},
   "source": [
    "# Parameters\n",
    "Specify all parameters for this notebook. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cee9e9de-b40f-4ddc-a7eb-709dddc638ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_calibration = True # True or False\n",
    "fake_labels = True # True or False\n",
    "confidence_type = 'argmax' # argmax or top2_diff or entropy\n",
    "i_prop, c_prop = 90, 10 # % of the incorrect-demo / correct-demo examples kept in the mixed set. Pairs to run: 5/95, 10/90, 25/75, 50/50\n",
    "debug_mode = False # True to display small subset of plots in notebook, False to make all plots and save them\n",
    "display_legends = False # whether to include legends on plots (all or none)\n",
    "max_eps = 0.5 # the maximum epsilon-value to plot. For the paper, this is between 0.1 and 0.5.\n",
    "\n",
    "# Fail fast on typos / inconsistent settings before any expensive cell runs\n",
    "assert confidence_type in ('argmax', 'top2_diff', 'entropy'), f'unknown confidence_type: {confidence_type}'\n",
    "assert i_prop + c_prop == 100, f'i_prop + c_prop must sum to 100, got {i_prop} + {c_prop}'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9ed141b-75e3-4c94-8934-b999a36e4425",
   "metadata": {},
   "source": [
    "# Define Variables\n",
    "These should not need to be modified to reproduce the results from the \"Safe In-Context Learning\" paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0b9aecb1-56d2-4ae0-9437-8acad7d5fc0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a list of all the models and datasets to plot\n",
    "models = [\"facebook/layerskip-llama3-8B\", \"facebook/layerskip-llama2-7B\", \"meta-llama/Meta-Llama-3-8B\", \"meta-llama/Llama-2-7B-hf\"]\n",
    "# tokenizers[k] and n_early_exits[k] belong to models[k]; the three lists are zipped together in later cells\n",
    "tokenizers = [\"meta-llama/Meta-Llama-3-8B\", \"meta-llama/Llama-2-7B-hf\", \"meta-llama/Meta-Llama-3-8B\", \"meta-llama/Llama-2-7B-hf\"]\n",
    "n_early_exits = [32, 32, 32, 32]\n",
    "datasets = ['sst2', 'trec', 'financial_phrasebank', 'tweeteval_hate', 'tweeteval_feminist', 'tweeteval_atheism', 'unnatural', 'ag_news']\n",
    "n_demos = 60  # selects the results/.../n_demos_<n> folder to read\n",
    "\n",
    "# zip() silently truncates on a length mismatch, which would drop models without any warning\n",
    "assert len(models) == len(tokenizers) == len(n_early_exits), 'models, tokenizers and n_early_exits must have the same length'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e7d0755e-1fba-466f-8c4b-bc408633c2bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display name for each dataset key (presumably used in plot titles)\n",
    "dataset_names_plot_titles = dict(\n",
    "    sst2='SST2',\n",
    "    trec='TREC',\n",
    "    financial_phrasebank='FinancialPhrasebank',\n",
    "    tweeteval_hate='TweetEval-Hate',\n",
    "    tweeteval_feminist='TweetEval-Feminist',\n",
    "    tweeteval_atheism='TweetEval-Atheism',\n",
    "    unnatural='Unnatural',\n",
    "    ag_news='AG News',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "df0b3247-df2e-4f1b-a5fb-bf6346fe26eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define base data folder to read from (and associated params for finding the right data files)\n",
    "# The mixed-split folder name differs per split, e.g. 10_90_mix, 5_95_mix\n",
    "mixed_filename = str(i_prop) + '_' + str(c_prop) + '_mix'\n",
    "# Assemble the path piecewise so that fake_labels=False does not leave an empty '//' segment\n",
    "precomputed_risk_path = ('./rc-precomputed/' + ('fake_labels/' if fake_labels else '')\n",
    "                         + ('calibrated/' if use_calibration else 'uncalibrated/') + confidence_type + '/')\n",
    "precomputed_mixed_path = precomputed_risk_path + mixed_filename + '/'\n",
    "results_folder = './results' + ('_fake_labels' if fake_labels else '') + ('/calibrated' if use_calibration else '/uncalibrated')\n",
    "\n",
    "# fake_label_map maps each entry of a label order to its fake label. It is only read when\n",
    "# fake_labels=True, so fake_labels.json is not required otherwise.\n",
    "fake_label_map = {}\n",
    "if fake_labels:\n",
    "    with open('fake_labels.json') as f:\n",
    "        fake_label_map = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e2ac457a-bf23-45ca-868b-360c53b2128a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define plotting params\n",
    "c_color, i_color, z_color = 'tab:blue', 'tab:orange', 'tab:green'  # correct / incorrect / zero-shot\n",
    "model_colors = {\n",
    "    \"facebook/layerskip-llama3-8B\": 'tab:brown',\n",
    "    \"facebook/layerskip-llama2-7B\": 'tab:pink',\n",
    "    \"meta-llama/Meta-Llama-3-8B\": 'tab:purple',\n",
    "    \"meta-llama/Llama-2-7B-hf\": 'tab:cyan',\n",
    "}\n",
    "first_exit = 15  # for risk control: the earliest layer at which exiting is allowed\n",
    "# Layout: icl_plots[_fake_labels]/<legend|no_legend>/<calibrated|uncalibrated>/n_demos_<n>/<confidence_type>/\n",
    "plot_directory = (\n",
    "    './icl_plots' + ('_fake_labels' if fake_labels else '')\n",
    "    + '/' + ('legend' if display_legends else 'no_legend')\n",
    "    + '/' + ('calibrated' if use_calibration else 'uncalibrated')\n",
    "    + '/n_demos_' + str(n_demos)\n",
    "    + '/' + confidence_type + '/'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8d47c1ea-e3d4-4f08-8e57-1182c8e146f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define lambdas and epsilons\n",
    "# NOTE: If running with loss_01_conversion=max_0, cannot have epsilon < 0!\n",
    "stepsize = 0.01\n",
    "# Build the grids from integer point counts: np.arange with a float stop can gain or lose its\n",
    "# last point through rounding error. The grid values are still i * stepsize.\n",
    "n_eps_points = int(round(max_eps / stepsize)) + 1\n",
    "n_lambda_points = int(round(1.0 / stepsize)) + 1\n",
    "eps_grid = np.arange(n_eps_points) * stepsize  # 0, stepsize, ..., max_eps\n",
    "lambdas = (np.arange(n_lambda_points) * stepsize)[::-1]  # 1.0 down to 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c8e7573d-3922-4aa2-aa48-8b7989152abd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define loss and risk-control parameters\n",
    "# Predictions over which to compute a relative loss: 'zeroshot_full_model' or 'full_model'\n",
    "relative_labels = 'zeroshot_full_model'\n",
    "# Ground truth used to compute the loss: 'true_label' or 'zeroshot_full_model'\n",
    "ground_truth_type = 'true_label'\n",
    "rcp_type = 'ltt'  # risk-control procedure identifier\n",
    "delta = 0.1  # forwarded to apply_risk_control (presumably the tolerated failure probability; confirm in risk_control_utils)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c3e85724-9fb5-440f-8bf0-32666e4349ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of repeated risk-control trials passed to apply_risk_control\n",
    "n_trials = 2 if debug_mode else 50\n",
    "if debug_mode:\n",
    "    # Quick run: restrict to a single dataset\n",
    "    datasets = ['financial_phrasebank']\n",
    "else:\n",
    "    # Non-interactive backend: figures are written to disk instead of displayed\n",
    "    matplotlib.use('Agg')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cbf4d28-f7b2-4f18-97d2-b7c783f1e493",
   "metadata": {},
   "source": [
    "# Pre-Compute Risk Control Results\n",
    "Cache the risk-control results on disk so the plotting cells below do not have to recompute them. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3fc752ee-f451-4a8a-bf23-983f55f200bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-compute and save the risk control results matrices (correct and incorrect demos separately).\n",
    "# Guard on the per-method folders rather than on precomputed_risk_path itself: the mixed-split cell\n",
    "# below writes into a sub-folder of precomputed_risk_path, which would otherwise make this cell skip.\n",
    "if not (os.path.exists(precomputed_risk_path + 'max0/') and os.path.exists(precomputed_risk_path + 'scaling/')):\n",
    "    for dataset in datasets:\n",
    "        for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "            label_order = get_label_order(dataset, tokenizer)\n",
    "            if fake_labels:\n",
    "                label_order = [fake_label_map[x] for x in label_order]\n",
    "\n",
    "            base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "            # All three experiment types are required; otherwise report and move on to the next model\n",
    "            if not all(os.path.exists(base_dir + expt_type + '.json') for expt_type in ['correct', 'incorrect', 'zeroshot']):\n",
    "                print('Missing data: ', dataset, model_name)\n",
    "                continue\n",
    "\n",
    "            data = {}\n",
    "            for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                    data[expt_type] = json.load(file)\n",
    "\n",
    "            c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "            c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "            for gt, rel, label in zip([c_gt, i_gt], [c_rel, i_rel], ['correct', 'incorrect']):\n",
    "                conf = get_all_confidences(data[label], n_early_exit, label_order, confidence_type, first_exit)\n",
    "                acc = get_all_accuracies(data[label], gt, n_early_exit, first_exit)\n",
    "                n_cal = int(len(data[label]['0'])/2)  # presumably the calibration-split size (half the examples)\n",
    "                # Clipping ('max_0') first, then scaling; each saved under <method_dir>/<dataset>/<model>/<label>/\n",
    "                for loss_conversion, method_dir in [('max_0', 'max0'), ('scaling', 'scaling')]:\n",
    "                    losses, test_risk, eff_gains, rcp_lams = apply_risk_control(conf, acc, gt, rel, lambdas, eps_grid,\n",
    "                                                                                delta, n_cal, n_trials, loss_conversion)\n",
    "                    out_path = precomputed_risk_path + method_dir + '/' + dataset + '/' + model_name + '/' + label + '/'\n",
    "                    os.makedirs(out_path, exist_ok=True)\n",
    "                    for name, arr in [('losses', losses), ('test_risk', test_risk), ('eff_gains', eff_gains), ('rcp_lams', rcp_lams)]:\n",
    "                        np.save(out_path + name + '.npy', arr)\n",
    "        print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9b3c1c60-e536-417a-bb35-702cab84e5c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter(dat, idx):\n",
    "    \"\"\"Return [dat[k] for k in idx], skipping every index outside 0 <= k < len(dat).\n",
    "\n",
    "    NOTE: this name shadows the builtin `filter` for the rest of the notebook.\n",
    "    \"\"\"\n",
    "    kept = []\n",
    "    for k in idx:\n",
    "        if k >= 0 and k < len(dat):\n",
    "            kept.append(dat[k])\n",
    "    return kept"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "83e618c3-6c54-4318-9449-5bb989babe7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished sst2\n",
      "Finished trec\n",
      "Finished financial_phrasebank\n",
      "Finished tweeteval_hate\n",
      "Finished tweeteval_feminist\n",
      "Finished tweeteval_atheism\n",
      "Finished unnatural\n",
      "Finished ag_news\n"
     ]
    }
   ],
   "source": [
    "# Pre-compute and save the risk control results matrices for combined correct + incorrect demos\n",
    "if not os.path.exists(precomputed_mixed_path):\n",
    "    for dataset in datasets:\n",
    "        for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "            label_order = get_label_order(dataset, tokenizer)\n",
    "            if fake_labels:\n",
    "                label_order = [fake_label_map[x] for x in label_order]\n",
    "\n",
    "            base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "            # All three experiment types are required; otherwise report and move on to the next model\n",
    "            if not all(os.path.exists(base_dir + expt_type + '.json') for expt_type in ['correct', 'incorrect', 'zeroshot']):\n",
    "                print('Missing data: ', dataset, model_name)\n",
    "                continue\n",
    "\n",
    "            data = {}\n",
    "            for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                    data[expt_type] = json.load(file)\n",
    "\n",
    "            c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "            c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "            # Keep c_prop% of the correct-demo and i_prop% of the incorrect-demo examples. The RNG is seeded\n",
    "            # from (dataset, model, split) so the subset is reproducible across runs instead of depending\n",
    "            # on the global `random` state.\n",
    "            rng = random.Random(dataset + '|' + model_name + '|' + mixed_filename)\n",
    "            c_idx = rng.sample(range(len(c_gt)), int(c_prop/100 * len(c_gt)))\n",
    "            i_idx = rng.sample(range(len(i_gt)), int(i_prop/100 * len(i_gt)))\n",
    "            c_gt, i_gt, c_rel, i_rel = filter(c_gt, c_idx), filter(i_gt, i_idx), filter(c_rel, c_idx), filter(i_rel, i_idx)\n",
    "            # Subset every column the same way (named to avoid shadowing the builtin `id`)\n",
    "            correct_sub = {col: filter(data['correct'][col], c_idx) for col in data['correct']}\n",
    "            incorrect_sub = {col: filter(data['incorrect'][col], i_idx) for col in data['correct']}\n",
    "\n",
    "            # Combine correct and incorrect (correct examples first)\n",
    "            gt, rel = c_gt + i_gt, c_rel + i_rel\n",
    "            combined_data = {col: correct_sub[col] + incorrect_sub[col] for col in correct_sub}\n",
    "\n",
    "            conf = get_all_confidences(combined_data, n_early_exit, label_order, confidence_type, first_exit)\n",
    "            acc = get_all_accuracies(combined_data, gt, n_early_exit, first_exit)\n",
    "            n_cal = int(len(combined_data['0'])/2)  # presumably the calibration-split size (half the examples)\n",
    "\n",
    "            # Run the max-0 (clipping) method\n",
    "            losses, test_risk, eff_gains, rcp_lams = apply_risk_control(conf, acc, gt, rel, lambdas, eps_grid,\n",
    "                                                                        delta, n_cal, n_trials, 'max_0')\n",
    "            # Save confidence + accuracy matrices and labels\n",
    "            conf_acc_path = precomputed_mixed_path + dataset + '/' + model_name + '/'\n",
    "            os.makedirs(conf_acc_path, exist_ok=True)\n",
    "            np.save(conf_acc_path + 'conf.npy', conf)\n",
    "            np.save(conf_acc_path + 'acc.npy', acc)\n",
    "            with open(conf_acc_path + 'true_labels.txt', 'w') as file:\n",
    "                file.writelines(f'{item}\\n' for item in gt)\n",
    "            with open(conf_acc_path + 'relative_labels.txt', 'w') as file:\n",
    "                file.writelines(f'{item}\\n' for item in rel)\n",
    "\n",
    "            # Save the max-0 results\n",
    "            max0_path = precomputed_mixed_path + 'max0/' + dataset + '/' + model_name + '/'\n",
    "            os.makedirs(max0_path, exist_ok=True)\n",
    "            for name, arr in [('losses', losses), ('test_risk', test_risk), ('eff_gains', eff_gains), ('rcp_lams', rcp_lams)]:\n",
    "                np.save(max0_path + name + '.npy', arr)\n",
    "\n",
    "            # Run the scaling method and save its results\n",
    "            losses, test_risk, eff_gains, rcp_lams = apply_risk_control(conf, acc, gt, rel, lambdas, eps_grid,\n",
    "                                                                        delta, n_cal, n_trials, 'scaling')\n",
    "            scaling_path = precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'\n",
    "            os.makedirs(scaling_path, exist_ok=True)\n",
    "            for name, arr in [('losses', losses), ('test_risk', test_risk), ('eff_gains', eff_gains), ('rcp_lams', rcp_lams)]:\n",
    "                np.save(scaling_path + name + '.npy', arr)\n",
    "        print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84d2b7d6-6071-4f45-a56d-6daef59037fa",
   "metadata": {},
   "source": [
    "# Plotting Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "bac0d40d-0f2e-4515-96b4-fd3a73db97fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_rc_fixed_lambda(eps_grid, rcp_type, rcp_lams, losses, exits):\n",
    "    \"\"\"Summarise risk and exit statistics at the lambda selected for each epsilon.\n",
    "\n",
    "    For every position e of eps_grid, rcp_lams[e] indexes into `losses` / `exits`. When it is None,\n",
    "    default_loss_uncontrolled_risk / default_eff_gain_uncontrolled_risk are used with zero error\n",
    "    (these globals are not defined in the visible cells -- presumably from plotting_style).\n",
    "    `rcp_type` is accepted but not used.\n",
    "\n",
    "    Returns four arrays (one entry per epsilon): mean risk, mean exit value, and their standard errors.\n",
    "    \"\"\"\n",
    "    def _sem(values):\n",
    "        # standard error of the mean along axis 0\n",
    "        return values.std(axis=0) / np.sqrt(values.shape[0])\n",
    "\n",
    "    risk_means, gain_means, risk_sems, gain_sems = [], [], [], []\n",
    "    for pos in range(len(eps_grid)):\n",
    "        chosen = rcp_lams[pos]\n",
    "        if chosen is not None:\n",
    "            risk_means.append(losses[chosen].mean())\n",
    "            gain_means.append(exits[chosen].mean())\n",
    "            risk_sems.append(_sem(losses[chosen]))\n",
    "            gain_sems.append(_sem(exits[chosen]))\n",
    "        else:\n",
    "            risk_means.append(default_loss_uncontrolled_risk)\n",
    "            gain_means.append(default_eff_gain_uncontrolled_risk)\n",
    "            risk_sems.append(0)\n",
    "            gain_sems.append(0)\n",
    "    return np.array(risk_means), np.array(gain_means), np.array(risk_sems), np.array(gain_sems)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "26e35a15-b255-497f-88db-e75c1f221cc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_risk_control(ax, eps_grid, losses, test_risk, label, color, linestyle):\n",
    "    \"\"\"Plot mean test risk (+/- standard error over axis 0 of test_risk) against epsilon on `ax`,\n",
    "    together with the dashed y = x reference line. `losses` is accepted but not used.\n",
    "    \"\"\"\n",
    "    n_eps = len(eps_grid)\n",
    "    # keep only the entries for the plotted epsilons\n",
    "    mean = test_risk.mean(axis=0)[:n_eps]\n",
    "    sem = (test_risk.std(axis=0) / np.sqrt(test_risk.shape[0]))[:n_eps]\n",
    "    ax.plot(eps_grid, mean, label=label, color=color, linestyle=linestyle)\n",
    "    ax.fill_between(eps_grid, mean - sem, mean + sem, alpha=0.2, color=color)\n",
    "    # reference line: risk == epsilon\n",
    "    lo, hi = min(eps_grid), max(eps_grid)\n",
    "    ax.plot([lo, hi], [lo, hi], 'k--')\n",
    "    ax.set_ylabel('Risk')\n",
    "    ax.set_xlabel('epsilon')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "2b45c758-dee5-49b5-a4df-df07371303b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_separate_risk(ax, eps_grid, c_mean, c_err, i_mean, i_err, color):\n",
    "    \"\"\"Plot correct-demo risk (solid) then incorrect-demo risk (dotted) vs epsilon, each with a +/- err band.\"\"\"\n",
    "    for mean, err, name, style in ((c_mean, c_err, 'correct', 'solid'), (i_mean, i_err, 'incorrect', 'dotted')):\n",
    "        ax.plot(eps_grid, mean, label=name, color=color, linestyle=style)\n",
    "        ax.fill_between(eps_grid, mean - err, mean + err, alpha=0.2, color=color)\n",
    "    ax.set_ylabel('Risk')\n",
    "    ax.set_xlabel('epsilon')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5a676887-1688-454c-893e-34bc2f90ba19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_efficiency_gains(ax, eps_grid, eff_gains, label, color, linestyle, n_layers=None):\n",
    "    \"\"\"Plot the average exit layer (n_layers minus mean eff_gains over axis 0) vs epsilon on `ax`.\n",
    "\n",
    "    The band is +/- the standard error of eff_gains over axis 0; only the first len(eps_grid)\n",
    "    entries are plotted.\n",
    "\n",
    "    n_layers: total number of exit layers of the model. Defaults to the notebook-global\n",
    "    `n_early_exit` (the previous implicit behaviour); pass it explicitly so the result does not\n",
    "    depend on whichever loop variable happens to be in scope.\n",
    "    \"\"\"\n",
    "    if n_layers is None:\n",
    "        n_layers = n_early_exit\n",
    "    exit_mean = n_layers - eff_gains.mean(axis=0)\n",
    "    exit_err = eff_gains.std(axis=0) / np.sqrt(eff_gains.shape[0])\n",
    "    exit_mean, exit_err = exit_mean[:len(eps_grid)], exit_err[:len(eps_grid)]\n",
    "    ax.plot(eps_grid, exit_mean, label=label, color=color, linestyle=linestyle)\n",
    "    ax.fill_between(eps_grid, exit_mean - exit_err, exit_mean + exit_err, alpha=0.2, color=color)\n",
    "    ax.set_ylabel('Average Exit Layer')\n",
    "    ax.set_xlabel('epsilon')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e5dfd76b-6ced-4f03-b0c4-6a2a29181384",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_lambda_vs_risk(ax, lambdas, losses, label, color, linestyle):\n",
    "    \"\"\"Plot empirical risk (mean of `losses` over axis 1) against lambda, with a +/- standard-error band.\"\"\"\n",
    "    risk = losses.mean(axis=1)\n",
    "    sem = np.std(losses, axis=1) / np.sqrt(losses.shape[1])\n",
    "    ax.plot(lambdas, risk, label=label, color=color, linestyle=linestyle)\n",
    "    ax.fill_between(lambdas, risk - sem, risk + sem, alpha=0.2, color=color)\n",
    "    ax.set_xlabel('Lambda')\n",
    "    ax.set_ylabel('Empirical Risk')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c743cb3-efdd-4aba-85b9-011304c104bd",
   "metadata": {},
   "source": [
    "# Save Legends\n",
    "Save the legends for the plots below as separate files, so that they do not cover the plotted results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "7e7f19f4-2ad1-4e12-969d-34326e5a9780",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Incorrect/correct legend\n",
    "ic_lines = [Line2D([0], [0], color='black', lw=5, linestyle=sty, alpha=0.5) for sty in ['dashed', 'solid']]\n",
    "ic_labels = ['incorrect', 'correct']\n",
    "\n",
    "# savefig does not create folders, so make sure the legends folder exists on a fresh checkout\n",
    "legend_dir = plot_directory.split(\"/\")[1] + \"/legends_only/\"\n",
    "os.makedirs(legend_dir, exist_ok=True)\n",
    "\n",
    "figlegend = plt.figure()\n",
    "figlegend.legend(ic_lines, ic_labels, loc='center', fontsize=20, ncol=2)\n",
    "figlegend.savefig(legend_dir + \"incorrect_correct_legend.pdf\")\n",
    "if not debug_mode:\n",
    "    plt.close(figlegend)  # only the PDF is needed; free the figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "7577004b-b29d-4df2-8458-8c30cd8e99d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model legend: only models\n",
    "model_lines = [Line2D([0], [0], color=model_colors[model], lw=5, linestyle='-',) for model in models]\n",
    "model_labels = [x.split('/')[1] for x in models]\n",
    "\n",
    "# savefig does not create folders, so make sure the legends folder exists on a fresh checkout\n",
    "legend_dir = plot_directory.split(\"/\")[1] + \"/legends_only/\"\n",
    "os.makedirs(legend_dir, exist_ok=True)\n",
    "\n",
    "figlegend = plt.figure()\n",
    "figlegend.legend(model_lines, model_labels, loc='center', fontsize=20, ncol=len(models))\n",
    "figlegend.savefig(legend_dir + \"models_legend.pdf\")\n",
    "if not debug_mode:\n",
    "    plt.close(figlegend)  # only the PDF is needed; free the figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "e163b046-f974-49e4-aba6-795527c2655c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correct vs incorrect, with and without clipping (appendix figure)\n",
    "lines1 = [Line2D([0], [0], color=c_color, lw=5, linestyle='solid'), Line2D([0], [0], color=i_color, lw=5, linestyle='solid'), \n",
    "              Line2D([0], [0], color=c_color, lw=5, linestyle='dashed'), Line2D([0], [0], color=i_color, lw=5, linestyle='dashed')]\n",
    "labels1 = ['correct with scaling', 'incorrect with scaling', 'correct with clipping', 'incorrect with clipping']\n",
    "\n",
    "# savefig does not create folders, so make sure the legends folder exists on a fresh checkout\n",
    "legend_dir = plot_directory.split(\"/\")[1] + \"/legends_only/\"\n",
    "os.makedirs(legend_dir, exist_ok=True)\n",
    "\n",
    "figlegend = plt.figure()\n",
    "figlegend.legend(lines1, labels1, loc='center', fontsize=20, ncol=4)\n",
    "figlegend.savefig(legend_dir + \"ic_scaled_clipped_legend.pdf\")\n",
    "if not debug_mode:\n",
    "    plt.close(figlegend)  # only the PDF is needed; free the figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "8f247c3f-0013-4b5d-8ad6-3be0607914d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correct vs incorrect, with and without demos (Fig 3)\n",
    "lines1 = [Line2D([0], [0], color=c_color, lw=5, linestyle='solid'), Line2D([0], [0], color=i_color, lw=5, linestyle='solid'), \n",
    "              Line2D([0], [0], color=c_color, lw=5, linestyle='dashed'), Line2D([0], [0], color=i_color, lw=5, linestyle='dashed')]\n",
    "labels1 = ['correct demos', 'incorrect demos', 'correct full model', 'incorrect full model']\n",
    "\n",
    "# savefig does not create folders, so make sure the legends folder exists on a fresh checkout\n",
    "legend_dir = plot_directory.split(\"/\")[1] + \"/legends_only/\"\n",
    "os.makedirs(legend_dir, exist_ok=True)\n",
    "\n",
    "figlegend = plt.figure()\n",
    "figlegend.legend(lines1, labels1, loc='center', fontsize=20, ncol=4)\n",
    "figlegend.savefig(legend_dir + \"c_vs_i_legend.pdf\")\n",
    "if not debug_mode:\n",
    "    plt.close(figlegend)  # only the PDF is needed; free the figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "9f615fb3-da8c-412e-a6f8-8c124558a760",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clipped vs scaled risk\n",
    "cs_lines = [Line2D([0], [0], color='black', lw=5, linestyle=sty, alpha=0.5) for sty in ['dashed', 'solid']]\n",
    "cs_labels = ['clipped risk', 'scaled risk']\n",
    "\n",
    "# savefig does not create folders, so make sure the legends folder exists on a fresh checkout\n",
    "legend_dir = plot_directory.split(\"/\")[1] + \"/legends_only/\"\n",
    "os.makedirs(legend_dir, exist_ok=True)\n",
    "\n",
    "figlegend = plt.figure()\n",
    "figlegend.legend(cs_lines, cs_labels, loc='center', fontsize=20, ncol=2)\n",
    "figlegend.savefig(legend_dir + \"clipped_scaled_legend.pdf\")\n",
    "if not debug_mode:\n",
    "    plt.close(figlegend)  # only the PDF is needed; free the figure"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3eef923b-87f9-4f80-a660-f4b713e72738",
   "metadata": {},
   "source": [
    "# Bar Plots - Losses\n",
    "A small figure for Sec. 3.2 that shows why scaling makes sense: it preserves the underlying loss distribution better than clipping does. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "817be5bf-380e-411e-b7a7-713f060aef19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished sst2\n",
      "Finished trec\n",
      "Finished financial_phrasebank\n",
      "Finished tweeteval_hate\n",
      "Finished tweeteval_feminist\n",
      "Finished tweeteval_atheism\n",
      "Finished unnatural\n",
      "Finished ag_news\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/1g/wx_xh_xx44bdx2g8gcx64pbr0000gn/T/ipykernel_34604/1038406289.py:71: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). Consider using `matplotlib.pyplot.close()`.\n",
      "  fig, ax = plt.subplots(1,1,figsize=(5,5))\n"
     ]
    }
   ],
   "source": [
    "for dataset in datasets:\n",
    "    all_counts = []\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        label_order = get_label_order(dataset, tokenizer)\n",
    "        if fake_labels:\n",
    "            label_order = [fake_label_map[x] for x in label_order]\n",
    "        \n",
    "        base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "        # First check that there exists all types of experiments\n",
    "        if (os.path.exists(base_dir + 'correct.json') and os.path.exists(base_dir + 'incorrect.json')\n",
    "                and os.path.exists(base_dir + 'zeroshot.json')):\n",
    "            # We have all the data\n",
    "            data = {}\n",
    "            for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                    data[expt_type] = json.load(file)\n",
    "                \n",
    "            c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "            c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "            # Sample proportions of the data as needed\n",
    "            c_idx = random.sample(range(len(c_gt)), int(c_prop/100 * len(c_gt))) \n",
    "            i_idx = random.sample(range(len(i_gt)), int(i_prop/100 * len(i_gt)))\n",
    "            c_gt, i_gt, c_rel, i_rel = filter(c_gt, c_idx), filter(i_gt, i_idx), filter(c_rel, c_idx), filter(i_rel, i_idx)\n",
    "            cd, id = {}, {}\n",
    "            for col in data['correct']:\n",
    "                cd[col] = filter(data['correct'][col], c_idx)\n",
    "                id[col] = filter(data['incorrect'][col], i_idx)\n",
    "            data['correct'], data['incorrect'] = cd, id\n",
    "\n",
    "            # Combine correct and incorrect\n",
    "            gt, rel = c_gt + i_gt, c_rel + i_rel\n",
    "            combined_data = {}\n",
    "            for col in data['correct']:\n",
    "                combined_data[col] = data['correct'][col] + data['incorrect'][col]\n",
    "\n",
    "            conf = get_all_confidences(combined_data, n_early_exit, label_order, confidence_type, first_exit)\n",
    "            acc = get_all_accuracies(combined_data, gt, n_early_exit, first_exit)\n",
    "            n_cal = int(len(combined_data['0'])/2)\n",
    "            \n",
    "            losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, rel, gt, True)\n",
    "            unique_values, counts = np.unique(losses, return_counts=True)\n",
    "            if len(all_counts) == 0:\n",
    "                all_counts = counts\n",
    "            else:\n",
    "                all_counts = [a + b for a,b in zip(all_counts, counts)]\n",
    "\n",
    "    all_counts = all_counts / sum(all_counts) * 100\n",
    "    \n",
    "    # Create the bar chart\n",
    "    fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "    ax.bar(unique_values, all_counts, width=0.6, color='skyblue', edgecolor='black')\n",
    "    ax.set_xticks([-1, 0, 1])\n",
    "    \n",
    "    # Add labels and title\n",
    "    ax.set_xlabel('Loss')\n",
    "    ax.set_ylabel('% of Data')\n",
    "    ax.set_title('ICL Loss Distribution')\n",
    "\n",
    "    if debug_mode:\n",
    "        # Display the image\n",
    "        plt.show()\n",
    "    else:\n",
    "        # Save out the image\n",
    "        path = plot_directory + mixed_filename + '/losses_barplot/' \n",
    "        if not os.path.exists(path):\n",
    "            os.makedirs(path)\n",
    "        plt.savefig(path + dataset + '_scaled.pdf')\n",
    "\n",
    "    # Create the bar chart for clipped losses\n",
    "    fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "    ax.bar([0,1], [all_counts[0] + all_counts[1], all_counts[2]], width=0.6, color='skyblue', edgecolor='black')\n",
    "    ax.set_xticks([-1, 0, 1])\n",
    "    \n",
    "    # Add labels and title\n",
    "    ax.set_xlabel('Loss')\n",
    "    ax.set_ylabel('% of Data')\n",
    "    ax.set_title('Clipped ICL Loss Distribution')\n",
    "\n",
    "    if debug_mode:\n",
    "        # Display the image\n",
    "        plt.show()\n",
    "    else:\n",
    "        # Save out the image\n",
    "        path = plot_directory + mixed_filename + '/losses_barplot/' \n",
    "        if not os.path.exists(path):\n",
    "            os.makedirs(path)\n",
    "        plt.savefig(path + dataset + '_clipped.pdf')\n",
    "\n",
    "    print(\"Finished\", dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9073383f-ca66-49db-add7-ee1d7c260902",
   "metadata": {},
   "source": [
    "## Incorrect/Correct Sub-group Plots - Risk Control\n",
    "Paper: fig 6, 13, 14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "709c9c8a-ce17-40e7-8a14-8fe1c4eda63e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scaling ------------------------\n",
      "Finished sst2\n",
      "Finished trec\n",
      "Finished financial_phrasebank\n",
      "Finished tweeteval_hate\n",
      "Finished tweeteval_feminist\n",
      "Finished tweeteval_atheism\n",
      "Finished unnatural\n",
      "Finished ag_news\n",
      "max0 ------------------------\n",
      "Finished sst2\n",
      "Finished trec\n",
      "Finished financial_phrasebank\n",
      "Finished tweeteval_hate\n",
      "Finished tweeteval_feminist\n",
      "Finished tweeteval_atheism\n",
      "Finished unnatural\n",
      "Finished ag_news\n"
     ]
    }
   ],
   "source": [
    "# Split plots - scaling or max0\n",
    "for subgroup_plot_type in ['scaling', 'max0']:\n",
    "    print(subgroup_plot_type, '------------------------')\n",
    "    for dataset in datasets:\n",
    "        fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "        for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "            # First check that all experiment results are precomputed\n",
    "            if os.path.exists(precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'):\n",
    "                label_order = get_label_order(dataset, tokenizer)\n",
    "                if fake_labels:\n",
    "                    label_order = [fake_label_map[x] for x in label_order]\n",
    "                \n",
    "                # load the data\n",
    "                base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "                data = {}\n",
    "                for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                    with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                        data[expt_type] = json.load(file)\n",
    "    \n",
    "                c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "                c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "    \n",
    "                # Combine correct and incorrect\n",
    "                gt, rel = c_gt + i_gt, c_rel + i_rel\n",
    "                combined_data = {}\n",
    "                for col in data['correct']:\n",
    "                    combined_data[col] = data['correct'][col] + data['incorrect'][col]\n",
    "                \n",
    "                # get the pre-computed lambda-hat's\n",
    "                scaling_path = precomputed_mixed_path + subgroup_plot_type + '/' + dataset + '/' + model_name + '/'\n",
    "                rcp_lams = np.load(scaling_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "    \n",
    "                # compute + plot epsilon vs risk for correct demos\n",
    "                conf = get_all_confidences(data['correct'], n_early_exit, label_order, confidence_type, first_exit)\n",
    "                acc = get_all_accuracies(data['correct'], c_gt, n_early_exit, first_exit)\n",
    "                losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, c_rel, c_gt)\n",
    "                c_mean, _, c_err, _ = compute_rc_fixed_lambda(eps_grid, rcp_type, rcp_lams, losses, exits)\n",
    "    \n",
    "                # compute + plot epsilon vs risk for incorrect demos\n",
    "                conf = get_all_confidences(data['incorrect'], n_early_exit, label_order, confidence_type, first_exit)\n",
    "                acc = get_all_accuracies(data['incorrect'], c_gt, n_early_exit, first_exit)\n",
    "                losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, c_rel, c_gt)\n",
    "                i_mean, _, i_err, _ = compute_rc_fixed_lambda(eps_grid, rcp_type, rcp_lams, losses, exits)\n",
    "    \n",
    "                # plot correct/incorrect separately\n",
    "                plot_separate_risk(ax, eps_grid, c_mean, c_err, i_mean, i_err, model_colors[model_name])\n",
    "            else:\n",
    "                print('Missing data:', model_name, dataset)\n",
    "\n",
    "        if display_legends:\n",
    "            lines2 = [Line2D([0], [0], color='black', lw=5, linestyle=sty, alpha=0.5) for sty in ['dashed', 'solid']]\n",
    "            labels2 = ['incorrect', 'correct']\n",
    "            \n",
    "            lines1 = [Line2D([0], [0], color=model_colors[model], lw=5, linestyle='-',) for model in models]\n",
    "            lines1 += lines2\n",
    "            labels1 = [x.split('/')[1] for x in models] + labels2\n",
    "            \n",
    "            # legend1 = plt.legend(lines1, labels1, loc='upper left',)\n",
    "            # Save the legend as its own separate thing\n",
    "            figlegend = plt.figure()\n",
    "            figlegend.legend(lines1, labels1, loc='center', fontsize=20)\n",
    "            \n",
    "        ax.plot([min(eps_grid), max(eps_grid)], [min(eps_grid), max(eps_grid)], 'k--')\n",
    "        plt.tight_layout()\n",
    "    \n",
    "        if debug_mode:\n",
    "            # Display the image\n",
    "            plt.show()\n",
    "        else:\n",
    "            # Save out the image\n",
    "            path = plot_directory + mixed_filename + '/rc_split_' + subgroup_plot_type + '/' \n",
    "            if not os.path.exists(path):\n",
    "                os.makedirs(path)\n",
    "            plt.savefig(path + dataset + '.pdf')\n",
    "            #figlegend.savefig(path + \"legend.pdf\")\n",
    "    \n",
    "            # Close plots to save memory\n",
    "            matplotlib.pyplot.close()\n",
    "    \n",
    "        print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04b5ac86-f68e-4fd8-9435-bd9e727a8573",
   "metadata": {},
   "source": [
    "## Risk Control - Epsilon vs Risk\n",
    "Paper: fig 5, 9-12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "4fc93f01-00ac-456d-8524-2c4b8c1336f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished sst2\n",
      "Finished trec\n",
      "Finished financial_phrasebank\n",
      "Finished tweeteval_hate\n",
      "Finished tweeteval_feminist\n",
      "Finished tweeteval_atheism\n",
      "Finished unnatural\n",
      "Finished ag_news\n"
     ]
    }
   ],
   "source": [
    "# Risk control plots - with only risk control on combined data\n",
    "for dataset in datasets:\n",
    "    fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        # First check that all experiment results are precomputed\n",
    "        if os.path.exists(precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'):\n",
    "            label_order = get_label_order(dataset, tokenizer)\n",
    "            if fake_labels:\n",
    "                label_order = [fake_label_map[x] for x in label_order]\n",
    "            \n",
    "            # load the data\n",
    "            base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "            data = {}\n",
    "            for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                    data[expt_type] = json.load(file)\n",
    "\n",
    "            c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "            c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "            # Combine correct and incorrect\n",
    "            gt, rel = c_gt + i_gt, c_rel + i_rel\n",
    "            combined_data = {}\n",
    "            for col in data['correct']:\n",
    "                combined_data[col] = data['correct'][col] + data['incorrect'][col]\n",
    "            \n",
    "            # get the pre-computed lambda-hat's\n",
    "            scaling_path = precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'\n",
    "            rcp_lams = np.load(scaling_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "\n",
    "            # plot epsilon vs risk for ALL demos (50-50 split)\n",
    "            test_risk, losses = np.load(scaling_path + 'test_risk.npy'), np.load(scaling_path + 'losses.npy')\n",
    "            plot_risk_control(ax, eps_grid, losses, test_risk, model_name.split('/')[1], model_colors[model_name], 'solid')\n",
    "        else:\n",
    "            print('Missing data:', model_name, dataset)\n",
    "\n",
    "    if display_legends:\n",
    "        ax.legend()\n",
    "\n",
    "    plt.title(dataset_names_plot_titles[dataset])\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if debug_mode:\n",
    "        # Display the image\n",
    "        plt.show()\n",
    "    else:\n",
    "        # Save out the image\n",
    "        path = plot_directory + mixed_filename + '/risk_control/'\n",
    "        if not os.path.exists(path):\n",
    "            os.makedirs(path)\n",
    "        plt.savefig(path + dataset + '.pdf')\n",
    "\n",
    "    # Close plots to save memory\n",
    "    matplotlib.pyplot.close()\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b235cd0b-50a4-49ae-9898-d064e6a914e6",
   "metadata": {},
   "source": [
    "## Risk Control + Efficiency Gains - Comparison (Clipping vs Our Approach)\n",
    "Paper: fig 4, 18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "64eb6c2d-788f-4648-9dbe-6aad345b3efe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished sst2\n",
      "Finished trec\n",
      "Finished financial_phrasebank\n",
      "Finished tweeteval_hate\n",
      "Finished tweeteval_feminist\n",
      "Finished tweeteval_atheism\n",
      "Finished unnatural\n",
      "Finished ag_news\n"
     ]
    }
   ],
   "source": [
    "# Plot risk control results and exit layer, and with comparison to the other approach\n",
    "# Also study the average efficiency gains\n",
    "for dataset in datasets:\n",
    "    fig, ax = plt.subplots(1,2,figsize=(10,5))\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        if os.path.exists(precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'):\n",
    "            label_order = get_label_order(dataset, tokenizer)\n",
    "            if fake_labels:\n",
    "                label_order = [fake_label_map[x] for x in label_order]\n",
    "            \n",
    "            base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "            # We have all the data\n",
    "            data = {}\n",
    "            for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                    data[expt_type] = json.load(file)\n",
    "\n",
    "            c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "            c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "            # Combine correct and incorrect\n",
    "            gt, rel = c_gt + i_gt, c_rel + i_rel\n",
    "            combined_data = {}\n",
    "            for col in data['correct']:\n",
    "                combined_data[col] = data['correct'][col] + data['incorrect'][col]\n",
    "            \n",
    "            # max-0\n",
    "            max0_path = precomputed_mixed_path + 'max0/' + dataset + '/' + model_name + '/'\n",
    "            #eff_gains, rcp_lams = np.load(max0_path + 'eff_gains.npy'), np.load(max0_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            losses, test_risk = np.load(max0_path + 'losses.npy'), np.load(max0_path + 'test_risk.npy')\n",
    "            eff_gains, rcp_lams = np.load(max0_path + 'eff_gains.npy'), np.load(max0_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            # using the rcp_lams compute risk\n",
    "            # conf = get_all_confidences(combined_data, n_early_exit, label_order, confidence_type, first_exit)\n",
    "            # acc = get_all_accuracies(combined_data, gt, n_early_exit, first_exit)\n",
    "            # losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, rel, gt)\n",
    "            # mean, _, err, _ = compute_rc_fixed_lambda(eps_grid, rcp_type, rcp_lams, losses, exits)\n",
    "            # mean, err = mean[:len(eps_grid)], err[:len(eps_grid)]\n",
    "            # ax[0].plot(eps_grid, mean, label='clipped risk', color=model_colors[model_name], linestyle='dashed')\n",
    "            # ax[0].fill_between(eps_grid, mean - err, mean + err, alpha=0.2, color=model_colors[model_name])\n",
    "            plot_risk_control(ax[0], eps_grid, losses, test_risk, 'clipped risk', model_colors[model_name], 'dashed')\n",
    "            plot_efficiency_gains(ax[1], eps_grid, eff_gains, 'clipped risk', model_colors[model_name], 'dashed')\n",
    "            # using the rcp_lams compute\n",
    "            scaling_path = precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/' \n",
    "            losses, test_risk = np.load(scaling_path + 'losses.npy'), np.load(scaling_path + 'test_risk.npy')\n",
    "            eff_gains, rcp_lams = np.load(scaling_path + 'eff_gains.npy'), np.load(scaling_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            plot_risk_control(ax[0], eps_grid, losses, test_risk, 'risk transformation', model_colors[model_name], 'solid')\n",
    "            plot_efficiency_gains(ax[1], eps_grid, eff_gains, 'risk transformation', model_colors[model_name], 'solid')\n",
    "        else:\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "\n",
    "    if display_legends:\n",
    "        lines2 = [Line2D([0], [0], color='black', lw=5, linestyle=sty, alpha=0.5) for sty in ['dashed', 'solid']]\n",
    "        labels2 = ['clipped risk', 'risk transformation']\n",
    "        \n",
    "        lines1 = [Line2D([0], [0], color=model_colors[model], lw=5, linestyle='-',) for model in models]\n",
    "        labels1 = [x.split('/')[1] for x in models]\n",
    "        \n",
    "        legend1 = ax[0].legend(lines1, labels1, loc='upper left',)\n",
    "        legend2 = ax[1].legend(lines2, labels2, loc='upper right',)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    \n",
    "    if debug_mode:\n",
    "        # Display the image\n",
    "        plt.show()\n",
    "    else:\n",
    "        # Save out the image\n",
    "        path = plot_directory + mixed_filename + '/risk_control_with_comparison/'\n",
    "        if not os.path.exists(path):\n",
    "            os.makedirs(path)\n",
    "        plt.savefig(path + dataset + '.pdf')\n",
    "\n",
    "    # Close plots to save memory\n",
    "    matplotlib.pyplot.close()\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73268017-ee32-472e-a771-6003129c84af",
   "metadata": {},
   "source": [
    "# Risk Control Only - Comparison (Scaling vs Clipping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "2e123a40-3400-4a8b-b0d3-53142b5e2bfe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished sst2\n",
      "Finished trec\n",
      "Finished financial_phrasebank\n",
      "Finished tweeteval_hate\n",
      "Finished tweeteval_feminist\n",
      "Finished tweeteval_atheism\n",
      "Finished unnatural\n",
      "Finished ag_news\n"
     ]
    }
   ],
   "source": [
    "# Plot risk control results with comparison to the other approach\n",
    "for dataset in datasets:\n",
    "    fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        if os.path.exists(precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'):\n",
    "            label_order = get_label_order(dataset, tokenizer)\n",
    "            if fake_labels:\n",
    "                label_order = [fake_label_map[x] for x in label_order]\n",
    "            \n",
    "            base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "            # We have all the data\n",
    "            data = {}\n",
    "            for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                    data[expt_type] = json.load(file)\n",
    "\n",
    "            c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "            c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "            # Combine correct and incorrect\n",
    "            gt, rel = c_gt + i_gt, c_rel + i_rel\n",
    "            combined_data = {}\n",
    "            for col in data['correct']:\n",
    "                combined_data[col] = data['correct'][col] + data['incorrect'][col]\n",
    "            \n",
    "            # max-0\n",
    "            max0_path = precomputed_mixed_path + 'max0/' + dataset + '/' + model_name + '/'\n",
    "            #eff_gains, rcp_lams = np.load(max0_path + 'eff_gains.npy'), np.load(max0_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            losses, test_risk = np.load(max0_path + 'losses.npy'), np.load(max0_path + 'test_risk.npy')\n",
    "            eff_gains, rcp_lams = np.load(max0_path + 'eff_gains.npy'), np.load(max0_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            plot_risk_control(ax, eps_grid, losses, test_risk, 'clipped risk', model_colors[model_name], 'dashed')\n",
    "            # using the rcp_lams compute\n",
    "            scaling_path = precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/' \n",
    "            losses, test_risk = np.load(scaling_path + 'losses.npy'), np.load(scaling_path + 'test_risk.npy')\n",
    "            eff_gains, rcp_lams = np.load(scaling_path + 'eff_gains.npy'), np.load(scaling_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            plot_risk_control(ax, eps_grid, losses, test_risk, 'risk transformation', model_colors[model_name], 'solid')\n",
    "        else:\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "\n",
    "    if display_legends:\n",
    "        lines2 = [Line2D([0], [0], color='black', lw=5, linestyle=sty, alpha=0.5) for sty in ['dashed', 'solid']]\n",
    "        labels2 = ['clipped risk', 'risk transformation']\n",
    "        \n",
    "        lines1 = [Line2D([0], [0], color=model_colors[model], lw=5, linestyle='-',) for model in models]\n",
    "        labels1 = [x.split('/')[1] for x in models]\n",
    "        \n",
    "        legend1 = ax[0].legend(lines1, labels1, loc='upper left',)\n",
    "        legend2 = ax[1].legend(lines2, labels2, loc='upper right',)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    \n",
    "    if debug_mode:\n",
    "        # Display the image\n",
    "        plt.show()\n",
    "    else:\n",
    "        # Save out the image\n",
    "        path = plot_directory + mixed_filename + '/risk_control_only_comparison/'\n",
    "        if not os.path.exists(path):\n",
    "            os.makedirs(path)\n",
    "        plt.savefig(path + dataset + '.pdf')\n",
    "\n",
    "        # Close plots to save memory\n",
    "        matplotlib.pyplot.close()\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "171f0783-edbe-48b4-920b-6517f26aa3d8",
   "metadata": {},
   "source": [
    "## Efficiency Gains - Comparison (Clipping vs Our Approach)\n",
    "Paper: fig 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "af72c5e9-2396-4437-9b2d-f728ab71ac9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished sst2\n",
      "Finished trec\n",
      "Finished financial_phrasebank\n",
      "Finished tweeteval_hate\n",
      "Finished tweeteval_feminist\n",
      "Finished tweeteval_atheism\n",
      "Finished unnatural\n",
      "Finished ag_news\n"
     ]
    }
   ],
   "source": [
    "# Plot JUST efficiency gains\n",
    "for dataset in datasets:\n",
    "    fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        if os.path.exists(precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'):\n",
    "            scaling_path = precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'\n",
    "            rcp_lams = np.load(scaling_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            \n",
    "            # max-0\n",
    "            max0_path = precomputed_mixed_path + 'max0/' + dataset + '/' + model_name + '/'\n",
    "            losses, test_risk = np.load(max0_path + 'losses.npy'), np.load(max0_path + 'test_risk.npy')\n",
    "            eff_gains, rcp_lams = np.load(max0_path + 'eff_gains.npy'), np.load(max0_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            plot_efficiency_gains(ax, eps_grid, eff_gains, 'clipped risk', model_colors[model_name], 'dashed')\n",
    "            # Scaling\n",
    "            scaling_path = precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/' \n",
    "            losses, test_risk = np.load(scaling_path + 'losses.npy'), np.load(scaling_path + 'test_risk.npy')\n",
    "            eff_gains, rcp_lams = np.load(scaling_path + 'eff_gains.npy'), np.load(scaling_path + 'rcp_lams.npy', allow_pickle=True)\n",
    "            plot_efficiency_gains(ax, eps_grid, eff_gains, 'risk transformation', model_colors[model_name], 'solid')\n",
    "        else:\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "\n",
    "    if display_legends:\n",
    "        lines2 = [Line2D([0], [0], color='black', lw=5, linestyle=sty, alpha=0.5) for sty in ['dashed', 'solid']]\n",
    "        labels2 = ['clipped risk', 'risk transformation']\n",
    "        \n",
    "        lines1 = [Line2D([0], [0], color=model_colors[model], lw=5, linestyle='-',) for model in models] + lines2\n",
    "        labels1 = [x.split('/')[1] for x in models] + labels2\n",
    "        \n",
    "        legend1 = ax.legend(lines1, labels1, loc='upper left',)\n",
    "\n",
    "    plt.title(dataset_names_plot_titles[dataset])\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    if debug_mode:\n",
    "        # Display the image\n",
    "        plt.show()\n",
    "    else:\n",
    "        # Save out the image\n",
    "        path = plot_directory + mixed_filename + '/eff_gains/'\n",
    "        if not os.path.exists(path):\n",
    "            os.makedirs(path)\n",
    "        plt.savefig(path + dataset + '.pdf')\n",
    "\n",
    "        # Close plots to save memory\n",
    "        matplotlib.pyplot.close()\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91cc5bc3-9baa-4a14-992f-390799a056ee",
   "metadata": {},
   "source": [
    "# Highlighted Lambda vs Accuracy\n",
    "Paper: fig 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "35432451-f444-453e-98b9-5127e9a099e4",
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "too many indices for array: array is 0-dimensional, but 1 were indexed",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[44], line 25\u001b[0m\n\u001b[1;32m     23\u001b[0m conf \u001b[38;5;241m=\u001b[39m get_all_confidences(data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcorrect\u001b[39m\u001b[38;5;124m'\u001b[39m], n_early_exit, label_order, confidence_type, first_exit)\n\u001b[1;32m     24\u001b[0m acc \u001b[38;5;241m=\u001b[39m get_all_accuracies(data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcorrect\u001b[39m\u001b[38;5;124m'\u001b[39m], c_gt, n_early_exit, first_exit)\n\u001b[0;32m---> 25\u001b[0m losses, exits \u001b[38;5;241m=\u001b[39m get_losses_and_exits_confidence(conf, acc, lambdas, \u001b[38;5;28;01mNone\u001b[39;00m, c_gt)\n\u001b[1;32m     26\u001b[0m c_acc \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\u001b[38;5;241m-\u001b[39mlosses\u001b[38;5;241m.\u001b[39mmean(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     27\u001b[0m ax\u001b[38;5;241m.\u001b[39mplot(lambdas, c_acc, label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcorrect demos\u001b[39m\u001b[38;5;124m'\u001b[39m, color\u001b[38;5;241m=\u001b[39mc_color)\n",
      "File \u001b[0;32m~/Documents/llm-risk-control/risk_control_utils.py:89\u001b[0m, in \u001b[0;36mget_losses_and_exits_confidence\u001b[0;34m(conf, acc, lambdas, relative_labels, true_labels, default_to_zero_shot)\u001b[0m\n\u001b[1;32m     86\u001b[0m lambda_acc \u001b[38;5;241m=\u001b[39m acc[np\u001b[38;5;241m.\u001b[39marange(\u001b[38;5;28mlen\u001b[39m(exits)), exits \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m] \u001b[38;5;66;03m# accuracy at each of the chosen exit points\u001b[39;00m\n\u001b[1;32m     88\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m default_to_zero_shot:\n\u001b[0;32m---> 89\u001b[0m     lambda_acc[rows_with_no_threshold] \u001b[38;5;241m=\u001b[39m (relative_labels[rows_with_no_threshold] \u001b[38;5;241m==\u001b[39m true_labels[rows_with_no_threshold])\n\u001b[1;32m     91\u001b[0m all_exits\u001b[38;5;241m.\u001b[39mappend(exits)\n\u001b[1;32m     92\u001b[0m losses \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m lambda_acc)\n",
      "\u001b[0;31mIndexError\u001b[0m: too many indices for array: array is 0-dimensional, but 1 were indexed"
     ]
    }
   ],
   "source": [
    "# Plot lambda vs relative accuracy\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "        label_order = get_label_order(dataset, tokenizer)\n",
    "        if fake_labels:\n",
    "            label_order = [fake_label_map[x] for x in label_order]\n",
    "        \n",
    "        base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "        # First check that there exists all types of experiments\n",
    "        if (os.path.exists(base_dir + 'correct.json') and os.path.exists(base_dir + 'incorrect.json')\n",
    "                and os.path.exists(base_dir + 'zeroshot.json')):\n",
    "            # We have all the data\n",
    "            data = {}\n",
    "            for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                    data[expt_type] = json.load(file)\n",
    "\n",
    "            c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "            c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "            # Plot lambda vs correct accuracy\n",
    "            conf = get_all_confidences(data['correct'], n_early_exit, label_order, confidence_type, first_exit)\n",
    "            acc = get_all_accuracies(data['correct'], c_gt, n_early_exit, first_exit)\n",
    "            losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, None, c_gt)\n",
    "            c_acc = 1-losses.mean(axis=1)\n",
    "            ax.plot(lambdas, c_acc, label='correct demos', color=c_color)\n",
    "\n",
    "            # Plot lambda vs incorrect accuracy\n",
    "            conf = get_all_confidences(data['incorrect'], n_early_exit, label_order, confidence_type, first_exit)\n",
    "            acc = get_all_accuracies(data['incorrect'], c_gt, n_early_exit, first_exit)\n",
    "            losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, None, i_gt)\n",
    "            i_acc = 1-losses.mean(axis=1)\n",
    "            ax.plot(lambdas, i_acc, label='incorrect demos', color=i_color)\n",
    "\n",
    "            # Plot full-model performance as horizontal lines\n",
    "            correct_acc = [1 if p == t else 0 for p,t in zip(data['correct'][str(n_early_exit-1)], c_gt)]\n",
    "            incorrect_acc = [1 if p == t else 0 for p,t in zip(data['incorrect'][str(n_early_exit-1)], i_gt)]\n",
    "            zeroshot_acc = [1 if p == t else 0 for p,t in zip(data['zeroshot'][str(n_early_exit-1)], z_gt)]\n",
    "            correct_rel = [c-z for c,z in zip(correct_acc, zeroshot_acc)]\n",
    "            incorrect_rel = [i-z for i,z in zip(incorrect_acc, zeroshot_acc)]\n",
    "            avg_correct_rel, avg_incorrect_rel = c_acc[0], i_acc[0]\n",
    "            # Plot the full-model accuracies\n",
    "            ax.axhline(y=avg_correct_rel, color=c_color, linestyle='--', linewidth=2, label='correct full model')\n",
    "            ax.axhline(y=avg_incorrect_rel, color=i_color, linestyle='--', linewidth=2, label='incorrect full model')\n",
    "            valid_lams = []\n",
    "            for lam_idx in range(len(lambdas)):\n",
    "                if (avg_correct_rel - c_acc[len(lambdas)-lam_idx-1]) < 0.05 and i_acc[len(lambdas)-lam_idx-1] > avg_incorrect_rel:\n",
    "                    valid_lams.append(len(lambdas)-lam_idx-1)\n",
    "\n",
    "            intervals_idx = [(g[0][1], g[-1][1]) for _, group in groupby(enumerate(valid_lams), lambda x: x[1] - x[0]) for g in [list(group)]]\n",
    "            for start,end in intervals_idx:\n",
    "                end = end+1 if start == end and end < len(lambdas)-1 else end\n",
    "                plt.axvspan(lambdas[start], lambdas[end], facecolor='yellow', alpha=0.3)\n",
    "            \n",
    "            ax.set_xlabel('Lambda')\n",
    "            ax.set_ylabel('Accuracy Relative to Zero-Shot')\n",
    "            if display_legends:\n",
    "                ax.legend(loc='upper left')\n",
    "            \n",
    "            plt.tight_layout()\n",
    "            \n",
    "            if debug_mode:\n",
    "                # Display the image\n",
    "                plt.show()\n",
    "            else:\n",
    "                # Save out the image\n",
    "                path = plot_directory + 'lambda_vs_accuracy/'\n",
    "                if not os.path.exists(path):\n",
    "                    os.makedirs(path)\n",
    "                plt.savefig(path + dataset + '_' + model_name.split('/')[1] + '.pdf')\n",
    "    \n",
    "                # Close plots to save memory\n",
    "                matplotlib.pyplot.close()\n",
    "        else:\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "    \n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "906f935e-bb35-4dee-8443-2890daa48551",
   "metadata": {},
   "source": [
    "# Highlighted Lambda vs Accuracy with Error Bars\n",
    "Paper: fig 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "982d3a49-5bc0-44de-9e05-b3fc82f0b602",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot lambda vs accuracy for correct and incorrect demos, with standard-error bands (paper fig 8).\n",
    "# Lambda ranges where the correct-demo accuracy stays within max_acc_drop of its reference value\n",
    "# and the incorrect-demo accuracy beats its reference (both after allowing for the error bars)\n",
    "# are shaded yellow.\n",
    "max_acc_drop = 0.05 # tolerated drop in correct-demo accuracy for a lambda to be highlighted\n",
    "demo_types = ['correct', 'incorrect', 'zeroshot']\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "        # First check that there exists all types of experiments\n",
    "        if not all(os.path.exists(base_dir + expt_type + '.json') for expt_type in demo_types):\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "            continue\n",
    "\n",
    "        label_order = get_label_order(dataset, tokenizer)\n",
    "        if fake_labels:\n",
    "            label_order = [fake_label_map[x] for x in label_order]\n",
    "\n",
    "        # We have all the data\n",
    "        data = {}\n",
    "        for expt_type in demo_types:\n",
    "            with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                data[expt_type] = json.load(file)\n",
    "\n",
    "        c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "        fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "\n",
    "        # Plot lambda vs correct accuracy (mean +/- standard error over the loss columns)\n",
    "        conf = get_all_confidences(data['correct'], n_early_exit, label_order, confidence_type, first_exit)\n",
    "        acc = get_all_accuracies(data['correct'], c_gt, n_early_exit, first_exit)\n",
    "        losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, None, c_gt)\n",
    "        c_acc, c_err = 1-losses.mean(axis=1), losses.std(axis=1)/np.sqrt(losses.shape[1])\n",
    "        ax.plot(lambdas, c_acc, label='correct demos', color=c_color)\n",
    "        ax.fill_between(lambdas, c_acc-c_err, c_acc+c_err, color=c_color, alpha=0.3)\n",
    "\n",
    "        # Plot lambda vs incorrect accuracy. Score against the incorrect-demo ground truth (i_gt),\n",
    "        # the same ground truth handed to get_losses_and_exits_confidence below.\n",
    "        conf = get_all_confidences(data['incorrect'], n_early_exit, label_order, confidence_type, first_exit)\n",
    "        acc = get_all_accuracies(data['incorrect'], i_gt, n_early_exit, first_exit)\n",
    "        losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, None, i_gt)\n",
    "        i_acc, i_err = 1-losses.mean(axis=1), losses.std(axis=1)/np.sqrt(losses.shape[1])\n",
    "        ax.plot(lambdas, i_acc, label='incorrect demos', color=i_color)\n",
    "        ax.fill_between(lambdas, i_acc-i_err, i_acc+i_err, color=i_color, alpha=0.3)\n",
    "\n",
    "        # Reference ('full model') accuracies are read off the first lambda of each curve.\n",
    "        # NOTE(review): assumes lambdas[0] corresponds to running the full model -- confirm.\n",
    "        avg_correct_rel, avg_incorrect_rel = c_acc[0], i_acc[0]\n",
    "        ax.axhline(y=avg_correct_rel, color=c_color, linestyle='--', linewidth=2, label='correct full model')\n",
    "        ax.axhline(y=avg_incorrect_rel, color=i_color, linestyle='--', linewidth=2, label='incorrect full model')\n",
    "\n",
    "        # Indices of highlighted lambdas, in ASCENDING order (required by the run grouping below)\n",
    "        valid_lams = [idx for idx in range(len(lambdas))\n",
    "                      if (avg_correct_rel - c_acc[idx]) - c_err[idx] < max_acc_drop\n",
    "                      and i_acc[idx] - i_err[idx] > avg_incorrect_rel]\n",
    "\n",
    "        # Group consecutive indices into runs (index minus position is constant within a run of\n",
    "        # ascending integers) and shade each run; a single point is widened to the next lambda.\n",
    "        for _run_key, run in groupby(enumerate(valid_lams), lambda pos_idx: pos_idx[1] - pos_idx[0]):\n",
    "            run_idx = [lam_idx for _pos, lam_idx in run]\n",
    "            start, end = run_idx[0], run_idx[-1]\n",
    "            end = end+1 if start == end and end < len(lambdas)-1 else end\n",
    "            ax.axvspan(lambdas[start], lambdas[end], facecolor='yellow', alpha=0.3)\n",
    "\n",
    "        ax.set_xlabel('Lambda')\n",
    "        ax.set_ylabel('Accuracy Relative to Zero-Shot')\n",
    "        if display_legends:\n",
    "            ax.legend(loc='upper left')\n",
    "\n",
    "        fig.tight_layout()\n",
    "\n",
    "        if debug_mode:\n",
    "            # Display the image\n",
    "            plt.show()\n",
    "        else:\n",
    "            # Save out the image\n",
    "            path = plot_directory + 'lambda_vs_accuracy/with_error_bars/'\n",
    "            os.makedirs(path, exist_ok=True)\n",
    "            fig.savefig(path + dataset + '_' + model_name.split('/')[1] + '.pdf')\n",
    "\n",
    "            # Close plots to save memory\n",
    "            plt.close(fig)\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1c65a5b-f469-4376-9e3e-95ea09081672",
   "metadata": {},
   "source": [
    "# Non-Monotonicity of Risk\n",
    "Paper: fig 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03f49020-a26a-4fa3-b962-fdba4a6f1ccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show non-monotonicity of risk: for each dataset/model, pool the correct- and incorrect-demo\n",
    "# prompts and plot 1 - mean loss against lambda (paper fig 7).\n",
    "demo_types = ['correct', 'incorrect', 'zeroshot']\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "        # First check that there exists all types of experiments\n",
    "        if not all(os.path.exists(base_dir + expt_type + '.json') for expt_type in demo_types):\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "            continue\n",
    "\n",
    "        label_order = get_label_order(dataset, tokenizer)\n",
    "        if fake_labels:\n",
    "            label_order = [fake_label_map[x] for x in label_order]\n",
    "\n",
    "        # We have all the data\n",
    "        data = {}\n",
    "        for expt_type in demo_types:\n",
    "            with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                data[expt_type] = json.load(file)\n",
    "\n",
    "        c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "        c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "        # Combine correct and incorrect: each column holds the correct-demo rows followed by the\n",
    "        # incorrect-demo rows (gt / rel are joined the same way, assuming they are lists)\n",
    "        gt, rel = c_gt + i_gt, c_rel + i_rel\n",
    "        combined_data = {col: data['correct'][col] + data['incorrect'][col] for col in data['correct']}\n",
    "\n",
    "        fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "\n",
    "        # Plot lambda vs overall accuracy for this model\n",
    "        conf = get_all_confidences(combined_data, n_early_exit, label_order, confidence_type, first_exit)\n",
    "        acc = get_all_accuracies(combined_data, gt, n_early_exit, first_exit)\n",
    "        losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, rel, gt)\n",
    "        # NOTE(review): this plots 1 - mean loss while the axis is labelled 'Risk' -- confirm which is intended.\n",
    "        ax.plot(lambdas, 1-losses.mean(axis=1))\n",
    "        ax.set_xlabel('Lambda')\n",
    "        ax.set_ylabel('Risk')\n",
    "        fig.tight_layout()\n",
    "\n",
    "        if debug_mode:\n",
    "            # Display the image\n",
    "            plt.show()\n",
    "        else:\n",
    "            # Save out the image\n",
    "            path = plot_directory + mixed_filename + '/lambda_vs_accuracy/'\n",
    "            os.makedirs(path, exist_ok=True)\n",
    "            fig.savefig(path + dataset + '_' + model_name.split('/')[1] + '.pdf')\n",
    "\n",
    "            # Close plots to save memory\n",
    "            plt.close(fig)\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "609e88ab-75d1-45b9-ad7b-2721e5a21705",
   "metadata": {},
   "source": [
    "# Separate Correct and Incorrect Demos Results - Risk Control and Efficiency Gains\n",
    "Not included in paper. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3fcada-1b40-4c97-8816-ec0e92d3ba60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot risk control results (left) and efficiency gains (right) for every model.\n",
    "# ONLY using the scaling approach, and not including zero-shot! One plot per dataset.\n",
    "# Correct demos are drawn solid, incorrect demos dashed; colour identifies the model.\n",
    "demo_styles = [('correct', 'solid'), ('incorrect', 'dashed')]\n",
    "for dataset in datasets:\n",
    "    fig, ax = plt.subplots(1,2,figsize=(10,5))\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        model_dir = precomputed_risk_path + 'scaling/' + dataset + '/' + model_name + '/'\n",
    "        # First check that the precomputed results exist for BOTH demo types (both are loaded below)\n",
    "        if not all(os.path.exists(model_dir + label) for label, _style in demo_styles):\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "            continue\n",
    "\n",
    "        for label, style in demo_styles:\n",
    "            # Run the scaling method ONLY\n",
    "            scaling_path = model_dir + label + '/'\n",
    "            losses, test_risk = np.load(scaling_path + 'losses.npy'), np.load(scaling_path + 'test_risk.npy')\n",
    "            eff_gains = np.load(scaling_path + 'eff_gains.npy')\n",
    "            plot_risk_control(ax[0], eps_grid, losses, test_risk, label, model_colors[model_name], style)\n",
    "            plot_efficiency_gains(ax[1], eps_grid, eff_gains, label, model_colors[model_name], style)\n",
    "\n",
    "    if display_legends:\n",
    "        # Left panel: one colour per model; right panel: line style per demo type\n",
    "        lines1 = [Line2D([0], [0], color=model_colors[model], lw=5, linestyle='-',) for model in models]\n",
    "        labels1 = [x.split('/')[1] for x in models]\n",
    "        lines2 = [Line2D([0], [0], color='black', lw=5, linestyle=sty, alpha=0.5) for sty in ['dashed', 'solid']]\n",
    "        labels2 = ['incorrect', 'correct']\n",
    "        ax[0].legend(lines1, labels1, loc='upper right',)\n",
    "        ax[1].legend(lines2, labels2, loc='upper right',)\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "    if debug_mode:\n",
    "        # Display the image\n",
    "        plt.show()\n",
    "    else:\n",
    "        # Save out the image\n",
    "        path = plot_directory + 'risk_control/'\n",
    "        os.makedirs(path, exist_ok=True)\n",
    "        fig.savefig(path + dataset + '.pdf')\n",
    "\n",
    "        # Close plots to save memory\n",
    "        plt.close(fig)\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "716f4a30-0a5a-49e4-a0f2-cdf5412119d6",
   "metadata": {},
   "source": [
    "# Compute Accuracy Difference between Clipping and Full Model\n",
    "Used to compute the accuracy-delta column in Table 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e4d0317-f646-4c9b-abc9-a054896df13f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average (over dataset/model pairs) difference, per epsilon, between the precomputed early-exit\n",
    "# test result and the full-model accuracy, both measured relative to zero-shot.\n",
    "avg_acc_diff = [0.0] * len(eps_grid)\n",
    "total_n = 0\n",
    "for dataset in datasets:\n",
    "    print(dataset)\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        scaling_path = precomputed_mixed_path + 'scaling/' + dataset + '/' + model_name + '/'\n",
    "        # First check that all experiment results are precomputed (silently skip otherwise)\n",
    "        if not os.path.exists(scaling_path):\n",
    "            continue\n",
    "\n",
    "        # load the data\n",
    "        base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "        data = {}\n",
    "        for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "            with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                data[expt_type] = json.load(file)\n",
    "\n",
    "        c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "        # Full-model (last exit) predictions on the pooled correct + incorrect demos\n",
    "        last_exit = str(n_early_exit-1)\n",
    "        gt = c_gt + i_gt\n",
    "        full_model_preds = data['correct'][last_exit] + data['incorrect'][last_exit]\n",
    "\n",
    "        # Full-model accuracy on this model + dataset, relative to zero-shot accuracy\n",
    "        full_model_hits = [p == t for p, t in zip(full_model_preds, gt)]\n",
    "        zeroshot_hits = [p == t for p, t in zip(data['zeroshot'][last_exit], z_gt)]\n",
    "        full_model_rel_acc = sum(full_model_hits) / len(full_model_hits) - sum(zeroshot_hits) / len(zeroshot_hits)\n",
    "\n",
    "        # Pre-computed test risk per epsilon, averaged over the first axis (presumably trials -- confirm)\n",
    "        test_risk = np.load(scaling_path + 'test_risk.npy').mean(axis=0)\n",
    "        avg_acc_diff = [avg_acc_diff[i] + test_risk[i] - full_model_rel_acc for i in range(len(avg_acc_diff))]\n",
    "        total_n += 1\n",
    "\n",
    "if total_n > 0:\n",
    "    avg_acc_diff = [i/total_n for i in avg_acc_diff]\n",
    "else:\n",
    "    # Nothing found: report it instead of raising ZeroDivisionError, and use NaNs rather than\n",
    "    # zeros so the result cannot be mistaken for 'no difference'\n",
    "    print('No precomputed results found in', precomputed_mixed_path)\n",
    "    avg_acc_diff = [float('nan')] * len(eps_grid)\n",
    "# for i in range(len(eps_grid)):\n",
    "#     print(eps_grid[i], 'Avg acc diff:', avg_acc_diff[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e050b67-4ce8-4acb-b51d-66da1b09671f",
   "metadata": {},
   "source": [
    "# Compute Efficiency Gains\n",
    "Used to compute the two efficiency-gain columns in Table 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c1d4271-69c6-4f1a-8c44-b687ce9cae80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate efficiency gain difference between scaling vs max0 (clipping) at a single epsilon,\n",
    "# averaged over all dataset/model pairs that have every precomputed result.\n",
    "eps_to_evaluate = 0.1 # must be less than max_eps!\n",
    "# Compare with a tolerance: eps_grid holds floats, so exact equality can miss e.g. 0.1\n",
    "eps_index = np.where(np.isclose(eps_grid, eps_to_evaluate))[0][0]\n",
    "# NOTE(review): divisor used to turn gains into a percentage; its meaning is not visible here -- confirm\n",
    "gain_normalizer = 17\n",
    "avg_incorrect_diff, avg_correct_diff, total_n = 0, 0, 0\n",
    "avg_correct_scaling_gains = 0\n",
    "\n",
    "def _mean_gain_at_eps(method, dataset, model_name, label):\n",
    "    # Mean (over the first axis) of the precomputed efficiency gains at eps_index\n",
    "    eff_gains = np.load(precomputed_risk_path + method + '/' + dataset + '/' + model_name + '/' + label + '/' + 'eff_gains.npy')\n",
    "    return eff_gains.mean(axis=0)[eps_index]\n",
    "\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        # Every method/demo-type combination is loaded below, so require all of them\n",
    "        needed_dirs = [precomputed_risk_path + method + '/' + dataset + '/' + model_name + '/' + label + '/'\n",
    "                       for method in ['max0', 'scaling'] for label in ['correct', 'incorrect']]\n",
    "        if not all(os.path.exists(d) for d in needed_dirs):\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "            continue\n",
    "        total_n += 1\n",
    "        # Correct demos\n",
    "        scaling_gains = _mean_gain_at_eps('scaling', dataset, model_name, 'correct')\n",
    "        avg_correct_scaling_gains += scaling_gains\n",
    "        avg_correct_diff += (scaling_gains - _mean_gain_at_eps('max0', dataset, model_name, 'correct'))\n",
    "        # Incorrect demos\n",
    "        avg_incorrect_diff += (_mean_gain_at_eps('scaling', dataset, model_name, 'incorrect')\n",
    "                               - _mean_gain_at_eps('max0', dataset, model_name, 'incorrect'))\n",
    "\n",
    "if total_n > 0:\n",
    "    avg_correct_diff, avg_incorrect_diff = avg_correct_diff / total_n, avg_incorrect_diff / total_n\n",
    "    print('Avg diff - correct examples:', avg_correct_diff, 'Percent:', avg_correct_diff/gain_normalizer*100)\n",
    "    print('Avg diff - incorrect examples:', avg_incorrect_diff, 'Percent:', avg_incorrect_diff/gain_normalizer*100)\n",
    "    avg_correct_scaling_gains = avg_correct_scaling_gains / total_n\n",
    "    print('Avg diff vs full model (correct):', avg_correct_scaling_gains, 'Percent:', avg_correct_scaling_gains/gain_normalizer*100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85923741-e26f-4a8b-a560-85e5fe5ddc77",
   "metadata": {},
   "source": [
    "# Efficiency Gains - Correct vs Incorrect Demos & Clipped vs Transformed Risk\n",
    "Not included in paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61f41686-597e-4919-8285-210fad941b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Efficiency gains of clipped (max0, dashed) vs transformed/scaled (solid) risk for correct and\n",
    "# incorrect demos; one plot per dataset/model pair.\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        method_dirs = {method: precomputed_risk_path + method + '/' + dataset + '/' + model_name + '/'\n",
    "                       for method in ['max0', 'scaling']}\n",
    "        # Both methods and both demo types are loaded below, so require all of them\n",
    "        if not all(os.path.exists(d + label + '/') for d in method_dirs.values() for label in ['correct', 'incorrect']):\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "            continue\n",
    "\n",
    "        fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "        for label, color in zip(['correct', 'incorrect'], [c_color, i_color]):\n",
    "            # max-0 (clipped risk)\n",
    "            eff_gains = np.load(method_dirs['max0'] + label + '/' + 'eff_gains.npy')\n",
    "            plot_efficiency_gains(ax, eps_grid, eff_gains, label + ' clipped', color, 'dashed')\n",
    "            # Scaling (transformed risk)\n",
    "            eff_gains = np.load(method_dirs['scaling'] + label + '/' + 'eff_gains.npy')\n",
    "            plot_efficiency_gains(ax, eps_grid, eff_gains, label + ' with transformation', color, 'solid')\n",
    "\n",
    "        if display_legends:\n",
    "            ax.legend()\n",
    "\n",
    "        fig.tight_layout()\n",
    "\n",
    "        if debug_mode:\n",
    "            # Display the image\n",
    "            plt.show()\n",
    "        else:\n",
    "            # Save out the image\n",
    "            path = plot_directory + 'eff_gains_with_comparison/'\n",
    "            os.makedirs(path, exist_ok=True)\n",
    "            fig.savefig(path + dataset + '_' + model_name.split('/')[1] + '.pdf')\n",
    "\n",
    "            # Close plots to save memory\n",
    "            plt.close(fig)\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38993989-5dc1-4f22-8fa6-c29056976f75",
   "metadata": {},
   "source": [
    "# Accuracy vs Layer\n",
    "Paper: fig 1, 15, 17, 19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a926147-459d-416b-9738-a72ea3481a37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy vs Layer plots (showing overthinking): mean accuracy at every exit layer for\n",
    "# correct-demo, incorrect-demo and zero-shot prompts.\n",
    "calib = 'Calibrated' if use_calibration else 'Uncalibrated'\n",
    "demo_types = ['correct', 'incorrect', 'zeroshot']\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "        # First check that there exists all types of experiments\n",
    "        if not all(os.path.exists(base_dir + expt_type + '.json') for expt_type in demo_types):\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "            continue\n",
    "\n",
    "        # We have all the data\n",
    "        data = {}\n",
    "        for expt_type in demo_types:\n",
    "            with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                data[expt_type] = json.load(file)\n",
    "        c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "        fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "        layers = list(range(n_early_exit))\n",
    "        # Each prompt type is scored against its own ground truth, starting from exit 0\n",
    "        for gt, label, color in zip([c_gt, i_gt, z_gt], demo_types, [c_color, i_color, z_color]):\n",
    "            acc = get_all_accuracies(data[label], gt, n_early_exit, 0)\n",
    "            ax.plot(layers, acc.mean(axis=0), label=label, color=color)\n",
    "\n",
    "        if display_legends:\n",
    "            ax.legend()\n",
    "\n",
    "        ax.set_xlabel('Layer')\n",
    "        ax.set_ylabel(calib + ' Accuracy')\n",
    "        fig.tight_layout()\n",
    "        if debug_mode:\n",
    "            # Display the image\n",
    "            plt.show()\n",
    "        else:\n",
    "            # Save out the image\n",
    "            path = plot_directory + 'loss_vs_layer/'\n",
    "            os.makedirs(path, exist_ok=True)\n",
    "            fig.savefig(path + dataset + '_' + model_name.split('/')[1] + '.pdf')\n",
    "\n",
    "            # Close plots to save memory\n",
    "            plt.close(fig)\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7fb5b79-257b-457c-9dec-7a2e9f583b42",
   "metadata": {},
   "source": [
    "# Lambda vs Risk - Correct vs Incorrect Demos and Scaling vs Clipping\n",
    "Paper: fig 18, top row."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f09f6d-9578-41e4-8b09-fe12f5801584",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot lambda vs empirical risk for correct and incorrect demos, each with the scaled\n",
    "# (transformed) risk as returned and the clipped risk (negative losses set to 0).\n",
    "demo_types = ['correct', 'incorrect', 'zeroshot']\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "        # First check that there exists all types of experiments\n",
    "        if not all(os.path.exists(base_dir + expt_type + '.json') for expt_type in demo_types):\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "            continue\n",
    "\n",
    "        label_order = get_label_order(dataset, tokenizer)\n",
    "        if fake_labels:\n",
    "            label_order = [fake_label_map[x] for x in label_order]\n",
    "\n",
    "        # We have all the data\n",
    "        data = {}\n",
    "        for expt_type in demo_types:\n",
    "            with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                data[expt_type] = json.load(file)\n",
    "\n",
    "        c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "        c_rel, i_rel, z_rel = get_relative_labels(relative_labels, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "\n",
    "        fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "        # Each demo type is scored against its OWN ground truth and relative labels\n",
    "        for expt_type, gt, rel, color in [('correct', c_gt, c_rel, c_color), ('incorrect', i_gt, i_rel, i_color)]:\n",
    "            conf = get_all_confidences(data[expt_type], n_early_exit, label_order, confidence_type, first_exit)\n",
    "            acc = get_all_accuracies(data[expt_type], gt, n_early_exit, first_exit)\n",
    "            losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, rel, gt)\n",
    "            # Plot lambda vs risk, scaled (losses as returned)\n",
    "            ax.plot(lambdas, losses.mean(axis=1), label=expt_type + ' with scaling', color=color)\n",
    "            # Plot lambda vs risk, clipped (negative per-example losses set to 0 before averaging)\n",
    "            ax.plot(lambdas, losses.clip(min=0).mean(axis=1), label=expt_type + ' with clipping', color=color, linestyle='dashed')\n",
    "\n",
    "        ax.set_xlabel('Lambda')\n",
    "        ax.set_ylabel('Empirical Risk')\n",
    "        if display_legends:\n",
    "            ax.legend()\n",
    "\n",
    "        fig.tight_layout()\n",
    "\n",
    "        if debug_mode:\n",
    "            # Display the image\n",
    "            plt.show()\n",
    "        else:\n",
    "            # Save out the image\n",
    "            path = plot_directory + mixed_filename + '/lambda_vs_risk/'\n",
    "            os.makedirs(path, exist_ok=True)\n",
    "            fig.savefig(path + dataset + '_' + model_name.split('/')[1] + '.pdf')\n",
    "\n",
    "            # Close plots to save memory\n",
    "            plt.close(fig)\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79b95793-fcfa-4903-bfb7-820325107732",
   "metadata": {},
   "source": [
    "# Confidence vs Layer\n",
    "Paper: fig 19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55185c93-5d0a-4148-abbc-90be6b6ecc22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the mean confidence of each layer's prediction through the layers, for correct-demo,\n",
    "# incorrect-demo and zero-shot prompts. One figure per dataset/model pair.\n",
    "demo_types = ['correct', 'incorrect', 'zeroshot']\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        layers = list(range(n_early_exit))\n",
    "        base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "        # First check that there exists all types of experiments; skip (and save nothing) otherwise\n",
    "        if not all(os.path.exists(base_dir + expt_type + '.json') for expt_type in demo_types):\n",
    "            print('Missing data: ', dataset, model_name)\n",
    "            continue\n",
    "\n",
    "        label_order = get_label_order(dataset, tokenizer)\n",
    "        if fake_labels:\n",
    "            label_order = [fake_label_map[x] for x in label_order]\n",
    "\n",
    "        # We have all the data\n",
    "        data = {}\n",
    "        for expt_type in demo_types:\n",
    "            with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                data[expt_type] = json.load(file)\n",
    "\n",
    "        fig, ax = plt.subplots(1,1,figsize=(5,5))\n",
    "        for expt_type in demo_types:\n",
    "            conf = get_all_confidences(data[expt_type], n_early_exit, label_order, confidence_type)\n",
    "            ax.plot(layers, conf.mean(axis=0), label=expt_type)\n",
    "        ax.set_xlabel('Layer')\n",
    "        ax.set_ylabel('Confidence')\n",
    "        if display_legends:\n",
    "            ax.legend()\n",
    "\n",
    "        fig.tight_layout()\n",
    "\n",
    "        if debug_mode:\n",
    "            # Display the image\n",
    "            plt.show()\n",
    "        else:\n",
    "            # Save out the image\n",
    "            path = plot_directory + 'confidences/' + model_name + '/'\n",
    "            os.makedirs(path, exist_ok=True)\n",
    "            fig.savefig(path + dataset + '.pdf')\n",
    "\n",
    "            # Close every per-model figure to save memory\n",
    "            plt.close(fig)\n",
    "\n",
    "    print('Finished', dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e237eb4-eef7-4feb-8729-fdef8fbf56e3",
   "metadata": {},
   "source": [
    "# Comparing Accuracy vs Full Model (instead of zero-shot) - Table\n",
    "Not in paper; used only for rebuttals. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8127949c-5e06-461d-9d66-f96e5889552b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print out data from the separate risk-control data\n",
    "eps = 0.2\n",
    "eps_index = np.where(eps_grid == eps)[0][0]\n",
    "print('Using epsilon value', eps)\n",
    "\n",
    "# Convert to Markdown format\n",
    "def to_markdown_table(rows):\n",
    "    header = \"| \" + \" | \".join(rows[0]) + \" |\"\n",
    "    separator = \"| \" + \" | \".join([\"---\"] * len(rows[0])) + \" |\"\n",
    "    content = \"\\n\".join(\"| \" + \" | \".join(row) + \" |\" for row in rows[1:])\n",
    "    return \"\\n\".join([header, separator, content])\n",
    "\n",
    "inc_rows = [[\"Dataset\", \"Zeroshot Performance\", \"Full Model Accuracy\", \"Early Exit Accuracy\", \"Accuracy Gain vs Full Model\",\n",
    "                \"Efficiency Gain vs Full Model\"]]\n",
    "correct_rows = [[\"Dataset\", \"Zeroshot Performance\", \"Full Model Accuracy\", \"Early Exit Accuracy\", \"Accuracy Gain vs Full Model\",\n",
    "                \"Efficiency Gain vs Full Model\"]]\n",
    "for dataset in datasets:\n",
    "    for model_name, n_early_exit, tokenizer in zip(models, n_early_exits, tokenizers):\n",
    "        if 'layerskip-llama3-8B' in model_name:\n",
    "            print('Running', dataset, model_name)\n",
    "            # First check that all experiment results are precomputed\n",
    "            if os.path.exists(precomputed_risk_path + 'scaling/' + dataset + '/' + model_name + '/'):\n",
    "                label_order = get_label_order(dataset, tokenizer)\n",
    "                if fake_labels:\n",
    "                    label_order = [fake_label_map[x] for x in label_order]\n",
    "                \n",
    "                # load the data\n",
    "                base_dir = results_folder + '/n_demos_' + str(n_demos) + '/' + dataset + '/' + model_name + '/'\n",
    "                data = {}\n",
    "                for expt_type in ['correct', 'incorrect', 'zeroshot']:\n",
    "                    with open(base_dir + expt_type + '.json', 'r') as file:\n",
    "                        data[expt_type] = json.load(file)\n",
    "    \n",
    "                c_gt, i_gt, z_gt = get_ground_truth_by_type(ground_truth_type, data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "                c_rel, i_rel, z_rel = get_relative_labels(\"zeroshot_full_model\", data['correct'], data['incorrect'], data['zeroshot'], n_early_exit)\n",
    "                \n",
    "                # For each dataset, print the following (TODO averaged over all models):\n",
    "                # Zeroshot, incorrect, correct demos full-model performance\n",
    "                c_acc = np.mean(get_all_accuracies(data['correct'], c_gt, n_early_exit, first_exit)[:,-1])\n",
    "                i_acc = np.mean(get_all_accuracies(data['incorrect'], i_gt, n_early_exit, first_exit)[:,-1])\n",
    "                z_acc = np.mean(get_all_accuracies(data['zeroshot'], z_gt, n_early_exit, first_exit)[:,-1])\n",
    "                #print('Zeroshot acc:', z_acc)\n",
    "                #print('Correct full model:', c_acc)\n",
    "                #print('Incorrect full model:', i_acc)\n",
    "\n",
    "                # Set default loss/eff gain when risk can't be controlled\n",
    "                def_loss_uncontrolled_risk = 1-z_acc \n",
    "                def_eff_gain_uncontrolled_risk = 0\n",
    "\n",
    "                # Compute early exit acc, delta w full model, and eff gains for correct demos\n",
    "                conf = get_all_confidences(data['correct'], n_early_exit, label_order, confidence_type, first_exit)\n",
    "                acc = get_all_accuracies(data['correct'], c_gt, n_early_exit, first_exit)\n",
    "                n_cal = int(len(data['correct']['0'])/2)\n",
    "                losses, test_risk, eff_gains, rcp_lams = apply_risk_control(conf, acc, c_gt, c_rel, lambdas, eps_grid,\n",
    "                                                                            delta, n_cal, n_trials, 'scaling')\n",
    "                losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, None, c_gt)\n",
    "                # Get the loss corresponding to the right lambda\n",
    "                if rcp_lams[eps_index] is not None:\n",
    "                    lam_idx = rcp_lams[eps_index]\n",
    "                    c_acc_ee = 1-losses[lam_idx,:].mean(axis=0)\n",
    "                else:\n",
    "                    c_acc_ee = z_acc\n",
    "                    c_eff_gain = 0\n",
    "                #print('Early exit correct acc:', c_acc_ee)\n",
    "                #print('Delta w full model:', c_acc-c_acc_ee)\n",
    "                c_eff_gain = eff_gains.mean(axis=0)[eps_index]\n",
    "                #print('Efficiency gain correct:', c_eff_gain)\n",
    "    \n",
    "                # Compute early exit acc, delta w full model, and eff gains for incorrect demos\n",
    "                conf = get_all_confidences(data['incorrect'], n_early_exit, label_order, confidence_type, first_exit)\n",
    "                acc = get_all_accuracies(data['incorrect'], i_gt, n_early_exit, first_exit)\n",
    "                n_cal = int(len(data['incorrect']['0'])/2)\n",
    "                losses, test_risk, eff_gains, rcp_lams = apply_risk_control(conf, acc, i_gt, i_rel, lambdas, eps_grid, \n",
    "                                                                            delta, n_cal, n_trials, 'scaling')\n",
    "                losses, exits = get_losses_and_exits_confidence(conf, acc, lambdas, None, i_gt)\n",
    "                # Get the loss corresponding to the right lambda\n",
    "                if rcp_lams[eps_index] is not None:\n",
    "                    lam_idx = rcp_lams[eps_index]\n",
    "                    i_acc_ee = 1-losses[lam_idx,:].mean(axis=0)\n",
    "                else:\n",
    "                    i_acc_ee = z_acc\n",
    "                    i_eff_gain = 0\n",
    "                #print('Early exit incorrect acc:', i_acc_ee)\n",
    "                #print('Delta w full model:', i_acc-i_acc_ee)\n",
    "                i_eff_gain = eff_gains.mean(axis=0)[eps_index]\n",
    "                #print('Efficiency gain incorrect:', i_eff_gain)\n",
    "                c_results = [dataset, round(z_acc,3), round(c_acc,3), round(c_acc_ee,3), round(c_acc-c_acc_ee,3), round(c_eff_gain,3)]\n",
    "                correct_rows.append([str(x) for x in c_results])\n",
    "                i_results = [dataset, round(z_acc,3), round(i_acc,3), round(i_acc_ee,3), round(i_acc-i_acc_ee,3), round(i_eff_gain,3)]\n",
    "                inc_rows.append([str(x) for x in i_results])\n",
    "                # print('Correct:', round(z_acc,3), round(c_acc,3), round(c_acc_ee,3), round(c_acc-c_acc_ee,3), round(c_eff_gain,3))\n",
    "                # print('Incorrect:', round(z_acc,3), round(i_acc,3), round(i_acc_ee,3), round(i_acc-i_acc_ee,3), round(i_eff_gain,3))\n",
    "\n",
    "print(to_markdown_table(correct_rows), '\\n\\n')\n",
    "print(to_markdown_table(inc_rows))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60f021a9-da53-41ae-9980-45dd68dab0cc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-risk-control",
   "language": "python",
   "name": "llm-risk-control"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
