{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44340263",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: install colorcet, which the feature-comparison plot at the end imports.\n",
    "# `%pip` installs into the running kernel's environment, unlike `!pip`.\n",
    "# %pip install colorcet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f089c26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports: standard library, then third-party, then project-local.\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "# torch's Dataset is aliased so it does not clash with the Hugging Face `Dataset` imported below.\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import Dataset as TorchDataset\n",
    "from torchvision.io import ImageReadMode, read_image\n",
    "from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "from datasets import load_dataset, Dataset\n",
    "from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel\n",
    "\n",
    "# Make the project helper modules (`_utils`, `chexbert`) importable.\n",
    "sys.path.append('.')\n",
    "sys.path.append('..')\n",
    "sys.path.append('./utils')\n",
    "\n",
    "# NOTE(review): star imports hide where helpers such as `Transform` and\n",
    "# `get_before_findings` come from; replace with explicit imports when possible.\n",
    "from _utils import *\n",
    "from chexbert import *\n",
    "\n",
    "# Imported after the star imports so these aliases cannot be shadowed by them.\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Paths and run configuration.\n",
    "mimic_dir = '../conformal_align/codes/cxr/physionet.org/files/mimic-cxr-resized-224/'\n",
    "ckpt = \"./models/trained_p101112\"\n",
    "\n",
    "split_name = 'dev'  # NOTE(review): not referenced by any cell in this notebook\n",
    "\n",
    "num_gens = 10  # NOTE(review): not referenced by any cell in this notebook\n",
    "\n",
    "# Plot style shared by all figures.\n",
    "sns.set_theme()\n",
    "sns.set_context(\"notebook\")\n",
    "SMALL_SIZE = 16\n",
    "MEDIUM_SIZE = 16\n",
    "BIGGER_SIZE = 16\n",
    "\n",
    "plt.rc('font', size=SMALL_SIZE)          # controls default text sizes\n",
    "plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title\n",
    "plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels\n",
    "plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize\n",
    "plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d55471a8",
   "metadata": {},
   "source": [
    "# Load fine-tuned VLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec40ed2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import concatenate_datasets\n",
    "\n",
    "# Load the fine-tuned vision-encoder-decoder model saved at `ckpt`.\n",
    "print(\"Loading model...\")\n",
    "trained_model = VisionEncoderDecoderModel.from_pretrained(ckpt)\n",
    "\n",
    "# GPT-2 tokenizer; the EOS token is reused as the padding token.\n",
    "print(\"Loading tokenizer...\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "        \"gpt2\",\n",
    "        use_fast=True\n",
    "    )\n",
    "tokenizer.eos_token = tokenizer.decode(tokenizer.eos_token_id)\n",
    "tokenizer.bos_token = tokenizer.decode(tokenizer.bos_token_id)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Image processor: provides the input size and normalisation statistics used below.\n",
    "print(\"Loading image processor...\")\n",
    "image_processor = AutoImageProcessor.from_pretrained(\n",
    "        \"google/vit-base-patch16-224-in21k\"\n",
    ")\n",
    "\n",
    "# Load dataset\n",
    "print(\"Loading dataset...\")\n",
    "base = \"./ap_pa_per_dicom_id\"\n",
    "dataset = load_dataset(\n",
    "    \"json\",\n",
    "    data_files={\n",
    "        \"train\": os.path.join(base, \"train.jsonl\"),\n",
    "        \"dev\": os.path.join(base, \"dev.jsonl\"),\n",
    "        \"validate\": os.path.join(base, \"validate.jsonl\"),\n",
    "        \"test\": os.path.join(base, \"test.jsonl\"),\n",
    "    }\n",
    ")\n",
    "\n",
    "# Evaluation pool: the dev split followed by the test split. Later cells index\n",
    "# `data_sub` by position, so the dev-then-test order must be kept.\n",
    "data_sub1 = dataset['dev']\n",
    "data_sub2 = dataset['test']\n",
    "data_sub = concatenate_datasets([data_sub1, data_sub2])\n",
    "assert len(data_sub) == len(data_sub1) + len(data_sub2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4701491",
   "metadata": {},
   "source": [
    "## An example of report generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4a2aec1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-example mini-batch (rows 1 and 10 of the evaluation pool).\n",
    "# Indexing with a list returns a dict mapping column name -> list of values.\n",
    "test_sub = data_sub[[1,10]]\n",
    "\n",
    "# Prompt = report text before the findings; reference finding = the remainder.\n",
    "# str.replace(p, '') equals ''.join(text.split(p)) for a non-empty p, but does\n",
    "# not raise ValueError when p is empty.\n",
    "test_prompts_sub = [get_before_findings(text) for text in test_sub['report']]\n",
    "true_finding = [text.replace(prompt, '').strip() for text, prompt in zip(test_sub['report'], test_prompts_sub)]\n",
    "\n",
    "# Left-pad so every prompt ends exactly where generation begins.\n",
    "tokenizer.padding_side = \"left\"\n",
    "encoded_prompts = tokenizer(test_prompts_sub, return_tensors=\"pt\", padding=True)\n",
    "# Select fields by key instead of unpacking .values(), which breaks if extra fields are returned.\n",
    "test_input_ids_sub = encoded_prompts[\"input_ids\"]\n",
    "test_attention_masks_sub = encoded_prompts[\"attention_mask\"]\n",
    "\n",
    "images_sub = [read_image(os.path.join(mimic_dir, image_path), mode=ImageReadMode.RGB) for image_path in test_sub['image_path']]\n",
    "image_size = image_processor.size['height']\n",
    "inference_image_transformations = Transform(image_size, image_processor.image_mean, image_processor.image_std)\n",
    "inference_image_transformations = torch.jit.script(inference_image_transformations)\n",
    "\n",
    "def inference_jit_image_processor(images):\n",
    "    \"\"\"Apply the scripted transform to each image and stack the results into one batch tensor.\"\"\"\n",
    "    return torch.stack([inference_image_transformations(image) for image in images], dim=0)\n",
    "\n",
    "test_pixel_values_sub = inference_jit_image_processor(images_sub)\n",
    "\n",
    "input_length = test_input_ids_sub.shape[1]  # padded prompt length in tokens\n",
    "\n",
    "input_datum = {\n",
    "    \"pixel_values\": test_pixel_values_sub,\n",
    "    \"decoder_input_ids\": test_input_ids_sub,\n",
    "    \"decoder_attention_mask\": test_attention_masks_sub,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4e3b077",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling is stochastic; fix the seed so the example reports are reproducible.\n",
    "GEN_SEED = 0\n",
    "torch.manual_seed(GEN_SEED)\n",
    "output = trained_model.generate(\n",
    "        **input_datum,\n",
    "        max_length=512,\n",
    "        do_sample=True,\n",
    "        num_return_sequences=1\n",
    "        )\n",
    "\n",
    "texts = tokenizer.batch_decode(output, skip_special_tokens=True)\n",
    "for t in texts:\n",
    "    print(\"--------------\")\n",
    "    print(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6e50397",
   "metadata": {},
   "source": [
    "# Load generated report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "195daaba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saved generations for the evaluation pool, one entry per record (indexed by position).\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files this project produced.\n",
    "path = './output/trained_cxr_10/01.pkl'\n",
    "with open(path, 'rb') as fh:\n",
    "    sequences = pickle.load(fh)\n",
    "\n",
    "first_record = sequences[0]\n",
    "print(len(sequences))\n",
    "print(first_record['question'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "023b987e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd  # imported explicitly rather than relying on the star imports\n",
    "\n",
    "# Predicted alignment scores and true alignment labels saved for threshold tau.\n",
    "tau = 0.5918367346938775\n",
    "result_dir = './output/trained_cxr_10/uq_result'\n",
    "summ = pd.read_csv(f'{result_dir}/scores_{tau}.csv')\n",
    "lab = np.array(pd.read_csv(f'{result_dir}/labels_{tau}.csv').iloc[:,0].to_list())\n",
    "\n",
    "# Column 0 holds the record ids; only the remaining numeric columns are averaged per row.\n",
    "ids = summ.iloc[:,0].to_list()\n",
    "summ_mean = 1 - summ.iloc[:, 1:].mean(axis=1, numeric_only=True).to_numpy()\n",
    "rk = summ_mean.argsort()  # row indices in increasing order of summ_mean\n",
    "lab_sorted = lab[rk]  # labels in ranked order; `lab` itself stays in row order\n",
    "id_seq = [sequences[ii]['id'] for ii in range(len(sequences))]\n",
    "\n",
    "def show_example(rank):\n",
    "    \"\"\"Print the record at position `rank` of the ascending summ_mean ranking.\n",
    "\n",
    "    Shows its prompt, its 'chexbert_score_generations' entry, image path,\n",
    "    predicted score, true label, and the reference vs. most-likely generated report.\n",
    "    \"\"\"\n",
    "    row = rk[rank]  # row index into summ / summ_mean / lab (unsorted order)\n",
    "    seq_idx = id_seq.index(ids[row])  # position of the same record in `sequences`\n",
    "    record = sequences[seq_idx]\n",
    "    print(record['question'])\n",
    "    print(record['chexbert_score_generations'])\n",
    "    # Assumes data_sub and sequences list the records in the same order.\n",
    "    assert  data_sub[seq_idx]['dicom_id'] == record['id']\n",
    "    print(data_sub[seq_idx]['image_path'])\n",
    "    print(f\"Predicted alignment score: {summ_mean[row]}\")\n",
    "    # `row` is an unsorted index, so it must index the unsorted labels.\n",
    "    print(f\"True label: {lab[row]}\")\n",
    "    print(\"\\nTrue report:\\n\")\n",
    "    print(''.join(record['answer'].split('\\n')).strip())\n",
    "    print('-------------------------------------------\\nGenerated report:\\n')\n",
    "    # NOTE(review): [1:] drops the first character, presumably a leading space; confirm.\n",
    "    print(''.join(record['most_likely_generation'][1:].split('\\n')).strip())\n",
    "\n",
    "# Ranks of the two displayed examples (alternative ranks noted by the author: 1556, 7330).\n",
    "RANK_LOW = 1500\n",
    "RANK_HIGH = 7000\n",
    "id_min = rk[RANK_LOW]\n",
    "id_max = rk[RANK_HIGH]\n",
    "show_example(RANK_LOW)\n",
    "print('===========================================')\n",
    "show_example(RANK_HIGH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1786fa9",
   "metadata": {},
   "source": [
    "# Reproduce plots"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b930964",
   "metadata": {},
   "source": [
    "## Results with varying reference sample size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b62dfbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for CXR: FDR and power curves as the reference sample size |D| = N varies\n",
    "\n",
    "repN = 500\n",
    "data = 'cxr'\n",
    "model = 'trained'\n",
    "par_path = f'./output/{model}_{data}_10'\n",
    "q_seq = np.round(np.linspace(0.05,0.95,19),2)  # target FDR levels on the x-axis\n",
    "train_size_seq = [100, 500, 2000]\n",
    "stat_keys = ['fdr', 'power', 'fdr_std', 'power_std']\n",
    "\n",
    "for split_pr_tune in [0.2]:\n",
    "    for split_pr in [0.5]:\n",
    "        if split_pr + split_pr_tune<1:\n",
    "            for uq_name in ['rf', 'logistic', 'xgbrf']:\n",
    "\n",
    "                # One panel per N; squeeze=False keeps `ax` 2-D even for a single panel.\n",
    "                fig, ax = plt.subplots(1,len(train_size_seq),figsize=(4*len(train_size_seq), 3.2), squeeze=False)\n",
    "\n",
    "                for i in tqdm(range(len(train_size_seq))):\n",
    "                    N = train_size_seq[i]\n",
    "\n",
    "                    # FDR / power curves and their standard deviations for this configuration.\n",
    "                    stats = {}\n",
    "                    for key in stat_keys:\n",
    "                        stat_path = os.path.join(par_path, f'uq_result/result/{key}_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                        with open(stat_path, 'rb') as f:\n",
    "                            stats[key] = np.asarray(pickle.load(f))\n",
    "                    fdp_seq, power_seq, fdp_std, power_std = (stats[key] for key in stat_keys)\n",
    "\n",
    "                    # Longer arrays are presumably on a 0.01-step grid; keep the entries matching q_seq.\n",
    "                    if len(fdp_seq) > 20:\n",
    "                        idx_ = np.arange(4,99,5)\n",
    "                        fdp_seq, power_seq, fdp_std, power_std = (s[idx_] for s in (fdp_seq, power_seq, fdp_std, power_std))\n",
    "\n",
    "                    axis = ax[0, i]\n",
    "                    axis.plot(q_seq, fdp_seq, 'bo--', label='FDR')\n",
    "                    axis.plot(q_seq, power_seq, 'rs--', label='Power')\n",
    "                    axis.fill_between(q_seq, fdp_seq-fdp_std, fdp_seq+fdp_std, alpha=0.5,\n",
    "                        edgecolor='lightblue', facecolor='lightblue')\n",
    "                    axis.fill_between(q_seq, power_seq-power_std, power_seq+power_std, alpha=0.5,\n",
    "                        edgecolor='lightpink', facecolor='lightpink')\n",
    "                    # y = x reference line: realised FDR equal to the target level.\n",
    "                    axis.plot([0, 1], [0, 1], 'g--', alpha=0.75, zorder=0)\n",
    "\n",
    "                    axis.set_xlabel(\"Target level of FDR\")\n",
    "                    axis.set_ylabel(\"FDR and Power\")\n",
    "                    axis.set_title(r'$|D|$=%s'%(N,))\n",
    "                    axis.set_xlim((0,1))\n",
    "                    axis.set_ylim((-0.03,1.03))\n",
    "                    axis.legend(loc='best', prop = { \"size\": 16 })\n",
    "\n",
    "                # Lay out after plotting so labels, titles and legends are accounted for.\n",
    "                fig.tight_layout()\n",
    "                plot_dir = f'./plots/fdp_plot_{model}_{data}_{uq_name}_{split_pr_tune}_{split_pr}_{repN}.pdf'\n",
    "                fig.savefig(plot_dir, dpi=400, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c78a7f09",
   "metadata": {},
   "source": [
    "## Results with varying $\\gamma_2$: sample size for training the alignment score predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b88c8588",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for CXR: varying split_pr (titled gamma_2) with split_pr_tune fixed\n",
    "\n",
    "repN = 500\n",
    "data = 'cxr'\n",
    "model = 'trained'\n",
    "par_path = f'./output/{model}_{data}_10'\n",
    "q_seq = np.round(np.linspace(0.05,0.95,19),2)  # target FDR levels on the x-axis\n",
    "train_size_seq = [100]\n",
    "pr_seq = [0.1,0.3,0.5,0.7]\n",
    "stat_keys = ['fdr', 'power', 'fdr_std', 'power_std']\n",
    "\n",
    "for split_pr_tune in [0.2]:\n",
    "    for N in train_size_seq:\n",
    "        for uq_name in ['logistic']:\n",
    "\n",
    "            # One panel per split_pr; squeeze=False keeps `ax` 2-D even for a single panel.\n",
    "            fig, ax = plt.subplots(1,len(pr_seq),figsize=(4*len(pr_seq), 3.2), squeeze=False)\n",
    "\n",
    "            for i in tqdm(range(len(pr_seq))):\n",
    "                split_pr = pr_seq[i]\n",
    "                # Same validity condition as the |D| cell: the two fractions must sum to less than 1.\n",
    "                assert split_pr + split_pr_tune < 1, (split_pr, split_pr_tune)\n",
    "\n",
    "                # FDR / power curves and their standard deviations for this configuration.\n",
    "                stats = {}\n",
    "                for key in stat_keys:\n",
    "                    stat_path = os.path.join(par_path, f'uq_result/result/{key}_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                    with open(stat_path, 'rb') as f:\n",
    "                        stats[key] = np.asarray(pickle.load(f))\n",
    "                fdp_seq, power_seq, fdp_std, power_std = (stats[key] for key in stat_keys)\n",
    "\n",
    "                # Longer arrays are presumably on a 0.01-step grid; keep the entries matching q_seq.\n",
    "                if len(fdp_seq) > 20:\n",
    "                    idx_ = np.arange(4,99,5)\n",
    "                    fdp_seq, power_seq, fdp_std, power_std = (s[idx_] for s in (fdp_seq, power_seq, fdp_std, power_std))\n",
    "\n",
    "                axis = ax[0, i]\n",
    "                axis.plot(q_seq, fdp_seq, 'bo--', label='FDR')\n",
    "                axis.plot(q_seq, power_seq, 'rs--', label='Power')\n",
    "                axis.fill_between(q_seq, fdp_seq-fdp_std, fdp_seq+fdp_std, alpha=0.5,\n",
    "                    edgecolor='lightblue', facecolor='lightblue')\n",
    "                axis.fill_between(q_seq, power_seq-power_std, power_seq+power_std, alpha=0.5,\n",
    "                    edgecolor='lightpink', facecolor='lightpink')\n",
    "                # y = x reference line: realised FDR equal to the target level.\n",
    "                axis.plot([0, 1], [0, 1], 'g--', alpha=0.75, zorder=0)\n",
    "\n",
    "                axis.set_xlabel(\"Target level of FDR\")\n",
    "                axis.set_ylabel(\"FDR and Power\")\n",
    "                axis.set_title(r'$\\gamma_2$=%s'%(split_pr,))\n",
    "                axis.set_xlim((0,1))\n",
    "                axis.set_ylim((-0.03,1.03))\n",
    "                axis.legend(loc='best', prop = { \"size\": 16 })\n",
    "\n",
    "            # Lay out after plotting so labels, titles and legends are accounted for.\n",
    "            fig.tight_layout()\n",
    "            plot_dir = f'./plots/fdp_plot_gamma_2_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{repN}.pdf'\n",
    "            fig.savefig(plot_dir, dpi=400, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6f565db",
   "metadata": {},
   "source": [
    "## Results with varying $\\gamma_1$: sample size for hyper-parameter tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb2e17d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for CXR: varying split_pr_tune (titled gamma_1) with split_pr fixed\n",
    "\n",
    "repN = 500\n",
    "data = 'cxr'\n",
    "model = 'trained'\n",
    "par_path = f'./output/{model}_{data}_10'\n",
    "q_seq = np.round(np.linspace(0.05,0.95,19),2)  # target FDR levels on the x-axis\n",
    "train_size_seq = [2000]\n",
    "tune_pr_seq = [0.2,0.4,0.6]\n",
    "stat_keys = ['fdr', 'power', 'fdr_std', 'power_std']\n",
    "\n",
    "for split_pr in [0.3]:\n",
    "    for N in train_size_seq:\n",
    "        for uq_name in ['logistic']:\n",
    "\n",
    "            # One panel per split_pr_tune; squeeze=False keeps `ax` 2-D even for a single panel.\n",
    "            fig, ax = plt.subplots(1,len(tune_pr_seq),figsize=(4*len(tune_pr_seq), 3.2), squeeze=False)\n",
    "\n",
    "            for i in tqdm(range(len(tune_pr_seq))):\n",
    "                split_pr_tune = tune_pr_seq[i]\n",
    "                # Same validity condition as the |D| cell: the two fractions must sum to less than 1.\n",
    "                assert split_pr + split_pr_tune < 1, (split_pr, split_pr_tune)\n",
    "\n",
    "                # FDR / power curves and their standard deviations for this configuration.\n",
    "                stats = {}\n",
    "                for key in stat_keys:\n",
    "                    stat_path = os.path.join(par_path, f'uq_result/result/{key}_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "                    with open(stat_path, 'rb') as f:\n",
    "                        stats[key] = np.asarray(pickle.load(f))\n",
    "                fdp_seq, power_seq, fdp_std, power_std = (stats[key] for key in stat_keys)\n",
    "\n",
    "                # Longer arrays are presumably on a 0.01-step grid; keep the entries matching q_seq.\n",
    "                if len(fdp_seq) > 20:\n",
    "                    idx_ = np.arange(4,99,5)\n",
    "                    fdp_seq, power_seq, fdp_std, power_std = (s[idx_] for s in (fdp_seq, power_seq, fdp_std, power_std))\n",
    "\n",
    "                axis = ax[0, i]\n",
    "                axis.plot(q_seq, fdp_seq, 'bo--', label='FDR')\n",
    "                axis.plot(q_seq, power_seq, 'rs--', label='Power')\n",
    "                axis.fill_between(q_seq, fdp_seq-fdp_std, fdp_seq+fdp_std, alpha=0.5,\n",
    "                    edgecolor='lightblue', facecolor='lightblue')\n",
    "                axis.fill_between(q_seq, power_seq-power_std, power_seq+power_std, alpha=0.5,\n",
    "                    edgecolor='lightpink', facecolor='lightpink')\n",
    "                # y = x reference line: realised FDR equal to the target level.\n",
    "                axis.plot([0, 1], [0, 1], 'g--', alpha=0.75, zorder=0)\n",
    "\n",
    "                axis.set_xlabel(\"Target level of FDR\")\n",
    "                axis.set_ylabel(\"FDR and Power\")\n",
    "                axis.set_title(r'$\\gamma_1$=%s'%(split_pr_tune,))\n",
    "                axis.set_xlim((0,1))\n",
    "                axis.set_ylim((-0.03,1.03))\n",
    "                axis.legend(loc='best', prop = { \"size\": 16 })\n",
    "\n",
    "            # Lay out after plotting so labels, titles and legends are accounted for.\n",
    "            fig.tight_layout()\n",
    "            plot_dir = f'./plots/fdp_plot_gamma_1_{model}_{data}_{uq_name}_{N}_{split_pr}_{repN}.pdf'\n",
    "            fig.savefig(plot_dir, dpi=400, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f204431",
   "metadata": {},
   "source": [
    "## Effect of features in alignment score prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a899a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import colorcet as cc\n",
    "\n",
    "# Smaller tick and legend text for this single, legend-heavy panel.\n",
    "SMALL_SIZE = 14\n",
    "MEDIUM_SIZE = 16\n",
    "BIGGER_SIZE = 16\n",
    "\n",
    "plt.rc('font', size=SMALL_SIZE)          # controls default text sizes\n",
    "plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title\n",
    "plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels\n",
    "plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize\n",
    "plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "# One curve per feature in uq_list; the saved result arrays follow the same order.\n",
    "data = 'cxr'\n",
    "model = 'trained'\n",
    "N = 2000\n",
    "uq_name = 'logistic'\n",
    "split_pr_tune = 0.2\n",
    "split_pr = 0.5\n",
    "repN = 500\n",
    "uq_list = [\n",
    "        'generations|numsets',\n",
    "        'lexical_sim',\n",
    "        'generations|spectral_eigv_clip|disagreement_w',\n",
    "        'generations|eccentricity|disagreement_w',\n",
    "        'generations|degree|disagreement_w',\n",
    "        'generations|spectral_eigv_clip|agreement_w',\n",
    "        'generations|eccentricity|agreement_w',\n",
    "        'generations|degree|agreement_w',\n",
    "        'generations|spectral_eigv_clip|jaccard',\n",
    "        'generations|eccentricity|jaccard',\n",
    "        'generations|degree|jaccard',\n",
    "        'semanticEntropy|unnorm',\n",
    "]\n",
    "\n",
    "def name_map(v):\n",
    "    \"\"\"Return the short legend label for the internal feature name `v`.\"\"\"\n",
    "    if v == 'self_prob': return \"Self_Eval\"\n",
    "    if v == 'lexical_sim': return \"Lexical_Sim\"\n",
    "    v = v.replace(\"|numsets\", \"|Num_Sets\")\n",
    "    v = v.replace(\"|disagreement_w\", \"|(C)\")\n",
    "    v = v.replace(\"|agreement_w\", \"|(E)\")\n",
    "    v = v.replace(\"|jaccard\", \"|(J)\")\n",
    "    v = v.replace(\"spectral_eigv_clip|\", \"EigV\")\n",
    "    v = v.replace(\"eccentricity|\", \"Ecc\")\n",
    "    v = v.replace(\"degree|\", \"Deg\")\n",
    "    return {'semanticEntropy|unnorm': 'SE',\n",
    "            'blind': 'Base Accuracy'}.get(v,v)\n",
    "\n",
    "q_seq = np.round(np.linspace(0.01,0.99,99),2)  # target FDR levels on the x-axis\n",
    "val = \"Power\"  # metric to plot: \"FDR\" or \"Power\"\n",
    "\n",
    "par_path = f'./output/{model}_{data}_10'\n",
    "path_fdr = os.path.join(par_path, f'uq_result/result/uq_fdr_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "path_power = os.path.join(par_path, f'uq_result/result/uq_power_{model}_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pkl')\n",
    "with open(path_fdr, 'rb') as f:\n",
    "    fdr_uq = pickle.load(f)\n",
    "with open(path_power, 'rb') as f:\n",
    "    power_uq = pickle.load(f)\n",
    "# Curves are labelled via uq_list[i], so the result files must hold one entry per feature.\n",
    "assert len(fdr_uq) == len(power_uq) == len(uq_list), (len(fdr_uq), len(power_uq), len(uq_list))\n",
    "\n",
    "fig, ax = plt.subplots(1,1,figsize=(5, 4.5))\n",
    "\n",
    "# Plot the last entry first, then the others in list order.\n",
    "idx_seq = [len(fdr_uq)-1] + list(range(len(fdr_uq)-1))\n",
    "for i, i_ in enumerate(idx_seq):\n",
    "    uq_nam = name_map(uq_list[i_]).split('generations|')[-1]\n",
    "\n",
    "    if uq_nam == 'Self_Eval':\n",
    "        color, mark, lty = 'black', 's', '-'\n",
    "    elif uq_nam in ['Num_Sets', 'SE', 'Lexical_Sim']:\n",
    "        color, mark, lty = cc.fire[50*(i+1)], 'o', '--'\n",
    "    else:\n",
    "        color, mark, lty = cc.CET_C6[15*i], '.', '-'\n",
    "\n",
    "    fdp_seq = fdr_uq[i_]\n",
    "    power_seq = power_uq[i_]\n",
    "    val_seq = fdp_seq if val == 'FDR' else power_seq\n",
    "    ax.plot(q_seq, val_seq, linestyle=lty, marker=mark, markersize=6, linewidth=2.5, label=uq_nam, c=color)\n",
    "\n",
    "ax.set_xlabel(\"Target level of FDR\")\n",
    "ax.set_ylabel(val)\n",
    "\n",
    "# Diagonal reference drawn once, spanning the combined x/y data range.\n",
    "lims = [\n",
    "    np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes\n",
    "    np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes\n",
    "]\n",
    "ax.plot(lims, lims, 'g--', alpha=0.75, zorder=0)\n",
    "ax.set_xlim(lims)\n",
    "ax.set_ylim((-0.02,1.02))\n",
    "\n",
    "ax.legend(loc='best',bbox_to_anchor=(1, 1.03))\n",
    "\n",
    "fig.savefig(f'./plots/ROC_{val}_plot_{data}_{uq_name}_{N}_{split_pr_tune}_{split_pr}_{repN}.pdf', dpi=400, bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
