{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import tikzplotlib\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ECE averaged over model families"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = 'ECE_15'\n",
    "dataset = 'yahoo_answers_topics' # 'CIFAR10', 'CIFAR100', 'ImageNet', amazon_food, dynasent, mnli, yahoo_answers_topics\n",
    "valid_size_IN = 25000 # calibration-set size that appears in the ImageNet result file names\n",
    "seeds = 5\n",
    "mean_or_std = 'mean' # 'mean' 'std'\n",
    "\n",
    "# load the per-seed results and add one 'netcal_HB_tva' row per model\n",
    "df = pd.DataFrame()\n",
    "for seed in range(seeds):\n",
    "    if dataset == 'ImageNet':\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_calibSize{valid_size_IN}_seed{seed}.csv')\n",
    "    else:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_seed{seed}.csv')\n",
    "    # choose best of eqsize and eqmass for HB (per model and per seed, on ECE_15)\n",
    "    for m in df_s['model'].unique():\n",
    "        hb_tva_eqmass = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqmass')]\n",
    "        hb_tva_eqsize = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqsize')]\n",
    "        if hb_tva_eqmass['ECE_15'].item() < hb_tva_eqsize['ECE_15'].item():\n",
    "            hb_tva = hb_tva_eqmass.copy()\n",
    "        else:\n",
    "            hb_tva = hb_tva_eqsize.copy()\n",
    "        hb_tva['method'] = 'netcal_HB_tva'\n",
    "        # The aggregation below also groups on the row index: give the selected row the same\n",
    "        # index whichever variant won, otherwise seeds that chose eqmass and seeds that chose\n",
    "        # eqsize are aggregated in separate groups instead of across all seeds.\n",
    "        hb_tva.index = hb_tva_eqsize.index\n",
    "        df_s = pd.concat([df_s, hb_tva], axis=0)\n",
    "    df = pd.concat([df, df_s], axis=0)\n",
    "\n",
    "# get mean, std values across seeds\n",
    "# (grouping on df.index assumes the rows of every seed file are in the same order)\n",
    "if mean_or_std == 'mean':\n",
    "    df = df.groupby(['dataset', 'model', 'method', df.index]).mean().reset_index().drop(columns='level_3')\n",
    "elif mean_or_std == 'std':\n",
    "    df = df.groupby(['dataset', 'model', 'method', df.index]).std().reset_index().drop(columns='level_3')\n",
    "else:\n",
    "    raise ValueError(f\"mean_or_std must be 'mean' or 'std', got {mean_or_std!r}\")\n",
    "\n",
    "# choose dataset & metric\n",
    "df_metric = pd.pivot_table(df, index=['dataset', 'model'], columns=['method'], values=[metric])[metric]\n",
    "if metric != 'Accuracy':\n",
    "    df_metric = df_metric * 100 # in %\n",
    "df_metric = df_metric.loc[dataset]\n",
    "\n",
    "# TvA method -> its reference method; each reference column is placed right before its TvA variant\n",
    "dict_ref = {'TS_tva': 'TS', 'VS_reg_tva': 'VS', 'Dir-ODIR_reg_tva': 'Dir-ODIR', 'netcal_Iso_tva': 'netcal_Iso', 'netcal_BBQ_tva': 'netcal_BBQ', 'netcal_HB_tva': 'netcal_HB_eqsize'}\n",
    "\n",
    "columns = ['original', 'IRM', 'Patel2021_sCW_imax']\n",
    "for k in dict_ref.keys():\n",
    "    columns.append(dict_ref[k])\n",
    "    columns.append(k)\n",
    "df_metric = df_metric[columns]\n",
    "\n",
    "# row order of the table\n",
    "if 'CIFAR' in dataset:\n",
    "    new_idx = [\n",
    "        'ResNet-50',\n",
    "        'ResNet-110',\n",
    "        'WRN',\n",
    "        'DenseNet',\n",
    "        'clip-vit-base-patch32',\n",
    "        'clip-vit-base-patch16',\n",
    "        'clip-vit-large-patch14'\n",
    "        ]\n",
    "elif dataset == 'ImageNet':\n",
    "    new_idx = [\n",
    "        'ResNet-18','ResNet-34','ResNet-50','ResNet-101','EffNet-B7','EffNetV2-S','EffNetV2-M','EffNetV2-L','ConvNeXt-T','ConvNeXt-S','ConvNeXt-B','ConvNeXt-L',\n",
    "        'ViT-B/32','ViT-B/16','ViT-L/32','ViT-L/16','ViT-H/14','Swin-T','Swin-S','Swin-B','SwinV2-T','SwinV2-S','SwinV2-B', 'clip-vit-base-patch32', 'clip-vit-base-patch16', 'clip-vit-large-patch14']\n",
    "elif dataset in ['amazon_food', 'dynasent', 'mnli', 'yahoo_answers_topics']:\n",
    "    new_idx = ['t5', 't5-large', 'roberta', 'roberta-large']\n",
    "else:\n",
    "    # without this, a stale new_idx from a previous run would be reused silently\n",
    "    raise ValueError(f'no model list defined for dataset {dataset!r}')\n",
    "\n",
    "df_metric = df_metric.reindex(index=new_idx)\n",
    "# display names used in the paper tables\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-base-patch32', 'CLIP (ViT-B/32)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-base-patch16', 'CLIP (ViT-B/16)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-large-patch14', 'CLIP (ViT-L/14)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('t5', 'T5'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('roberta', 'RoBERTa'))\n",
    "\n",
    "df_metric = df_metric.map(lambda x: float(f'{x:.2f}')) # to consider all close values as min/max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# average the (rounded) per-model values over each model family\n",
    "vit_models = ['ViT-B/32', 'ViT-B/16', 'ViT-L/32', 'ViT-L/16', 'ViT-H/14']\n",
    "if 'CIFAR' in dataset:\n",
    "    family_masks = {\n",
    "        'ConvNets': df_metric.index.str.contains('ResNet|WRN|DenseNet'),\n",
    "        'CLIP': df_metric.index.str.contains('CLIP'),\n",
    "    }\n",
    "elif dataset == 'ImageNet':\n",
    "    family_masks = {}\n",
    "    for family in ['ResNet', 'EffNet', 'ConvNeXt', 'ViT', 'Swin', 'CLIP']:\n",
    "        # a substring match on 'ViT' would also select the CLIP (ViT-...) rows\n",
    "        family_masks[family] = df_metric.index.isin(vit_models) if family == 'ViT' else df_metric.index.str.contains(family)\n",
    "elif dataset in ['amazon_food', 'dynasent', 'mnli', 'yahoo_answers_topics']:\n",
    "    family_masks = {family: df_metric.index.str.contains(family) for family in ['T5', 'RoBERTa']}\n",
    "\n",
    "model_families = list(family_masks)\n",
    "df_avg = pd.DataFrame(index=model_families, columns=df_metric.columns)\n",
    "for family, mask in family_masks.items():\n",
    "    df_avg.loc[family, :] = df_metric.loc[mask].mean()\n",
    "\n",
    "# lowest value of each row in bold, 2 decimals\n",
    "s = df_avg.style.highlight_min(axis=1, props='textbf:--rwrap;').format('{:.2f}')\n",
    "print(s.to_latex())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ECE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = 'ECE_15'\n",
    "dataset = 'yahoo_answers_topics' # 'CIFAR10', 'CIFAR100', 'ImageNet', 'ImageNet21k', amazon_food, dynasent, mnli, yahoo_answers_topics\n",
    "valid_size_IN = 78741 # 25000 for IN, 261250 for IN21k , 14001 for yahoo, 19635 mnli, 11160 dynasent, 78741 amazon\n",
    "seeds = 5\n",
    "mean_or_std = 'std' # 'mean' 'std'\n",
    "\n",
    "# load the per-seed results and add one 'netcal_HB_tva' row per model\n",
    "df = pd.DataFrame()\n",
    "for seed in range(seeds):\n",
    "    if 'ImageNet' in dataset:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_calibSize{valid_size_IN}_seed{seed}.csv')\n",
    "    else:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_seed{seed}.csv')\n",
    "    # choose best of eqsize and eqmass for HB (per model and per seed, on ECE_15)\n",
    "    for m in df_s['model'].unique():\n",
    "        hb_tva_eqmass = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqmass')]\n",
    "        hb_tva_eqsize = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqsize')]\n",
    "        if hb_tva_eqmass['ECE_15'].item() < hb_tva_eqsize['ECE_15'].item():\n",
    "            hb_tva = hb_tva_eqmass.copy()\n",
    "        else:\n",
    "            hb_tva = hb_tva_eqsize.copy()\n",
    "        hb_tva['method'] = 'netcal_HB_tva'\n",
    "        # The aggregation below also groups on the row index: give the selected row the same\n",
    "        # index whichever variant won, otherwise seeds that chose eqmass and seeds that chose\n",
    "        # eqsize are aggregated in separate groups instead of across all seeds.\n",
    "        hb_tva.index = hb_tva_eqsize.index\n",
    "        df_s = pd.concat([df_s, hb_tva], axis=0)\n",
    "    df = pd.concat([df, df_s], axis=0)\n",
    "\n",
    "# get mean, std values across seeds\n",
    "# (grouping on df.index assumes the rows of every seed file are in the same order)\n",
    "if mean_or_std == 'mean':\n",
    "    df = df.groupby(['dataset', 'model', 'method', df.index]).mean().reset_index().drop(columns='level_3')\n",
    "elif mean_or_std == 'std':\n",
    "    df = df.groupby(['dataset', 'model', 'method', df.index]).std().reset_index().drop(columns='level_3')\n",
    "else:\n",
    "    raise ValueError(f\"mean_or_std must be 'mean' or 'std', got {mean_or_std!r}\")\n",
    "\n",
    "# choose dataset & metric (columns kept in the order the methods appear in the results)\n",
    "df_metric = pd.pivot_table(df, index=['dataset', 'model'], columns=['method'], values=[metric])[metric].reindex(columns=df['method'].unique())\n",
    "if metric != 'Accuracy':\n",
    "    df_metric = df_metric * 100 # in %\n",
    "df_metric = df_metric.loc[dataset]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TvA method -> its reference method; each reference column is placed right before its TvA variant\n",
    "dict_ref = {'TS_tva': 'TS', 'VS_reg_tva': 'VS', 'Dir-ODIR_reg_tva': 'Dir-ODIR', 'netcal_Beta_tva': 'netcal_Beta', 'netcal_Iso_tva': 'netcal_Iso', 'netcal_BBQ_tva': 'netcal_BBQ', 'netcal_HB_tva': 'netcal_HB_eqsize'}\n",
    "\n",
    "columns = ['original', 'IRM', 'Patel2021_sCW_imax']\n",
    "for method_tva, method_base in dict_ref.items():\n",
    "    columns.extend([method_base, method_tva])\n",
    "df_metric = df_metric[columns]\n",
    "\n",
    "# row order of the table\n",
    "if 'CIFAR' in dataset:\n",
    "    new_idx = ['ResNet-50', 'ResNet-110', 'WRN', 'DenseNet',\n",
    "               'clip-vit-base-patch32', 'clip-vit-base-patch16', 'clip-vit-large-patch14']\n",
    "elif dataset == 'ImageNet':\n",
    "    new_idx = [\n",
    "        'ResNet-18', 'ResNet-34', 'ResNet-50', 'ResNet-101',\n",
    "        'EffNet-B7', 'EffNetV2-S', 'EffNetV2-M', 'EffNetV2-L',\n",
    "        'ConvNeXt-T', 'ConvNeXt-S', 'ConvNeXt-B', 'ConvNeXt-L',\n",
    "        'ViT-B/32', 'ViT-B/16', 'ViT-L/32', 'ViT-L/16', 'ViT-H/14',\n",
    "        'Swin-T', 'Swin-S', 'Swin-B', 'SwinV2-T', 'SwinV2-S', 'SwinV2-B',\n",
    "        'clip-vit-base-patch32', 'clip-vit-base-patch16', 'clip-vit-large-patch14']\n",
    "elif dataset == 'ImageNet21k':\n",
    "    new_idx = ['mobilenetv3_large_100_miil_in21k', 'vit_base_patch16_224_miil_in21k']\n",
    "elif dataset in ['amazon_food', 'dynasent', 'mnli', 'yahoo_answers_topics']:\n",
    "    new_idx = ['t5', 't5-large', 'roberta', 'roberta-large']\n",
    "df_metric = df_metric.reindex(index=new_idx)\n",
    "\n",
    "# substring replacements, applied in this order, that turn model ids into display names\n",
    "display_names = [\n",
    "    ('clip-vit-base-patch32', 'CLIP (ViT-B/32)'),\n",
    "    ('clip-vit-base-patch16', 'CLIP (ViT-B/16)'),\n",
    "    ('clip-vit-large-patch14', 'CLIP (ViT-L/14)'),\n",
    "    ('vit_base_patch16_224_miil_in21k', 'ViT-B/16'),\n",
    "    ('mobilenetv3_large_100_miil_in21k', 'MN3'),\n",
    "    ('t5', 'T5'),\n",
    "    ('roberta', 'RoBERTa'),\n",
    "]\n",
    "\n",
    "def _display_name(name):\n",
    "    for raw, pretty in display_names:\n",
    "        name = name.replace(raw, pretty)\n",
    "    return name\n",
    "\n",
    "df_metric.index = df_metric.index.map(_display_name)\n",
    "\n",
    "# round first so that values equal at 2 decimals are all treated as the row minimum\n",
    "df_metric = df_metric.map(lambda x: float(f'{x:.2f}'))\n",
    "\n",
    "if mean_or_std == 'mean':\n",
    "    s = df_metric.style.highlight_min(axis=1, props='textbf:--rwrap;') # min value per row in bold\n",
    "elif mean_or_std == 'std':\n",
    "    s = df_metric.style\n",
    "s = s.format('{:.2f}')\n",
    "# missing results are reported as 'err.'\n",
    "print(s.to_latex().replace('nan', 'err.'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'CIFAR' in dataset:\n",
    "    model_families = ['N', 'CLIP'] # 'N' matches every ConvNet name (ResNet, WRN, DenseNet) and no CLIP name\n",
    "elif dataset == 'ImageNet':\n",
    "    model_families = ['ResNet', 'EffNet', 'ConvNeXt', 'ViT', 'Swin', 'CLIP']\n",
    "elif dataset == 'ImageNet21k':\n",
    "    model_families = [''] # the empty string matches every model: global improvement only\n",
    "elif dataset in ['amazon_food', 'dynasent', 'mnli', 'yahoo_answers_topics']:\n",
    "    model_families = ['T5', 'RoBERTa']\n",
    "else:\n",
    "    raise ValueError(f'no model families defined for dataset {dataset!r}')\n",
    "\n",
    "vit_models = ['ViT-B/32', 'ViT-B/16', 'ViT-L/32', 'ViT-L/16', 'ViT-H/14']\n",
    "method_refs = list(dict_ref.values())\n",
    "\n",
    "# mean relative change (in %) of each TvA method w.r.t. its reference method, per model family\n",
    "df_improv = pd.DataFrame(index=model_families, columns=method_refs, dtype=float)\n",
    "for model_family in model_families:\n",
    "    if model_family == 'ViT': # a substring match would also select the CLIP (ViT-...) models\n",
    "        in_family = df_metric.index.isin(vit_models)\n",
    "    else:\n",
    "        in_family = df_metric.index.str.contains(model_family)\n",
    "    for method_new, method_ref in dict_ref.items():\n",
    "        new_val = df_metric.loc[in_family, method_new]\n",
    "        ref_val = df_metric.loc[in_family, method_ref]\n",
    "        df_improv.loc[model_family, method_ref] = (100 * (new_val - ref_val) / ref_val).mean()\n",
    "\n",
    "# one LaTeX row per model family, one column pair per method in dict_ref (the last one without right border)\n",
    "for model_family in model_families:\n",
    "    print('\\n', model_family)\n",
    "    cells = [rf'\\multicolumn{{4}}{{c|}}{{Mean improvement {model_family}}}']\n",
    "    for i, method_ref in enumerate(method_refs):\n",
    "        align = 'c' if i == len(method_refs) - 1 else 'c|'\n",
    "        value = float(df_improv.loc[model_family, method_ref])\n",
    "        cells.append(rf'\\multicolumn{{2}}{{{align}}}{{{value:.0f}\\%}}')\n",
    "    print(' & '.join(cells) + r' \\\\')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = 'Brier_top' # Accuracy AdaECE_15 Brier_top ECE_100 AdaECE_100\n",
    "dataset = 'yahoo_answers_topics' # 'CIFAR10', 'CIFAR100', 'ImageNet', 'ImageNet21k' amazon_food, dynasent, mnli, yahoo_answers_topics\n",
    "valid_size_IN = 261250 # 25000 for IN, 261250 for IN21k \n",
    "seeds = 5\n",
    "\n",
    "# metric used to choose between the eqmass and eqsize HB variants (lower wins):\n",
    "# the table metric itself for the listed metrics, ECE_15 otherwise\n",
    "hb_selection_metric = metric if metric in ['AdaECE_15', 'ECE_100', 'AdaECE_100', 'Brier_top'] else 'ECE_15'\n",
    "\n",
    "df = pd.DataFrame()\n",
    "for seed in range(seeds):\n",
    "    if 'ImageNet' in dataset:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_calibSize{valid_size_IN}_seed{seed}.csv')\n",
    "    else:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_seed{seed}.csv')\n",
    "    # choose best of eqsize and eqmass for HB\n",
    "    for m in df_s['model'].unique():\n",
    "        hb_tva_eqmass = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqmass')]\n",
    "        hb_tva_eqsize = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqsize')]\n",
    "        if hb_tva_eqmass[hb_selection_metric].item() < hb_tva_eqsize[hb_selection_metric].item():\n",
    "            hb_tva = hb_tva_eqmass.copy()\n",
    "        else:\n",
    "            hb_tva = hb_tva_eqsize.copy()\n",
    "        hb_tva['method'] = 'netcal_HB_tva'\n",
    "        df_s = pd.concat([df_s, hb_tva], axis=0)\n",
    "    df = pd.concat([df, df_s], axis=0)\n",
    "\n",
    "# choose dataset & metric (pivot_table's default aggfunc averages the seeds)\n",
    "df_metric = pd.pivot_table(df, index=['dataset', 'model'], columns=['method'], values=[metric])[metric].reindex(columns=df['method'].unique())\n",
    "if metric != 'Accuracy':\n",
    "    df_metric = df_metric * 100 # in %\n",
    "df_metric = df_metric.loc[dataset]\n",
    "\n",
    "# `columns` (method order) is defined in the ECE table cell above\n",
    "df_metric = df_metric[columns]\n",
    "\n",
    "# row order of the table\n",
    "if 'CIFAR' in dataset:\n",
    "    new_idx = [\n",
    "        'ResNet-50',\n",
    "        'ResNet-110',\n",
    "        'WRN',\n",
    "        'DenseNet',\n",
    "        'clip-vit-base-patch32',\n",
    "        'clip-vit-base-patch16',\n",
    "        'clip-vit-large-patch14']\n",
    "elif dataset == 'ImageNet':\n",
    "    new_idx = [\n",
    "        'ResNet-18','ResNet-34','ResNet-50','ResNet-101','EffNet-B7','EffNetV2-S','EffNetV2-M','EffNetV2-L','ConvNeXt-T','ConvNeXt-S','ConvNeXt-B','ConvNeXt-L',\n",
    "        'ViT-B/32','ViT-B/16','ViT-L/32','ViT-L/16','ViT-H/14','Swin-T','Swin-S','Swin-B','SwinV2-T','SwinV2-S','SwinV2-B', 'clip-vit-base-patch32', 'clip-vit-base-patch16', 'clip-vit-large-patch14']\n",
    "elif dataset == 'ImageNet21k':\n",
    "    new_idx = ['mobilenetv3_large_100_miil_in21k', 'vit_base_patch16_224_miil_in21k']\n",
    "elif dataset in ['amazon_food', 'dynasent', 'mnli', 'yahoo_answers_topics']:\n",
    "    new_idx = ['t5', 't5-large', 'roberta', 'roberta-large']\n",
    "else:\n",
    "    raise ValueError(f'no model list defined for dataset {dataset!r}')\n",
    "\n",
    "df_metric = df_metric.reindex(index=new_idx)\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-base-patch32', 'CLIP (ViT-B/32)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-base-patch16', 'CLIP (ViT-B/16)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-large-patch14', 'CLIP (ViT-L/14)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('vit_base_patch16_224_miil_in21k', 'ViT-B/16'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('mobilenetv3_large_100_miil_in21k', 'MN3'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('t5', 'T5'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('roberta', 'RoBERTa'))\n",
    "\n",
    "# DataFrame.map (as in the cells above) instead of the deprecated DataFrame.applymap\n",
    "df_metric = df_metric.map(lambda x: float(f'{x:.2f}')) # to consider all close values as min/max\n",
    "\n",
    "# bold the best value per row: highest accuracy, lowest value for the other metrics\n",
    "s = df_metric.style.format('{:.2f}')\n",
    "if metric == 'Accuracy':\n",
    "    s = s.highlight_max(axis=1, props='textbf:--rwrap;')\n",
    "else:\n",
    "    s = s.highlight_min(axis=1, props='textbf:--rwrap;')\n",
    "print(s.to_latex().replace('nan', 'err.'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AUROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = 'AUROC'\n",
    "dataset = 'yahoo_answers_topics' # 'CIFAR10', 'CIFAR100', 'ImageNet', 'ImageNet21k' amazon_food, dynasent, mnli, yahoo_answers_topics\n",
    "valid_size_IN = 261250 # 25000 for IN, 261250 for IN21k \n",
    "seeds = 5\n",
    "\n",
    "df = pd.DataFrame()\n",
    "for seed in range(seeds):\n",
    "    if 'ImageNet' in dataset:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_calibSize{valid_size_IN}_seed{seed}.csv')\n",
    "    else:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_seed{seed}.csv')\n",
    "    # choose best of eqsize and eqmass for HB (per model and per seed, on ECE_15)\n",
    "    for m in df_s['model'].unique():\n",
    "        hb_tva_eqmass = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqmass')]\n",
    "        hb_tva_eqsize = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqsize')]\n",
    "        if hb_tva_eqmass['ECE_15'].item() < hb_tva_eqsize['ECE_15'].item():\n",
    "            hb_tva = hb_tva_eqmass.copy()\n",
    "        else:\n",
    "            hb_tva = hb_tva_eqsize.copy()\n",
    "        hb_tva['method'] = 'netcal_HB_tva'\n",
    "        # same index whichever variant won, so the seed aggregation (which groups on the\n",
    "        # row index) puts all seeds of 'netcal_HB_tva' in one group\n",
    "        hb_tva.index = hb_tva_eqsize.index\n",
    "        df_s = pd.concat([df_s, hb_tva], axis=0)\n",
    "    df = pd.concat([df, df_s], axis=0)\n",
    "\n",
    "# get mean, std values across seeds\n",
    "df_std = df.groupby(['dataset', 'model', 'method', df.index]).std().reset_index().drop(columns='level_3')\n",
    "df = df.groupby(['dataset', 'model', 'method', df.index]).mean().reset_index().drop(columns='level_3')\n",
    "\n",
    "# choose dataset & metric\n",
    "df_metric = pd.pivot_table(df, index=['dataset', 'model'], columns=['method'], values=[metric])[metric].reindex(columns=df['method'].unique())\n",
    "if metric != 'Accuracy':\n",
    "    df_metric = df_metric * 100 # in %\n",
    "df_metric = df_metric.loc[dataset]\n",
    "\n",
    "# `columns` (method order, starting with 'original') is defined in the ECE table cell above\n",
    "df_metric = df_metric[columns]\n",
    "\n",
    "# row order of the table\n",
    "if 'CIFAR' in dataset:\n",
    "    new_idx = [\n",
    "        'ResNet-50',\n",
    "        'ResNet-110',\n",
    "        'WRN',\n",
    "        'DenseNet',\n",
    "        'clip-vit-base-patch32',\n",
    "        'clip-vit-base-patch16',\n",
    "        'clip-vit-large-patch14']\n",
    "elif dataset == 'ImageNet':\n",
    "    new_idx = [\n",
    "        'ResNet-18','ResNet-34','ResNet-50','ResNet-101','EffNet-B7','EffNetV2-S','EffNetV2-M','EffNetV2-L','ConvNeXt-T','ConvNeXt-S','ConvNeXt-B','ConvNeXt-L',\n",
    "        'ViT-B/32','ViT-B/16','ViT-L/32','ViT-L/16','ViT-H/14','Swin-T','Swin-S','Swin-B','SwinV2-T','SwinV2-S','SwinV2-B', 'clip-vit-base-patch32', 'clip-vit-base-patch16', 'clip-vit-large-patch14']\n",
    "elif dataset == 'ImageNet21k':\n",
    "    new_idx = ['mobilenetv3_large_100_miil_in21k', 'vit_base_patch16_224_miil_in21k']\n",
    "elif dataset in ['amazon_food', 'dynasent', 'mnli', 'yahoo_answers_topics']:\n",
    "    new_idx = ['t5', 't5-large', 'roberta', 'roberta-large']\n",
    "else:\n",
    "    raise ValueError(f'no model list defined for dataset {dataset!r}')\n",
    "\n",
    "df_metric = df_metric.reindex(index=new_idx)\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-base-patch32', 'CLIP (ViT-B/32)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-base-patch16', 'CLIP (ViT-B/16)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-large-patch14', 'CLIP (ViT-L/14)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('vit_base_patch16_224_miil_in21k', 'ViT-B/16'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('mobilenetv3_large_100_miil_in21k', 'MN3'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('t5', 'T5'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('roberta', 'RoBERTa'))\n",
    "\n",
    "# Colour each method: blue when >= the uncalibrated ('original') value, orange when lower,\n",
    "# plain 2-decimal text when no comparison is possible; NaN stays NaN and prints as 'err.'.\n",
    "# Object dtype first: pandas >= 2.1 deprecates writing strings into part of a float column.\n",
    "df_metric_color = df_metric.astype(object)\n",
    "ref_values = df_metric['original']\n",
    "for col in columns[1:]:\n",
    "    values = df_metric[col]\n",
    "    not_lower = values >= ref_values\n",
    "    lower = values < ref_values\n",
    "    df_metric_color[col] = values.apply(lambda x: x if pd.isna(x) else f'{x:.2f}')\n",
    "    df_metric_color.loc[not_lower, col] = values[not_lower].apply(lambda x: rf'\\textcolor{{blue}}{{{x:.2f}}}')\n",
    "    df_metric_color.loc[lower, col] = values[lower].apply(lambda x: rf'\\textcolor{{orange}}{{{x:.2f}}}')\n",
    "df_metric_color['original'] = df_metric_color['original'].apply(lambda x: f'{x:.2f}')\n",
    "\n",
    "s = df_metric_color.style\n",
    "print(s.to_latex().replace('nan', 'err.'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Underconfidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = 'Average_Confidence'\n",
    "dataset = 'yahoo_answers_topics' # 'CIFAR10', 'CIFAR100', 'ImageNet', 'ImageNet21k' amazon_food, dynasent, mnli, yahoo_answers_topics\n",
    "valid_size_IN = 261250 # 25000 for IN, 261250 for IN21k \n",
    "seeds = 5\n",
    "\n",
    "df = pd.DataFrame()\n",
    "for seed in range(seeds):\n",
    "    if 'ImageNet' in dataset:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_calibSize{valid_size_IN}_seed{seed}.csv')\n",
    "    else:\n",
    "        df_s = pd.read_csv(f'../results/benchmark_calibration_{dataset}_seed{seed}.csv')\n",
    "    # choose best of eqsize and eqmass for HB (per model and per seed, on ECE_15)\n",
    "    for m in df_s['model'].unique():\n",
    "        hb_tva_eqmass = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqmass')]\n",
    "        hb_tva_eqsize = df_s[(df_s['model'] == m) & (df_s['method'] == 'netcal_HB_tva_eqsize')]\n",
    "        if hb_tva_eqmass['ECE_15'].item() < hb_tva_eqsize['ECE_15'].item():\n",
    "            hb_tva = hb_tva_eqmass.copy()\n",
    "        else:\n",
    "            hb_tva = hb_tva_eqsize.copy()\n",
    "        hb_tva['method'] = 'netcal_HB_tva'\n",
    "        # same index whichever variant won, so the seed aggregation (which groups on the\n",
    "        # row index) puts all seeds of 'netcal_HB_tva' in one group\n",
    "        hb_tva.index = hb_tva_eqsize.index\n",
    "        df_s = pd.concat([df_s, hb_tva], axis=0)\n",
    "    df = pd.concat([df, df_s], axis=0)\n",
    "\n",
    "# get mean, std values across seeds\n",
    "df_std = df.groupby(['dataset', 'model', 'method', df.index]).std().reset_index().drop(columns='level_3')\n",
    "df = df.groupby(['dataset', 'model', 'method', df.index]).mean().reset_index().drop(columns='level_3')\n",
    "\n",
    "# choose dataset & metric\n",
    "df_metric = pd.pivot_table(df, index=['dataset', 'model'], columns=['method'], values=[metric])[metric].reindex(columns=df['method'].unique())\n",
    "if metric != 'Accuracy':\n",
    "    df_metric = df_metric * 100 # in %\n",
    "df_metric = df_metric.loc[dataset]\n",
    "\n",
    "# accuracy of the uncalibrated ('original') models, aligned on the model name rather than on row order\n",
    "df_original = df[(df['dataset'] == dataset) & (df['method'] == 'original')]\n",
    "df_metric['Accuracy'] = df_original.groupby('model')['Accuracy'].mean()\n",
    "# `columns` (method order) is defined in the ECE table cell above\n",
    "df_metric = df_metric[['Accuracy']+columns]\n",
    "\n",
    "# row order of the table\n",
    "if 'CIFAR' in dataset:\n",
    "    new_idx = [\n",
    "        'ResNet-50',\n",
    "        'ResNet-110',\n",
    "        'WRN',\n",
    "        'DenseNet',\n",
    "        'clip-vit-base-patch32',\n",
    "        'clip-vit-base-patch16',\n",
    "        'clip-vit-large-patch14']\n",
    "elif dataset == 'ImageNet':\n",
    "    new_idx = [\n",
    "        'ResNet-18','ResNet-34','ResNet-50','ResNet-101','EffNet-B7','EffNetV2-S','EffNetV2-M','EffNetV2-L','ConvNeXt-T','ConvNeXt-S','ConvNeXt-B','ConvNeXt-L',\n",
    "        'ViT-B/32','ViT-B/16','ViT-L/32','ViT-L/16','ViT-H/14','Swin-T','Swin-S','Swin-B','SwinV2-T','SwinV2-S','SwinV2-B', 'clip-vit-base-patch32', 'clip-vit-base-patch16', 'clip-vit-large-patch14']\n",
    "elif dataset == 'ImageNet21k':\n",
    "    new_idx = ['mobilenetv3_large_100_miil_in21k', 'vit_base_patch16_224_miil_in21k']\n",
    "elif dataset in ['amazon_food', 'dynasent', 'mnli', 'yahoo_answers_topics']:\n",
    "    new_idx = ['t5', 't5-large', 'roberta', 'roberta-large']\n",
    "else:\n",
    "    raise ValueError(f'no model list defined for dataset {dataset!r}')\n",
    "\n",
    "df_metric = df_metric.reindex(index=new_idx)\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-base-patch32', 'CLIP (ViT-B/32)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-base-patch16', 'CLIP (ViT-B/16)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('clip-vit-large-patch14', 'CLIP (ViT-L/14)'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('vit_base_patch16_224_miil_in21k', 'ViT-B/16'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('mobilenetv3_large_100_miil_in21k', 'MN3'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('t5', 'T5'))\n",
    "df_metric.index = df_metric.index.map(lambda x: x.replace('roberta', 'RoBERTa'))\n",
    "\n",
    "# Colour each confidence: violet above the uncalibrated accuracy, brown below, plain 1-decimal\n",
    "# text when equal (or not comparable); NaN stays NaN and prints as 'err.'.\n",
    "# Object dtype first: pandas >= 2.1 deprecates writing strings into part of a float column.\n",
    "df_metric_color = df_metric.astype(object)\n",
    "accuracy = df_metric['Accuracy']\n",
    "for col in columns:\n",
    "    conf = df_metric[col]\n",
    "    above = conf > accuracy\n",
    "    below = conf < accuracy\n",
    "    df_metric_color[col] = conf.apply(lambda x: x if pd.isna(x) else f'{x:.1f}')\n",
    "    df_metric_color.loc[above, col] = conf[above].apply(lambda x: rf'\\textcolor{{violet}}{{{x:.1f}}}')\n",
    "    df_metric_color.loc[below, col] = conf[below].apply(lambda x: rf'\\textcolor{{brown}}{{{x:.1f}}}')\n",
    "df_metric_color['Accuracy'] = df_metric_color['Accuracy'].apply(lambda x: f'{x:.1f}')\n",
    "\n",
    "s = df_metric_color.style\n",
    "print(s.to_latex().replace('nan', 'err.'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calibration set size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ECE as a function of the calibration-set size, for one model: each scaling method and its TvA variant\n",
    "model = 'ResNet-101'\n",
    "methods = ['TS', 'TS_tva', 'VS', 'VS_reg_tva', 'Dir-ODIR', 'Dir-ODIR_reg_tva']\n",
    "metric = 'ECE_15'\n",
    "dataset = 'ImageNet'\n",
    "load_many_seeds = True\n",
    "\n",
    "if load_many_seeds:\n",
    "    df = pd.DataFrame()\n",
    "    for seed in range(5):\n",
    "        if dataset != 'ImageNet':\n",
    "            continue\n",
    "        for valid_size_IN in [5000, 10000, 15000, 20000, 25000]:\n",
    "            df_s = pd.read_csv(f'../results/benchmark_calibration_ImageNet_calibSize{valid_size_IN}_seed{seed}_final.csv')\n",
    "            df = pd.concat([df, df_s], axis=0)\n",
    "    # mean and std across seeds, per calibration-set size\n",
    "    group_keys = ['dataset', 'model', 'method', 'valid_size', df.index]\n",
    "    df_std = df.groupby(group_keys).std().reset_index().drop(columns='level_4')\n",
    "    df = df.groupby(group_keys).mean().reset_index().drop(columns='level_4')\n",
    "\n",
    "# pivot to (valid_size x method) for `metric` (in % unless Accuracy), restricted to (dataset, model)\n",
    "def _select_metric(frame):\n",
    "    table = pd.pivot_table(frame, index=['dataset', 'model', 'valid_size'], columns=['method'], values=[metric])[metric]\n",
    "    if metric != 'Accuracy':\n",
    "        table = table * 100 # in %\n",
    "    return table.loc[(dataset, model)]\n",
    "\n",
    "df_metric = _select_metric(df)\n",
    "df_metric_std = _select_metric(df_std)\n",
    "\n",
    "labels = {\n",
    "    'TS': 'TS',\n",
    "    'TS_tva': r'TS\\textsubscript{TvA}',\n",
    "    'VS': 'VS',\n",
    "    'VS_reg_tva': r'VS\\textsubscript{reg\\_TvA}',\n",
    "    'Dir-ODIR': 'DC',\n",
    "    'Dir-ODIR_reg_tva': r'DC\\textsubscript{reg\\_TvA}'}\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for i, method in enumerate(methods):\n",
    "    color = f'C{i // 2}' # consecutive (method, TvA variant) pairs share a colour\n",
    "    ls = '-' if 'tva' in method else '--'\n",
    "    center, spread = df_metric[method], df_metric_std[method]\n",
    "    ax.plot(df_metric.index, center, color=color, ls=ls, label=labels[method])\n",
    "    ax.fill_between(df_metric.index, center - spread, center + spread, color=color, alpha=0.2)\n",
    "ax.legend()\n",
    "ax.set_xlabel('calibration set size')\n",
    "ax.set_ylabel('ECE test [%]')\n",
    "tikzplotlib.save('calib_size_scaling.tikz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ECE vs. calibration-set size on ImageNet (mean +/- std over seeds) for a single model\n",
    "model = 'ResNet-101'\n",
    "reference_method = 'Patel2021_sCW_imax' # drawn in black; every other method belongs to a (baseline, TvA) colour pair\n",
    "methods = [\n",
    "    reference_method,\n",
    "    'netcal_HB_eqsize', 'netcal_HB_tva_eqmass',\n",
    "    'netcal_Iso', 'netcal_Iso_tva',\n",
    "    #    'netcal_Beta', 'netcal_Beta_tva',\n",
    "    'netcal_BBQ', 'netcal_BBQ_tva']\n",
    "\n",
    "metric = 'ECE_15'\n",
    "dataset = 'ImageNet'\n",
    "calib_sizes = [5000, 10000, 15000, 20000, 25000]\n",
    "n_seeds = 5\n",
    "# set to False to reuse `df` / `df_std` already in the kernel (they must contain a `valid_size` column)\n",
    "load_many_seeds = True\n",
    "\n",
    "if load_many_seeds:\n",
    "    if dataset != 'ImageNet':\n",
    "        # the file pattern below only exists for ImageNet; fail loudly instead of pivoting an empty frame\n",
    "        raise ValueError(f'calibration-size results are only available for ImageNet, got {dataset!r}')\n",
    "    df = pd.concat(\n",
    "        [pd.read_csv(f'../results/benchmark_calibration_ImageNet_calibSize{calib_size}_seed{seed}_final.csv')\n",
    "         for seed in range(n_seeds) for calib_size in calib_sizes],\n",
    "        axis=0)\n",
    "\n",
    "    # mean / std across seeds; grouping on df.index pairs up rows at the same position in each per-seed file\n",
    "    # (assumes all files list their rows in the same order)\n",
    "    group_keys = ['dataset', 'model', 'method', 'valid_size', df.index]\n",
    "    df_std = df.groupby(group_keys).std().reset_index().drop(columns='level_4')\n",
    "    df = df.groupby(group_keys).mean().reset_index().drop(columns='level_4')\n",
    "\n",
    "\n",
    "def metric_table(frame, metric, dataset, model):\n",
    "    # valid_size x method table of `metric` for one (dataset, model); scaled to % unless it is Accuracy\n",
    "    table = pd.pivot_table(frame, index=['dataset', 'model', 'valid_size'], columns=['method'], values=[metric])[metric]\n",
    "    if metric != 'Accuracy':\n",
    "        table = table * 100 # in %\n",
    "    return table.loc[(dataset, model)]\n",
    "\n",
    "\n",
    "df_metric = metric_table(df, metric, dataset, model)\n",
    "df_metric_std = metric_table(df_std, metric, dataset, model)\n",
    "\n",
    "labels = {\n",
    "    'netcal_HB_eqsize': 'HB',\n",
    "    'netcal_HB_tva_eqmass': r'HB\\textsubscript{TvA}',\n",
    "    'netcal_Iso': 'Iso',\n",
    "    'netcal_Iso_tva': r'Iso\\textsubscript{TvA}',\n",
    "    'netcal_Beta': 'Beta',\n",
    "    'netcal_Beta_tva': r'Beta\\textsubscript{TvA}',\n",
    "    'netcal_BBQ': 'BBQ',\n",
    "    'netcal_BBQ_tva': r'BBQ\\textsubscript{TvA}',\n",
    "    'Patel2021_sCW_imax': 'I-Max'}\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "pair_pos = 0 # position among the non-reference methods; consecutive (baseline, TvA) entries share a colour\n",
    "for method in methods:\n",
    "    if method == reference_method:\n",
    "        ls, color = '--', 'k'\n",
    "    else:\n",
    "        ls = '-' if 'tva' in method else '--'\n",
    "        color = f'C{pair_pos // 2}'\n",
    "        pair_pos += 1\n",
    "    mean_curve, std_curve = df_metric[method], df_metric_std[method]\n",
    "    ax.plot(df_metric.index, mean_curve, color=color, ls=ls, label=labels[method])\n",
    "    ax.fill_between(df_metric.index, mean_curve - std_curve, mean_curve + std_curve, color=color, alpha=0.2)\n",
    "ax.set_xlabel('calibration set size')\n",
    "ax.set_ylabel('ECE test [%]')\n",
    "ax.legend()\n",
    "\n",
    "tikzplotlib.save('calib_size_binary.tikz', figure=fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computation time of the calibration methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'ImageNet' # 'CIFAR10', 'CIFAR100', 'ImageNet'\n",
    "valid_size_IN = 25000 # calibration-set size, only part of the ImageNet file name\n",
    "\n",
    "# only load a single file (seed 0)\n",
    "# NOTE(review): assumes non-ImageNet computeTime files omit the calibSize tag, like the benchmark CSVs loaded at the top\n",
    "if dataset == 'ImageNet':\n",
    "    results_path = f'../results/benchmark_calibration_{dataset}_calibSize{valid_size_IN}_seed0_computeTime.csv'\n",
    "else:\n",
    "    results_path = f'../results/benchmark_calibration_{dataset}_seed0_computeTime.csv'\n",
    "df = pd.read_csv(results_path)\n",
    "\n",
    "# unify model-name spelling\n",
    "df = df.replace({'Effnet-B7': 'EffNet-B7'})\n",
    "\n",
    "# execution time per model (rows) and method (columns) for the chosen dataset\n",
    "df_metric = pd.pivot_table(df, index=['dataset', 'model'], columns=['method'], values=['execution_time'])['execution_time']\n",
    "df_metric = df_metric.loc[dataset]\n",
    "\n",
    "df_metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# each TvA variant mapped to its baseline method; the table shows them side by side as (baseline, TvA) column pairs\n",
    "dict_ref = {'TS_tva': 'TS', 'VS_reg_tva': 'VS', 'Dir-ODIR_reg_tva': 'Dir-ODIR', 'netcal_Beta_tva': 'netcal_Beta', 'netcal_Iso_tva': 'netcal_Iso', 'netcal_BBQ_tva': 'netcal_BBQ', 'netcal_HB_tva': 'netcal_HB_eqsize'}\n",
    "\n",
    "columns = ['original', 'Patel2021_sCW_imax'] + [name for tva, ref in dict_ref.items() for name in (ref, tva)]\n",
    "\n",
    "# row order of the LaTeX table; other datasets keep the order found in the results\n",
    "if 'CIFAR' in dataset:\n",
    "    new_idx = ['ResNet-50', 'ResNet-110', 'WRN', 'DenseNet']\n",
    "elif dataset == 'ImageNet':\n",
    "    new_idx = ['ResNet-50', 'ViT-B/16']\n",
    "else:\n",
    "    new_idx = list(df_metric.index)\n",
    "\n",
    "# new name instead of overwriting df_metric, so re-running this cell never depends on its own output\n",
    "df_time_table = df_metric[columns].reindex(index=new_idx)\n",
    "\n",
    "s = df_time_table.style\n",
    "s = s.format('{:.0f}', na_rep='--') # float format; models missing from the results (NaN after reindex) print as '--'\n",
    "print(s.to_latex())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
