{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e5fd065-8111-48de-9c92-3f7c8f378762",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "import matplotlib \n",
    "from matplotlib import colors\n",
    "import matplotlib.pyplot as plt \n",
    "\n",
    "from lm_polygraph.utils.manager import UEManager, _recombine_data, _delete_nans\n",
    "from lm_polygraph.ue_metrics import PredictionRejectionArea, KendallTauCorrelation, SpearmanRankCorrelation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2046bc0c-9d7a-484d-8acd-f347dcb28e23",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "names_dict = {\n",
    "    \"MaximumSequenceProbability\": \"Maximum Sequence Probability\",\n",
    "    \"Perplexity\":\"Perplexity\",\n",
    "\n",
    "    \"MeanTokenEntropy\":\"Mean Token Entropy\",\n",
    "    \"MutualInformation\":\"Pointwise Mutual Information\",\n",
    "    \"MeanPointwiseMutualInformation\": \"Pointwise Mutual Information\",\n",
    "    \"MeanConditionalPointwiseMutualInformation\":\"Conditional Pointwise Mutual Information\",\n",
    "\n",
    "    \"PTrueSampling\":\"P(True) Sampling\",\n",
    "    \"PTrue\":\"P(True)\",\n",
    "\n",
    "    \"SemanticEntropy\": \"Semantic Entropy\",\n",
    "    \"MonteCarloSequenceEntropy\": \"Monte Carlo Sequence Entropy\",\n",
    "    \"MonteCarloNormalizedSequenceEntropy\":\"Monte Carlo Normalized Sequence Entropy\",\n",
    "\n",
    "    \"LexicalSimilarity_rouge1\":\"Lexical Similarity Rouge-1\",\n",
    "    \"LexicalSimilarity_rouge2\":\"Lexical Similarity Rouge-2\",\n",
    "    \"LexicalSimilarity_rougeL\":\"Lexical Similarity Rouge-L\",\n",
    "    \"LexicalSimilarity_BLEU\":\"Lexical Similarity BLEU\",\n",
    "\n",
    "    'EigValLaplacian_Jaccard_score':'EigValLaplacian Jaccard Score', \n",
    "    'DegMat_Jaccard_score':'DegMat Jaccard Score',\n",
    "    'Eccentricity_Jaccard_score': 'Eccentricity Jaccard Score',\n",
    "\n",
    "    'EigValLaplacian_NLI_score_contra':'EigValLaplacian NLI Score contra.', \n",
    "    'DegMat_NLI_score_contra':'DegMat NLI Score Contra.',\n",
    "    'Eccentricity_NLI_score_contra': 'Eccentricity NLI Score contra.',\n",
    "\n",
    "    'EigValLaplacian_NLI_score_entail':'EigValLaplacian NLI Score entail.', \n",
    "    'DegMat_NLI_score_entail':'DegMat NLI Score entail.',\n",
    "    'Eccentricity_NLI_score_entail': 'Eccentricity NLI Score entail.',\n",
    "\n",
    "    \"MahalanobisDistanceSeq_decoder\":\"Mahalanobis Distance - Decoder\",\n",
    "    \"RelativeMahalanobisDistanceSeq_decoder\":\"Relative Mahalanobis Distance - Decoder\",\n",
    "    \"RDESeq_decoder\":\"RDE - Decoder\",\n",
    "    \"PPLMDSeq_decoder\":\"HUQ-MD - Decoder\",\n",
    "    \"PPLRMDSeq_decoder\":\"HUQ-RMD - Decoder\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb03658-a53b-4df3-84d6-2f171badec5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bootstraping_for_std(ue, metric, ue_metric, num_runs: int = 1000, return_string: bool = False):\n",
    "    idx = np.arange(0, len(ue))\n",
    "    all_samples = np.random.choice(idx, num_runs * len(ue), True)\n",
    "    samples = np.array(np.array_split(all_samples, num_runs))\n",
    "    mean_values = []\n",
    "    for s in samples:\n",
    "        mean_values.append(ue_metric(ue[s], metric[s]))\n",
    "    mean_values = np.array(mean_values)\n",
    "    sorted_mean_values = np.array(sorted(mean_values))\n",
    "    return sorted_mean_values[int(0.05*num_runs):int(0.95*num_runs)].std()\n",
    "\n",
    "def get_random_scores(function, metrics, num_iter=1000, seed=42, is_bartscore=False):\n",
    "    np.random.seed(seed)\n",
    "    rand_scores = np.arange(len(metrics))\n",
    "\n",
    "    value, scores = [], []\n",
    "    for i in range(num_iter):\n",
    "        np.random.shuffle(rand_scores)\n",
    "        rand_val = function(rand_scores, metrics)\n",
    "        value.append(rand_val)\n",
    "    return np.mean(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963f5e7c-3a06-405e-bc3f-c16d3fe83074",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_mans(paths, model, dataset):\n",
    "    mans = []\n",
    "    for path in paths:\n",
    "        man_path = f\"{path}/{model}/{dataset}\"\n",
    "        for subdir, _, man_files in os.walk(man_path):\n",
    "            if subdir != man_path:\n",
    "                continue\n",
    "            for man_file in man_files:\n",
    "                try:\n",
    "                    man = UEManager.load(os.path.join(man_path, man_file))\n",
    "                except:\n",
    "                    continue\n",
    "                mans.append(man)\n",
    "    return mans\n",
    "\n",
    "def update_mans(mans):\n",
    "    final_man = mans[0]\n",
    "    for man in mans[1:]:\n",
    "        for stats in [\"estimations\", \"metrics\", \"gen_metrics\"]:\n",
    "            final_man.__dict__[stats].update(man.__dict__[stats])\n",
    "    return final_man\n",
    "\n",
    "def get_tables(paths, models, datasets, gen_metrics, ue_metrics, recompute_metrics=True):\n",
    "    dfs = {}\n",
    "    quality_dfs = {}\n",
    "    for model in models:\n",
    "        dfs[model] = {}\n",
    "        quality_dfs[model] = pd.DataFrame({})\n",
    "        for ue_metric_name, ue_metric in ue_metrics.items():\n",
    "            result = pd.DataFrame({})\n",
    "            for ds in datasets:\n",
    "                mans = read_mans(paths, model, ds)\n",
    "                if len(mans) == 0:\n",
    "                    continue\n",
    "                final_man = update_mans(mans)\n",
    "                \n",
    "                ue_methods = np.array([k[1] for k in final_man.metrics.keys()])\n",
    "                _, idx = np.unique(ue_methods, return_index=True)\n",
    "                ue_methods = ue_methods[np.sort(idx)]\n",
    "                result[('','UE Method')] = [names_dict.get(m, m) for m in ue_methods]\n",
    "\n",
    "                for gen_metric in gen_metrics:  \n",
    "                    score_vals = []\n",
    "                    metrics_val = np.array(final_man.gen_metrics[(\"sequence\", gen_metric)])\n",
    "                    if (ds, gen_metric) not in quality_dfs[model].columns:\n",
    "                        quality_dfs[model][(ds, gen_metric)] = [metrics_val[~np.isnan(metrics_val)].mean()] \n",
    "                    for ue_method in tqdm(ue_methods):\n",
    "                        ue = np.array(final_man.estimations[(\"sequence\", ue_method)])\n",
    "                        ue_, metrics_val_, selected_ids = _delete_nans(ue, metrics_val)\n",
    "                        \n",
    "                        if len(ue):\n",
    "                            inputs_no_nans = np.array(final_man.stats['input_texts'])[selected_ids]\n",
    "                            ue_ = np.array(ue_)\n",
    "                            metrics_val_ = np.array(metrics_val_)\n",
    "                            rec_ue, rec_metrics_val = _recombine_data(ue_, metrics_val_, inputs_no_nans)\n",
    "                            rec_metrics_val = np.array(rec_metrics_val)\n",
    "                            rec_ue = np.array(rec_ue)\n",
    "                            \n",
    "                            dict_key = ('sequence', ue_method, gen_metric, ue_metric_name)\n",
    "                            if (dict_key in final_man.metrics.keys()) and not recompute_metrics:\n",
    "                                mean_val = np.array(final_man.metrics[dict_key])\n",
    "                            else:\n",
    "                                mean_val = ue_metric(rec_ue, rec_metrics_val)\n",
    "                                \n",
    "                            oracle = ue_metric(-rec_metrics_val, rec_metrics_val)\n",
    "                            random = get_random_scores(ue_metric, rec_metrics_val)\n",
    "                            final_score = (mean_val - random) / (oracle - random) \n",
    "                            std = bootstraping_for_std(rec_ue, rec_metrics_val, ue_metric)\n",
    "                        else:\n",
    "                            std = 0\n",
    "                            final_score = 0\n",
    "                        score_vals.append(f\"{final_score:.2f}±{std:.2f}\")\n",
    "                    result[(ds, gen_metric)] = score_vals\n",
    "            quality_dfs[model].columns=pd.MultiIndex.from_tuples(quality_dfs[model].columns)\n",
    "            result.columns=pd.MultiIndex.from_tuples(result.columns)\n",
    "            dfs[model][ue_metric_name] = result\n",
    "    return dfs, quality_dfs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f00a8e4a-cd37-4c3f-983a-11e00b41d0b9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "paths = [\"../workdir/camera_ready_exps/v1\", \"../workdir/camera_ready_exps/bertscore\"]\n",
    "models = [\"vicuna\", \"llama\"]\n",
    "datasets = [\"aeslc\", \"xsum\", \"coqa\", \"babiqa\", \"wmt14_deen\", \"wmt14_fren\"]\n",
    "gen_metrics = [\"Rouge_rougeL\", \"Bert\"]\n",
    "ue_metrics = {\"prr\": PredictionRejectionArea(), \n",
    "              \"kendalltau\": KendallTauCorrelation(),\n",
    "              \"spearmanr\": SpearmanRankCorrelation()\n",
    "             }\n",
    "dfs, quality_dfs = get_tables(paths, models, datasets, gen_metrics, ue_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e091f49d-9735-4499-b88a-efe20bf94c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = matplotlib.cm.get_cmap('Greens')\n",
    "my_cmap = cmap(np.arange(cmap.N))\n",
    "my_cmap[:,-1] = 0.5\n",
    "my_cmap = colors.ListedColormap(my_cmap)\n",
    "\n",
    "def b_g(s, cmap, low=0, high=0):\n",
    "    values = s.apply(lambda x: float(x.split(\"±\")[0]) if len(x.split(\"±\"))>1 else x)\n",
    "    if isinstance(values.max(), str):\n",
    "        return ['' for c in values]\n",
    "    rng = values.max() - values.min()\n",
    "    norm = colors.Normalize(values.min() - (rng * low), values.max() + (rng * high))\n",
    "    normed = norm(values.values)\n",
    "    back_colors = [colors.rgb2hex(x) for x in plt.cm.get_cmap(cmap)(normed)]\n",
    "    text_colors = [\"white\" if x>0.3 else \"black\" for x in normed]\n",
    "    return [f'color: {text_color}; background-color: {color}' for text_color, color in zip(text_colors, back_colors)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "042d5a04-3340-4157-b900-ede00a53f0ee",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "order = list(range(9))+[23]+[9]+list(range(11,23))+list(range(24,29))\n",
    "table_style = {\n",
    "    'selector': 'caption',\n",
    "    'props': [\n",
    "        ('color', 'black'),\n",
    "        ('font-size', '20px'),\n",
    "        ('font-weight', 'bold')\n",
    "    ]\n",
    "}\n",
    "dfs[\"vicuna\"][\"prr\"].iloc[order].style.apply(b_g, cmap=my_cmap).set_caption('Vicuna, PRR').set_table_styles([table_style])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "285aa81e-ab64-4856-9333-702fc2deec5e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dfs[\"vicuna\"][\"kendalltau\"].iloc[order].style.apply(b_g, cmap=my_cmap).set_caption('Vicuna, Kendall $\\\\tau$').set_table_styles([table_style])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0417f2bf-a92e-4ad4-b95b-e2f570b1f0a1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dfs[\"llama\"][\"kendalltau\"].iloc[order].style.apply(b_g, cmap=my_cmap).set_caption('Llama, Kendall $\\\\tau$').set_table_styles([table_style])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6a34fc8-aa56-43c7-b8e3-8346b91e170b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"vicuna\"][\"spearmanr\"].iloc[order].style.apply(b_g, cmap=my_cmap).set_caption('Vicuna, Spearman $\\\\rho$').set_table_styles([table_style])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "607f1f14-0d18-4f8f-b49c-26963346f443",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs[\"llama\"][\"spearmanr\"].iloc[order].style.apply(b_g, cmap=my_cmap).set_caption('Llama, Spearman $\\\\rho$').set_table_styles([table_style])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353f4dd7-25a8-4c4f-a5bc-567ebc2d9a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "quality_dfs[\"vicuna\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f7cf37-6d95-4045-ad74-308db39c6735",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def rgba2rgb(rgba, background=(1,1,1)):\n",
    "    ch = rgba.shape[0]\n",
    "    if ch == 3:\n",
    "        return rgba\n",
    "\n",
    "    assert ch == 4, 'RGBA image has 4 channels.'\n",
    "\n",
    "    r, g, b, a = rgba[0], rgba[1], rgba[2], rgba[3]\n",
    "    a = np.asarray(a, dtype='float32')\n",
    "    R, G, B = background\n",
    "\n",
    "    r_new = r * a + (1.0 - a) * R\n",
    "    g_new = g * a + (1.0 - a) * G\n",
    "    b_new = b * a + (1.0 - a) * B\n",
    "\n",
    "    return [r_new, g_new, b_new]\n",
    "\n",
    "\n",
    "def to_color(text, vals):\n",
    "    vals = rgba2rgb(np.array(vals))\n",
    "    return '\\\\cellcolor[rgb]{'+f'{vals[0]},'+f'{vals[1]},'+f'{vals[2]}'+'} '+f'{text}'\n",
    "\n",
    "def bold_best(df, columns):\n",
    "    for col in columns:\n",
    "        values_init_raw = [float(x.split('±')[0]) if x!='-' else np.nan for x in df[col]]\n",
    "        values_init = np.array([float(x.split('±')[0]) for x in df[col] if x!='-'])\n",
    "        if values_init.min() != values_init.max():\n",
    "            values_init_raw = np.array([(x - values_init.min()) / (values_init.max() - values_init.min()) if not np.isnan(x) else x for x in values_init_raw])\n",
    "            \n",
    "        def get_new_x(x):\n",
    "            if isinstance(x, str):\n",
    "                return x\n",
    "            return '-'\n",
    "        values = [to_color(raw, my_cmap(float(x))) if (isinstance(x, float) and (not np.isnan(x))) else get_new_x(x) for raw, x in zip(df[col], values_init_raw)]\n",
    "        df[col] = values\n",
    "    return df\n",
    "\n",
    "def prepare_latex(df1):\n",
    "    start_tex = '\\\\begin{table*}[!ht] \\\\resizebox{\\\\textwidth}{!}{'\n",
    "    end_tex = \"}\\\\caption{\\\\label{tab:llama_results} PRR$\\\\uparrow$ for Llama v2 model for various tasks for the considered sequence-level methods. Darker color indicates better results.}\\end{table*}\"\n",
    "    df1 = bold_best(df1, df1.columns[1:])\n",
    "    latex_table = df1.to_latex(bold_rows=False, index=False).replace('±', '$\\pm$')\n",
    "        \n",
    "    latex_table = latex_table.replace('\\\\textbackslash ', '\\\\')\n",
    "    latex_table = latex_table.replace('{lllllllllllll}', '{l|cc|cc|cc|cc|cc|cc}')\n",
    "    latex_table = latex_table.replace('{lllllll}', '{l|c|c|c|c|c|c}')\n",
    "    \n",
    "    latex_table = latex_table.replace('\\\\{', '{')\n",
    "    latex_table = latex_table.replace('\\\\}', '}')\n",
    "    str_list = latex_table.split('\\n')\n",
    "    str_list.pop(3)\n",
    "    latex_table = '\\n'.join(str_list)\n",
    "    return start_tex+latex_table+end_tex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed88e8c4-6917-424e-8140-06ea655edc18",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "with pd.option_context(\"max_colwidth\", 1000):\n",
    "    res_str = prepare_latex(copy.deepcopy(dfs[\"vicuna\"][\"prr\"]).round(2)).split('\\n')\n",
    "    res_str[2] = \"\"\"\\\\multirow{2}{*}{\\\\textbf{UE Method}} & \\multicolumn{2}{c|}{\\\\textbf{AESLC}} & \\multicolumn{2}{c|}{\\\\textbf{XSUM}} & \\multicolumn{2}{c|}{\\\\textbf{CoQA}} & \\multicolumn{2}{c|}{\\\\textbf{bAbiQA}} & \\multicolumn{2}{c|}{\\\\textbf{WMT14 De-En}} & \\multicolumn{2}{c}{\\\\textbf{WMT14 Fr-En}} \\\\\\\\ \\\\cline{2-13}\n",
    "    & \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}&  \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore} \\\\\\\\\\\\midrule\"\"\"\n",
    "    print('\\n'.join(res_str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "379e7df9-2d7f-4ad3-9c77-4f22645619fc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "with pd.option_context(\"max_colwidth\", 1000):\n",
    "    res_str = prepare_latex(copy.deepcopy(dfs[\"llama\"][\"prr\"]).round(2)).split('\\n')\n",
    "    res_str[2] = \"\"\"\\\\multirow{2}{*}{\\\\textbf{UE Method}} & \\multicolumn{2}{c|}{\\\\textbf{AESLC}} & \\multicolumn{2}{c|}{\\\\textbf{XSUM}} & \\multicolumn{2}{c|}{\\\\textbf{CoQA}} & \\multicolumn{2}{c|}{\\\\textbf{bAbiQA}} & \\multicolumn{2}{c|}{\\\\textbf{WMT14 De-En}} & \\multicolumn{2}{c}{\\\\textbf{WMT14 Fr-En}} \\\\\\\\ \\\\cline{2-13}\n",
    "    & \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}&  \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore} \\\\\\\\\\\\midrule\"\"\"\n",
    "    print('\\n'.join(res_str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "684981c2-c552-4f6c-81aa-6b8131d48289",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_latex(df1):\n",
    "    start_tex = '\\\\begin{table*}[!ht] \\\\resizebox{\\\\textwidth}{!}{'\n",
    "    end_tex = \"}\\\\caption{\\\\label{tab:llama_results} PRR$\\\\uparrow$ for Llama v2 model for various tasks for the considered sequence-level methods. Darker color indicates better results.}\\end{table*}\"\n",
    "    latex_table = df1.to_latex(bold_rows=False, index=False).replace('±', '$\\pm$')\n",
    "        \n",
    "    latex_table = latex_table.replace('\\\\textbackslash ', '\\\\')\n",
    "    latex_table = latex_table.replace('{lllllllllllll}', '{l|cc|cc|cc|cc|cc|cc}')\n",
    "    latex_table = latex_table.replace('{lllllll}', '{l|c|c|c|c|c|c}')\n",
    "    \n",
    "    latex_table = latex_table.replace('\\\\{', '{')\n",
    "    latex_table = latex_table.replace('\\\\}', '}')\n",
    "    str_list = latex_table.split('\\n')\n",
    "    str_list.pop(3)\n",
    "    latex_table = '\\n'.join(str_list)\n",
    "    return start_tex+latex_table+end_tex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ece6e42a-8ab9-40af-af3f-9d1e1d04c9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "with pd.option_context(\"max_colwidth\", 1000):\n",
    "    res_str = prepare_latex(copy.deepcopy(quality_dfs['vicuna']).round(2)).split('\\n')\n",
    "    res_str[2] = \"\"\"\\multicolumn{2}{c|}{\\\\textbf{AESLC}} & \\multicolumn{2}{c|}{\\\\textbf{XSUM}} & \\multicolumn{2}{c|}{\\\\textbf{CoQA}} & \\multicolumn{2}{c|}{\\\\textbf{bAbiQA}} & \\multicolumn{2}{c|}{\\\\textbf{WMT14 De-En}} & \\multicolumn{2}{c}{\\\\textbf{WMT14 Fr-En}} \\\\\\\\ \\\\midrule\n",
    "    \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}&  \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore} \\\\\\\\\\\\midrule\"\"\"\n",
    "    print('\\n'.join(res_str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81cf4f44-aec3-44f3-a6e7-a793213a5d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with pd.option_context(\"max_colwidth\", 1000):\n",
    "    res_str = prepare_latex(copy.deepcopy(quality_dfs['llama']).round(2)).split('\\n')\n",
    "    res_str[2] = \"\"\"\\multicolumn{2}{c|}{\\\\textbf{AESLC}} & \\multicolumn{2}{c|}{\\\\textbf{XSUM}} & \\multicolumn{2}{c|}{\\\\textbf{CoQA}} & \\multicolumn{2}{c|}{\\\\textbf{bAbiQA}} & \\multicolumn{2}{c|}{\\\\textbf{WMT14 De-En}} & \\multicolumn{2}{c}{\\\\textbf{WMT14 Fr-En}} \\\\\\\\ \\\\midrule\n",
    "    \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}&  \\\\textbf{Rouge-L} & \\\\textbf{BERTScore}& \\\\textbf{Rouge-L} & \\\\textbf{BERTScore} \\\\\\\\\\\\midrule\"\"\"\n",
    "    print('\\n'.join(res_str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe86f555-9114-409c-a606-c7e6241a9bd3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3964df99-7429-4e57-b5b7-7ba5fc377e92",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lm_poly",
   "language": "python",
   "name": "lm_poly"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
