{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff10d9c1-be8d-4ef7-844f-fe4731f982b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.debugger import set_trace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc85b82b-56e0-4b27-8074-1a3e06df81e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Working directory of the kernel; the manager files referenced below are looked up here.\n",
    "# Obtained in pure Python instead of the `!pwd` shell magic, so the cell does not need a POSIX shell.\n",
    "pwd = os.getcwd()\n",
    "\n",
    "# Absolute path to default Hydra config for normalization script\n",
    "# (located under the parent of the working directory)\n",
    "config_path = os.path.dirname(pwd) + '/examples/configs/normalization/fit/default.yaml'\n",
    "\n",
    "# Absolute path to saved train and eval managers.\n",
    "# Every manager file name contains the lower-cased dataset name.\n",
    "DATASET_NAMES = ['CoQA', 'GSM8K', 'MMLU', 'TriviaQA', 'WMT14', 'WMT19', 'XSum']\n",
    "\n",
    "EVAL_MAN_PATHS = {\n",
    "    name: pwd + '/polygraph_tacl_stablelm12b_' + name.lower() + '.man'\n",
    "    for name in DATASET_NAMES\n",
    "}\n",
    "\n",
    "TRAIN_MAN_PATHS = {\n",
    "    name: pwd + '/polygraph_tacl_stablelm12b_' + name.lower() + '_train.man'\n",
    "    for name in DATASET_NAMES\n",
    "}\n",
    "\n",
    "# Short display label for every uncertainty estimation method.\n",
    "# The insertion order of this mapping defines the order of UE_METHOD_NAMES below.\n",
    "UE_METHOD_NAMES_ABBR = dict(\n",
    "    MaximumSequenceProbability='MSP',\n",
    "    Perplexity='PPL',\n",
    "    MeanTokenEntropy='MTE',\n",
    "    MonteCarloSequenceEntropy='MCSE',\n",
    "    MonteCarloNormalizedSequenceEntropy='MCNSE',\n",
    "    MeanPointwiseMutualInformation='MPMI',\n",
    "    RenyiNeg='RenyiNeg',\n",
    "    FisherRao='FisherRao',\n",
    "    TokenSAR='TokenSAR',\n",
    "    CCP='CCP',\n",
    "    SemanticEntropy='SE',\n",
    "    SentenceSAR='SentenceSAR',\n",
    "    SAR='SAR',\n",
    "    PTrue='PTrue',\n",
    "    NumSemSets='NumSemSets',\n",
    "    EigValLaplacian_NLI_score_entail='EVL_entail',\n",
    "    EigValLaplacian_NLI_score_contra='EVL_contra',\n",
    "    EigValLaplacian_Jaccard_score='EVL_Jaccard',\n",
    "    DegMat_NLI_score_entail='DegMat_entail',\n",
    "    DegMat_NLI_score_contra='DegMat_contra',\n",
    "    DegMat_Jaccard_score='DegMat_Jaccard',\n",
    "    Eccentricity_NLI_score_entail='Eccentricity_entail',\n",
    "    Eccentricity_NLI_score_contra='Eccentricity_contra',\n",
    "    Eccentricity_Jaccard_score='Eccentricity_Jaccard',\n",
    "    LexicalSimilarity_rouge1='LS_rouge1',\n",
    "    LexicalSimilarity_rouge2='LS_rouge2',\n",
    "    LexicalSimilarity_rougeL='LS_rougeL',\n",
    ")\n",
    "\n",
    "# Full method names, in the same order as the abbreviations above\n",
    "UE_METHOD_NAMES = list(UE_METHOD_NAMES_ABBR)\n",
    "\n",
    "# Generation quality metrics that the confidences are compared against\n",
    "GEN_METRIC_NAMES = ['AlignScore']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "260d8fa3-c70d-451b-8de5-6f68e3870293",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download all managers to current directory\n",
    "#!wget -r --cut-dirs=2 -nH --no-parent -A '*man' http://209.38.249.180:8000/polygraph_data/mans/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80a60d0d-2255-4719-b08f-f8b3a72efa28",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_man_paths_list(man_paths):\n",
    "    \"\"\"Format manager paths as a shell-quoted list for a Hydra parameter override.\n",
    "\n",
    "    Each path is wrapped in double quotes, the paths are joined with commas inside\n",
    "    square brackets, and the whole list is quoted for the shell so that it can be\n",
    "    appended to `man_paths=` on a command line.\n",
    "\n",
    "    Args:\n",
    "        man_paths: iterable of manager file paths.\n",
    "\n",
    "    Returns:\n",
    "        The quoted list as a string, e.g. '[\"/a.man\",\"/b.man\"]' including the\n",
    "        surrounding single quotes.\n",
    "    \"\"\"\n",
    "    import shlex\n",
    "\n",
    "    hydra_list = '[' + ','.join(f'\"{path}\"' for path in man_paths) + ']'\n",
    "\n",
    "    # shlex.quote always wraps this value in single quotes (it contains '[') and,\n",
    "    # unlike plain concatenation, also escapes single quotes embedded in a path.\n",
    "    return shlex.quote(hydra_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d90b943-ff6a-480a-bf5d-d3bd22e258d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Run polygraph_normalize to fit all normalizers using all train datasets\n",
    "train_man_paths = get_man_paths_list(list(TRAIN_MAN_PATHS.values()))\n",
    "fit_command = f'HYDRA_CONFIG={config_path} polygraph_normalize save_path=\"./\" man_paths={train_man_paths}'\n",
    "fit_status = os.system(fit_command)\n",
    "\n",
    "# os.system reports failure only through its return value; stop here so that later cells\n",
    "# do not silently load a stale or missing fitted_normalizers file.\n",
    "if fit_status != 0:\n",
    "    raise RuntimeError(f'polygraph_normalize failed with status {fit_status}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31d2d6d7-39f9-4bf4-842e-24f8f464d6a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Load saved fitted normalizers.\n",
    "# The file is read in binary mode and decoded with pickle, despite its .json extension.\n",
    "with open('fitted_normalizers.json', mode='rb') as normalizers_file:\n",
    "    fitted_normalizers = pickle.loads(normalizers_file.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc3d056e-e44a-4b17-a2b1-0dcadfdf94e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from lm_polygraph.normalizers.minmax import MinMaxNormalizer\n",
    "from lm_polygraph.normalizers.quantile import QuantileNormalizer\n",
    "from lm_polygraph.normalizers.binned_pcc import BinnedPCCNormalizer\n",
    "from lm_polygraph.normalizers.isotonic_pcc import IsotonicPCCNormalizer\n",
    "\n",
    "# Normalizer class for every normalizer type key\n",
    "NORMALIZERS = dict(\n",
    "    min_max=MinMaxNormalizer,\n",
    "    quantile=QuantileNormalizer,\n",
    "    binned_pcc=BinnedPCCNormalizer,\n",
    "    isotonic_pcc=IsotonicPCCNormalizer,\n",
    ")\n",
    "\n",
    "def get_confidences(normalizers, ues):\n",
    "    \"\"\"Transform raw UE values with every stored normalizer.\n",
    "\n",
    "    `normalizers` is indexed by (metric name, method name, normalizer type); each entry\n",
    "    is restored through the `loads` method of the matching class in NORMALIZERS and\n",
    "    applied to `ues[method name]`.\n",
    "\n",
    "    Returns a nested dict: normalizer type -> metric name -> method name -> transformed values.\n",
    "    \"\"\"\n",
    "    normalizer_types = ('min_max', 'quantile', 'binned_pcc', 'isotonic_pcc')\n",
    "    confidences = {normalizer_type: defaultdict(dict) for normalizer_type in normalizer_types}\n",
    "\n",
    "    for normalizer_type, per_metric in confidences.items():\n",
    "        for method_name in UE_METHOD_NAMES:\n",
    "            for metric_name in GEN_METRIC_NAMES:\n",
    "                encoded = normalizers[(metric_name, method_name, normalizer_type)]\n",
    "                restored = NORMALIZERS[normalizer_type].loads(encoded)\n",
    "                per_metric[metric_name][method_name] = restored.transform(ues[method_name])\n",
    "\n",
    "    return confidences\n",
    "\n",
    "def calculate_mses(confidences, gen_metrics, ues):\n",
    "    \"\"\"Compute the MSE between normalized confidences and generation metrics.\n",
    "\n",
    "    Args:\n",
    "        confidences: nested mapping normalizer type -> metric name -> method name\n",
    "            -> confidence values (the structure built by get_confidences).\n",
    "        gen_metrics: mapping metric name -> ground-truth metric values.\n",
    "        ues: not used; kept so that existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Mapping normalizer type -> metric name -> list of MSE values, one per entry\n",
    "        of UE_METHOD_NAMES and in that order.\n",
    "\n",
    "    Errors raised while computing an MSE propagate to the caller. The previous\n",
    "    bare `except:` dropped into the debugger and then appended an unbound or\n",
    "    stale `mse` value.\n",
    "    \"\"\"\n",
    "    mses = {key: defaultdict(dict) for key in confidences}\n",
    "\n",
    "    for key, per_metric in confidences.items():\n",
    "        for metric_name in GEN_METRIC_NAMES:\n",
    "            gt_metric = gen_metrics[metric_name]\n",
    "            mses[key][metric_name] = [\n",
    "                ((per_metric[metric_name][method_name] - gt_metric) ** 2).mean()\n",
    "                for method_name in UE_METHOD_NAMES\n",
    "            ]\n",
    "\n",
    "    return mses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "952a1cf6-e03f-439b-b246-428ffefcbb68",
   "metadata": {},
   "source": [
    "### All datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ae19fa-3d56-4ad9-92dc-575836c143d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lm_polygraph.utils.normalize import get_mans_ues_metrics\n",
    "\n",
    "# Load and concatenate all UE values and metrics for all test datasets.\n",
    "# The paths are passed as a list, like in the other get_mans_ues_metrics calls of this notebook.\n",
    "ues, gen_metrics = get_mans_ues_metrics(list(EVAL_MAN_PATHS.values()), UE_METHOD_NAMES, GEN_METRIC_NAMES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20a4b2ea-c66b-4feb-8d87-f08f9d50f662",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize the test UE values with the loaded normalizers,\n",
    "# then compute the MSE of every confidence against the generation metric\n",
    "confidences = get_confidences(normalizers=fitted_normalizers, ues=ues)\n",
    "mses = calculate_mses(confidences=confidences, gen_metrics=gen_metrics, ues=ues)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fcc5210-6255-481e-b41c-839cf989cfe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mses(ax, mses, title, metric_name='AlignScore'):\n",
    "    \"\"\"Draw grouped bars with the MSE of every normalizer for each UE method.\n",
    "\n",
    "    Args:\n",
    "        ax: matplotlib axes to draw on.\n",
    "        mses: mapping normalizer type -> metric name -> list of MSE values ordered\n",
    "            like UE_METHOD_NAMES (the structure returned by calculate_mses).\n",
    "        title: axes title.\n",
    "        metric_name: generation metric whose MSE values are plotted.\n",
    "    \"\"\"\n",
    "    # Bar group positions are computed here rather than read from a global `x`,\n",
    "    # so the function does not depend on a variable created by another cell.\n",
    "    positions = np.arange(len(UE_METHOD_NAMES))\n",
    "\n",
    "    # (normalizer type, horizontal offset inside the group, colour, legend label)\n",
    "    bar_specs = [\n",
    "        ('min_max', -0.3, 'g', 'Linear'),\n",
    "        ('quantile', -0.1, 'b', 'Quantile'),\n",
    "        ('binned_pcc', 0.1, 'tab:olive', 'Binned'),\n",
    "        ('isotonic_pcc', 0.3, 'r', 'Isotonic'),\n",
    "    ]\n",
    "    for key, offset, color, label in bar_specs:\n",
    "        ax.bar(positions + offset, mses[key][metric_name], width=0.2, color=color, align='center', label=label)\n",
    "\n",
    "    # Look the labels up by method name so that every tick matches its bar group\n",
    "    tick_labels = [UE_METHOD_NAMES_ABBR[name] for name in UE_METHOD_NAMES]\n",
    "    ax.set_xticks(positions, tick_labels, rotation=90, fontsize=14)\n",
    "\n",
    "    ax.set_title(title, fontsize=20)\n",
    "    ax.set_ylabel('MSE', fontsize=18)\n",
    "    ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8a0e47d-c28e-43a4-9b32-6d8498fc91a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Positions of the bar groups on the x axis\n",
    "x = np.array(list(range(len(UE_METHOD_NAMES))))\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(9, 7))\n",
    "\n",
    "plot_mses(ax, mses, 'MSE between AlignScore and confidence')\n",
    "\n",
    "fig.tight_layout()\n",
    "# Change this to plt.show() to display inline\n",
    "fig.savefig('normalization_mse_total.pdf', bbox_inches='tight')\n",
    "# plt.show()\n",
    "# Close the figure instead of only clearing it, so that it is released and no blank canvas stays open\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c0ab75e-1c83-4fdc-a5db-186161d02f58",
   "metadata": {},
   "source": [
    "### OOD Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3cc6137-dfec-4754-8d0d-0305c0ebef5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "ood_confidences = {}\n",
    "ood_mses = {}\n",
    "\n",
    "for dataset_name in DATASET_NAMES:\n",
    "    # Fit normalizers excluding current OOD dataset from train set\n",
    "    train_man_paths_wo_dataset = [value for key, value in TRAIN_MAN_PATHS.items() if key != dataset_name]\n",
    "    train_man_paths_wo_dataset = get_man_paths_list(train_man_paths_wo_dataset)\n",
    "    fit_status = os.system(f'HYDRA_CONFIG={config_path} polygraph_normalize save_path=\"./ood_{dataset_name}\" man_paths={train_man_paths_wo_dataset}')\n",
    "    # Stop on a failed fit instead of loading a stale or missing normalizers file below\n",
    "    if fit_status != 0:\n",
    "        raise RuntimeError(f'polygraph_normalize failed with status {fit_status} for OOD dataset {dataset_name}')\n",
    "\n",
    "    # Get UE and metric values for OOD dataset\n",
    "    ood_ues, ood_gen_metrics = get_mans_ues_metrics([EVAL_MAN_PATHS[dataset_name]], UE_METHOD_NAMES, GEN_METRIC_NAMES)\n",
    "\n",
    "    # Keep the OOD normalizers in their own variable: assigning to `fitted_normalizers` here\n",
    "    # would overwrite the normalizers fitted on all train datasets that were loaded earlier\n",
    "    with open(f'./ood_{dataset_name}/fitted_normalizers.json', 'rb') as f:\n",
    "        ood_fitted_normalizers = pickle.load(f)\n",
    "\n",
    "    ood_confidences[dataset_name] = get_confidences(ood_fitted_normalizers, ood_ues)\n",
    "    ood_mses[dataset_name] = calculate_mses(ood_confidences[dataset_name], ood_gen_metrics, ood_ues)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c7851db-8d4d-4d7a-985a-8f02697c3120",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.gridspec import GridSpec\n",
    "\n",
    "# Datasets are plotted two per figure; with an odd number of datasets the last one\n",
    "# gets a single-panel figure of its own\n",
    "DATASETS_PER_FIGURE = 2\n",
    "dataset_groups = [\n",
    "    DATASET_NAMES[start:start + DATASETS_PER_FIGURE]\n",
    "    for start in range(0, len(DATASET_NAMES), DATASETS_PER_FIGURE)\n",
    "]\n",
    "\n",
    "# Positions of the bar groups on the x axis\n",
    "x = np.array(list(range(len(UE_METHOD_NAMES))))\n",
    "\n",
    "for dataset_group in dataset_groups:\n",
    "    n_panels = len(dataset_group)\n",
    "    # 9 inches of width per panel: 18 for a pair, 9 for a single dataset\n",
    "    fig = plt.figure(figsize=(9 * n_panels, 5))\n",
    "    gs = GridSpec(1, n_panels, figure=fig)\n",
    "\n",
    "    for i, dataset_name in enumerate(dataset_group):\n",
    "        ax = fig.add_subplot(gs[0, i])\n",
    "        plot_mses(ax, ood_mses[dataset_name], f'MSE between true AlignScore and confidence: {dataset_name}')\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # Change this to plt.show() to display inline\n",
    "    fig.savefig(f'normalization_mse_ood_{\"_\".join(dataset_group).lower()}.pdf')\n",
    "    # plt.show()\n",
    "    # Close (not just clear) the figure so that figures do not pile up across iterations\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bf5e4be-eec1-45da-b1ad-ec0662707c34",
   "metadata": {},
   "source": [
    "### PRR change relative to raw uncertainty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25439bd3-1e06-4770-ba89-3eab823e34ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lm_polygraph.utils.normalize import filter_nans\n",
    "from lm_polygraph.ue_metrics.pred_rej_area import PredictionRejectionArea\n",
    "from lm_polygraph.ue_metrics.ue_metric import (\n",
    "    get_random_scores,\n",
    "    normalize_metric,\n",
    ")\n",
    "import pandas as pd\n",
    "\n",
    "ue_metric = PredictionRejectionArea()\n",
    "\n",
    "# Normalizer type key -> column title of the result tables (column order follows this mapping)\n",
    "NORMALIZER_COLUMNS = {\n",
    "    'min_max': 'MinMax',\n",
    "    'quantile': 'Quantile',\n",
    "    'binned_pcc': 'Binned PCC',\n",
    "    'isotonic_pcc': 'Isotonic PCC',\n",
    "}\n",
    "cols = list(NORMALIZER_COLUMNS.values())\n",
    "\n",
    "# For each of the dataset we take all confidences calculated in OOD setting\n",
    "# and compare PRR of this to raw unnormalized UE\n",
    "for dataset_name, path in EVAL_MAN_PATHS.items():\n",
    "    res = {}\n",
    "    all_ues, all_gen_metrics = get_mans_ues_metrics([path], UE_METHOD_NAMES, GEN_METRIC_NAMES)\n",
    "\n",
    "    for metric_name in GEN_METRIC_NAMES:\n",
    "        for ue_method_name in UE_METHOD_NAMES:\n",
    "            # Local names are used here so that the module-level `ues` / `gen_metrics` are not overwritten\n",
    "            filtered_metric, filtered_ues = filter_nans(all_gen_metrics[metric_name], all_ues[ue_method_name])\n",
    "\n",
    "            oracle_score = ue_metric(-filtered_metric, filtered_metric)\n",
    "            random_score = get_random_scores(ue_metric, filtered_metric)\n",
    "\n",
    "            raw_ue_metric_val = ue_metric(filtered_ues, filtered_metric)\n",
    "            raw_score = normalize_metric(raw_ue_metric_val, oracle_score, random_score)\n",
    "\n",
    "            diffs = []\n",
    "            for normalizer_key in NORMALIZER_COLUMNS:\n",
    "                # -np.array() because we need to use UE, not confidence to calculate PRR\n",
    "                # NOTE(review): unlike the raw UE, these values are not passed through filter_nans;\n",
    "                # confirm they still line up with filtered_metric when entries were filtered out\n",
    "                normalized_ues = -np.array(ood_confidences[dataset_name][normalizer_key][metric_name][ue_method_name])\n",
    "                normalized_ue_metric_val = ue_metric(normalized_ues, filtered_metric)\n",
    "                normalized_score = normalize_metric(normalized_ue_metric_val, oracle_score, random_score)\n",
    "                diffs.append(raw_score - normalized_score)\n",
    "\n",
    "            res[ue_method_name] = diffs\n",
    "\n",
    "    # Show table for each datasets that contains difference between raw UE PRR and PRR based on normalized confidence\n",
    "    # Lower is better, negative is best (means normalized confidence improves upon raw PRR)\n",
    "    df = pd.DataFrame.from_dict(res, orient='index', columns=cols)\n",
    "    display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f04893f-0516-4855-a753-d4b8f5756539",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_ues, all_gen_metrics = get_mans_ues_metrics(list(EVAL_MAN_PATHS.values()), UE_METHOD_NAMES, GEN_METRIC_NAMES)\n",
    "\n",
    "# Normalizer type keys, in the same order as the table columns in `cols`\n",
    "normalizer_keys = ('min_max', 'quantile', 'binned_pcc', 'isotonic_pcc')\n",
    "\n",
    "# Start from an empty result instead of reusing the `res` left over by the per-dataset cell\n",
    "res = {}\n",
    "\n",
    "# Same for all datasets concatenated\n",
    "for metric_name in GEN_METRIC_NAMES:\n",
    "    for ue_method_name in UE_METHOD_NAMES:\n",
    "        # Local names are used here so that the module-level `ues` / `gen_metrics` are not overwritten\n",
    "        filtered_metric, filtered_ues = filter_nans(all_gen_metrics[metric_name], all_ues[ue_method_name])\n",
    "\n",
    "        oracle_score = ue_metric(-filtered_metric, filtered_metric)\n",
    "        random_score = get_random_scores(ue_metric, filtered_metric)\n",
    "\n",
    "        raw_ue_metric_val = ue_metric(filtered_ues, filtered_metric)\n",
    "        raw_score = normalize_metric(raw_ue_metric_val, oracle_score, random_score)\n",
    "\n",
    "        diffs = []\n",
    "        for normalizer_key in normalizer_keys:\n",
    "            # Negated because PRR is computed on uncertainty, not on confidence\n",
    "            # NOTE(review): unlike the raw UE, these values are not passed through filter_nans;\n",
    "            # confirm they still line up with filtered_metric when entries were filtered out\n",
    "            normalized_ues = -np.array(confidences[normalizer_key][metric_name][ue_method_name])\n",
    "            normalized_ue_metric_val = ue_metric(normalized_ues, filtered_metric)\n",
    "            normalized_score = normalize_metric(normalized_ue_metric_val, oracle_score, random_score)\n",
    "            diffs.append(raw_score - normalized_score)\n",
    "\n",
    "        res[ue_method_name] = diffs\n",
    "\n",
    "df = pd.DataFrame.from_dict(res, orient='index', columns=cols)\n",
    "display(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7a70cac-b626-4dde-a268-33cf3682a5d0",
   "metadata": {},
   "source": [
    "### Table coloring and formatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d0e5b48-82ce-4160-97ed-f3c2b8c66881",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "from matplotlib import colors\n",
    "\n",
    "# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and later removed\n",
    "cmap = plt.get_cmap('Greens')\n",
    "# Copy of the colormap with the alpha of every colour set to 0.5\n",
    "my_cmap = cmap(np.arange(cmap.N))\n",
    "my_cmap[:,-1] = 0.5\n",
    "my_cmap = colors.ListedColormap(my_cmap)\n",
    "\n",
    "def b_g(values, cmap, low=0, high=0):\n",
    "    \"\"\"Build per-cell CSS (text and background colour) for a table of numbers.\n",
    "\n",
    "    Written to be passed to `Styler.apply(..., axis=None)`.\n",
    "\n",
    "    Args:\n",
    "        values: DataFrame of numeric values.\n",
    "        cmap: colormap name or Colormap instance used for the background.\n",
    "        low: fraction of the value range by which the lower normalization bound is extended.\n",
    "        high: fraction of the value range by which the upper normalization bound is extended.\n",
    "\n",
    "    Returns:\n",
    "        numpy array of CSS strings, one per cell of `values`.\n",
    "    \"\"\"\n",
    "    value_min = values.min().min()\n",
    "    value_max = values.max().max()\n",
    "    rng = value_max - value_min\n",
    "    norm = colors.Normalize(value_min - (rng * low), value_max + (rng * high))\n",
    "    normed = norm(values.values)\n",
    "\n",
    "    # plt.get_cmap accepts both names and Colormap objects; plt.cm.get_cmap was\n",
    "    # deprecated in Matplotlib 3.7 and later removed\n",
    "    back_colors = [[colors.rgb2hex(val) for val in row] for row in plt.get_cmap(cmap)(normed)]\n",
    "    # Cells with a normalized value above 0.3 get white text, the others black text\n",
    "    text_colors = [[\"white\" if val > 0.3 else \"black\" for val in row] for row in normed]\n",
    "\n",
    "    return np.array([\n",
    "        [f'color: {text_color}; background-color: {color}' for text_color, color in zip(row_text_colors, row_colors)]\n",
    "        for row_text_colors, row_colors in zip(text_colors, back_colors)\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abed508d-988b-4542-943f-a35a361b55f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rgba2rgb(rgba, background=(1,1,1)):\n",
    "    \"\"\"Blend an RGBA colour over `background` and return the three RGB channels.\n",
    "\n",
    "    The channels are read along the first axis of `rgba`. Input with three channels\n",
    "    is returned unchanged; otherwise four channels are required and the result is\n",
    "    the list [r, g, b] with channel * alpha + (1 - alpha) * background applied.\n",
    "    \"\"\"\n",
    "    n_channels = rgba.shape[0]\n",
    "    if n_channels == 3:\n",
    "        return rgba\n",
    "\n",
    "    assert n_channels == 4, 'RGBA image has 4 channels.'\n",
    "\n",
    "    alpha = np.asarray(rgba[3], dtype='float32')\n",
    "    bg_red, bg_green, bg_blue = background\n",
    "\n",
    "    return [\n",
    "        channel * alpha + (1.0 - alpha) * bg_channel\n",
    "        for channel, bg_channel in zip(rgba[:3], (bg_red, bg_green, bg_blue))\n",
    "    ]\n",
    "\n",
    "\n",
    "def to_color(text, vals):\n",
    "    \"\"\"Prefix `text` with a LaTeX cellcolor[rgb] command built from the colour `vals`.\"\"\"\n",
    "    rgb = rgba2rgb(np.array(vals))\n",
    "    return f'\\\\cellcolor[rgb]{{{rgb[0]},{rgb[1]},{rgb[2]}}} {text}'\n",
    "\n",
    "def bold_best(df, columns):\n",
    "    \"\"\"Replace the cells of `columns` with LaTeX strings coloured through `my_cmap`.\n",
    "\n",
    "    Values are scaled with the minimum and maximum of the whole frame, unless the\n",
    "    column holds a single repeated value, in which case the raw values are used as\n",
    "    colormap input. Cells equal to '-' stay '-'. The frame is modified in place and\n",
    "    also returned.\n",
    "    \"\"\"\n",
    "    global_min = df.values.min().min()\n",
    "    global_max = df.values.max().max()\n",
    "\n",
    "    def render_cell(raw, scaled):\n",
    "        # Numeric cell: coloured LaTeX text with three decimals\n",
    "        if isinstance(scaled, float) and (not np.isnan(scaled)):\n",
    "            return to_color('{:.3f}'.format(raw), my_cmap(float(scaled)))\n",
    "        # Strings pass through unchanged, anything else becomes a dash\n",
    "        if isinstance(scaled, str):\n",
    "            return scaled\n",
    "        return '-'\n",
    "\n",
    "    for col in columns:\n",
    "        column = df[col]\n",
    "        scaled_cells = [cell if cell != '-' else np.nan for cell in column]\n",
    "        present = np.array([cell for cell in column if cell != '-'])\n",
    "\n",
    "        if present.min() != present.max():\n",
    "            scaled_cells = np.array([\n",
    "                (cell - global_min) / (global_max - global_min) if not np.isnan(cell) else cell\n",
    "                for cell in scaled_cells\n",
    "            ])\n",
    "\n",
    "        df[col] = [render_cell(raw, scaled) for raw, scaled in zip(column, scaled_cells)]\n",
    "\n",
    "    return df\n",
    "\n",
    "df = pd.DataFrame.from_dict(res, orient='index', columns=cols)\n",
    "# NOTE(review): the Styler returned here is neither displayed nor stored, so this statement\n",
    "# has no visible effect; display it or drop it once the intent is confirmed\n",
    "df.style.apply(b_g, cmap=cmap, axis=None)\n",
    "# bold_best replaces the numeric cells of df with coloured LaTeX strings in place\n",
    "df_colored = bold_best(df, df.columns)\n",
    "\n",
    "# Render the table before opening the output file, so that a rendering error does not\n",
    "# leave an empty total_prr_table.tex behind\n",
    "with pd.option_context(\"max_colwidth\", 1000):\n",
    "    table = df_colored.to_latex()\n",
    "\n",
    "# Drop the sign of negative zeros; the remaining replacements presumably undo the\n",
    "# escaping that to_latex applies to the embedded LaTeX commands\n",
    "table = table.replace('-0.000', '0.000')\n",
    "table = table.replace('\\\\textbackslash ', '\\\\')\n",
    "table = table.replace('\\\\{', '{')\n",
    "table = table.replace('\\\\}', '}')\n",
    "\n",
    "with open('total_prr_table.tex', 'w') as f:\n",
    "    f.write(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b6d2474-4322-432f-8f33-0a88e75d952e",
   "metadata": {},
   "source": [
    "### Normalized confidence vs raw uncertainty plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4764979-5e3f-40f7-a2c5-9a00d4d45875",
   "metadata": {},
   "outputs": [],
   "source": [
    "for method in UE_METHOD_NAMES:\n",
    "    metric = 'AlignScore'\n",
    "    cur_ues = all_ues[method]\n",
    "    # Sort by raw uncertainty so that every curve is drawn from left to right\n",
    "    order = np.argsort(cur_ues)\n",
    "    sor_ues = cur_ues[order]\n",
    "    # The metric values have to be reordered with the same permutation as the UE values,\n",
    "    # otherwise they are drawn against x positions of other samples\n",
    "    sor_metrics = np.asarray(all_gen_metrics[metric])[order]\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(sor_ues, sor_metrics, label=metric)\n",
    "    ax.plot(sor_ues, confidences['min_max'][metric][method][order], label='MinMax')\n",
    "    ax.plot(sor_ues, confidences['quantile'][metric][method][order], label='Quantile')\n",
    "    ax.plot(sor_ues, confidences['binned_pcc'][metric][method][order], label='Binned PCC')\n",
    "    ax.plot(sor_ues, confidences['isotonic_pcc'][metric][method][order], label='Isotonic PCC')\n",
    "    ax.set_title(method)\n",
    "    ax.legend()\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "    # Close the figure so that figures do not accumulate over the loop\n",
    "    plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
