{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `display` and `HTML` are public API of IPython.display; importing them from\n",
    "# IPython.core.display is deprecated.\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Stretch the notebook container to the full browser width.\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "#os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
    "import sys\n",
    "import copy \n",
    "import datetime as dt\n",
    "import json\n",
    "import re\n",
    "import os\n",
    "import time as t\n",
    "import gzip\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from   torchvision.transforms import Compose, ToTensor, Resize\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import sklearn\n",
    "from   sklearn.metrics import confusion_matrix,r2_score\n",
    "from   sklearn.svm import SVC\n",
    "from   sklearn.pipeline import make_pipeline\n",
    "from   sklearn.preprocessing import StandardScaler\n",
    "from   sklearn.model_selection import train_test_split\n",
    "from   sklearn import tree\n",
    "from   sklearn.ensemble import RandomForestClassifier\n",
    "from   sklearn.utils import shuffle\n",
    "from   sklearn.decomposition import PCA\n",
    "from sklearn.isotonic import IsotonicRegression\n",
    "from sklearn.datasets import load_boston, load_diabetes, fetch_california_housing\n",
    "\n",
    "import scipy.stats as stats\n",
    "from   scipy.stats import moment, kurtosis, skew, norm, kstest, wasserstein_distance\n",
    "from scipy.interpolate import interp1d\n",
    "\n",
    "import seaborn as sbn\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network, Metrics, ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def cov_matrix(c,n):\n",
    "    return (1-c) * np.diag(np.ones(n)) + c * np.ones((n,n))\n",
    "\n",
    "def corr_init_matrix(in_features,out_features,c):\n",
    "    \"\"\"Draw a correlated weight matrix of shape (out_features, in_features).\n",
    "\n",
    "    Two square Gaussian samples with covariance cov_matrix(c, n) are drawn,\n",
    "    the second one is transposed, and both are averaged. The average is mapped\n",
    "    through a normal CDF (std sqrt(2)/2) and rescaled to\n",
    "    [-sqrt(1/in_features), sqrt(1/in_features)] before it is cropped to the\n",
    "    requested shape and wrapped in a torch.nn.Parameter.\n",
    "    \"\"\"\n",
    "    side      = max(in_features,out_features)\n",
    "    zero_mean = np.zeros(side)\n",
    "\n",
    "    # the order of the two draws is kept so that a seeded RNG reproduces the same weights\n",
    "    first_draw  = np.random.multivariate_normal(zero_mean,cov_matrix(c,side),side)\n",
    "    second_draw = np.random.multivariate_normal(zero_mean,cov_matrix(c,side),side).T\n",
    "    averaged    = 0.5*(first_draw + second_draw)\n",
    "\n",
    "    uniformised = norm.cdf(averaged,loc=0,scale=np.sqrt(2)/2)\n",
    "    rescaled    = (2*uniformised-1)*np.sqrt(1.0/in_features)\n",
    "    cropped     = rescaled[:out_features,:in_features]\n",
    "\n",
    "    return torch.nn.Parameter(torch.FloatTensor(cropped))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# network for MC dropout and deep ensembles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fully connected neural network with three hidden layers (with dropout)\n",
    "class Net(nn.Module):\n",
    "    def __init__(self, net_params, train_params):\n",
    "        super(Net, self).__init__()\n",
    "        \n",
    "        self.n_input      = net_params['n_input']\n",
    "        self.layer_width  = net_params['layer_width']\n",
    "        self.num_layers   = net_params['num_layers']\n",
    "        self.n_output     = net_params['n_output']\n",
    "        self.nonlinearity = net_params['nonlinearity']\n",
    "            \n",
    "        self.drop_bool    = train_params['drop_bool']\n",
    "        self.drop_bool_ll = train_params['drop_bool_ll']\n",
    "        self.drop_p       = train_params['drop_p']\n",
    "        \n",
    "        self.layers       = nn.ModuleList()\n",
    "        \n",
    "        if self.num_layers == 0:\n",
    "            self.layers.append(nn.Linear(self.n_input,self.n_output))\n",
    "        else:\n",
    "            self.layers.append(nn.Linear(self.n_input,self.layer_width))\n",
    "            for _ in range(self.num_layers-1):\n",
    "                self.layers.append(nn.Linear(self.layer_width,self.layer_width))\n",
    "            self.layers.append(nn.Linear(self.layer_width,self.n_output))\n",
    "    \n",
    "        for layer in self.layers:\n",
    "            layer.weight = corr_init_matrix(layer.in_features,layer.out_features,net_params['init_corrcoef'])\n",
    "\n",
    "    def forward(self, x, drop_bool=None):\n",
    "        \n",
    "        # drop_bool controls whether last layer dropout is used (True/False) or if values from the constructor shall be used (None)\n",
    "        if drop_bool is None:\n",
    "            drop_bool    = self.drop_bool\n",
    "            drop_bool_ll = self.drop_bool_ll\n",
    "        elif drop_bool is False:\n",
    "            drop_bool_ll = False\n",
    "        elif drop_bool is True:\n",
    "            drop_bool_ll = True\n",
    "        \n",
    "        if self.num_layers == 0:\n",
    "            x = F.dropout(x, p=self.drop_p, training=drop_bool_ll)\n",
    "            x = self.layers[-1](x)\n",
    "        else:\n",
    "            for layer in self.layers[:-2]:\n",
    "                x = layer(x)\n",
    "                x = self.nonlinearity(x)\n",
    "                x = F.dropout(x, p=self.drop_p, training=drop_bool)\n",
    "\n",
    "            x = self.layers[-2](x)\n",
    "            x = self.nonlinearity(x)\n",
    "            x = F.dropout(x, p=self.drop_p, training=drop_bool_ll)\n",
    "\n",
    "            x = self.layers[-1](x)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# network for parametric uncertainty (PU)\n",
    "# x[:, 0]: Network output, x[:, 1] uncertainty estimate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net_PU(Net):\n",
    "    \"\"\"Net variant for parametric uncertainty.\n",
    "\n",
    "    Output column 0 is passed through unchanged; output column 1 is mapped\n",
    "    through a softplus so that it is positive.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, net_params, train_params):\n",
    "        super(Net_PU, self).__init__(net_params=net_params, train_params=train_params)\n",
    "        self.softplus    = nn.Softplus()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a two-column tensor built from the first two raw outputs.\"\"\"\n",
    "        last_idx = len(self.layers) - 1\n",
    "        # every layer but the last: linear -> nonlinearity -> dropout (constructor switch)\n",
    "        for idx in range(last_idx):\n",
    "            x = self.layers[idx](x)\n",
    "            x = self.nonlinearity(x)\n",
    "            x = F.dropout(x, p=self.drop_p, training=self.drop_bool)\n",
    "        raw = self.layers[last_idx](x)\n",
    "        first_col  = raw[:,0]\n",
    "        second_col = self.softplus(raw[:,1])\n",
    "        return torch.stack([first_col,second_col],dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_network(net,data,train_params,method):\n",
    "    \"\"\"Train a single network, or every member of an ensemble, in place.\n",
    "\n",
    "    net          : network, or a list of networks when method is 'de', 'pu_de' or 'sml_de'\n",
    "    data         : (X_train, y_train) pair\n",
    "    train_params : dict; reads 'batch_size', 'learning_rate', 'loss_func', 'num_epochs',\n",
    "                   'device' and, for the two sml losses, 'sml_loss_params'\n",
    "    method       : only used to tell ensembles apart from single networks\n",
    "\n",
    "    Prints the summed batch loss and a run-time estimate every 100 epochs.\n",
    "    \"\"\"\n",
    "\n",
    "    if method in ['de','pu_de','sml_de']:  # de = deep ensembles; net is a list, train all networks in that list\n",
    "        for member in net:\n",
    "            train_network(member,data=data,train_params=train_params,method='mc')\n",
    "        return\n",
    "\n",
    "    X_train,y_train = data\n",
    "    batch_size = train_params['batch_size']\n",
    "    device     = train_params['device']\n",
    "    num_epochs = train_params['num_epochs']\n",
    "    log_every  = 100\n",
    "\n",
    "    # Ceil division: the trailing partial batch is trained on as well. With floor\n",
    "    # division it was dropped, and a dataset smaller than one batch was never trained on.\n",
    "    batch_no = (len(X_train) + batch_size - 1) // batch_size\n",
    "\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=train_params['learning_rate'])\n",
    "    loss_func = train_params['loss_func']\n",
    "\n",
    "    start_time = t.time()\n",
    "    for epoch in range(num_epochs):\n",
    "\n",
    "        X_train, y_train = shuffle(X_train, y_train)\n",
    "        running_loss = 0.0\n",
    "\n",
    "        for i in range(batch_no):\n",
    "\n",
    "            start  = i * batch_size\n",
    "            end    = start + batch_size  # slicing clips `end` for the last, shorter batch\n",
    "            inputs = torch.FloatTensor(X_train[start:end]).to(device)\n",
    "            labels = torch.FloatTensor(y_train[start:end].flatten()).to(device)\n",
    "            labels = torch.unsqueeze(labels,dim=1)\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward + backward + optimize; the sml losses run the network themselves\n",
    "            if loss_func == sml_loss:\n",
    "                loss = sml_loss(net=net,data=[inputs,labels],loss_params=train_params['sml_loss_params'])\n",
    "            elif loss_func == train_second_moments_loss:\n",
    "                loss = train_second_moments_loss(net=net,data=[inputs,labels],loss_params=train_params['sml_loss_params'])\n",
    "            else:\n",
    "                outputs = net(inputs)\n",
    "                loss    = loss_func(outputs,labels)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if epoch % log_every == 0:\n",
    "            elapsed = t.time() - start_time\n",
    "            # the first report covers a single epoch, every later one covers log_every epochs\n",
    "            epochs_timed = 1 if epoch == 0 else log_every\n",
    "            print('Epoch {}'.format(epoch), \"loss: \",running_loss, \"took: %.5fs (exp. total time: %.5fs)\" % (elapsed, elapsed*num_epochs/epochs_timed) )\n",
    "            start_time = t.time()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_datapoint_statistics(net,data,method, iso_reg=None):\n",
    "    \"\"\"Build a per-datapoint DataFrame of predictions, uncertainties and calibration metrics.\n",
    "\n",
    "    net     : a single network, or an iterable of networks for 'de', 'pu_de' and 'sml_de'\n",
    "    data    : (X, y) pair; X.tolist() and y.flatten() are called on them\n",
    "    method  : 'de', 'pu', 'pu_de', 'sml_de', or any string containing 'mc'\n",
    "              ('mc_mod_sml' additionally switches to the no-dropout prediction\n",
    "              and adds |spread| to the std)\n",
    "    iso_reg : None, or a list used as a one-element cache: if empty, the isotonic\n",
    "              regression fitted on this call is appended and applied; otherwise\n",
    "              iso_reg[0] is applied. Nothing is applied for any other value.\n",
    "\n",
    "    Returns the DataFrame with columns gt, x, pred_mean, pred_std, total_std, nll,\n",
    "    pred_residual, pred_residual_normed, error_quantile, net_gradient_norm and,\n",
    "    where applicable, pred_no_mc, spread and error_quantile_calibrated.\n",
    "    \"\"\"\n",
    "    \n",
    "    X,y = data\n",
    "    pred_y_samples = []\n",
    "    eps = 1e-10\n",
    "\n",
    "    df = pd.DataFrame(y.flatten()).rename(columns={0:'gt'})\n",
    "    df['x'] = X.tolist()\n",
    "    \n",
    "    with torch.no_grad(): \n",
    "        \n",
    "        # Compute mean and std from network outputs   \n",
    "        if 'mc' in method: # Get predictions with deactivated dropout and multiple predictions per input point with activated dropout\n",
    "            pred_y_no_mc = list((net(torch.FloatTensor(X),drop_bool=False).cpu().numpy()).flatten()) \n",
    "            # 200 forward passes with the network's constructor dropout settings\n",
    "            for _ in range(200):\n",
    "                pred_y_samples.append(list((net(torch.FloatTensor(X)).cpu().numpy()).flatten()))\n",
    "            # rows = passes, columns = datapoints, so mean()/std() are per datapoint\n",
    "            df['pred_mean'] = pd.DataFrame(pred_y_samples).mean()\n",
    "            df['pred_std']  = pd.DataFrame(pred_y_samples).std()\n",
    "            \n",
    "        \n",
    "        elif method == 'de':\n",
    "            # one prediction per ensemble member; mean/std are taken across members\n",
    "            for i in range(len(net)):\n",
    "                pred_y_samples.append(list((net[i](torch.FloatTensor(X)).cpu().numpy()).flatten()))\n",
    "            #df['pred_y_samples'] = list(np.asarray(pred_y_samples).reshape((-1, len(net))))\n",
    "            df['pred_mean'] = pd.DataFrame(pred_y_samples).mean()\n",
    "            df['pred_std']  = pd.DataFrame(pred_y_samples).std()\n",
    "            \n",
    "        \n",
    "        elif method == 'pu':\n",
    "            # the two network output columns are used directly as mean and std\n",
    "            df[['pred_mean','pred_std']] = pd.DataFrame(net(torch.FloatTensor(X)).cpu().numpy())\n",
    "\n",
    "            \n",
    "        elif method == 'pu_de':\n",
    "            mus = []\n",
    "            sigmas = []\n",
    "            for net_ in net:\n",
    "                net_mu_sigma = net_(torch.FloatTensor(X)).cpu().data.numpy()    \n",
    "                mus.append(net_mu_sigma[:,0])\n",
    "                sigmas.append(net_mu_sigma[:,1])\n",
    "\n",
    "            mus    = np.array(mus)\n",
    "            sigmas = np.array(sigmas)\n",
    "            # combine the members: mean of the means, std = sqrt(mean(sigma^2 + mu^2) - mean^2)\n",
    "            df['pred_mean'] = mus.mean(axis=0)\n",
    "            df['pred_std']  = np.sqrt( (sigmas**2 + mus**2).mean(axis=0) - df['pred_mean']**2 )\n",
    "\n",
    "            \n",
    "        elif method == 'sml_de':\n",
    "            mus = []\n",
    "            sigmas = []\n",
    "            for net_ in net:\n",
    "                net_pred_no_mc = list((net_(torch.FloatTensor(X),drop_bool=False).cpu().numpy()).flatten())\n",
    "                pred_y_samples = []\n",
    "\n",
    "                for _ in range(200):\n",
    "                    pred_y_samples.append(list((net_(torch.FloatTensor(X)).cpu().numpy()).flatten()))\n",
    "\n",
    "                net_pred_mean = pd.DataFrame(pred_y_samples).mean()\n",
    "                net_pred_std  = pd.DataFrame(pred_y_samples).std()\n",
    "                net_spread    = net_pred_mean - net_pred_no_mc\n",
    "                # per member: the no-dropout prediction is the mean, dropout std + |spread| the std\n",
    "                net_total_std = net_pred_std + np.abs(net_spread)\n",
    "                mus.append(net_pred_no_mc)\n",
    "                sigmas.append(net_total_std)\n",
    "\n",
    "            mus    = np.array(mus)\n",
    "            sigmas = np.array(sigmas)\n",
    "            # same combination rule as in the 'pu_de' branch\n",
    "            df['pred_mean'] = mus.mean(axis=0)\n",
    "            df['pred_std']  = np.sqrt( (sigmas**2 + mus**2).mean(axis=0) - df['pred_mean']**2 )\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        if 'mc' in method:\n",
    "            # spread = offset of the dropout mean from the no-dropout prediction\n",
    "            df['pred_no_mc'] = pred_y_no_mc\n",
    "            df['spread']     = df['pred_mean'] - df['pred_no_mc']\n",
    "\n",
    "\n",
    "        # Further metrics: nll (of gt in model under gaussian assumption), residual (i.e. mean - gt), error quantile (quantile of gt in normalized prediction distribution)        \n",
    "        if method == 'mc_mod_sml':\n",
    "            df['total_std'] = df['pred_std']+np.abs(df['spread'])\n",
    "            df['nll']                  = df.apply(lambda x: nll(x['pred_no_mc'],x['total_std'],x['gt']),axis=1)\n",
    "            df['pred_residual']        = df['pred_no_mc']-df['gt']\n",
    "        else:\n",
    "            df['total_std'] = df['pred_std']\n",
    "            df['nll']                  = df.apply(lambda x: nll(x['pred_mean'],x['pred_std'],x['gt']),axis=1)\n",
    "            df['pred_residual']        = df['pred_mean']-df['gt']\n",
    "\n",
    "        # eps keeps the division finite when total_std is zero\n",
    "        df['pred_residual_normed'] = df['pred_residual']/(df['total_std']+eps)   \n",
    "        df['error_quantile']       = df['pred_residual_normed'].apply(lambda x: np.round(norm.cdf(x),2))\n",
    "\n",
    "\n",
    "\n",
    "    # computed outside the no_grad block because net_gradient_norm needs autograd\n",
    "    if 'mc' in method:\n",
    "        df['net_gradient_norm'] = pd.DataFrame(X).apply(lambda x: net_gradient_norm(x,net),axis=1)\n",
    "    else:\n",
    "        # constant placeholder for methods without a single network\n",
    "        df['net_gradient_norm'] = 1e10\n",
    "\n",
    "\n",
    "\n",
    "    # the regression is fitted on every call, but only used when iso_reg is an empty list\n",
    "    _, iso_reg_ = calc_ece_and_iso_reg(df['error_quantile'])\n",
    "    if iso_reg is not None:\n",
    "        if isinstance(iso_reg, list):\n",
    "            if len(iso_reg) == 0:\n",
    "                iso_reg.append(iso_reg_)\n",
    "                df['error_quantile_calibrated'] = iso_reg_.predict(df['error_quantile'])\n",
    "            else:\n",
    "                df['error_quantile_calibrated'] = iso_reg[0].predict(df['error_quantile'])\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nll(mu,sigma,y):\n",
    "    eps = 1e-10\n",
    "    return np.log(eps+sigma**2)/2 + ((y-mu)**2)/(eps+2*sigma**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nll_floored(y_pred,y_gt):  # only for training of parametric uncertainty model\n",
    "    mu    = y_pred[:,0]\n",
    "    sigma = y_pred[:,1]\n",
    "    y_gt  = torch.squeeze(y_gt)\n",
    "    \n",
    "    nll = torch.log(sigma**2)/2 + ((y_gt-mu)**2)/(2*sigma**2)\n",
    "    nll[nll<-100]=-100 # why floor?\n",
    "    nll = nll.mean()  # why mean? should be sum i guess\n",
    "    \n",
    "    return nll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a MSE(y_pred, y) + b MSE(|y_MC - y_pred|, |y - y_pred|) + c MSE(y_MC, y)\n",
    "# Dropout spread is learned to be equal to the residual of the prediction\n",
    "def sml_loss(net, data,loss_params):\n",
    "    inputs, labels = data\n",
    "    alpha, beta, gamma = loss_params\n",
    "    mse_loss = torch.nn.MSELoss(reduction='mean') \n",
    "    \n",
    "    outputs_no_mc     = net(inputs,drop_bool=False)\n",
    "    outputs_no_mc_det = outputs_no_mc.detach()\n",
    "    outputs_mc        = net(inputs)\n",
    "\n",
    "    loss0 = mse_loss(outputs_no_mc, labels)        \n",
    "\n",
    "    a_abs = torch.abs(outputs_mc - outputs_no_mc_det) # | y_MC - y_pred|\n",
    "    b_abs = torch.abs(labels     - outputs_no_mc_det) # | y_gt - y_pred|\n",
    "    loss1 = mse_loss(a_abs, b_abs)\n",
    "\n",
    "    loss2 = mse_loss(outputs_mc, labels)  \n",
    "\n",
    "    loss = alpha * loss0 + beta * loss1 + gamma * loss2\n",
    "\n",
    "    return loss "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# retrieves multiple mc outputs\n",
    "# loss: mean of outputs should be equal to gt (loss0)\n",
    "# distance of the mc outputs should resemble the distance of mean output to gt (loss1)\n",
    "def train_second_moments_loss(net,data,loss_params):\n",
    "    inputs, labels = data\n",
    "    alpha, beta, gamma = loss_params\n",
    "    mse_loss = torch.nn.MSELoss(reduction='mean') \n",
    "    \n",
    "    #outputs_no_mc     = net(inputs,drop_bool=False)\n",
    "    #outputs_no_mc_det = outputs_no_mc.detach()\n",
    "    outputs_mc_1    = net(inputs)\n",
    "    outputs_mc_2    = net(inputs)\n",
    "    outputs_mc_mean = 0.5*(outputs_mc_1 + outputs_mc_2) \n",
    "\n",
    "    loss01 = mse_loss(outputs_mc_1,labels)        \n",
    "    loss02 = mse_loss(outputs_mc_2,labels)\n",
    "    loss0  = 0.5*(loss01+loss02)\n",
    "    \n",
    "    a_abs = torch.abs(outputs_mc_1 - outputs_mc_2)\n",
    "    b_abs = torch.abs(labels       - outputs_mc_mean)\n",
    "    loss1 = mse_loss(a_abs, b_abs)\n",
    "\n",
    "    loss2 = mse_loss(outputs_mc_1, labels)  \n",
    "\n",
    "    loss = alpha * loss0 + beta * loss1 + gamma * loss2\n",
    "\n",
    "    return loss "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_ece(pred_error_quantiles):\n",
    "    bins = np.linspace(0.1, 0.9, 9)\n",
    "    n    = len(pred_error_quantiles)\n",
    "    \n",
    "    digitized = np.digitize(pred_error_quantiles, bins)\n",
    "    ece       = np.abs(((pd.Series(digitized).value_counts()/n)-0.1)).sum()\n",
    "    \n",
    "    return ece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_ece_and_iso_reg(pred_error_quantiles):\n",
    "    \"\"\"Return (ece, iso_reg) for the given error quantiles.\n",
    "\n",
    "    ece     : sum over 20 bins of |relative bin frequency - 0.05|\n",
    "    iso_reg : IsotonicRegression fitted from the bin positions (upper bin edge\n",
    "              minus 0.025) to the cumulated relative frequencies\n",
    "    \"\"\"\n",
    "    edges   = np.linspace(-0.0001,1.0001,21)\n",
    "    n_total = len(pred_error_quantiles)\n",
    "\n",
    "    bin_ids   = np.digitize(pred_error_quantiles, edges)\n",
    "    bin_share = pd.Series(bin_ids).value_counts()/n_total\n",
    "\n",
    "    # np.digitize ids start at 1 for the first bin, hence the shift by one\n",
    "    rel_freqs = np.zeros(len(edges)-1)\n",
    "    for bin_id, share in bin_share.items():\n",
    "        rel_freqs[bin_id-1] = share\n",
    "\n",
    "    ece = np.abs(rel_freqs-0.05).sum()\n",
    "\n",
    "    model_quantiles = edges[1:]-0.025\n",
    "    emp_quantiles   = np.cumsum(rel_freqs)\n",
    "    iso_reg         = IsotonicRegression(out_of_bounds='clip')\n",
    "    iso_reg.fit(model_quantiles,emp_quantiles)\n",
    "\n",
    "    return ece, iso_reg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_global_statistics(df):\n",
    "    \"\"\"Aggregate a per-datapoint statistics frame into a dict of global metrics.\n",
    "\n",
    "    Reads the columns pred_residual, gt, pred_mean, nll, error_quantile and\n",
    "    pred_residual_normed. Returns the keys rmse, r2, nll, ece, ks_dist and\n",
    "    ws_dist, plus ece_calib when the column error_quantile_calibrated exists.\n",
    "    ws_dist compares against 100000 fresh standard-normal draws, so it varies\n",
    "    between calls unless the numpy RNG is seeded.\n",
    "    \"\"\"\n",
    "    normed_residuals = df['pred_residual_normed']\n",
    "\n",
    "    root_mse  = np.sqrt((df['pred_residual']**2).mean())\n",
    "    r_squared = r2_score(df['gt'],df['pred_mean'])\n",
    "    mean_nll  = df['nll'].mean()\n",
    "    ece       = calc_ece_and_iso_reg(df['error_quantile'])[0]\n",
    "\n",
    "    # distance of the normalised residuals to a standard normal, measured in two ways\n",
    "    ws_dist = wasserstein_distance(normed_residuals,np.random.randn(100000))\n",
    "    ks_dist = kstest(normed_residuals.values,'norm')[0]\n",
    "\n",
    "    res = {\n",
    "        'rmse':    root_mse,\n",
    "        'r2':      r_squared,\n",
    "        'nll':     mean_nll,\n",
    "        'ece':     ece,\n",
    "        'ks_dist': ks_dist,\n",
    "        'ws_dist': ws_dist,\n",
    "    }\n",
    "    if 'error_quantile_calibrated' in df.columns:\n",
    "        res['ece_calib'] = calc_ece_and_iso_reg(df['error_quantile_calibrated'])[0]\n",
    "    return res\n",
    "    \n",
    "# example for re-scaling\n",
    "#np.sqrt((df_train['pred_residual']**2).mean())*np.sqrt(y_scaler.var_[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def net_gradient_norm(datapoint,net):\n",
    "    test_in  = torch.tensor(datapoint,requires_grad=True,dtype=torch.float32)\n",
    "    test_out = net(test_in)\n",
    "    return torch.autograd.grad(test_out, test_in)[0].norm().item()\n",
    "\n",
    "# better use isotropic Gaussian for const density on eps-sphere\n",
    "def random_perturb_hull(eps,dim):\n",
    "    per = np.random.uniform(2*eps,5*eps,dim)\n",
    "    if norm(per) > eps:\n",
    "        per = eps*per/norm(per)\n",
    "    return per\n",
    "\n",
    "def random_perturb_ball(eps,dim):     \n",
    "    return norm(np.random.multivariate_normal(np.zeros(dim),0.005*eps*np.diag(np.ones(dim))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import data / Toy dataset generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Computes gaussian * sine; Represents the noise/uncertainty of the main polynomial function\n",
    "def sine_bump(centre, std, amplitude, frequency):\n",
    "    def sine_bump_instance(x):\n",
    "        return amplitude * np.exp( -((x-centre)**2) / (2*std**2) ) * np.sin(frequency*x)\n",
    "\n",
    "    return sine_bump_instance\n",
    "\n",
    "# third degree polynomial + uncertainty (sine * gaussian)\n",
    "def poly_fluct(x, centre=-1, std=1, amplitude=4000, frequency=2):\n",
    "    \"\"\"Evaluate 0.01 * ((5x)^2 - x^3 + sine_bump(centre, std, amplitude, frequency)(x)).\"\"\"\n",
    "    bump = sine_bump(centre,std,amplitude,frequency)\n",
    "    return 0.01*((5*x)**2-(1*x)**3+bump(x))\n",
    "\n",
    "# Like poly_fluct, but with a wider bump (std=10 instead of 1) and sine frequency 1 instead of 2. TODO: clarify why this represents the \"mean\".\n",
    "def poly_fluct_mean(x, centre=-1, std=10, amplitude=4000, frequency=1):\n",
    "    \"\"\"Same formula as poly_fluct, with defaults std=10 and frequency=1.\"\"\"\n",
    "    bump = sine_bump(centre, std, amplitude, frequency)\n",
    "    return 0.01 * ((5 * x) ** 2 - (1 * x) ** 3 + bump(x))\n",
    "\n",
    "# Takes the absolute value of the uncertainty curve\n",
    "def poly_fluct_sigma(x):\n",
    "    \"\"\"Absolute value of the sine bump with centre 12, std 5, amplitude 10 and frequency 2.\"\"\"\n",
    "    bump = sine_bump(12, 5, 10, 2)\n",
    "    return np.abs(bump(x))\n",
    "\n",
    "# samples from a gaussian with the third degree polynomial evaluated at x as mean and the absolute uncertainty curve eval. at x as sigma \n",
    "def poly_fluct_sigma_fluct_normal(x,sample_size, centre_1=-1, std_1=1, amplitude_1=4000, frequency_1=2, \n",
    "                                  centre_2=12, std_2=5, amplitude_2=1200, frequency_2=0.1, added_std=0):\n",
    "    \"\"\"Draw sample_size normal samples around poly_fluct(x).\n",
    "\n",
    "    The *_1 parameters configure poly_fluct (the mean). The *_2 parameters\n",
    "    configure the sine bump whose absolute value, plus added_std, is the\n",
    "    standard deviation. Sampling happens on a scale multiplied by 100 and the\n",
    "    result is multiplied by 0.01 afterwards.\n",
    "    \"\"\"\n",
    "    loc   = 100*poly_fluct(x, centre_1, std_1, amplitude_1, frequency_1)\n",
    "    scale = np.abs(sine_bump(centre_2,std_2,amplitude_2,frequency_2)(x))+added_std\n",
    "    return 0.01*(np.random.normal(loc,scale,sample_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reads plain text file without header, separation by arbitrary number of whitespace\n",
    "# converts to float\n",
    "def plain_table_reader(file):\n",
    "    res = []\n",
    "    with open(file) as f:\n",
    "        for line in f:\n",
    "            str_feats = [feat.strip() for feat in re.split(r'\\s+', line)]\n",
    "            float_feats = [float(feat) for feat in str_feats if len(feat) > 0]\n",
    "            if len(float_feats) > 0:\n",
    "                res.append(float_feats)\n",
    "    return np.asarray(res)\n",
    "\n",
    "def load_dataset(id):\n",
    "    \n",
    "    if id == 'toy':\n",
    "        lb, ub, size = -15, 20, 1000 #-20, 30, 1000\n",
    "        x_range = np.linspace(lb, ub, size)\n",
    "        X = x_range[:, None]\n",
    "        y = poly_fluct_mean(x_range)\n",
    "        return X, y\n",
    "    \n",
    "    if id == 'toy_hf':\n",
    "        lb, ub, size = -15, 20, 1000 #-20, 30, 1000\n",
    "        x_range = np.linspace(lb, ub, size)\n",
    "        X = x_range[:, None]\n",
    "        y = poly_fluct_mean(x_range, frequency=3)\n",
    "        return X, y\n",
    "    \n",
    "    if id == 'toy_uniform':\n",
    "        sample_size = 10\n",
    "        lb, ub, steps = -15, 30, 2000\n",
    "        data_range = np.linspace(lb, ub, steps)\n",
    "        X = np.repeat(data_range, sample_size)[:, None]\n",
    "        y = np.concatenate([np.random.randn(sample_size) for i in data_range])#[:, None]\n",
    "        return X, y\n",
    "    \n",
    "    if id == 'toy_modulated':\n",
    "        sample_size = 10\n",
    "        lb, ub, steps = -15, 15, 2000\n",
    "        data_range = np.linspace(lb, ub, steps)\n",
    "        X = np.repeat(data_range, sample_size)[:, None]\n",
    "        y = np.concatenate([np.random.normal(0,np.exp(-0.02*i**2),sample_size) for i in data_range])\n",
    "        return X, y\n",
    "    \n",
    "    if id == 'toy_noise':\n",
    "        sample_size = 10\n",
    "        lb, ub, steps = -15, 30, 1000\n",
    "        data_range = np.linspace(lb, ub, steps)\n",
    "        X = np.repeat(data_range, sample_size)[:, None]\n",
    "        y = np.concatenate([poly_fluct_sigma_fluct_normal(i, sample_size) for i in data_range])[:, None]\n",
    "        return X, y\n",
    "    \n",
    "    if id == 'toy_noise_strong':\n",
    "        sample_size = 10\n",
    "        lb, ub, steps = -15, 30, 2000\n",
    "        data_range = np.linspace(lb, ub, steps)\n",
    "        X = np.repeat(data_range, sample_size)[:, None]\n",
    "        y = np.concatenate([poly_fluct_sigma_fluct_normal(i ,sample_size, centre_1=-5, std_1=2, amplitude_1=4000, frequency_1=2,\n",
    "                                                          amplitude_2=2000, frequency_2=0.1, added_std=80) for i in data_range])[:, None]\n",
    "        return X, y\n",
    "\n",
    "    # Features: 13, Points: 506\n",
    "    if id == 'boston':\n",
    "        boston = load_boston()\n",
    "        return boston['data'], boston['target']\n",
    "    \n",
    "    # features: 8, points: 20640\n",
    "    if id == 'california':\n",
    "        california = fetch_california_housing()\n",
    "        return california['data'], california['target']\n",
    "    \n",
    "    # features: 7, points: 442\n",
    "    if id == 'diabetes':\n",
    "        diabetes = load_diabetes()\n",
    "        return diabetes['data'], diabetes['target']\n",
    "    \n",
    "    #http://archive.ics.uci.edu/ml/datasets/Concrete+Compressive+Strength\n",
    "    # features: 8, points: 1030\n",
    "    if id == 'concrete':\n",
    "        concrete = pd.ExcelFile('./data/Concrete_Data.xls').parse()\n",
    "        concrete = concrete.to_numpy()\n",
    "        return concrete[:, :-1], concrete[:, -1]\n",
    "        \n",
    "    #https://archive.ics.uci.edu/ml/datasets/Energy+efficiency\n",
    "    # features: 8, points: 768; 2 gt labels (using latter one)\n",
    "    if id == 'energy':\n",
    "        energy_n_feat = 8\n",
    "        energy_n_gt = 2\n",
    "        energy = pd.ExcelFile('./data/ENB2012_data.xlsx').parse()\n",
    "        energy = energy.to_numpy()\n",
    "        assert(energy.shape[1] == (energy_n_feat + energy_n_gt))\n",
    "        return energy[:, :-energy_n_gt], energy[:, -1] # note: using cooling load gt only #energy[:, -energy_n_gt:]\n",
    "    \n",
    "    #https://archive.ics.uci.edu/ml/datasets/abalone\n",
    "    # features: 8 (using only 7, first feature is ignored), points: 4176\n",
    "    if id == 'abalone':\n",
    "        abalone = pd.read_csv('./data/abalone.data')\n",
    "        abalone = abalone.to_numpy()[:, 1:].astype(np.float64) # ignoring first feature which is categorical\n",
    "        return abalone[:, :-1], abalone[:, -1]\n",
    "    \n",
    "    #https://archive.ics.uci.edu/ml/datasets/Condition+Based+Maintenance+of+Naval+Propulsion+Plants\n",
    "    #features: 16, points: 11934, has 2 gt labels, using the latter one\n",
    "    if id == 'naval':\n",
    "        naval_n_feat = 16\n",
    "        naval_n_gt = 2\n",
    "        naval = plain_table_reader('./data/UCI CBM Dataset/data.txt')\n",
    "        return naval[:, :-naval_n_gt], naval[:, -1] # note: using turbine gt only #naval[:, -naval_n_gt:]\n",
    "    \n",
    "    #https://archive.ics.uci.edu/ml/datasets/Combined+Cycle+Power+Plant\n",
    "    if id == 'power':\n",
    "        power = pd.ExcelFile('./data/CCPP/Folds5x2_pp.xlsx').parse()\n",
    "        power = power.to_numpy()\n",
    "        return power[:, :-1], power[:, -1]\n",
    "    \n",
    "    #https://archive.ics.uci.edu/ml/datasets/Physicochemical+Properties+of+Protein+Tertiary+Structure\n",
    "    if id == 'protein':\n",
    "        protein = pd.read_csv('./data/CASP.csv')\n",
    "        protein = protein.to_numpy()\n",
    "        return protein[:, 1:], protein[:, 0]\n",
    "    \n",
    "    #https://archive.ics.uci.edu/ml/datasets/wine+quality\n",
    "    # features: 11, points: 1599\n",
    "    if id == 'wine_red':\n",
    "        wine_red = pd.read_csv('./data/winequality-red.csv', sep=';')\n",
    "        wine_red = wine_red.to_numpy()\n",
    "        return wine_red[:, :-1], wine_red[:, -1]\n",
    "    \n",
    "    #http://archive.ics.uci.edu/ml/datasets/yacht+hydrodynamics\n",
    "    # features: 6, points: 308\n",
    "    if id == 'yacht':\n",
    "        yacht = plain_table_reader('./data/yacht_hydrodynamics.data')\n",
    "        return yacht[:, :-1], yacht[:, -1]\n",
    "    \n",
    "    #https://archive.ics.uci.edu/ml/datasets/YearPredictionMSD\n",
    "    # features: 90, points: 515345\n",
    "    if id == 'year':\n",
    "        year = pd.read_csv('./data/YearPredictionMSD.txt', header=None)\n",
    "        year = year.to_numpy()\n",
    "        return year[:, 1:], year[:, 0]\n",
    "    \n",
    "    # features: 81, points: 21263\n",
    "    if id == 'superconduct':\n",
    "        superconduct = pd.read_csv('./data/superconduct/train.csv')\n",
    "        superconduct = superconduct.to_numpy()\n",
    "        return superconduct[:, :-1], superconduct[:, -1]\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate idx lists for data split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_idx_splits(X, y, fold_idxs=None, split_perc=0.8, splits=None):\n",
    "    \n",
    "    if fold_idxs is None:\n",
    "        fold_idxs = list(range(10))\n",
    "    else:\n",
    "        fold_idxs = np.array(fold_idxs)\n",
    "        if np.any((fold_idxs < 0) | (fold_idxs > 9) ):\n",
    "            raise Exception(\"Given fold_idxs have to lie in [0, 9]\")\n",
    "    \n",
    "    res = dict()\n",
    "    \n",
    "    n_data = X.shape[0]\n",
    "    n_test = n_data // 10\n",
    "    n_train = n_data - n_test\n",
    "    assert(n_data == len(y))\n",
    "    \n",
    "    idxs_random = np.random.choice(n_data, size=n_data, replace=False)\n",
    "    \n",
    "    if 'random_folds' in splits:\n",
    "        folds = []\n",
    "        for i in fold_idxs:\n",
    "            start_test = i*n_test\n",
    "            end_test = start_test + n_test\n",
    "            folds.append((np.concatenate((idxs_random[0:start_test], idxs_random[end_test:])),\n",
    "                              idxs_random[start_test:end_test]))\n",
    "        res['random_folds'] = folds\n",
    "    \n",
    "    if 'single_random_split' in splits:\n",
    "        n_single_train = int(split_perc*n_data)\n",
    "        res['single_random_split'] = (idxs_random[:n_single_train], idxs_random[n_single_train:])\n",
    "    \n",
    "    if 'single_label_split' in splits:\n",
    "        y_median = np.median(y)\n",
    "        idxs_lower_half = np.where(y <= y_median)[0]\n",
    "        idxs_upper_half = np.where(y > y_median)[0]\n",
    "        res['single_label_split'] = (idxs_lower_half, idxs_upper_half)\n",
    "    \n",
    "    if 'label_folds' in splits:\n",
    "        quantile_fold_range = 1. / 10.\n",
    "        fold_label = []\n",
    "        for i in fold_idxs:\n",
    "            lower_quantile = i * quantile_fold_range\n",
    "            upper_quantile = lower_quantile + quantile_fold_range\n",
    "            lower_quantile = np.quantile(y, lower_quantile)\n",
    "            upper_quantile = np.quantile(y, upper_quantile)\n",
    "            fold_label.append((np.concatenate((np.where(y < lower_quantile)[0], np.where(y > upper_quantile)[0])),\n",
    "                               np.where((lower_quantile <= y) & (y <= upper_quantile))[0]))\n",
    "        res['label_folds'] = fold_label\n",
    "    \n",
    "    pca_scaler = StandardScaler()  # each feature centered around mean with std = 1\n",
    "    X_scaled = pca_scaler.fit_transform(X)\n",
    "\n",
    "    pca = PCA(n_components=min(X_scaled.shape[1], 5))\n",
    "    pca.fit(X_scaled)\n",
    "    projections = np.matmul(X_scaled, pca.components_[0])\n",
    "    res['projections'] = projections\n",
    "    \n",
    "    if 'single_pca_split' in splits:\n",
    "        projections_median = np.median(projections)\n",
    "        idxs_lower_pca0 = np.where(projections <= projections_median)[0]\n",
    "        idxs_upper_pca0 = np.where(projections > projections_median)[0]\n",
    "        res['single_pca_split'] = (idxs_lower_pca0, idxs_upper_pca0)\n",
    "    \n",
    "    if 'pca_folds' in splits:\n",
    "        quantile_fold_range = 1. / 10.\n",
    "        fold_pca0 = []\n",
    "        for i in fold_idxs:\n",
    "            lower_quantile = i * quantile_fold_range\n",
    "            upper_quantile = lower_quantile + quantile_fold_range\n",
    "            lower_quantile = np.quantile(projections, lower_quantile)\n",
    "            upper_quantile = np.quantile(projections, upper_quantile)\n",
    "            fold_pca0.append((np.concatenate((np.where(projections < lower_quantile)[0], np.where(projections > upper_quantile)[0])),\n",
    "                             np.where((lower_quantile <= projections) & (projections <= upper_quantile))[0]))\n",
    "        res['pca_folds'] = fold_pca0\n",
    "    \n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scale_to_standard(X_train, y_train, X_test, y_test):\n",
    "\n",
    "    X_scaler = StandardScaler()\n",
    "    y_scaler = StandardScaler()\n",
    "\n",
    "    X_train = X_scaler.fit_transform(X_train)\n",
    "    y_train = y_scaler.fit_transform(y_train.reshape(-1,1))\n",
    "\n",
    "    X_test = X_scaler.transform(X_test)\n",
    "    y_test = y_scaler.transform(y_test.reshape(-1,1))\n",
    "    \n",
    "    return X_train, y_train, X_test, y_test\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def plot_densitymap(x, y, ax):\n",
    "    xmin, xmax = x.min(), x.max()\n",
    "    ymin, ymax = y.min(), y.max()\n",
    "    x_range, y_range = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]\n",
    "    positions = np.vstack([x_range.ravel(), y_range.ravel()])\n",
    "    values = np.vstack([x, y])\n",
    "    kernel = stats.gaussian_kde(values)\n",
    "    density = np.reshape(kernel(positions).T, x_range.shape)\n",
    "\n",
    "    ax.imshow(np.rot90(density), cmap=plt.cm.gist_heat_r, extent=[xmin, xmax, ymin, ymax], aspect='equal')\n",
    "    ax.plot(x, y, 'k.', markersize=1, alpha=0.1)\n",
    "\n",
    "def plot_results(method_dict, file=None):\n",
    "    # set min-/max-values for all subplots in the next cell\n",
    "    \n",
    "    plt.clf()\n",
    "\n",
    "    concatted = pd.DataFrame()\n",
    "\n",
    "    for key in method_dict:\n",
    "        for i in [0,1]:\n",
    "            concatted = pd.concat([concatted,method_dict[key][i][['pred_mean','pred_std','pred_residual','pred_residual_normed']]])\n",
    "\n",
    "    max_pred_mean     = concatted.quantile(0.98)['pred_mean']\n",
    "    min_pred_mean     = concatted.quantile(0.02)['pred_mean']\n",
    "    max_pred_std      = concatted.quantile(0.98)['pred_std']\n",
    "    min_pred_std      = concatted.quantile(0.02)['pred_std']\n",
    "    max_pred_residual = concatted.quantile(0.98)['pred_residual']\n",
    "    min_pred_residual = concatted.quantile(0.02)['pred_residual']\n",
    "    max_pred_residual_normed = concatted['pred_residual_normed'].quantile(0.98)\n",
    "    min_pred_residual_normed = concatted['pred_residual_normed'].quantile(0.02)\n",
    "\n",
    "    # visualize all results\n",
    "\n",
    "    num_methods  = method_dict.__len__()\n",
    "    method_names = list(method_dict.keys())\n",
    "    datasets     = ['train','test']\n",
    "    colors       = ['b','orange']\n",
    "\n",
    "    fig,ax = plt.subplots(15,num_methods,figsize=(35,40), squeeze=False)\n",
    "    ax = np.array(ax)\n",
    "    \n",
    "    for j,method in enumerate(method_dict):\n",
    "        for i,df in enumerate(method_dict[method]):\n",
    "        \n",
    "            df.plot.scatter(x='gt',y='pred_mean',ax=ax[0,j],color=colors[i])\n",
    "            df.plot.scatter(x='gt',y='pred_std',ax=ax[1,j],color=colors[i])\n",
    "            df.plot.scatter(x='gt',y='total_std',ax=ax[2,j],color=colors[i])\n",
    "            df.plot.scatter(x='gt',y='pred_residual',ax=ax[3,j],color=colors[i])\n",
    "            df.plot.scatter(x='pred_residual',y='pred_std',ax=ax[4,j],color=colors[i])\n",
    "            df.plot.scatter(x='pred_residual',y='total_std',ax=ax[5,j],color=colors[i])\n",
    "            \n",
    "            df.plot.scatter(x='pca0_projection',y='pred_mean',ax=ax[10,j],color=colors[i])\n",
    "            df.plot.scatter(x='pca0_projection',y='pred_std',ax=ax[11,j],color=colors[i])\n",
    "            df.plot.scatter(x='pca0_projection',y='pred_residual',ax=ax[12,j],color=colors[i])\n",
    "            ax[13,j].hist(df['pred_residual_normed'],bins=30,density=True,color=colors[i])\n",
    "            df.plot.scatter(x='gt',y='net_gradient_norm',ax=ax[14,j],color=colors[i])\n",
    "\n",
    "        try:\n",
    "            plot_densitymap(method_dict[method][0]['pred_residual'], method_dict[method][0]['pred_std'], ax[6, j])\n",
    "        except Exception:\n",
    "            print(\"Exception caught in plot_densitymap, skipping plot ... \", method_dict[method][0]['pred_residual'],  method_dict[method][0]['pred_std'])\n",
    "        \n",
    "        try:\n",
    "            plot_densitymap(method_dict[method][0]['pred_residual'], method_dict[method][0]['total_std'], ax[7, j])\n",
    "        except Exception:\n",
    "            print(\"Exception caught in plot_densitymap, skipping plot ... \", method_dict[method][0]['pred_residual'],  method_dict[method][0]['total_std'])\n",
    "        \n",
    "        try:\n",
    "            plot_densitymap(method_dict[method][1]['pred_residual'], method_dict[method][1]['pred_std'], ax[8, j])\n",
    "        except Exception:\n",
    "            print(\"Exception caught in plot_densitymap, skipping plot ... \", method_dict[method][1]['pred_residual'],  method_dict[method][1]['pred_std'])\n",
    "\n",
    "        try:\n",
    "            plot_densitymap(method_dict[method][1]['pred_residual'], method_dict[method][1]['total_std'], ax[9, j])\n",
    "        except Exception:\n",
    "            print(\"Exception caught in plot_densitymap, skipping plot ... \", method_dict[method][1]['pred_residual'],  method_dict[method][1]['total_std'])\n",
    "        \n",
    "        line_sigma_1_data =  pd.DataFrame([[x,np.abs(x)] for x in np.linspace(min_pred_residual-0.2,max_pred_residual+0.2,200)])\n",
    "        line_sigma_3_data =  pd.DataFrame([[x,np.abs(x/3)] for x in np.linspace(min_pred_residual-0.2,max_pred_residual+0.2,200)])\n",
    "       \n",
    "        line_sigma_1_data.plot(kind='line',x=0,y=1, color='k',ax=ax[4,j],alpha=1)\n",
    "        line_sigma_3_data.plot(kind='line',x=0,y=1,color='r',ax=ax[4,j],alpha=1) \n",
    "        line_sigma_1_data.plot(kind='line',x=0,y=1,color='k',ax=ax[5,j],alpha=1)\n",
    "        line_sigma_3_data.plot(kind='line',x=0,y=1,color='r',ax=ax[5,j],alpha=1)\n",
    "        line_sigma_1_data.plot(kind='line',x=0,y=1,color='k',ax=ax[6,j],alpha=1)\n",
    "        line_sigma_3_data.plot(kind='line',x=0,y=1,color='r',ax=ax[6,j],alpha=1) \n",
    "        line_sigma_1_data.plot(kind='line',x=0,y=1,color='k',ax=ax[7,j],alpha=1)\n",
    "        line_sigma_3_data.plot(kind='line',x=0,y=1,color='r',ax=ax[7,j],alpha=1)\n",
    "        line_sigma_1_data.plot(kind='line',x=0,y=1,color='k',ax=ax[8,j],alpha=1)\n",
    "        line_sigma_3_data.plot(kind='line',x=0,y=1,color='r',ax=ax[8,j],alpha=1) \n",
    "        line_sigma_1_data.plot(kind='line',x=0,y=1,color='k',ax=ax[9,j],alpha=1)\n",
    "        line_sigma_3_data.plot(kind='line',x=0,y=1,color='r',ax=ax[9,j],alpha=1)\n",
    "        \n",
    "        ax[0,j].set_ylim([min_pred_mean,max_pred_mean])\n",
    "        ax[1,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[2,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[3,j].set_ylim([min_pred_residual,max_pred_residual])\n",
    "        ax[4,j].set_xlim([min_pred_residual-0.2,max_pred_residual+0.2])\n",
    "        ax[4,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[4,j].set_xlabel('pred_residual')\n",
    "        ax[5,j].set_xlim([min_pred_residual-0.2,max_pred_residual+0.2])\n",
    "        ax[5,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[5,j].set_xlabel('pred_residual')\n",
    "        \n",
    "        ax[6,j].set_xlim([min_pred_residual-0.2,max_pred_residual+0.2])\n",
    "        ax[6,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[6,j].set_xlabel('pred_residual')\n",
    "        ax[7,j].set_xlim([min_pred_residual-0.2,max_pred_residual+0.2])\n",
    "        ax[7,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[7,j].set_xlabel('pred_residual')\n",
    "        \n",
    "        ax[8,j].set_xlim([min_pred_residual-0.2,max_pred_residual+0.2])\n",
    "        ax[8,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[8,j].set_xlabel('pred_residual')\n",
    "        ax[9,j].set_xlim([min_pred_residual-0.2,max_pred_residual+0.2])\n",
    "        ax[9,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[9,j].set_xlabel('pred_residual')\n",
    "        \n",
    "        ax[11,j].set_ylim([min_pred_mean,max_pred_mean])\n",
    "        ax[11,j].set_ylim([min_pred_std,max_pred_std])\n",
    "        ax[12,j].set_ylim([min_pred_residual,max_pred_residual])\n",
    "        #ax[7,j].set_ylim(-0.05,1)\n",
    "        ax[13,j].set_xlim([min_pred_residual_normed-5,max_pred_residual_normed+5])\n",
    "        ax[13,j].set_xlabel('pred_residual_normed')\n",
    "        ax[13,j].set_ylabel('pdf')\n",
    "        #ax[7,j].set_yscale('log')\n",
    "\n",
    "        for k in range(6):\n",
    "            ax[k,j].set_title(method_names[j]+' (train/test data)')\n",
    "        \n",
    "        for k in range(6, 8):\n",
    "            ax[k, j].set_title(method_names[j] + ' (train data)')\n",
    "            ax[k, j].legend()\n",
    "        \n",
    "        for k in range(8, 10):\n",
    "            ax[k, j].set_title(method_names[j] + ' (test data)')\n",
    "            ax[k, j].legend()\n",
    "        \n",
    "        for k in range(10, 15):\n",
    "            ax[k,j].set_title(method_names[j]+' (train/test data)')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    if file is not None:\n",
    "        plt.savefig(file)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_net_from_method(method, n_feat, net_params, train_params):\n",
    "    \n",
    "    net_params['n_input'] = n_feat\n",
    "    \n",
    "    if method == 'mc':  # dropout in all layers, standard mse\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        train_params['loss_func']    = torch.nn.MSELoss(reduction='mean')\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['sml_loss_params'] = None\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "\n",
    "    elif method == 'mc_ll':  # dropout in last layer, standard mse\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = False\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        train_params['loss_func']    = torch.nn.MSELoss(reduction='mean')\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['sml_loss_params'] = None\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "    \n",
    "    elif method == 'mc_mod_sml0':\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['loss_func']    = sml_loss\n",
    "        train_params['sml_loss_params']= [1,0,0]\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "    \n",
    "    elif method == 'mc_mod_sml10':\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['loss_func']    = sml_loss\n",
    "        train_params['sml_loss_params']= [1,10,0]\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "\n",
    "    elif method == 'mc_mod_sml':  # dropout in all layers, sml loss\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['loss_func']    = sml_loss\n",
    "        train_params['sml_loss_params']= [1,0.5,0]\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "    \n",
    "    elif method == 'mc_mod_sml75':  # dropout in all layers, sml loss\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['loss_func']    = sml_loss\n",
    "        train_params['sml_loss_params']= [1,0.75,0]\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "    \n",
    "    elif method == 'mc_mod_sml25':  # dropout in all layers, sml loss\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['loss_func']    = sml_loss\n",
    "        train_params['sml_loss_params']= [1,0.25,0]\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "    \n",
    "    elif method == 'mc_mod_sml1':  # dropout in all layers, sml loss\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['loss_func']    = sml_loss\n",
    "        train_params['sml_loss_params']= [1,0.1,0]\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "    \n",
    "    elif method == 'mc_mod_sml9':  # dropout in all layers, sml loss\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['loss_func']    = sml_loss\n",
    "        train_params['sml_loss_params'] = [1,0.9,0]\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "\n",
    "    elif method == 'mc_mod_2moments': # dropout in all layers, 2 moments loss\n",
    "        net_params['n_output']       = 1    \n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['loss_func']    = train_second_moments_loss\n",
    "        train_params['sml_loss_params'] = [1,0.5,0]\n",
    "\n",
    "        net = Net(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "\n",
    "    elif method == 'pu':    # trains mu, sigma uses nll loss\n",
    "        net_params['n_output']       = 2\n",
    "        train_params['drop_bool']    = False\n",
    "        train_params['drop_bool_ll'] = False\n",
    "        train_params['loss_func']    = nll_floored\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['sml_loss_params'] = None\n",
    "\n",
    "        net = Net_PU(net_params=net_params,train_params=train_params)\n",
    "        net.to(train_params['device'])\n",
    "        return net\n",
    "        \n",
    "    elif method == 'de':\n",
    "        net_params['n_output']       = 1\n",
    "        train_params['drop_bool']    = False\n",
    "        train_params['drop_bool_ll'] = False\n",
    "        train_params['loss_func']    = torch.nn.MSELoss(reduction='mean')\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['sml_loss_params'] = None\n",
    "        \n",
    "        net = []\n",
    "        for i in range(net_params['de_components']):\n",
    "            net_ = Net(net_params=net_params,train_params=train_params)\n",
    "            net_.to(train_params['device'])\n",
    "            net.append(net_)\n",
    "            \n",
    "    elif method == 'pu_de':  \n",
    "        net_params['n_output']       = 2\n",
    "        train_params['drop_bool']    = False\n",
    "        train_params['drop_bool_ll'] = False\n",
    "        train_params['loss_func']    = nll_floored\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['sml_loss_params'] = None\n",
    "\n",
    "        net = []\n",
    "        for i in range(net_params['de_components']):\n",
    "            net_ = Net_PU(net_params=net_params,train_params=train_params)\n",
    "            net_.to(train_params['device'])\n",
    "            net.append(net_)\n",
    "\n",
    "    elif method == 'sml_de':  \n",
    "        net_params['n_output']       = 1\n",
    "        train_params['drop_bool']    = True\n",
    "        train_params['drop_bool_ll'] = True\n",
    "        train_params['loss_func']    = sml_loss\n",
    "        #train_params['num_epochs']   = 2000\n",
    "        train_params['sml_loss_params'] = [1,0.5,0]\n",
    "\n",
    "        net = []\n",
    "        for i in range(net_params['de_components']):\n",
    "            net_ = Net(net_params=net_params,train_params=train_params)\n",
    "            net_.to(train_params['device'])\n",
    "            net.append(net_)\n",
    "            \n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" Timestamp of the format: hour:minute:second \"\"\"                  \n",
    "def timestamp(dt_obj):\n",
    "    return \"%d_%d_%d_%d_%d_%d\" % (dt_obj.year, dt_obj.month, dt_obj.day, dt_obj.hour, dt_obj.minute, dt_obj.second)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "available_datasets = {'boston', 'concrete', 'energy', 'abalone', 'naval', \n",
    "                      'power', 'protein', 'wine_red', 'yacht', 'year', \n",
    "                      'california', 'diabetes', 'superconduct', 'toy', 'toy_noise',\n",
    "                     'toy_uniform', 'toy_modulated', 'toy_noise_strong', 'toy_hf'}\n",
    "\n",
    "toy_datasets = {'toy', 'toy_hf', 'toy_noise', 'toy_uniform', 'toy_modulated', 'toy_noise_strong'}\n",
    "small_datasets = {'toy', 'toy_hf', 'toy_noise', 'toy_uniform', 'toy_modulated', 'toy_noise_strong',\n",
    "                  'yacht', 'diabetes', 'boston', 'energy', 'concrete', 'wine_red'}\n",
    "large_datasets = {'abalone', 'naval', 'power', 'superconduct', 'protein'} #'california',\n",
    "very_large_datasets = {'year'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "available_splits = {'random_folds', 'single_random_split', 'single_label_split', 'label_folds', 'single_pca_split', 'pca_folds'}\n",
    "\n",
    "available_methods = {'de','pu','mc_mod_sml','mc_ll','mc', 'mc_mod_sml9','pu_de','sml_de'}\n",
    "\n",
    "# Base parameters\n",
    "net_params = {'n_output':1,\n",
    "            'layer_width':50,\n",
    "            'num_layers':2,\n",
    "            'nonlinearity':nn.ReLU(), #tanh,sigmoid\n",
    "            'init_corrcoef':0.0,\n",
    "            'de_components': 5} \n",
    "\n",
    "train_params = {'device': 'cpu', #torch.device('cuda' if torch.cuda.is_available() else 'cpu'),\n",
    "              'drop_bool':True,\n",
    "              'drop_bool_ll':True,\n",
    "              'drop_p':0.1,\n",
    "              'num_epochs': 1000,\n",
    "              'batch_size': 100,\n",
    "              'learning_rate': 0.001,\n",
    "              'loss_func':torch.nn.MSELoss(reduction='mean'),\n",
    "              'sml_loss_params':[1,0.5,0]}\n",
    "\n",
    "dt_now = dt.datetime.now()\n",
    "exp_dir = './INSERT/PATH/TO/EXP/DIR/HERE/NAME_%s' % timestamp(dt_now)\n",
    "os.makedirs(exp_dir, exist_ok=True)\n",
    "\n",
    "methods = available_methods # use all methods (de, pu, mc, ...)\n",
    "\n",
    "start_ = t.time()\n",
    "\n",
    "for dataset_id in available_datasets: # use all datasets\n",
    "    \n",
    "    X, y = load_dataset(dataset_id)\n",
    "    n_feat = X.shape[1]\n",
    "    \n",
    "    net_params_ = dict(net_params)\n",
    "    train_params_ = dict(train_params)\n",
    "    \n",
    "    if dataset_id in very_large_datasets:\n",
    "        splits = compute_idx_splits(X, y, fold_idxs=[0, 3, 5, 7, 9], splits=['random_folds', 'label_folds', 'pca_folds'])\n",
    "        train_params_['num_epochs'] = 150\n",
    "        train_params_['batch_size'] = 500\n",
    "    \n",
    "    elif dataset_id in large_datasets:\n",
    "        splits = compute_idx_splits(X, y, fold_idxs=[0, 3, 5, 7, 9], splits=available_splits)\n",
    "        train_params_['num_epochs'] = 150 \n",
    "    else:\n",
    "        splits = compute_idx_splits(X, y, splits=available_splits) # use 10-folds\n",
    "    \n",
    "    projections = splits['projections']\n",
    "    \n",
    "    for split_mode in splits:\n",
    "        \n",
    "        if split_mode == 'projections':\n",
    "            continue\n",
    "        \n",
    "        folds = splits[split_mode]\n",
    "        \n",
    "        if (type(folds) == tuple) and (len(folds) == 2):\n",
    "            folds = [folds]\n",
    "        \n",
    "        for fold_idx, (train_idxs, test_idxs) in enumerate(folds):\n",
    "            \n",
    "            identifier = 'dataset=%s_splitmode=%s_foldidx=%d' % (dataset_id, split_mode, fold_idx)\n",
    "            \n",
    "            X_train = X[train_idxs]\n",
    "            X_test = X[test_idxs]\n",
    "            y_train = y[train_idxs]\n",
    "            y_test = y[test_idxs]\n",
    "            \n",
    "            X_train, y_train, X_test, y_test = scale_to_standard(X_train, y_train, X_test, y_test)\n",
    "            \n",
    "            # choose a bunch of uncertainty methods and train the respective models \n",
    "            method_dict = {}\n",
    "            method_dict_json = {}\n",
    "            net_dict = {}\n",
    "            for method in methods: \n",
    "                method_identifier = '%s_method=%s' % (identifier, method)\n",
    "                print(method_identifier)\n",
    "\n",
    "                net = get_net_from_method(method, n_feat, net_params_, train_params_)\n",
    "                print(net_params_, train_params_)\n",
    "                train_network(net=net,data=[X_train,y_train], train_params=train_params_, method=method)\n",
    "\n",
    "                iso_reg = []\n",
    "                df_train = calc_datapoint_statistics(net=net,data=[X_train,y_train],method=method, iso_reg=iso_reg)\n",
    "                df_test  = calc_datapoint_statistics(net=net,data=[X_test,y_test],  method=method, iso_reg=iso_reg)\n",
    "\n",
    "                df_train['pca0_projection'] = projections[train_idxs] \n",
    "                df_test['pca0_projection']  = projections[test_idxs] \n",
    "\n",
    "                method_dict[method] = [df_train,df_test]\n",
    "                method_dict_json[method] = [df_train.to_json(), df_test.to_json()]\n",
    "            \n",
    "                # store model\n",
    "                if isinstance(net, list):\n",
    "                    for i, subnet in enumerate(net):\n",
    "                        net_dict['%s_sub=%d' % (method_identifier, i)] = copy.deepcopy(subnet.state_dict())\n",
    "                else:\n",
    "                    net_dict[method_identifier] = copy.deepcopy(net.state_dict())\n",
    "            \n",
    "            exp_dataset_dir = '%s/%s' % (exp_dir, dataset_id)\n",
    "            os.makedirs(exp_dataset_dir, exist_ok=True)\n",
    "            \n",
    "            with gzip.open('%s/method_dict_%s.json.zip' % (exp_dataset_dir, identifier), 'wt', encoding='ascii') as fp:\n",
    "                json.dump(method_dict_json, fp)\n",
    "                \n",
    "            plot_results(method_dict, '%s/%s.png' % (exp_dataset_dir, identifier))\n",
    "            \n",
    "            # print global statistics for the different methods (for both train and test)\n",
    "            global_stats = {}\n",
    "            for method in method_dict:\n",
    "                global_stats[method] = []\n",
    "                for i in [0, 1]:\n",
    "                    global_stats[method].append(calc_global_statistics(method_dict[method][i]))\n",
    "            \n",
    "            with gzip.open('%s/global_stats_%s.json.zip' % (exp_dataset_dir, identifier), 'wt', encoding='ascii') as fp:\n",
    "                json.dump(global_stats, fp)\n",
    "                \n",
    "            torch.save(net_dict, '%s/model_%s.pt' % (exp_dataset_dir, identifier))\n",
    "\n",
    "print(t.time() - start_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-torch-env]",
   "language": "python",
   "name": "conda-env-.conda-torch-env-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
