{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "# Widen the notebook's main container to the full browser width.\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import re\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.table as tbl\n",
    "from matplotlib.font_manager import FontProperties\n",
    "import itertools as itert\n",
    "import numpy as np\n",
    "import seaborn as sbn\n",
    "import gzip\n",
    "\n",
    "available_datasets = {\n",
    "    'toy', 'toy_noise', 'toy_hf', 'toy_modulated', 'toy_uniform', 'toy_noise_strong',\n",
    "    'yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'naval',\n",
    "    'power', 'california', 'superconduct', 'protein', 'year',\n",
    "}\n",
    "available_methods = {'de', 'pu', 'mc_mod_sml', 'mc_ll', 'mc'}\n",
    "available_splits = {\n",
    "    'random_folds', 'single_random_split', 'single_label_split',\n",
    "    'label_folds', 'single_pca_split', 'pca_folds',\n",
    "}\n",
    "\n",
    "# Number of samples per dataset (shown in the x-tick labels of the metric plots).\n",
    "dataset_to_size = dict(\n",
    "    boston=506, wine_red=1599, concrete=1030, toy_noise=10000, abalone=4176, energy=768,\n",
    "    year=515345, protein=45730, california=20640, superconduct=21263, diabetes=442, naval=11934,\n",
    "    power=9568, yacht=308, toy=1000, toy_hf=1000,\n",
    "    toy_noise_strong=20000, toy_uniform=20000, toy_modulated=20000,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# File Reader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import StringIO\n",
    "\n",
    "\n",
    "def get_dir_files(exp_dir, dataset_id):\n",
    "    \"\"\"Index the result files found in `exp_dir/dataset_id`.\n",
    "\n",
    "    File names are matched against\n",
    "    `[<kind>_]dataset=<dataset>_splitmode=<split>_foldidx=<idx>`; files without a\n",
    "    `<kind>_` prefix are filed under 'plots'. Non-matching files are ignored;\n",
    "    matching files with another dataset, an unknown split or an unknown kind are\n",
    "    reported and skipped.\n",
    "\n",
    "    Returns a dict kind -> split -> fold_idx -> file path, with the kinds 'plots',\n",
    "    'method_dict', 'global_stats' and 'model' and one (possibly empty) entry per\n",
    "    split in `available_splits`.\n",
    "    \"\"\"\n",
    "    dataset_dir = os.path.join(exp_dir, dataset_id)\n",
    "    dir_files = {kind: {split: {} for split in available_splits}\n",
    "                 for kind in ('plots', 'method_dict', 'global_stats', 'model')}\n",
    "\n",
    "    file_matcher = re.compile(r'((\\w+)_|)dataset=(\\w+)_splitmode=(\\w+)_foldidx=(\\d+)')\n",
    "    for dir_file in os.listdir(dataset_dir):\n",
    "        matches = file_matcher.match(dir_file)\n",
    "        if matches is None:\n",
    "            continue\n",
    "\n",
    "        _, prefix, file_dataset, split, fold_idx = matches.groups()\n",
    "        kind = prefix if prefix is not None else 'plots'\n",
    "        # An unknown kind would otherwise raise a KeyError below.\n",
    "        if file_dataset != dataset_id or split not in available_splits or kind not in dir_files:\n",
    "            print(\"Warning. File %s has unexpected form\" % dir_file)\n",
    "            continue\n",
    "\n",
    "        dir_files[kind][split][int(fold_idx)] = os.path.join(dataset_dir, dir_file)\n",
    "\n",
    "    return dir_files\n",
    "\n",
    "\n",
    "def _load_json(file):\n",
    "    \"\"\"Load a plain `.json` file or a gzip-compressed `.json.zip` file.\"\"\"\n",
    "    if file.endswith('.json'):\n",
    "        with open(file) as f:\n",
    "            return json.load(f)\n",
    "    if file.endswith('.json.zip'):\n",
    "        # Despite the extension, these files are opened with gzip (not zipfile).\n",
    "        with gzip.open(file) as f:\n",
    "            return json.load(f)\n",
    "    raise Exception(\"File has to be .json or .json.zip, but is %s\" % file)\n",
    "\n",
    "\n",
    "def load_global_stats(dir_files, splitmode):\n",
    "    \"\"\"Return the parsed global-stats JSON of every fold of `splitmode`, ordered by fold index.\"\"\"\n",
    "    split_files = dir_files['global_stats'][splitmode]\n",
    "    return [_load_json(split_files[fold_idx]) for fold_idx in sorted(split_files)]\n",
    "\n",
    "\n",
    "def load_method_dict(dir_files, splitmode):\n",
    "    \"\"\"Return one dict per fold of `splitmode` (ordered by fold index) mapping\n",
    "    each method name to `[df_train, df_test]`.\n",
    "    \"\"\"\n",
    "    res = []\n",
    "    split_files = dir_files['method_dict'][splitmode]\n",
    "    for fold_idx in sorted(split_files):\n",
    "        method_dict_json = _load_json(split_files[fold_idx])\n",
    "        # Entries 0 and 1 hold the train and test DataFrames serialised as JSON strings.\n",
    "        # They are wrapped in StringIO because pandas deprecated passing literal JSON\n",
    "        # strings to read_json.\n",
    "        res.append({key: [pd.read_json(StringIO(value[0])), pd.read_json(StringIO(value[1]))]\n",
    "                    for key, value in method_dict_json.items()})\n",
    "    return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tables and plots: global matplotlib style settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SMALL_SIZE = 6\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 25\n",
    "\n",
    "# Global matplotlib style: large fonts, thick axes, ticks and lines.\n",
    "plt.rcParams.update({\n",
    "    'font.size': BIGGER_SIZE,           # default text size\n",
    "    'axes.titlesize': BIGGER_SIZE,      # axes title\n",
    "    'axes.labelsize': BIGGER_SIZE,      # x and y labels\n",
    "    'axes.linewidth': 5,\n",
    "    'xtick.labelsize': BIGGER_SIZE,\n",
    "    'xtick.major.width': 5, 'xtick.major.size': 10,\n",
    "    'xtick.minor.width': 5, 'xtick.minor.size': 10,\n",
    "    'ytick.labelsize': BIGGER_SIZE,\n",
    "    'ytick.major.width': 5, 'ytick.major.size': 10,\n",
    "    'ytick.minor.width': 5, 'ytick.minor.size': 10,\n",
    "    'legend.fontsize': MEDIUM_SIZE,\n",
    "    'figure.titlesize': BIGGER_SIZE,    # figure suptitle\n",
    "    'lines.linewidth': 5,\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Residual vs. std (1/3-sigma plot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import scipy.stats as spst\n",
    "from scipy.interpolate import interp1d\n",
    "\n",
    "def plot_densitymap(x, y, ax):\n",
    "    \"\"\"Draw a Gaussian-KDE heat map of the points (x, y) on `ax` and overlay the points.\n",
    "\n",
    "    The density is evaluated on a 100x100 grid spanning the data range; the image\n",
    "    extent is that same range so that heat map and points line up.\n",
    "    \"\"\"\n",
    "    xmin, xmax = x.min(), x.max()\n",
    "    ymin, ymax = y.min(), y.max()\n",
    "    x_range, y_range = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]\n",
    "    positions = np.vstack([x_range.ravel(), y_range.ravel()])\n",
    "    kernel = spst.gaussian_kde(np.vstack([x, y]))\n",
    "    density = np.reshape(kernel(positions).T, x_range.shape)\n",
    "\n",
    "    # density is indexed [x, y]; rot90 puts ymax in the first row, as imshow's\n",
    "    # default origin='upper' expects. The vertical extent must be the grid range\n",
    "    # [ymin, ymax]: starting it at 0 stretches the image away from the points.\n",
    "    ax.imshow(np.rot90(density), cmap=plt.cm.gist_heat_r, extent=[xmin, xmax, ymin, ymax], aspect='auto')\n",
    "    ax.plot(x, y, 'k.', markersize=1, alpha=0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def show_sigmaplots(exp_dirs, datasets=available_datasets, methods=['mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'de'], splitmode = 'single_random_split', sml_name='ours', use_heat=False, savefig=None):\n",
    "    \"\"\"Plot predicted std vs. prediction residual on the test set of the first fold.\n",
    "\n",
    "    One row per dataset and one column per method. The orange line marks\n",
    "    std = |residual| (1 sigma) and the blue line std = |residual| / 3 (3 sigma).\n",
    "    Axis limits are shared by all panels: x from the 3%/97% quantiles of all plotted\n",
    "    residuals (made symmetric around 0), y from 0 to the 97% quantile of all stds.\n",
    "\n",
    "    exp_dirs: experiment directories, each with one sub-directory per dataset.\n",
    "    use_heat: draw a KDE heat map (plot_densitymap) instead of a scatter plot.\n",
    "    savefig: optional output path passed to plt.savefig.\n",
    "    \"\"\"\n",
    "    # Fix an order so grid positions can be looked up again (datasets may be a set).\n",
    "    datasets, methods = list(datasets), list(methods)\n",
    "\n",
    "    plt.clf()\n",
    "    fig, ax = plt.subplots(len(datasets), len(methods), figsize=(len(methods)*10, len(datasets)*10), squeeze=False)\n",
    "\n",
    "    all_x, all_y = np.empty(0), np.empty(0)\n",
    "    used_rows, used_cols = set(), set()  # grid positions that received data\n",
    "    for exp_dir in exp_dirs:\n",
    "        dataset_dirs = os.listdir(exp_dir)\n",
    "\n",
    "        for row_i, dataset_id in enumerate(datasets):\n",
    "            if dataset_id not in dataset_dirs:\n",
    "                continue\n",
    "\n",
    "            method_dict = load_method_dict(get_dir_files(exp_dir, dataset_id), splitmode)\n",
    "            if not method_dict:\n",
    "                continue\n",
    "\n",
    "            for col_i, method in enumerate(methods):\n",
    "                if method not in method_dict[0]:  # using first fold\n",
    "                    continue\n",
    "\n",
    "                test_df = method_dict[0][method][1]\n",
    "                x, y = test_df['pred_residual'].values, test_df['pred_std'].values\n",
    "                all_x, all_y = np.concatenate((x, all_x)), np.concatenate((y, all_y))\n",
    "                if use_heat:\n",
    "                    plot_densitymap(x, y, ax[row_i, col_i])\n",
    "                else:\n",
    "                    ax[row_i, col_i].scatter(x, y)\n",
    "                used_rows.add(row_i)\n",
    "                used_cols.add(col_i)\n",
    "\n",
    "    if all_x.size == 0:\n",
    "        print(\"Warning. No %s results found for the requested datasets/methods\" % splitmode)\n",
    "        return\n",
    "\n",
    "    # x limits come from the residuals (x), y limit from the stds (y).\n",
    "    xmin, xmax = np.quantile(all_x, 0.03), np.quantile(all_x, 0.97)\n",
    "    ymax = np.quantile(all_y, 0.97)\n",
    "    xmin, xmax = min(xmin, -xmax), max(-xmin, xmax)  # symmetric x\n",
    "\n",
    "    # Decorate exactly the panels at the grid positions that were plotted.\n",
    "    first_row, last_row, first_col = min(used_rows), max(used_rows), min(used_cols)\n",
    "    for row_i in sorted(used_rows):\n",
    "        for col_i in sorted(used_cols):\n",
    "            cur_ax = ax[row_i, col_i]\n",
    "            method = methods[col_i]\n",
    "            if row_i == last_row:\n",
    "                cur_ax.set_xlabel('pred_residual')\n",
    "            if row_i == first_row:\n",
    "                cur_ax.set_title(method if method != 'mc_mod_sml' else sml_name)\n",
    "            if col_i == first_col:\n",
    "                cur_ax.set_ylabel('pred_std')\n",
    "\n",
    "            cur_ax.plot([xmin, 0, xmax], [abs(xmin), 0, xmax], color='orange', label=r'$1 \\sigma$')\n",
    "            cur_ax.plot([xmin, 0, xmax], [(1./3)*abs(xmin), 0, (1./3)*xmax], color='b', label=r'$3 \\sigma$')\n",
    "            cur_ax.set_xlim(xmin, xmax)\n",
    "            cur_ax.set_ylim(0, ymax)\n",
    "            cur_ax.legend()\n",
    "\n",
    "    plt.tight_layout()\n",
    "    if savefig is not None:\n",
    "        plt.savefig(savefig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dirs = ['/INSERT/PATH/TO/EXPERIMENT/LOGS/HERE', '/YOU/CAN/ALSO/INSERT/MULTIPLE/PATHS']\n",
    "os.makedirs('./plots', exist_ok=True)  # plt.savefig does not create missing directories\n",
    "\n",
    "# Each figure is saved as pdf and jpg, once as scatter plot and once as KDE heat map.\n",
    "for use_heat, suffix in [(False, ''), (True, '_heat')]:\n",
    "    for ext in ['pdf', 'jpg']:\n",
    "        show_sigmaplots(exp_dirs, datasets=['toy_modulated', 'naval', 'abalone', 'superconduct'],\n",
    "                        use_heat=use_heat, savefig='./plots/sigma%s.%s' % (suffix, ext))\n",
    "\n",
    "for use_heat, suffix in [(False, ''), (True, '_heat')]:\n",
    "    for ext in ['pdf', 'jpg']:\n",
    "        show_sigmaplots(exp_dirs, datasets=['abalone'], methods=['mc', 'mc_mod_sml', 'pu', 'de'],\n",
    "                        use_heat=use_heat, savefig='./plots/sigma_abalone%s.%s' % (suffix, ext))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# x vs. ground truth, predicted mean and predicted std (toy data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def plot_x_vs_preds(method_dicts, methods=None, trte=[0, 1], rows=['gt', 'mean', 'res', 'gt_std', 'std'], fold_idx=0, s=5, interpol_gt=False, savefig=None, sml_name='ours'):\n",
    "    \"\"\"Plot ground truth and predictions against the 1-d input x, one column per method.\n",
    "\n",
    "    For every entry of `trte` (0 = train, 1 = test frame of a method) a block of rows\n",
    "    is drawn, one per entry of `rows` and in that order:\n",
    "      'gt'     ground truth (a line if `interpol_gt`, otherwise a scatter plot)\n",
    "      'mean'   predicted mean ('pred_no_mc' for mc_mod_sml)\n",
    "      'res'    residual pred_mean - gt\n",
    "      'gt_std' std of the ground truth over samples with identical x\n",
    "      'std'    predicted std ('total_std' for mc_mod_sml)\n",
    "    y-limits are shared by all methods within a row and are computed from the data\n",
    "    actually drawn; the 'mean' row is symmetric around 0.\n",
    "\n",
    "    method_dicts: list with one {method: [df_train, df_test]} dict per fold.\n",
    "    sml_name: column title used for 'mc_mod_sml'.\n",
    "    \"\"\"\n",
    "    if methods is None:\n",
    "        methods = sorted(method_dicts[fold_idx])\n",
    "\n",
    "    n_rows = len(rows)\n",
    "    # squeeze=False keeps `ax` 2-d even for a single method or a single row.\n",
    "    fig, ax = plt.subplots(len(trte)*n_rows, len(methods), figsize=(len(methods)*8, len(trte)*n_rows*5), squeeze=False)\n",
    "\n",
    "    trte_to_color = {0: 'orange', 1: 'blue'}\n",
    "    # Per train/test block and row: minima and maxima of the drawn data over methods.\n",
    "    ylims = [[([], []) for _ in rows] for _ in trte]\n",
    "\n",
    "    for i, trte_i in enumerate(trte):  # train/test\n",
    "        color = trte_to_color[trte_i]\n",
    "        for j, method in enumerate(methods):\n",
    "            df = method_dicts[fold_idx][method][trte_i]\n",
    "            gt = df['gt'].values\n",
    "            x = np.array([val[0] for val in df['x'].values])\n",
    "            use_sml_columns = method == 'mc_mod_sml'\n",
    "\n",
    "            for r, row in enumerate(rows):\n",
    "                cur_ax = ax[i*n_rows + r, j]\n",
    "                if row == 'gt':\n",
    "                    data = gt\n",
    "                    if interpol_gt:\n",
    "                        x_argsort = np.argsort(x)\n",
    "                        cur_ax.plot(x[x_argsort], gt[x_argsort], color=color)\n",
    "                    else:\n",
    "                        cur_ax.scatter(x, gt, s=s, color=color)\n",
    "                elif row == 'gt_std':\n",
    "                    x_unique = np.unique(x)\n",
    "                    data = [np.std(gt[x == x_val]) for x_val in x_unique]\n",
    "                    cur_ax.scatter(x_unique, data, s=s, color=color)\n",
    "                else:\n",
    "                    if row == 'mean':\n",
    "                        data = df['pred_no_mc'] if use_sml_columns else df['pred_mean']\n",
    "                    elif row == 'res':\n",
    "                        data = df['pred_mean'] - gt\n",
    "                    elif row == 'std':\n",
    "                        data = df['total_std'] if use_sml_columns else df['pred_std']\n",
    "                    else:\n",
    "                        raise ValueError('Unknown row %r' % row)\n",
    "                    cur_ax.scatter(x, data, s=s, color=color)\n",
    "                ylims[i][r][0].append(np.min(data))\n",
    "                ylims[i][r][1].append(np.max(data))\n",
    "\n",
    "    for j, method in enumerate(methods):\n",
    "        ax[0, j].set_title(method if method != 'mc_mod_sml' else sml_name)\n",
    "        ax[-1, j].set_xlabel('x')\n",
    "\n",
    "    for i in range(len(trte)):\n",
    "        for r, row in enumerate(rows):\n",
    "            ax[i*n_rows + r, 0].set_ylabel(row)\n",
    "            ymin, ymax = np.min(ylims[i][r][0]), np.max(ylims[i][r][1])\n",
    "            if row == 'mean':\n",
    "                bound = max(abs(ymin), abs(ymax))\n",
    "                row_ylim = (-bound - 0.5, bound + 0.5)\n",
    "            else:\n",
    "                row_ylim = (ymin - (ymax - ymin)*0.2, ymax + (ymax - ymin)*0.2)\n",
    "            for j in range(len(methods)):\n",
    "                ax[i*n_rows + r, j].set_ylim(*row_ylim)\n",
    "\n",
    "    plt.subplots_adjust(wspace=0.2, hspace=0.2)\n",
    "\n",
    "    if savefig is not None:\n",
    "        plt.savefig(savefig)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SMALL_SIZE = 6\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 25\n",
    "\n",
    "# Style for the toy-data figures: larger titles, axis labels and tick labels.\n",
    "plt.rcParams.update({\n",
    "    'font.size': BIGGER_SIZE,           # default text size\n",
    "    'axes.titlesize': 40,               # axes title\n",
    "    'axes.labelsize': 40,               # x and y labels\n",
    "    'axes.linewidth': 5,\n",
    "    'xtick.labelsize': 28,\n",
    "    'xtick.major.width': 5, 'xtick.major.size': 10,\n",
    "    'xtick.minor.width': 5, 'xtick.minor.size': 10,\n",
    "    'ytick.labelsize': 28,\n",
    "    'ytick.major.width': 5, 'ytick.major.size': 10,\n",
    "    'ytick.minor.width': 5, 'ytick.minor.size': 10,\n",
    "    'legend.fontsize': MEDIUM_SIZE,\n",
    "    'figure.titlesize': BIGGER_SIZE,    # figure suptitle\n",
    "    'lines.linewidth': 5,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dir = '/INSERT/PATH/TO/TOY_MODULATED/HERE'\n",
    "dataset_id = 'toy_modulated'\n",
    "splitmode = 'single_random_split'\n",
    "\n",
    "dir_files = get_dir_files(exp_dir, dataset_id)\n",
    "global_stats = load_global_stats(dir_files, splitmode)\n",
    "method_dicts = load_method_dict(dir_files, splitmode)\n",
    "\n",
    "os.makedirs('./plots', exist_ok=True)  # plt.savefig does not create missing directories\n",
    "plot_x_vs_preds(method_dicts, methods=['mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'de'],\n",
    "                rows=['gt', 'mean', 'std'], trte=[1], s=40, interpol_gt=False,\n",
    "                savefig='./plots/toy_modulated.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dir = '/INSERT/PATH/TO/TOY_HF/HERE'\n",
    "dataset_id = 'toy_hf'\n",
    "splitmode = 'single_random_split'\n",
    "\n",
    "dir_files = get_dir_files(exp_dir, dataset_id)\n",
    "global_stats = load_global_stats(dir_files, splitmode)\n",
    "method_dicts = load_method_dict(dir_files, splitmode)\n",
    "\n",
    "os.makedirs('./plots', exist_ok=True)  # plt.savefig does not create missing directories\n",
    "plot_x_vs_preds(method_dicts, methods=['mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'de'],\n",
    "                rows=['gt', 'mean', 'std'], trte=[1], s=40, interpol_gt=True,\n",
    "                savefig='./plots/toy_hf.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# UCI cross-validated runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MEDIUM_SIZE = 27\n",
    "BIGGER_SIZE = 35\n",
    "\n",
    "# Style for the UCI metric figures.\n",
    "plt.rcParams.update({\n",
    "    'font.size': BIGGER_SIZE,           # default text size\n",
    "    'axes.titlesize': BIGGER_SIZE,      # axes title\n",
    "    'axes.labelsize': BIGGER_SIZE,      # x and y labels\n",
    "    'axes.linewidth': 5,\n",
    "    'xtick.labelsize': MEDIUM_SIZE,\n",
    "    'xtick.major.width': 5, 'xtick.major.size': 10,\n",
    "    'xtick.minor.width': 5, 'xtick.minor.size': 10,\n",
    "    'ytick.labelsize': BIGGER_SIZE,\n",
    "    'ytick.major.width': 5, 'ytick.major.size': 10,\n",
    "    'ytick.minor.width': 5, 'ytick.minor.size': 10,\n",
    "    'legend.fontsize': MEDIUM_SIZE,\n",
    "    'figure.titlesize': BIGGER_SIZE,    # figure suptitle\n",
    "    'lines.linewidth': 5,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ranks(x, higher_is_better=False):\n",
    "    \"\"\"Return the 0-based rank of every entry of `x` (rank 0 = best).\n",
    "\n",
    "    Lower values are better unless `higher_is_better` is set (used for r2).\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    # Negation orders values from highest to lowest for any sign; the key\n",
    "    # 1/(x+1) is not monotone for x <= -1 and misranks strongly negative r2.\n",
    "    sort_key = -x if higher_is_better else x\n",
    "    ranks = np.zeros(len(x))\n",
    "    ranks[np.argsort(sort_key)] = np.arange(len(x))\n",
    "    return ranks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def aggregate_over_folds(exp_dirs):\n",
    "    \"\"\"Collect the per-fold global stats of every known dataset found in `exp_dirs`.\n",
    "\n",
    "    Returns a DataFrame indexed by (dataset, ident) with columns (method, metric);\n",
    "    each cell holds a numpy array with one value per fold. ident is 'train' or\n",
    "    'test' for the 'random_folds' split, and '<label|pca>_test_<extrapolate|interpolate>'\n",
    "    for 'label_folds' / 'pca_folds', where 'extrapolate' collects the test stats of\n",
    "    the first and last fold and 'interpolate' those of the folds in between.\n",
    "    Split modes without result files are skipped.\n",
    "    \"\"\"\n",
    "\n",
    "    def _pair_to_string(a, b):\n",
    "        return str(a) + \" \" + str(b)\n",
    "\n",
    "    aggregated = pd.DataFrame(dtype=object)\n",
    "\n",
    "    def _set_fold_value(row, col, n_folds, pos, value):\n",
    "        # Write `value` at position `pos` of the per-fold array in cell (row, col);\n",
    "        # cells that do not hold an array yet get a zero-filled array of length n_folds.\n",
    "        if (row not in aggregated.index\n",
    "                or col not in aggregated.columns\n",
    "                or not isinstance(aggregated.loc[row, col], np.ndarray)):\n",
    "            aggregated.loc[row, col] = 0.\n",
    "            aggregated[col] = aggregated[col].astype('object')\n",
    "            aggregated.at[row, col] = np.zeros(n_folds)\n",
    "        aggregated.loc[row, col][pos] = value\n",
    "\n",
    "    splitmode_to_ident = {'label_folds': 'label_test', 'pca_folds': 'pca_test'}\n",
    "    for exp_dir in exp_dirs:\n",
    "        for dataset_id in os.listdir(exp_dir):\n",
    "            if dataset_id not in available_datasets:\n",
    "                continue\n",
    "            print(dataset_id)\n",
    "            dir_files = get_dir_files(exp_dir, dataset_id)\n",
    "\n",
    "            # Random folds: stats index 0 is train, index 1 is test.\n",
    "            global_stats_folds = load_global_stats(dir_files, 'random_folds')\n",
    "            for fold_idx, fold in enumerate(global_stats_folds):\n",
    "                for method in sorted(fold):\n",
    "                    for i, trte in enumerate(['train', 'test']):\n",
    "                        dataset_trte = _pair_to_string(dataset_id, trte)\n",
    "                        for metric in fold[method][i]:\n",
    "                            _set_fold_value(dataset_trte, _pair_to_string(method, metric),\n",
    "                                            len(global_stats_folds), fold_idx, fold[method][i][metric])\n",
    "\n",
    "            for splitmode in sorted(splitmode_to_ident):\n",
    "                global_stats_folds = load_global_stats(dir_files, splitmode)\n",
    "                n_folds = len(global_stats_folds)\n",
    "                if n_folds == 0:\n",
    "                    continue  # indexing fold 0 / n_folds-1 below would raise IndexError\n",
    "                fold_mode_to_fold_idxs = {'extrapolate': [0, n_folds - 1], 'interpolate': np.arange(1, n_folds - 1)}\n",
    "                for fold_mode in ['extrapolate', 'interpolate']:\n",
    "                    dataset_ident = _pair_to_string(dataset_id, '%s_%s' % (splitmode_to_ident[splitmode], fold_mode))\n",
    "                    fold_idxs = fold_mode_to_fold_idxs[fold_mode]\n",
    "                    for i, fold_idx in enumerate(fold_idxs):\n",
    "                        fold = global_stats_folds[fold_idx]\n",
    "                        for method in sorted(fold):\n",
    "                            for metric in fold[method][1]:\n",
    "                                _set_fold_value(dataset_ident, _pair_to_string(method, metric),\n",
    "                                                len(fold_idxs), i, fold[method][1][metric])\n",
    "\n",
    "    if aggregated.empty:\n",
    "        return aggregated\n",
    "    aggregated.columns = aggregated.columns.str.split(expand=True)\n",
    "    aggregated = aggregated.set_index(aggregated.index.str.split(expand=True))\n",
    "    return aggregated\n",
    "\n",
    "def _75q(x):\n",
    "    \"\"\"75% quantile of x.\"\"\"\n",
    "    return np.quantile(x, 0.75)\n",
    "\n",
    "def _25q(x):\n",
    "    \"\"\"25% quantile of x.\"\"\"\n",
    "    return np.quantile(x, 0.25)\n",
    "\n",
    "# Horizontal offset added to each dataset's x position, per ident, so markers sit side by side.\n",
    "default_ident_offsets = dict(zip(\n",
    "    ('train', 'test', 'label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'),\n",
    "    (-0.2, -0.12, -0.04, 0.04, 0.12, 0.2),\n",
    "))\n",
    "\n",
    "def plot_metrics(aggregated_mean, metric, idents=('train', 'test', 'label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "                 datasets=None, methods=None, ylim=None, yscale=None, figsize=(10, 6), ax=None, \n",
    "                 summary_stat_over=None, summary_stat_funcs=None, summary_rank_funcs=None,\n",
    "                ident_offsets=default_ident_offsets, sml_name='ours', ticklabel_tilt=0, s=10,\n",
    "                savefig=None, xticklabels=True):\n",
    "    \n",
    "    show = False\n",
    "    if ax is None:\n",
    "        plt.figure(figsize=figsize)\n",
    "        ax = plt.gca()\n",
    "        show = True\n",
    "    \n",
    "    if datasets is None:\n",
    "        datasets = aggregated_mean.index.get_level_values(0).unique().values\n",
    "    else:\n",
    "        datasets = np.array(datasets)\n",
    "    \n",
    "    if methods is None:\n",
    "        methods = aggregated_mean.columns.get_level_values(0).unique().values\n",
    "        \n",
    "    datasets_idx = np.arange(datasets.size)\n",
    "    max_dataset_idx = datasets_idx.max()\n",
    "    \n",
    "    if summary_stat_funcs is None:\n",
    "        summary_stat_funcs = [np.mean, np.median, np.min, lambda x: np.quantile(x, 0.25), lambda x: np.quantile(x, 0.75), np.max]\n",
    "    if summary_rank_funcs is None:\n",
    "        summary_rank_funcs = [np.mean]\n",
    "    \n",
    "    aggregated_ranks = None\n",
    "    if len(summary_rank_funcs) > 0:\n",
    "        aggregated_ranks = aggregated_mean.loc[(datasets, idents), (methods, metric)].apply(\n",
    "            get_ranks if metric != 'r2' else lambda x: get_ranks(x, higher_is_better=True), axis=1, result_type='broadcast')\n",
    "    \n",
    "    method_str = {'mc_mod_sml': r'%s' % sml_name, 'mc_mod_sml1': r'%s, $\\beta=0.1$' % sml_name, \n",
    "                  'mc_mod_sml25': r'%s, $\\beta=0.25$' % sml_name, 'mc_mod_sml75': r'%s, $\\beta=0.75$' % sml_name, \n",
    "                  'mc_mod_sml9': r'%s, $\\beta=0.9$' % sml_name, 'mc_mod_sml10': r'%s, $\\beta=10$' % sml_name,\n",
    "                 'mc_mod_sml0': r'%s, $\\beta=0$' % sml_name, 'sml_de': '%s_de' % sml_name}\n",
    "    ident_str = {'label_test_interpolate': 'label_test_interp', 'label_test_extrapolate': 'label_test_extrap', 'pca_test_interpolate': 'pca_test_interp', 'pca_test_extrapolate': 'pca_test_extrap'}\n",
    "    \n",
    "    for ident in idents:\n",
    "        for method in sorted(methods):\n",
    "            \n",
    "            values_over_datasets = aggregated_mean.loc[(datasets, ident), (method, metric)].values\n",
    "            ax.scatter(datasets_idx + ident_offsets[ident], \n",
    "                        values_over_datasets,\n",
    "                       s=s,\n",
    "                       label='%s, %s' % (method_str[method] if method in method_str else method, ident_str[ident] if ident in ident_str else ident),\n",
    "                       marker=method_to_marker[method],\n",
    "                       color=ident_to_color[ident],\n",
    "                       alpha=0.5)\n",
    "            \n",
    "            if summary_stat_over is None:\n",
    "                summary_values = values_over_datasets\n",
    "            else:\n",
    "                summary_values = aggregated_mean.loc[(summary_stat_over, ident), (method, metric)].values\n",
    "            \n",
    "            for summary_stat_count, summary_stat_func in enumerate(summary_stat_funcs):\n",
    "                if callable(summary_stat_func):\n",
    "                    ax.scatter(max_dataset_idx + summary_stat_count + 1 + ident_offsets[ident], \n",
    "                               summary_stat_func(summary_values),\n",
    "                               s=s,\n",
    "                               marker=method_to_marker[method],\n",
    "                               color=ident_to_color[ident],\n",
    "                               alpha=0.5)\n",
    "    \n",
    "    \n",
    "    if ylim is None:\n",
    "        ylim = ax.get_ylim()\n",
    "        \n",
    "    if summary_rank_funcs is not None and len(summary_rank_funcs) > 0:\n",
    "        for ident in idents:\n",
    "            for method in sorted(methods):\n",
    "                if summary_stat_over is None:\n",
    "                    rank_summary_values = aggregated_ranks.loc[(datasets, ident), (method, metric)].values\n",
    "                else:\n",
    "                    rank_summary_values = aggregated_ranks.loc[(summary_stat_over, ident), (method, metric)].values\n",
    "                \n",
    "                rank_summary_values = ((ylim[1]-ylim[0])/len(methods))*rank_summary_values + ylim[0]\n",
    "                \n",
    "                for summary_stat_count, summary_rank_func in enumerate(summary_rank_funcs):\n",
    "                    if callable(summary_stat_func):\n",
    "                        ax.scatter(max_dataset_idx + summary_stat_count + len(summary_stat_funcs) + 1 + ident_offsets[ident], \n",
    "                                   summary_rank_func(rank_summary_values),\n",
    "                                   s=s,\n",
    "                                   marker=method_to_marker[method],\n",
    "                                   color=ident_to_color[ident],\n",
    "                                   alpha=0.5)\n",
    "            \n",
    "                   \n",
    "    a = 0.2\n",
    "    ax.plot([-0.5, -0.5], ylim, '--', color='grey', alpha=a)\n",
    "    for ds_idx in datasets_idx:\n",
    "        if ds_idx == max_dataset_idx:\n",
    "            a = 0.5\n",
    "        else:\n",
    "            a = 0.2\n",
    "        ax.plot([ds_idx+0.5, ds_idx+0.5], ylim, '--', color='grey', alpha=a)\n",
    "    \n",
    "    for ds_idx in range(max_dataset_idx + 1, max_dataset_idx + len(summary_stat_funcs) + len(summary_rank_funcs) +1):\n",
    "        ax.plot([ds_idx+0.5, ds_idx+0.5], ylim, '--', color='grey', alpha=0.2)\n",
    "    ax.axvspan(max_dataset_idx + 1 - 0.5, max_dataset_idx + len(summary_stat_funcs) + len(summary_rank_funcs) + 1.5, color='grey', alpha=0.05)\n",
    "            \n",
    "    ax.legend(prop=fontP, bbox_to_anchor=(1, 1), loc='upper left')\n",
    "    \n",
    "    if yscale is None:\n",
    "        ax.set_yscale('log')\n",
    "    else:\n",
    "        ax.set_yscale(yscale)\n",
    "\n",
    "    reduce_func_to_str = {np.mean: 'mean', np.median: 'median', np.min: 'min', np.max: 'max', _75q: '75q', _25q: '25q'}\n",
    "        \n",
    "    ax.set_ylabel(metric if metric != 'ws_dist' else 'Wasserstein distance')\n",
    "    \n",
    "    \n",
    "    ax.set_xticks(np.concatenate((datasets_idx, np.arange(max_dataset_idx + 1, max_dataset_idx + len(summary_stat_funcs) + len(summary_rank_funcs) +1))))\n",
    "    if xticklabels:\n",
    "        ax.set_xticklabels(np.concatenate(([dataset + \"\\n\" + (\"(%.0fk)\"% (float(dataset_to_size[dataset])/1000.) if (float(dataset_to_size[dataset]) >= 1000) else \"(%.1fk)\"%(float(dataset_to_size[dataset])/1000.) ) for dataset in datasets], \n",
    "                                       [reduce_func_to_str[func] if func in reduce_func_to_str else func.__name__ for func in summary_stat_funcs],\n",
    "                                      [reduce_func_to_str[func] + \" rank\" if func in reduce_func_to_str else func.__name__ for func in summary_rank_funcs])),\n",
    "                          rotation=ticklabel_tilt)\n",
    "    else:\n",
    "        ax.set_xticklabels([])\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(*ylim)\n",
    "    ax.set_xlim(-1.25, max_dataset_idx + len(summary_stat_funcs) + len(summary_rank_funcs) + 1.25)\n",
    "    \n",
    "\n",
    "    if savefig is not None:\n",
    "        plt.savefig(savefig)\n",
    "    \n",
    "    if show:\n",
    "        plt.show()\n",
    "    \n",
    "\n",
    "def plot_uncertainty_vs_performance(aggregated_mean, perf, unc, item, second_dim_vals=None, second_dim='datasets',\n",
    "                                    idents=('train', 'test', 'label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "                                    xlim=None, ylim=None, xscale=None, figsize=(10, 6)):\n",
    "    \"\"\"Scatter a performance metric (x-axis) against an uncertainty metric (y-axis).\n",
    "\n",
    "    aggregated_mean is indexed by (dataset, ident) rows and (method, metric) columns.\n",
    "    With second_dim='datasets', `item` is a method and one point is drawn per\n",
    "    (dataset, ident); with second_dim='methods', `item` is a dataset and one point\n",
    "    is drawn per (method, ident).\n",
    "\n",
    "    Args:\n",
    "        aggregated_mean: DataFrame of metrics, e.g. the fold-averaged aggregated_mean.\n",
    "        perf: metric name plotted on the x-axis.\n",
    "        unc: metric name plotted on the y-axis.\n",
    "        item: fixed method (second_dim='datasets') or dataset (second_dim='methods').\n",
    "        second_dim_vals: datasets/methods to iterate over; defaults to all present.\n",
    "        second_dim: 'datasets' or 'methods'.\n",
    "        idents: split identifiers to plot, each colored via ident_to_color.\n",
    "        xlim, ylim: optional axis limits.\n",
    "        xscale: x-axis scale; defaults to 'symlog'.\n",
    "        figsize: size of the created figure.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if second_dim is neither 'datasets' nor 'methods'.\n",
    "    \"\"\"\n",
    "    # Fail fast instead of silently showing an empty, untitled figure.\n",
    "    if second_dim not in ('datasets', 'methods'):\n",
    "        raise ValueError(\"second_dim must be 'datasets' or 'methods', got %r\" % (second_dim,))\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "\n",
    "    if second_dim == 'datasets':\n",
    "        # `item` is a method; one point per (dataset, ident).\n",
    "        if second_dim_vals is None:\n",
    "            second_dim_vals = aggregated_mean.index.get_level_values(0).unique().values\n",
    "\n",
    "        for dataset_id in second_dim_vals:\n",
    "            for ident in idents:\n",
    "                plt.scatter(aggregated_mean.loc[(dataset_id, ident), (item, perf)],\n",
    "                            aggregated_mean.loc[(dataset_id, ident), (item, unc)],\n",
    "                            color=ident_to_color[ident],\n",
    "                            label='%s_%s' % (dataset_id, ident),\n",
    "                            marker=dataset_to_marker[dataset_id],\n",
    "                            alpha=0.5)\n",
    "\n",
    "        plt.title('Method=%s' % item)\n",
    "\n",
    "    else:  # second_dim == 'methods'\n",
    "        # `item` is a dataset; one point per (method, ident).\n",
    "        if second_dim_vals is None:\n",
    "            second_dim_vals = aggregated_mean.columns.get_level_values(0).unique().values\n",
    "\n",
    "        for method in second_dim_vals:\n",
    "            for ident in idents:\n",
    "                plt.scatter(aggregated_mean.loc[(item, ident), (method, perf)],\n",
    "                            aggregated_mean.loc[(item, ident), (method, unc)],\n",
    "                            color=ident_to_color[ident],\n",
    "                            marker=method_to_marker[method],\n",
    "                            label='%s_%s' % (method, ident),\n",
    "                            alpha=0.5)\n",
    "\n",
    "        plt.title('Dataset=%s' % item)\n",
    "\n",
    "    plt.xlabel(perf)\n",
    "    plt.ylabel(unc)\n",
    "    plt.legend(prop=fontP, bbox_to_anchor=(1, 1), loc='upper left')\n",
    "    # 'symlog' is the default x-scale; an explicit xscale overrides it.\n",
    "    plt.xscale('symlog' if xscale is None else xscale)\n",
    "    plt.yscale('linear')\n",
    "\n",
    "    if xlim is not None:\n",
    "        plt.xlim(*xlim)\n",
    "\n",
    "    if ylim is not None:\n",
    "        plt.ylim(*ylim)\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "def plot_uncertainty_vs_uncertainty(aggregated_mean, metrics, item, second_dim='datasets', second_dim_vals=None,\n",
    "                                    idents=('train', 'test', 'label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "                                    ylim=None, xlim=None, xscale=None, figsize=(10, 6)):\n",
    "    \"\"\"Scatter each pair of metrics against each other, one figure per pair.\n",
    "\n",
    "    aggregated_mean is indexed by (dataset, ident) rows and (method, metric) columns.\n",
    "    With second_dim='datasets', `item` is a method and one point is drawn per\n",
    "    (dataset, ident); with second_dim='methods', `item` is a dataset and one point\n",
    "    is drawn per (method, ident).\n",
    "\n",
    "    Args:\n",
    "        aggregated_mean: DataFrame of metrics, e.g. the fold-averaged aggregated_mean.\n",
    "        metrics: metric names; every unordered pair (m1, m2) gets its own figure\n",
    "            with m1 on the x-axis and m2 on the y-axis.\n",
    "        item: fixed method (second_dim='datasets') or dataset (second_dim='methods').\n",
    "        second_dim: 'datasets' or 'methods'.\n",
    "        second_dim_vals: datasets/methods to iterate over; defaults to all present.\n",
    "        idents: split identifiers to plot, each colored via ident_to_color.\n",
    "        ylim, xlim: optional axis limits.\n",
    "        xscale: x-axis scale; defaults to 'log'.\n",
    "        figsize: size of each created figure.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if second_dim is neither 'datasets' nor 'methods'.\n",
    "    \"\"\"\n",
    "    if second_dim not in ('datasets', 'methods'):\n",
    "        raise ValueError(\"second_dim must be 'datasets' or 'methods', got %r\" % (second_dim,))\n",
    "\n",
    "    # Datasets (row level 0) or methods (column level 0) to draw one point per.\n",
    "    if second_dim_vals is not None:\n",
    "        keys = second_dim_vals\n",
    "    elif second_dim == 'datasets':\n",
    "        keys = aggregated_mean.index.get_level_values(0).unique().values\n",
    "    else:\n",
    "        keys = aggregated_mean.columns.get_level_values(0).unique().values\n",
    "\n",
    "    for unc1, unc2 in itert.combinations(metrics, 2):\n",
    "        # Create the figure per pair: after plt.show() below, further pyplot calls\n",
    "        # typically go to a fresh default-sized figure, so a single figure created\n",
    "        # before the loop only applied figsize to the first pair.\n",
    "        plt.figure(figsize=figsize)\n",
    "\n",
    "        for ident in idents:\n",
    "            if second_dim == 'datasets':\n",
    "                # `item` is a method; one point per dataset.\n",
    "                for dataset_id in keys:\n",
    "                    plt.scatter(aggregated_mean.loc[(dataset_id, ident), (item, unc1)],\n",
    "                                aggregated_mean.loc[(dataset_id, ident), (item, unc2)],\n",
    "                                label='%s_%s_%s' % (unc1, unc2, ident),\n",
    "                                color=ident_to_color[ident],\n",
    "                                marker=dataset_to_marker[dataset_id],\n",
    "                                alpha=0.5)\n",
    "            else:\n",
    "                # `item` is a dataset; one point per method.\n",
    "                for method in keys:\n",
    "                    plt.scatter(aggregated_mean.loc[(item, ident), (method, unc1)],\n",
    "                                aggregated_mean.loc[(item, ident), (method, unc2)],\n",
    "                                label='%s_%s_%s' % (unc1, unc2, ident),\n",
    "                                color=ident_to_color[ident],\n",
    "                                marker=method_to_marker[method],\n",
    "                                alpha=0.5)\n",
    "\n",
    "        # In 'datasets' mode the fixed `item` is a method, not a dataset.\n",
    "        plt.title(('method=%s' if second_dim == 'datasets' else 'dataset=%s') % item)\n",
    "        plt.legend(prop=fontP, bbox_to_anchor=(1, 1), loc='upper left')\n",
    "\n",
    "        # Label with the pair actually plotted (equals metrics[0]/metrics[1] when\n",
    "        # exactly two metrics are given).\n",
    "        plt.xlabel(unc1)\n",
    "        plt.ylabel(unc2)\n",
    "        plt.xscale('log' if xscale is None else xscale)\n",
    "\n",
    "        if ylim is not None:\n",
    "            plt.ylim(*ylim)\n",
    "\n",
    "        if xlim is not None:\n",
    "            plt.xlim(*xlim)\n",
    "\n",
    "        plt.show()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dirs = ['/INCLUDE/EXPERIMENT/DIR/HERE', '/MULTIPLE/PATHS/CAN/BE/GIVEN']\n",
    "    \n",
    "aggregated = aggregate_over_folds(exp_dirs)\n",
    "# Reduce every cell to its mean (presumably over the per-fold values collected by aggregate_over_folds).\n",
    "aggregated_mean = aggregated.applymap(lambda x: np.mean(x))\n",
    "\n",
    "# order: fixed display order for datasets; datasets absent from the results are skipped.\n",
    "sorted_datasets = ['toy', 'toy_hf', 'toy_uniform', 'toy_modulated', 'toy_noise', 'toy_noise_strong', 'yacht', 'diabetes',  'boston', 'energy', 'concrete',  'wine_red', 'abalone', 'power','naval', 'california','superconduct','protein','year']\n",
    "aggregated_mean = aggregated_mean.loc[[ds for ds in sorted_datasets if ds in aggregated_mean.index.get_level_values(0).values]]\n",
    "\n",
    "\n",
    "method_to_marker = {'mc': 'D', 'mc_ll': 'd', 'mc_mod_sml': 's', \n",
    "                   'pu': '.', 'de': 'x', 'mc_mod_sml1': 'o', 'mc_mod_sml25': '^', 'mc_mod_sml75': 'p', 'mc_mod_sml9': '+', 'pu_de': '*', 'sml_de': '+', 'mc_mod_sml0': '.', 'mc_mod_sml10': '*'}\n",
    "ident_to_color = {'train': 'g', 'test': 'b', 'label_test_interpolate': 'r', 'label_test_extrapolate': 'lightcoral',\n",
    "                  'pca_test_interpolate': 'y', 'pca_test_extrapolate': 'orange'}\n",
    "# Every dataset kept by sorted_datasets needs a marker: the per-dataset scatter\n",
    "# plots index this dict directly, so a missing entry raises KeyError.\n",
    "dataset_to_marker = {'toy': ',', 'toy_hf': '2', 'toy_uniform': '4', 'toy_modulated': '<',\n",
    "                     'toy_noise': '.', 'toy_noise_strong': '>', 'yacht': '+', 'diabetes': 'x', 'boston': '|', \n",
    "                     'energy': '_', 'concrete': '1', 'wine_red': '3', \n",
    "                    'abalone': 'o', 'naval': 'v', 'power': '^', 'california': 's', 'superconduct': 'P', \n",
    "                     'protein': 'D', 'year': '*'}\n",
    "# Small legend font shared by the plotting helpers.\n",
    "fontP = FontProperties()\n",
    "fontP.set_size('xx-small')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 90% quantile of the Wasserstein distance per method over all (dataset, ident) rows.\n",
    "# Sorted so the printout order is stable across runs (set iteration order is not).\n",
    "for method in sorted(available_methods):\n",
    "    print(method, aggregated_mean[(method, 'ws_dist')].quantile(0.9))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest ECE per method over all (dataset, ident) rows.\n",
    "# Sorted so the printout order is stable across runs (set iteration order is not).\n",
    "for method in sorted(available_methods):\n",
    "    print(method, aggregated_mean[(method, 'ece')].max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RMSE: top panel train/test splits, bottom panel label/PCA interpolation and extrapolation splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "\n",
    "plot_metrics(aggregated_mean, 'rmse', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[-0.05, 1.1], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test':0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "             sml_name='ours', xticklabels=False, ticklabel_tilt=45, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'rmse', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[-0.05, 1.1], yscale='linear', figsize=(32, 10), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "             sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/rmse.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# RMSE including sml_de: top panel train/test, bottom panel label/PCA interpolation and extrapolation splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "\n",
    "plot_metrics(aggregated_mean, 'rmse', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'sml_de'], \n",
    "             ylim=[-0.05, 1.1], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test':0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "             sml_name='ours', xticklabels=False, ticklabel_tilt=45, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'rmse', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'sml_de'], \n",
    "             ylim=[-0.05, 1.1], yscale='linear', figsize=(32, 10), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "             sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/rmse_smlde.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RMSE on the label/PCA interpolation and extrapolation splits only (single panel).\n",
    "os.makedirs('./plots', exist_ok=True)  # plot_metrics saves into ./plots; savefig does not create it\n",
    "plot_metrics(aggregated_mean, 'rmse', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[-0.05, 1.1], yscale='linear', figsize=(32, 12), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "             sml_name='ours', ticklabel_tilt=45, s=160, savefig='./plots/rmse_ood.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLL: top panel train/test splits, bottom panel label/PCA interpolation and extrapolation splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "plot_metrics(aggregated_mean, 'nll', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[-10, 70], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test':0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', xticklabels=False, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'nll', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[-10, 70], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/nll.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLL including sml_de: top panel train/test, bottom panel label/PCA interpolation and extrapolation splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "plot_metrics(aggregated_mean, 'nll', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'sml_de'], \n",
    "             ylim=[-10, 70], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test':0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', xticklabels=False, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'nll', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'sml_de'], \n",
    "             ylim=[-10, 70], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/nll_smlde.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLL on the label/PCA interpolation and extrapolation splits only (single panel).\n",
    "os.makedirs('./plots', exist_ok=True)  # plot_metrics saves into ./plots; savefig does not create it\n",
    "# Same summary columns as the RMSE/ECE OOD figures: include the 25% quantile and pass\n",
    "# summary_rank_funcs=[] explicitly (plot_metrics calls len() on it unconditionally).\n",
    "plot_metrics(aggregated_mean, 'nll', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[-10, 70], yscale='linear', figsize=(32, 12 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig='./plots/nll_ood.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ECE: top panel train/test splits, bottom panel label/PCA interpolation and extrapolation splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "plot_metrics(aggregated_mean, 'ece', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[0, 2], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test':0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45,xticklabels=False, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'ece', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[0, 2], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/ece.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ECE including sml_de: top panel train/test, bottom panel label/PCA interpolation and extrapolation splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "plot_metrics(aggregated_mean, 'ece', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'sml_de'], \n",
    "             ylim=[0, 2], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test':0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45,xticklabels=False, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'ece', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'sml_de'], \n",
    "             ylim=[0, 2], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/ece_smlde.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ECE on the label/PCA interpolation and extrapolation splits only (single panel).\n",
    "os.makedirs('./plots', exist_ok=True)  # plot_metrics saves into ./plots; savefig does not create it\n",
    "plot_metrics(aggregated_mean, 'ece', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[0, 2], yscale='linear', figsize=(32, 12 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig='./plots/ece_ood.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wasserstein distance: top panel train/test, bottom panel label/PCA interpolation and extrapolation splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "plot_metrics(aggregated_mean, 'ws_dist', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[0, 4], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test': 0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, xticklabels=False, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'ws_dist', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de'], \n",
    "             ylim=[0, 4], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.2, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/ws_dist.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wasserstein distance including sml_de: top panel train/test, bottom panel label/PCA interpolation and extrapolation splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "plot_metrics(aggregated_mean, 'ws_dist', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'sml_de'], \n",
    "             ylim=[0, 4], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test': 0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, xticklabels=False, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'ws_dist', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['de', 'mc', 'mc_ll', 'mc_mod_sml', 'pu', 'pu_de', 'sml_de'], \n",
    "             ylim=[0, 4], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.2, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/ws_dist_smlde.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter study (beta parameter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dirs = ['/INCLUDE/PATH/TO/HYPERPARAMETER_EXPERIMENT/DIR/HERE', '/MULTIPLE/PATHS/CAN/BE/GIVEN']\n",
    "\n",
    "aggregated = aggregate_over_folds(exp_dirs)\n",
    "# Reduce every cell to its mean (presumably over the per-fold values collected by aggregate_over_folds).\n",
    "aggregated_mean = aggregated.applymap(lambda x: np.mean(x))\n",
    "\n",
    "# order: fixed display order for datasets; datasets absent from the results are skipped.\n",
    "sorted_datasets = ['toy', 'toy_hf', 'toy_uniform', 'toy_modulated', 'toy_noise', 'toy_noise_strong', 'yacht', 'diabetes',  'boston', 'energy', 'concrete',  'wine_red', 'abalone', 'power','naval', 'california','superconduct','protein','year']\n",
    "aggregated_mean = aggregated_mean.loc[[ds for ds in sorted_datasets if ds in aggregated_mean.index.get_level_values(0).values]]\n",
    "\n",
    "\n",
    "method_to_marker = {'mc': 'D', 'mc_ll': 'd', 'mc_mod_sml': 's', \n",
    "                   'pu': '.', 'de': 'x', 'mc_mod_sml1': 'o', 'mc_mod_sml25': '^', 'mc_mod_sml75': 'p', 'mc_mod_sml9': '+', 'pu_de': '*', 'sml_de': '+', 'mc_mod_sml0': '.', 'mc_mod_sml10': '*'}\n",
    "ident_to_color = {'train': 'g', 'test': 'b', 'label_test_interpolate': 'r', 'label_test_extrapolate': 'lightcoral',\n",
    "                  'pca_test_interpolate': 'y', 'pca_test_extrapolate': 'orange'}\n",
    "# Every dataset kept by sorted_datasets needs a marker: the per-dataset scatter\n",
    "# plots index this dict directly, so a missing entry raises KeyError.\n",
    "dataset_to_marker = {'toy': ',', 'toy_hf': '2', 'toy_uniform': '4', 'toy_modulated': '<',\n",
    "                     'toy_noise': '.', 'toy_noise_strong': '>', 'yacht': '+', 'diabetes': 'x', 'boston': '|', \n",
    "                     'energy': '_', 'concrete': '1', 'wine_red': '3', \n",
    "                    'abalone': 'o', 'naval': 'v', 'power': '^', 'california': 's', 'superconduct': 'P', \n",
    "                     'protein': 'D', 'year': '*'}\n",
    "# Small legend font shared by the plotting helpers.\n",
    "fontP = FontProperties()\n",
    "fontP.set_size('xx-small')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RMSE across the mc_mod_sml beta variants: top panel train/test, bottom panel label/PCA splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "\n",
    "plot_metrics(aggregated_mean, 'rmse', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['mc_mod_sml1', 'mc_mod_sml25', 'mc_mod_sml', 'mc_mod_sml75', 'mc_mod_sml9'], \n",
    "             ylim=[-0.05, 1.1], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test':0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "             sml_name='ours', xticklabels=False, ticklabel_tilt=45, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'rmse', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['mc_mod_sml1', 'mc_mod_sml25', 'mc_mod_sml', 'mc_mod_sml75', 'mc_mod_sml9'], \n",
    "             ylim=[-0.05, 1.1], yscale='linear', figsize=(32, 10), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "             sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/rmse_smlbeta.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLL across the mc_mod_sml beta variants: top panel train/test, bottom panel label/PCA splits.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "plot_metrics(aggregated_mean, 'nll', idents=('train', 'test'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['mc_mod_sml1', 'mc_mod_sml25', 'mc_mod_sml', 'mc_mod_sml75', 'mc_mod_sml9'], \n",
    "             ylim=[-10, 70], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'train':-0.2, 'test':0.2}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', xticklabels=False, s=160, savefig=None, ax=ax[0])\n",
    "\n",
    "plot_metrics(aggregated_mean, 'nll', idents=('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'), \n",
    "             datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval','california', 'superconduct', 'protein', 'year'], \n",
    "             methods=['mc_mod_sml1', 'mc_mod_sml25', 'mc_mod_sml', 'mc_mod_sml75', 'mc_mod_sml9'], \n",
    "             ylim=[-10, 70], yscale='linear', figsize=(32, 10 ), \n",
    "             ident_offsets={'label_test_interpolate':-0.25, 'label_test_extrapolate':-0.1, 'pca_test_interpolate':0.1, 'pca_test_extrapolate': 0.25}, summary_stat_funcs=[np.mean, np.median, _25q, _75q],\n",
    "             summary_rank_funcs=[],\n",
    "            sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[1])\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs('./plots', exist_ok=True)  # savefig does not create missing directories\n",
    "plt.savefig('./plots/nll_sml.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "# Row 0: train/test idents with x tick labels hidden; row 1: label/pca split idents.\n",
    "ece_ident_groups = [('train', 'test'),\n",
    "                    ('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate')]\n",
    "ece_offset_groups = [{'train': -0.2, 'test': 0.2},\n",
    "                     {'label_test_interpolate': -0.25, 'label_test_extrapolate': -0.1, 'pca_test_interpolate': 0.1, 'pca_test_extrapolate': 0.25}]\n",
    "ece_row_extras = [{'xticklabels': False}, {}]\n",
    "for row, (row_idents, row_offsets, row_extra) in enumerate(zip(ece_ident_groups, ece_offset_groups, ece_row_extras)):\n",
    "    plot_metrics(aggregated_mean, 'ece', idents=row_idents,\n",
    "                 datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'],\n",
    "                 methods=['mc_mod_sml1', 'mc_mod_sml25', 'mc_mod_sml', 'mc_mod_sml75', 'mc_mod_sml9'],\n",
    "                 ylim=[0, 2], yscale='linear', figsize=(32, 10),\n",
    "                 ident_offsets=row_offsets,\n",
    "                 summary_stat_funcs=[np.mean, np.median, _25q, _75q], summary_rank_funcs=[],\n",
    "                 sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=ax[row], **row_extra)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./plots/ece_sml.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ws_dist figure: top row shows the train/test idents, bottom row the label/pca split idents.\n",
    "# The split offsets (-0.25, -0.1, 0.1, 0.25) match the rmse, nll and ece figures so the\n",
    "# four split markers are spaced identically across all metric figures.\n",
    "fig, ax = plt.subplots(2, figsize=(32, 20))\n",
    "ws_rows = [\n",
    "    (ax[0], ('train', 'test'),\n",
    "     {'train': -0.2, 'test': 0.2},\n",
    "     {'xticklabels': False}),\n",
    "    (ax[1], ('label_test_interpolate', 'label_test_extrapolate', 'pca_test_interpolate', 'pca_test_extrapolate'),\n",
    "     {'label_test_interpolate': -0.25, 'label_test_extrapolate': -0.1, 'pca_test_interpolate': 0.1, 'pca_test_extrapolate': 0.25},\n",
    "     {}),\n",
    "]\n",
    "for row_ax, row_idents, row_offsets, row_options in ws_rows:\n",
    "    plot_metrics(aggregated_mean, 'ws_dist', idents=row_idents,\n",
    "                 datasets=['yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red', 'abalone', 'power', 'naval', 'california', 'superconduct', 'protein', 'year'],\n",
    "                 methods=['mc_mod_sml1', 'mc_mod_sml25', 'mc_mod_sml', 'mc_mod_sml75', 'mc_mod_sml9'],\n",
    "                 ylim=[0, 4], yscale='linear', figsize=(32, 10),\n",
    "                 ident_offsets=row_offsets,\n",
    "                 summary_stat_funcs=[np.mean, np.median, _25q, _75q], summary_rank_funcs=[],\n",
    "                 sml_name='ours', ticklabel_tilt=45, s=160, savefig=None, ax=row_ax, **row_options)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./plots/ws_dist_sml.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Uncertainty/Performance measure analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SMALL_SIZE = 6\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 25\n",
    "\n",
    "# Global matplotlib style for the figures below: large text, thick axes/ticks/lines,\n",
    "# small legend and figure-title text.\n",
    "plt.rcParams.update({\n",
    "    'font.size': BIGGER_SIZE,          # default text size\n",
    "    'axes.titlesize': BIGGER_SIZE,     # axes title\n",
    "    'axes.labelsize': BIGGER_SIZE,     # x and y axis labels\n",
    "    'axes.linewidth': 5,               # axes frame width\n",
    "    'xtick.labelsize': BIGGER_SIZE,    # x tick labels\n",
    "    'xtick.major.width': 5,\n",
    "    'xtick.major.size': 10,\n",
    "    'xtick.minor.width': 5,\n",
    "    'xtick.minor.size': 10,\n",
    "    'ytick.labelsize': BIGGER_SIZE,    # y tick labels\n",
    "    'ytick.major.width': 5,\n",
    "    'ytick.major.size': 10,\n",
    "    'ytick.minor.width': 5,\n",
    "    'ytick.minor.size': 10,\n",
    "    'legend.fontsize': SMALL_SIZE,     # legend text\n",
    "    'figure.titlesize': SMALL_SIZE,    # figure suptitle\n",
    "    'lines.linewidth': 5,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dirs = ['/INCLUDE/PATH/TO/EXPERIMENTS/DIR/HERE', '/MULTIPLE/PATHS/CAN/BE/GIVEN']\n",
    "\n",
    "aggregated = aggregate_over_folds(exp_dirs)\n",
    "# reduce every entry (presumably the per-fold values) to its mean\n",
    "aggregated_mean = aggregated.applymap(np.mean)\n",
    "\n",
    "# Order the datasets (first index level) for plotting. Datasets present in the results\n",
    "# but missing from sorted_datasets are appended at the end instead of being silently dropped.\n",
    "sorted_datasets = ['toy', 'toy_hf', 'toy_uniform', 'toy_modulated', 'toy_noise', 'toy_noise_strong', 'yacht', 'diabetes',  'boston', 'energy', 'concrete',  'wine_red', 'abalone', 'power','naval', 'california','superconduct','protein','year']\n",
    "present_datasets = list(dict.fromkeys(aggregated_mean.index.get_level_values(0)))\n",
    "dataset_order = [ds for ds in sorted_datasets if ds in present_datasets]\n",
    "dataset_order += [ds for ds in present_datasets if ds not in sorted_datasets]\n",
    "aggregated_mean = aggregated_mean.loc[dataset_order]\n",
    "\n",
    "\n",
    "method_to_marker = {'mc': 'D', 'mc_ll': 'd', 'mc_mod_sml': 's', \n",
    "                   'pu': '.', 'de': 'x', 'mc_mod_sml1': 'o', 'mc_mod_sml25': '^', 'mc_mod_sml75': 'p', 'mc_mod_sml9': '+', 'pu_de': '*', 'sml_de': '+', 'mc_mod_sml0': '.', 'mc_mod_sml10': '*'}\n",
    "ident_to_color = {'train': 'g', 'test': 'b', 'label_test_interpolate': 'r', 'label_test_extrapolate': 'lightcoral',\n",
    "                  'pca_test_interpolate': 'y', 'pca_test_extrapolate': 'orange'}\n",
    "# one distinct marker per dataset, covering every entry of sorted_datasets\n",
    "dataset_to_marker = {'toy': ',', 'toy_noise': '.', 'yacht': '+', 'diabetes': 'x', 'boston': '|', \n",
    "                     'energy': '_', 'concrete': '1', 'wine_red': '3', \n",
    "                    'abalone': 'o', 'naval': 'v', 'power': '^', 'california': 's', 'superconduct': 'P', \n",
    "                     'protein': 'D', 'year': '*',\n",
    "                     'toy_hf': '2', 'toy_uniform': '4', 'toy_modulated': '<', 'toy_noise_strong': '>'}\n",
    "fontP = FontProperties()\n",
    "fontP.set_size('xx-small')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rmse vs. nll over all available datasets\n",
    "plot_uncertainty_vs_performance(\n",
    "    aggregated_mean, 'rmse', 'nll', available_datasets,\n",
    "    second_dim='methods',\n",
    "    xscale='linear', xlim=[0, 1.3], ylim=[-6, 20],\n",
    "    figsize=(40, 25),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rmse vs. r2 over all available datasets\n",
    "plot_uncertainty_vs_performance(\n",
    "    aggregated_mean, 'rmse', 'r2', available_datasets,\n",
    "    second_dim='methods',\n",
    "    xscale='linear', xlim=[0, 1.3], ylim=[-10, 2],\n",
    "    figsize=(40, 25),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ws_dist vs. ece\n",
    "plot_uncertainty_vs_uncertainty(\n",
    "    aggregated_mean, ['ws_dist', 'ece'], available_datasets, 'methods',\n",
    "    figsize=(40, 25),\n",
    "    xlim=[1e-2, 50],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ws_dist vs. ks_dist\n",
    "plot_uncertainty_vs_uncertainty(\n",
    "    aggregated_mean, ['ws_dist', 'ks_dist'], available_datasets, 'methods',\n",
    "    figsize=(40, 25),\n",
    "    xlim=[1e-2, 50],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ks_dist vs. ece on a linear x axis\n",
    "plot_uncertainty_vs_uncertainty(\n",
    "    aggregated_mean, ['ks_dist', 'ece'], available_datasets, 'methods',\n",
    "    xscale='linear', xlim=[0, 1],\n",
    "    figsize=(40, 25),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ece vs. ece_calib, both axes limited to [0, 2]\n",
    "plot_uncertainty_vs_uncertainty(\n",
    "    aggregated_mean, ['ece', 'ece_calib'], available_datasets, 'methods',\n",
    "    xscale='linear', xlim=[0, 2], ylim=[0, 2],\n",
    "    figsize=(40, 25),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rmse vs. ece\n",
    "plot_uncertainty_vs_performance(\n",
    "    aggregated_mean, 'rmse', 'ece', available_datasets,\n",
    "    second_dim='methods',\n",
    "    xscale='linear', ylim=[0, 2],\n",
    "    figsize=(30, 15),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nll vs. ece\n",
    "plot_uncertainty_vs_performance(\n",
    "    aggregated_mean, 'nll', 'ece', available_datasets,\n",
    "    second_dim='methods',\n",
    "    xscale='linear', xlim=[-6, 50], ylim=[0, 2],\n",
    "    figsize=(40, 25),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-torch-env]",
   "language": "python",
   "name": "conda-env-.conda-torch-env-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
