{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "566db452",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ! pip install colorcet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f089c26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports. NOTE(review): none of the torch / torchvision / datasets / transformers names imported\n",
    "# below are used by any cell of this notebook, and no cell seems to need the sys.path entries;\n",
    "# they look like leftovers from the generation scripts -- confirm before relying on them.\n",
    "import torch\n",
    "import os\n",
    "import sys\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torchvision.io import ImageReadMode, read_image\n",
    "from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "from datasets import load_dataset, Dataset  # NOTE(review): rebinds `Dataset`, shadowing torch.utils.data.Dataset above\n",
    "from transformers import VisionEncoderDecoderModel\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import AutoImageProcessor\n",
    "import pickle\n",
    "sys.path.append('.')\n",
    "sys.path.append('..')\n",
    "sys.path.append('./utils')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# plot: seaborn theme + 'notebook' context; later cells override font sizes via plt.rc\n",
    "import seaborn as sns\n",
    "sns.set_theme()\n",
    "sns.set_context(\"notebook\")\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44cee832",
   "metadata": {},
   "source": [
    "# Load generated output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f55241a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled generation records; each record is a dict (the next cell reads its\n",
    "# 'id', 'question', 'answer', 'most_likely_generation' and 'generations' entries).\n",
    "# pickle.load can execute arbitrary code: only open files produced by your own pipeline.\n",
    "path = './output/llama-2-13b-chat-hf_triviaqa_10/0.pkl'\n",
    "with open(path, 'rb') as fh:\n",
    "    sequences = pickle.load(fh)\n",
    "print(len(sequences), sequences[0], sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af207f0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect two example questions: generated answer, predicted alignment score and true label.\n",
    "tau = 0.3877551020408163    # revealed in saved filename\n",
    "result_dir = './output/llama-2-13b-chat-hf_triviaqa_10/uq_result'\n",
    "path = f'{result_dir}/scores_{tau}.csv'\n",
    "summ = pd.read_csv(path)\n",
    "path = f'{result_dir}/labels_{tau}.csv'\n",
    "lab = np.array(pd.read_csv(path).iloc[:,0].to_list())\n",
    "assert len(lab) == len(summ), 'scores and labels must be row-aligned'\n",
    "\n",
    "# Column 0 holds the question ids (matched against sequences[...]['id'] below);\n",
    "# only the remaining columns are averaged, so the id column never enters the score.\n",
    "ids = summ.iloc[:,0].to_list()\n",
    "summ_mean = 1 - summ.iloc[:,1:].mean(axis=1).to_numpy()\n",
    "rk = summ_mean.argsort()\n",
    "# `lab` stays in the original row order so that lab[row], ids[row] and summ_mean[row]\n",
    "# all describe the same question (re-ordering it by `rk` paired rows with the wrong label).\n",
    "id_min = rk[4000]   # row at position 4000 of the ascending score ranking (hand-picked example)\n",
    "id_max = rk[2530]   # row at position 2530 of the ascending score ranking (hand-picked example)\n",
    "id_seq = [seq['id'] for seq in sequences]\n",
    "\n",
    "def show_example(row):\n",
    "    \"\"\"Print question, predicted score, true label, reference and generated answers for `row` of `summ`.\"\"\"\n",
    "    seq = sequences[id_seq.index(ids[row])]\n",
    "    print(seq['question'])\n",
    "    print(f\"Predicted alignment score: {summ_mean[row]}\")\n",
    "    print(f\"True label: {lab[row]}\")\n",
    "    print(\"\\nTrue answer:\\n\")\n",
    "    print(seq['answer'].replace('\\n', '').strip())\n",
    "    print('-------------------------------------------\\nGenerated answer:\\n')\n",
    "    print(seq['most_likely_generation'].replace('\\n', '').strip())\n",
    "    print(seq['generations'])\n",
    "\n",
    "show_example(id_min)\n",
    "print('===========================================')\n",
    "show_example(id_max)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b808a29",
   "metadata": {},
   "source": [
    "# Reproduce plots"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32522ab0",
   "metadata": {},
   "source": [
    "## Results with varying reference sample size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "619eab3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FDR and power versus the target FDR level, for several reference-set sizes |D| = N.\n",
    "SMALL_SIZE = 18\n",
    "MEDIUM_SIZE = 18\n",
    "BIGGER_SIZE = 16\n",
    "\n",
    "plt.rc('font', size=MEDIUM_SIZE)          # controls default text sizes\n",
    "plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title\n",
    "plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels\n",
    "plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize\n",
    "plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "repN = 500\n",
    "data = 'coqa'\n",
    "model_name = ['OPT-13B', 'LLaMA-2-13B-chat']\n",
    "q_seq = np.round(np.linspace(0.05,0.95,19),2)   # target FDR levels on the x-axis\n",
    "train_size_seq = [100, 500, 2000]\n",
    "os.makedirs('./plots', exist_ok=True)   # savefig does not create missing directories\n",
    "\n",
    "for split_pr_tune in [0.2]:\n",
    "    for split_pr in [0.5]:\n",
    "        if split_pr + split_pr_tune<1:\n",
    "            for uq_name in ['rf', 'logistic', 'xgbrf']:\n",
    "\n",
    "                fig, ax = plt.subplots(2,3,figsize=(12, 6))\n",
    "                fig.tight_layout()\n",
    "                plt.subplots_adjust(hspace=0.5)\n",
    "\n",
    "                # one row of panels per model, one column per reference-set size\n",
    "                for xi, model in enumerate(['opt-13b', 'llama-2-13b-chat-hf']):\n",
    "                    par_path = f'./output/{model}_{data}_10'\n",
    "\n",
    "                    for i in tqdm(range(len(train_size_seq))):\n",
    "\n",
    "                        N = train_size_seq[i]\n",
    "\n",
    "                        path_fdr = os.path.join(par_path, f'uq_result/result/fdr_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                        path_power = os.path.join(par_path, f'uq_result/result/power_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                        path_fdr_std = os.path.join(par_path, f'uq_result/result/fdr_std_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                        path_power_std = os.path.join(par_path, f'uq_result/result/power_std_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "\n",
    "                        with open(path_fdr, 'rb') as f:\n",
    "                            fdp_seq = pickle.load(f)\n",
    "                        with open(path_power, 'rb') as f:\n",
    "                            power_seq = pickle.load(f)\n",
    "                        with open(path_fdr_std, 'rb') as f:\n",
    "                            fdp_std = pickle.load(f)\n",
    "                        with open(path_power_std, 'rb') as f:\n",
    "                            power_std = pickle.load(f)\n",
    "\n",
    "                        # Longer results are subsampled at indices 4, 9, ..., 94 to line up with\n",
    "                        # q_seq (presumably saved on a 99-point grid 0.01..0.99 -- TODO confirm).\n",
    "                        if len(fdp_seq) > 20:\n",
    "                            idx_ = np.arange(4,99,5)\n",
    "                            fdp_seq = fdp_seq[idx_]\n",
    "                            power_seq = power_seq[idx_]\n",
    "                            fdp_std = fdp_std[idx_]\n",
    "                            power_std = power_std[idx_]\n",
    "\n",
    "                        ax[xi,i].plot(q_seq, fdp_seq, 'bo--', label='FDR')\n",
    "                        ax[xi,i].plot(q_seq, power_seq, 'rs--', label='Power')\n",
    "                        ax[xi,i].fill_between(q_seq, fdp_seq-fdp_std, fdp_seq+fdp_std, alpha=0.5, \n",
    "                            edgecolor='lightblue', facecolor='lightblue')\n",
    "                        ax[xi,i].fill_between(q_seq, power_seq-power_std, power_seq+power_std, alpha=0.5, \n",
    "                            edgecolor='lightpink', facecolor='lightpink')\n",
    "\n",
    "                        if xi==1:\n",
    "                            ax[xi,i].set_xlabel(\"Target level of FDR\")\n",
    "                        ax[xi,i].set_ylabel(\"FDR and Power\")\n",
    "\n",
    "                        # dashed y = x reference line spanning both data ranges\n",
    "                        lims = [\n",
    "                            np.min([ax[xi,i].get_xlim(), ax[xi,i].get_ylim()]),  # min of both axes\n",
    "                            np.max([ax[xi,i].get_xlim(), ax[xi,i].get_ylim()]),  # max of both axes\n",
    "                        ]\n",
    "                        ax[xi,i].plot(lims, lims, 'g--', alpha=0.75, zorder=0)\n",
    "                        ax[xi,i].set_title(r'%s: $|D|$=%s'%(model_name[xi],N,))\n",
    "                        ax[xi,i].set_xlim((0,1))\n",
    "                        ax[xi,i].set_ylim((-0.03,1.03))\n",
    "                        ax[xi,i].legend(loc='best', prop = { \"size\": 16 })\n",
    "\n",
    "                plot_dir = f'./plots/fdp_plot_{data}_{uq_name}_{split_pr_tune}_{split_pr}_{repN}.pdf'\n",
    "                plt.savefig(plot_dir, dpi=400, bbox_inches='tight')\n",
    "                plt.show()   # render this figure and release it before the next one is created\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fa2ddac",
   "metadata": {},
   "source": [
    "## Results with varying $\\gamma_1$: proportion of samples used for parameter tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6173085",
   "metadata": {},
   "outputs": [],
   "source": [
    "# vary split_pr_tune (gamma_1), holding the training proportion split_pr fixed\n",
    "\n",
    "repN = 500\n",
    "data = 'coqa'\n",
    "model_name = ['OPT-13B', 'LLaMA-2-13B-chat']\n",
    "train_size_seq = [2000]\n",
    "pr_tune_seq = [0.2,0.4,0.6]\n",
    "q_seq = np.round(np.linspace(0.05,0.95,19),2)   # target FDR levels on the x-axis\n",
    "os.makedirs('./plots', exist_ok=True)   # savefig does not create missing directories\n",
    "\n",
    "for N in train_size_seq:\n",
    "    for split_pr in [0.3]:\n",
    "        for uq_name in ['rf', 'logistic', 'xgbrf']:\n",
    "\n",
    "            fig, ax = plt.subplots(2,len(pr_tune_seq),figsize=(4*len(pr_tune_seq), 6), squeeze=False)\n",
    "            fig.tight_layout()\n",
    "            plt.subplots_adjust(hspace=0.5)\n",
    "\n",
    "            # one row of panels per model, one column per gamma_1 value\n",
    "            for xi, model in enumerate(['opt-13b', 'llama-2-13b-chat-hf']):\n",
    "                par_path = f'./output/{model}_{data}_10'\n",
    "\n",
    "                for i, split_pr_tune in enumerate(tqdm(pr_tune_seq)):\n",
    "                    # Feasibility must be checked with the *current* split_pr_tune; checking it\n",
    "                    # before this loop read an unassigned/stale value left by an earlier cell.\n",
    "                    if split_pr + split_pr_tune >= 1:\n",
    "                        ax[xi,i].set_visible(False)\n",
    "                        continue\n",
    "\n",
    "                    path_fdr = os.path.join(par_path, f'uq_result/result/fdr_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                    path_power = os.path.join(par_path, f'uq_result/result/power_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                    path_fdr_std = os.path.join(par_path, f'uq_result/result/fdr_std_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                    path_power_std = os.path.join(par_path, f'uq_result/result/power_std_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "\n",
    "                    with open(path_fdr, 'rb') as f:\n",
    "                        fdp_seq = pickle.load(f)\n",
    "                    with open(path_power, 'rb') as f:\n",
    "                        power_seq = pickle.load(f)\n",
    "                    with open(path_fdr_std, 'rb') as f:\n",
    "                        fdp_std = pickle.load(f)\n",
    "                    with open(path_power_std, 'rb') as f:\n",
    "                        power_std = pickle.load(f)\n",
    "\n",
    "                    # Longer results are subsampled at indices 4, 9, ..., 94 to line up with q_seq.\n",
    "                    if len(fdp_seq) > 20:\n",
    "                        idx_ = np.arange(4,99,5)\n",
    "                        fdp_seq = fdp_seq[idx_]\n",
    "                        power_seq = power_seq[idx_]\n",
    "                        fdp_std = fdp_std[idx_]\n",
    "                        power_std = power_std[idx_]\n",
    "\n",
    "                    ax[xi,i].plot(q_seq, fdp_seq, 'bo--', label='FDR')\n",
    "                    ax[xi,i].plot(q_seq, power_seq, 'rs--', label='Power')\n",
    "                    ax[xi,i].fill_between(q_seq, fdp_seq-fdp_std, fdp_seq+fdp_std, alpha=0.5, \n",
    "                        edgecolor='lightblue', facecolor='lightblue')\n",
    "                    ax[xi,i].fill_between(q_seq, power_seq-power_std, power_seq+power_std, alpha=0.5, \n",
    "                        edgecolor='lightpink', facecolor='lightpink')\n",
    "\n",
    "                    if xi==1:\n",
    "                        ax[xi,i].set_xlabel(\"Target level of FDR\")\n",
    "                    ax[xi,i].set_ylabel(\"FDR and Power\")\n",
    "\n",
    "                    # dashed y = x reference line spanning both data ranges\n",
    "                    lims = [\n",
    "                        np.min([ax[xi,i].get_xlim(), ax[xi,i].get_ylim()]),  # min of both axes\n",
    "                        np.max([ax[xi,i].get_xlim(), ax[xi,i].get_ylim()]),  # max of both axes\n",
    "                    ]\n",
    "                    ax[xi,i].plot(lims, lims, 'g--', alpha=0.75, zorder=0)\n",
    "                    ax[xi,i].set_title(r'%s: $\\gamma_1$=%s'%(model_name[xi],split_pr_tune,))\n",
    "                    ax[xi,i].set_xlim((0,1))\n",
    "                    ax[xi,i].set_ylim((-0.03,1.03))\n",
    "                    ax[xi,i].legend(loc='best', prop = { \"size\": 16 })\n",
    "\n",
    "            plot_dir = f'./plots/fdp_plot_gamma_1_{data}_{uq_name}_{N}_{split_pr}_{repN}.pdf'\n",
    "            plt.savefig(plot_dir, dpi=400, bbox_inches='tight')\n",
    "            plt.show()   # render this figure and release it before the next one is created\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bc78a13",
   "metadata": {},
   "source": [
    "## Results with varying $\\gamma_2$: proportion of samples used for training the alignment score predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69ba8cf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# vary split_pr (gamma_2), holding the tuning proportion split_pr_tune fixed\n",
    "\n",
    "repN = 500\n",
    "data = 'coqa'\n",
    "model_name = ['OPT-13B', 'LLaMA-2-13B-chat']\n",
    "train_size_seq = [100]\n",
    "pr_seq = [0.1,0.3,0.5,0.7]\n",
    "q_seq = np.round(np.linspace(0.05,0.95,19),2)   # target FDR levels on the x-axis\n",
    "os.makedirs('./plots', exist_ok=True)   # savefig does not create missing directories\n",
    "\n",
    "for N in train_size_seq:\n",
    "    for split_pr_tune in [0.2]:\n",
    "        for uq_name in ['rf', 'logistic', 'xgbrf']:\n",
    "\n",
    "            fig, ax = plt.subplots(2,len(pr_seq),figsize=(4*len(pr_seq), 6), squeeze=False)\n",
    "            fig.tight_layout()\n",
    "            plt.subplots_adjust(hspace=0.5)\n",
    "\n",
    "            # one row of panels per model, one column per gamma_2 value\n",
    "            for xi, model in enumerate(['opt-13b', 'llama-2-13b-chat-hf']):\n",
    "                par_path = f'./output/{model}_{data}_10'\n",
    "\n",
    "                for i, split_pr in enumerate(tqdm(pr_seq)):\n",
    "                    # Same feasibility condition as the other cells.\n",
    "                    if split_pr + split_pr_tune >= 1:\n",
    "                        ax[xi,i].set_visible(False)\n",
    "                        continue\n",
    "\n",
    "                    path_fdr = os.path.join(par_path, f'uq_result/result/fdr_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                    path_power = os.path.join(par_path, f'uq_result/result/power_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                    path_fdr_std = os.path.join(par_path, f'uq_result/result/fdr_std_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                    path_power_std = os.path.join(par_path, f'uq_result/result/power_std_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "\n",
    "                    with open(path_fdr, 'rb') as f:\n",
    "                        fdp_seq = pickle.load(f)\n",
    "                    with open(path_power, 'rb') as f:\n",
    "                        power_seq = pickle.load(f)\n",
    "                    with open(path_fdr_std, 'rb') as f:\n",
    "                        fdp_std = pickle.load(f)\n",
    "                    with open(path_power_std, 'rb') as f:\n",
    "                        power_std = pickle.load(f)\n",
    "\n",
    "                    # Longer results are subsampled at indices 4, 9, ..., 94 to line up with q_seq.\n",
    "                    if len(fdp_seq) > 20:\n",
    "                        idx_ = np.arange(4,99,5)\n",
    "                        fdp_seq = fdp_seq[idx_]\n",
    "                        power_seq = power_seq[idx_]\n",
    "                        fdp_std = fdp_std[idx_]\n",
    "                        power_std = power_std[idx_]\n",
    "\n",
    "                    ax[xi,i].plot(q_seq, fdp_seq, 'bo--', label='FDR')\n",
    "                    ax[xi,i].plot(q_seq, power_seq, 'rs--', label='Power')\n",
    "                    ax[xi,i].fill_between(q_seq, fdp_seq-fdp_std, fdp_seq+fdp_std, alpha=0.5, \n",
    "                        edgecolor='lightblue', facecolor='lightblue')\n",
    "                    ax[xi,i].fill_between(q_seq, power_seq-power_std, power_seq+power_std, alpha=0.5, \n",
    "                        edgecolor='lightpink', facecolor='lightpink')\n",
    "\n",
    "                    if xi==1:\n",
    "                        ax[xi,i].set_xlabel(\"Target level of FDR\")\n",
    "                    ax[xi,i].set_ylabel(\"FDR and Power\")\n",
    "\n",
    "                    # dashed y = x reference line spanning both data ranges\n",
    "                    lims = [\n",
    "                        np.min([ax[xi,i].get_xlim(), ax[xi,i].get_ylim()]),  # min of both axes\n",
    "                        np.max([ax[xi,i].get_xlim(), ax[xi,i].get_ylim()]),  # max of both axes\n",
    "                    ]\n",
    "                    ax[xi,i].plot(lims, lims, 'g--', alpha=0.75, zorder=0)\n",
    "                    ax[xi,i].set_title(r'%s: $\\gamma_2$=%s'%(model_name[xi],split_pr,))\n",
    "                    ax[xi,i].set_xlim((0,1))\n",
    "                    ax[xi,i].set_ylim((-0.03,1.03))\n",
    "                    ax[xi,i].legend(loc='best', prop = { \"size\": 16 })\n",
    "\n",
    "            plot_dir = f'./plots/fdp_plot_gamma_2_{data}_{uq_name}_{N}_{split_pr_tune}_{repN}.pdf'\n",
    "            plt.savefig(plot_dir, dpi=400, bbox_inches='tight')\n",
    "            plt.show()   # render this figure and release it before the next one is created\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f54bfad",
   "metadata": {},
   "source": [
    "## Effect of features in alignment score predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3450395",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One curve per uncertainty measure in uq_list (labelled via name_map): power (or FDR)\n",
    "# against the target FDR level, one panel per model.\n",
    "import colorcet as cc   # optional dependency: see the install cell at the top\n",
    "# (matplotlib.cm.get_cmap was unused here and no longer exists in Matplotlib >= 3.9.)\n",
    "\n",
    "# Reset the seaborn theme, then apply this figure's font sizes.\n",
    "sns.set_theme()\n",
    "sns.set_context(\"notebook\")\n",
    "SMALL_SIZE = 14\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 20\n",
    "\n",
    "plt.rc('font', size=MEDIUM_SIZE)          # controls default text sizes\n",
    "plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title\n",
    "plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels\n",
    "plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize\n",
    "plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "\n",
    "# ROC\n",
    "data = 'triviaqa'\n",
    "N = 2000\n",
    "uq_name = 'logistic'\n",
    "split_pr_tune = 0.2\n",
    "split_pr = 0.5\n",
    "repN = 500\n",
    "model_name = ['OPT-13B', 'LLaMA-2-13B-chat']\n",
    "uq_list = [\n",
    "        'generations|numsets', \n",
    "        'lexical_sim',\n",
    "        'generations|spectral_eigv_clip|disagreement_w',\n",
    "        'generations|eccentricity|disagreement_w',\n",
    "        'generations|degree|disagreement_w',\n",
    "        'generations|spectral_eigv_clip|agreement_w',\n",
    "        'generations|eccentricity|agreement_w',\n",
    "        'generations|degree|agreement_w',\n",
    "        'generations|spectral_eigv_clip|jaccard',\n",
    "        'generations|eccentricity|jaccard',\n",
    "        'generations|degree|jaccard',\n",
    "        'semanticEntropy|unnorm', \n",
    "        'self_prob',\n",
    "]\n",
    "def name_map(v):\n",
    "    \"\"\"Map an internal uncertainty-measure key from uq_list to its short legend label.\"\"\"\n",
    "    if v == 'self_prob': return \"Self_Eval\"\n",
    "    if v == 'lexical_sim': return \"Lexical_Sim\"\n",
    "    v = v.replace(\"|numsets\", \"|Num_Sets\")\n",
    "    v = v.replace(\"|disagreement_w\", \"|(C)\")\n",
    "    v = v.replace(\"|agreement_w\", \"|(E)\")\n",
    "    v = v.replace(\"|jaccard\", \"|(J)\")\n",
    "    v = v.replace(\"spectral_eigv_clip|\", \"EigV\")\n",
    "    v = v.replace(\"eccentricity|\", \"Ecc\")\n",
    "    v = v.replace(\"degree|\", \"Deg\")\n",
    "    return {'semanticEntropy|unnorm': 'SE',\n",
    "            'blind': 'Base Accuracy'}.get(v,v)\n",
    "\n",
    "val = \"Power\"   # which metric to plot: \"Power\" or \"FDR\"\n",
    "q_seq = np.round(np.linspace(0.01,0.99,99),2)\n",
    "os.makedirs('./plots', exist_ok=True)   # savefig does not create missing directories\n",
    "fig, ax = plt.subplots(1,2,figsize=(11, 5))\n",
    "fig.tight_layout()\n",
    "\n",
    "for xi, model in enumerate(['opt-13b', 'llama-2-13b-chat-hf']):\n",
    "    par_path = f'./output/{model}_{data}_10'\n",
    "    path_fdr = os.path.join(par_path, f'uq_result/result/uq_fdr_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "    path_power = os.path.join(par_path, f'uq_result/result/uq_power_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "    with open(path_fdr, 'rb') as f:\n",
    "        fdr_uq = pickle.load(f)\n",
    "    with open(path_power, 'rb') as f:\n",
    "        power_uq = pickle.load(f)\n",
    "    print(fdr_uq.shape)\n",
    "    # Curve labels are taken from uq_list by row position, so the row counts must agree.\n",
    "    assert len(fdr_uq) == len(uq_list), 'number of result rows must match uq_list'\n",
    "\n",
    "    # Draw the last two measures (self_prob, semanticEntropy) first, then the rest in order.\n",
    "    idx_seq = [len(fdr_uq)-1, len(fdr_uq)-2] + [i for i in range(len(fdr_uq)-2)]\n",
    "\n",
    "    for i in range(len(idx_seq)):\n",
    "        i_ = idx_seq[i]\n",
    "        uq_nam = name_map(uq_list[i_])\n",
    "        uq_nam = uq_nam.split('generations|')[-1]\n",
    "\n",
    "        if uq_nam == 'Self_Eval':\n",
    "            color = 'black'\n",
    "            mark = 's'\n",
    "            lty = '--'\n",
    "        elif uq_nam in ['Num_Sets', 'SE', 'Lexical_Sim']:\n",
    "            color = cc.fire[50*(i+1)]\n",
    "            mark = 'o'\n",
    "            lty = '--'\n",
    "        else:\n",
    "            color = cc.CET_C6[15*i]\n",
    "            mark = '.'\n",
    "            lty = '-'\n",
    "\n",
    "        val_seq = fdr_uq[i_] if val == 'FDR' else power_uq[i_]\n",
    "        ax[xi].plot(q_seq, val_seq, linestyle=lty, marker=mark, markersize=6, linewidth=2.5, label=uq_nam, c=color)\n",
    "\n",
    "    # Draw the dashed y = x reference once per panel (it used to be re-drawn for every curve).\n",
    "    lims = [\n",
    "        np.min([ax[xi].get_xlim(), ax[xi].get_ylim()]),  # min of both axes\n",
    "        np.max([ax[xi].get_xlim(), ax[xi].get_ylim()]),  # max of both axes\n",
    "    ]\n",
    "    ax[xi].plot(lims, lims, 'g--', alpha=0.75, zorder=0)\n",
    "    ax[xi].set_xlim(lims)\n",
    "    ax[xi].set_ylim((0,1.02))\n",
    "    ax[xi].set_xlabel(\"Target level of FDR\")\n",
    "    ax[xi].set_ylabel(val)\n",
    "    ax[xi].set_title(r'%s'%(model_name[xi],))\n",
    "ax[1].legend(loc='best',bbox_to_anchor=(1, 1.02))\n",
    "\n",
    "# NOTE(review): the space in 'ROC _' looks accidental; kept so the output file name is unchanged.\n",
    "plt.savefig(f'./plots/ROC _{val}_plot_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pdf', dpi=400, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1fd1294",
   "metadata": {},
   "source": [
    "## Comparison with the baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d336585",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROC\n",
    "# Conf_Align versus the Self_Eval (p_true) baseline: FDR and power against the target FDR level.\n",
    "SMALL_SIZE = 22\n",
    "MEDIUM_SIZE = 23\n",
    "BIGGER_SIZE = 23\n",
    "\n",
    "plt.rc('font', size=MEDIUM_SIZE)          # controls default text sizes\n",
    "plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title\n",
    "plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels\n",
    "plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize\n",
    "plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "data = 'triviaqa'\n",
    "model_seq = ['opt-13b', 'llama-2-13b-chat-hf']\n",
    "N = 2000\n",
    "uq_name = 'logistic'\n",
    "split_pr_tune = 0.2\n",
    "split_pr = 0.5\n",
    "repN = 500\n",
    "model_name = ['OPT-13B', 'LLaMA-2-13B-chat']\n",
    "q_seq = np.linspace(0.05,0.95,19)\n",
    "\n",
    "def to_q_grid(arr):\n",
    "    \"\"\"Return `arr` as a 1-D array aligned with the 19-point `q_seq`.\n",
    "\n",
    "    Sequences longer than 20 are subsampled at indices 4, 9, ..., 94, exactly as in the\n",
    "    other plotting cells, so results saved on a finer grid line up with q_seq.\n",
    "    \"\"\"\n",
    "    arr = np.asarray(arr).ravel()\n",
    "    return arr[np.arange(4,99,5)] if len(arr) > 20 else arr\n",
    "\n",
    "def load_pkl(path):\n",
    "    \"\"\"Unpickle one result file and align it with q_seq (only open trusted, locally produced files).\"\"\"\n",
    "    with open(path, 'rb') as f:\n",
    "        return to_q_grid(pickle.load(f))\n",
    "\n",
    "os.makedirs('./plots', exist_ok=True)   # savefig does not create missing directories\n",
    "fig, ax = plt.subplots(1,4,figsize=(20, 5))\n",
    "fig.tight_layout()\n",
    "\n",
    "# Panels: [model 0 FDR, model 0 power, model 1 FDR, model 1 power]\n",
    "for xi in range(2):\n",
    "    model = model_seq[xi]\n",
    "    par_path = f'./output/{model}_{data}_10'\n",
    "    suffix = f'{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl'\n",
    "    fdp_base_ptrue = load_pkl(os.path.join(par_path, f'uq_result/result/baseline_ptrue_fdr_{suffix}'))\n",
    "    power_base_ptrue = load_pkl(os.path.join(par_path, f'uq_result/result/baseline_ptrue_power_{suffix}'))\n",
    "    fdp_std_base_ptrue = load_pkl(os.path.join(par_path, f'uq_result/result/baseline_ptrue_fdr_std_{suffix}'))\n",
    "    power_std_base_ptrue = load_pkl(os.path.join(par_path, f'uq_result/result/baseline_ptrue_power_std_{suffix}'))\n",
    "\n",
    "    fdp_seq0 = load_pkl(os.path.join(par_path, f'uq_result/result/fdr_{suffix}'))\n",
    "    power_seq0 = load_pkl(os.path.join(par_path, f'uq_result/result/power_{suffix}'))\n",
    "    fdp_std = load_pkl(os.path.join(par_path, f'uq_result/result/fdr_std_{suffix}'))\n",
    "    power_std = load_pkl(os.path.join(par_path, f'uq_result/result/power_std_{suffix}'))\n",
    "\n",
    "    ax[xi*2].plot(q_seq, fdp_seq0, linestyle='-', marker='o', markersize=6, linewidth=2.5, label='Conf_Align', c='blue')\n",
    "    ax[xi*2+1].plot(q_seq, power_seq0, linestyle='-', marker='o', markersize=6, linewidth=2.5, label='Conf_Align', c='blue')\n",
    "    ax[xi*2].fill_between(q_seq, fdp_seq0-fdp_std, fdp_seq0+fdp_std, alpha=0.5, \n",
    "        edgecolor='lightblue', facecolor='lightblue')\n",
    "    ax[xi*2+1].fill_between(q_seq, power_seq0-power_std, power_seq0+power_std, alpha=0.5, \n",
    "        edgecolor='lightblue', facecolor='lightblue')\n",
    "\n",
    "    ax[xi*2].plot(q_seq, fdp_base_ptrue, linestyle='-', marker='s', markersize=6, linewidth=2.5, label='Self_Eval', c='black')\n",
    "    ax[xi*2+1].plot(q_seq, power_base_ptrue, linestyle='-', marker='s', markersize=6, linewidth=2.5, label='Self_Eval', c='black')\n",
    "    ax[xi*2].fill_between(q_seq, fdp_base_ptrue-fdp_std_base_ptrue, fdp_base_ptrue+fdp_std_base_ptrue, alpha=0.5, \n",
    "        edgecolor='lightgray', facecolor='lightgray')\n",
    "    ax[xi*2+1].fill_between(q_seq, power_base_ptrue-power_std_base_ptrue, power_base_ptrue+power_std_base_ptrue, alpha=0.5, \n",
    "        edgecolor='lightgray', facecolor='lightgray')\n",
    "\n",
    "    ax[xi*2].set_xlabel(\"target level of FDR\")\n",
    "    ax[xi*2].set_ylabel(\"FDR\")\n",
    "    ax[xi*2+1].set_xlabel(\"target level of FDR\")\n",
    "    ax[xi*2+1].set_ylabel(\"Power\")\n",
    "\n",
    "    # dashed y = x reference on the FDR panel\n",
    "    lims = [\n",
    "        np.min([ax[xi*2].get_xlim(), ax[xi*2].get_ylim()]),  # min of both axes\n",
    "        np.max([ax[xi*2].get_xlim(), ax[xi*2].get_ylim()]),  # max of both axes\n",
    "    ]\n",
    "    ax[xi*2].plot(lims, lims, 'g--', alpha=0.75, zorder=0)\n",
    "    ax[xi*2].set_xlim((0,1))\n",
    "    ax[xi*2].set_ylim((-0.02,1))\n",
    "    ax[xi*2].set_title(model_name[xi])\n",
    "    ax[xi*2+1].set_title(model_name[xi])\n",
    "\n",
    "# Legend and save once, after both models have been drawn (previously done inside the loop,\n",
    "# which added a legend to a still-empty panel and saved a half-finished figure first).\n",
    "ax[3].legend(loc='lower right')\n",
    "plt.savefig(f'./plots/baseline_plot_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pdf', dpi=400, bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
