{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils import data\n",
    "from torchvision import datasets, transforms\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torch.nn.parameter import Parameter\n",
    "from torch.distributions import Dirichlet, Normal\n",
    "from torch.distributions.kl import kl_divergence\n",
    "import numpy as np\n",
    "import math\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "batch_size = 128\n",
    "max_epochs = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Preparing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data_fashion(batch_size):\n",
    "    kwargs = {'num_workers': 20, 'pin_memory': True}\n",
    "\n",
    "    trafos = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.2860,), std=(0.3530,))\n",
    "    ])\n",
    "\n",
    "    datadir = \"./data/\"\n",
    "\n",
    "    train_data = datasets.FashionMNIST(datadir + \"fashion\", train=True, download=True, transform=trafos)\n",
    "    test_data = datasets.FashionMNIST(datadir + \"fashion\", train=False, download=True, transform=trafos)\n",
    "    ood_data = datasets.MNIST(datadir + \"mnist\", train=False, download=True, transform=trafos)\n",
    "\n",
    "    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)\n",
    "    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True, **kwargs)\n",
    "    ood_loader = DataLoader(ood_data, batch_size=batch_size, shuffle=False, drop_last=True, **kwargs)\n",
    "\n",
    "    return train_loader, test_loader, ood_loader\n",
    "\n",
    "\n",
    "\n",
    "def prepare_data(batch_size):\n",
    "    train_loader, test_loader, ood_loader = prepare_data_fashion(\n",
    "        batch_size=batch_size)\n",
    "    n_channel = 1\n",
    "    n_classes = 10\n",
    "    \n",
    "    return train_loader, test_loader, ood_loader, n_channel, n_classes\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VBLinear(nn.Module):\n",
    "    def __init__(self, in_features, out_features, prior_prec=10, map=True):\n",
    "        super(VBLinear, self).__init__()\n",
    "        self.n_in = in_features\n",
    "        self.n_out = out_features\n",
    "\n",
    "        self.prior_prec = prior_prec\n",
    "        self.map = map\n",
    "\n",
    "        self.bias = nn.Parameter(th.Tensor(out_features))\n",
    "        self.mu_w = Parameter(th.Tensor(out_features, in_features))\n",
    "        self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features))\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        stdv = 1. / math.sqrt(self.mu_w.size(1))\n",
    "        self.mu_w.data.normal_(0, stdv)\n",
    "        self.logsig2_w.data.zero_().normal_(-9, 0.001) # var init via Louizos\n",
    "        self.bias.data.zero_()\n",
    "\n",
    "    def KL(self, loguniform=False):\n",
    "        if loguniform:\n",
    "            k1 = 0.63576; k2 = 1.87320; k3 = 1.48695\n",
    "            log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8)\n",
    "            kl = -th.sum(k1 * th.sigmoid(k2 + k3 * log_alpha) - 0.5 * F.softplus(-log_alpha) - k1)\n",
    "        else:\n",
    "            logsig2_w = self.logsig2_w.clamp(-11, 11)\n",
    "            kl = 0.5 * (self.prior_prec * (self.mu_w.pow(2) + logsig2_w.exp()) - logsig2_w - 1 - np.log(self.prior_prec)).sum()\n",
    "        return kl\n",
    "\n",
    "    def forward(self, input):\n",
    "        # Sampling free forward pass only if MAP prediction and no training rounds\n",
    "        if self.map and not self.training:\n",
    "            return F.linear(input, self.mu_w, self.bias)\n",
    "        else:\n",
    "            mu_out = F.linear(input, self.mu_w, self.bias)\n",
    "            logsig2_w = self.logsig2_w.clamp(-11, 11)\n",
    "            s2_w = logsig2_w.exp()\n",
    "            var_out = F.linear(input.pow(2), s2_w) + 1e-8\n",
    "            return mu_out + var_out.sqrt() * th.randn_like(mu_out)\n",
    "\n",
    "\n",
    "class LeNet5(nn.Module):\n",
    "    def __init__(self, n_channels=1, n_classes=10, isbnn=False, prior_precision=10):\n",
    "        \"\"\"\n",
    "        :param n_channels: 1 creates MNIST arch, 3 creates Cifar arch\n",
    "        :param n_classes: 10 target classes\n",
    "        \"\"\"\n",
    "        super(LeNet5, self).__init__()\n",
    "\n",
    "        self.isbnn = isbnn\n",
    "        self.drop_rate = 0.25\n",
    "        self.n_samples = 10\n",
    "\n",
    "        self.n_channels = n_channels\n",
    "        self.n_classes = n_classes\n",
    "        self.dataset_size = None # will be initialized\n",
    "\n",
    "        if n_channels == 1:\n",
    "            self.conv1 = nn.Conv2d(1, 20, 5, bias=True)\n",
    "            self.conv2 = nn.Conv2d(20, 50, 5, bias=True)\n",
    "            dim_cf = 4 * 4 * 50\n",
    "            if isbnn:\n",
    "                self.fc1 = VBLinear(dim_cf, 500, prior_prec=prior_precision)\n",
    "                self.fc2 = VBLinear(500, n_classes, prior_prec=prior_precision)\n",
    "            else:\n",
    "                self.fc1 = nn.Linear(dim_cf, 500, bias=True)\n",
    "                self.fc2 = nn.Linear(500, n_classes)\n",
    "\n",
    "        elif n_channels == 3:\n",
    "            self.conv1 = nn.Conv2d(3, 192, 5, bias=True)\n",
    "            self.conv2 = nn.Conv2d(192, 192, 5, bias=True)\n",
    "\n",
    "            dim_cf = 5 * 5 * 192\n",
    "\n",
    "            if isbnn:\n",
    "                self.fc1 = VBLinear(dim_cf, 1000, prior_prec=prior_precision)\n",
    "                self.fc2 = VBLinear(1000, n_classes, prior_prec=prior_precision)\n",
    "            else:\n",
    "                self.fc1 = nn.Linear(dim_cf, 1000, bias=True)\n",
    "                self.fc2 = nn.Linear(1000, n_classes)\n",
    "        else:\n",
    "            raise NotImplementedError(f\"Sorry {n_channels} is currently not possible\")\n",
    "\n",
    "    def forward(self, input, context=None):\n",
    "\n",
    "        out = F.relu(self.conv1(input))\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = F.relu(self.conv2(out))\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = th.flatten(out, 1)\n",
    "\n",
    "        out = F.relu(self.fc1(out))\n",
    "        out = self.fc2(out)\n",
    "        return out\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ETP-BETP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ETPHyperParams:\n",
    "    def __init__(self, n_classes=None):\n",
    "        super(ETPHyperParams, self).__init__()\n",
    "        self.memory_learning_rate = 0.99\n",
    "        self.memory_size = 20\n",
    "        self.anneal_factor = 1e-3\n",
    "        self.memo_variance = 0.1\n",
    "        \n",
    "\n",
    "\n",
    "class EvidentialTuringProcess(nn.Module):\n",
    "    def __init__(self, arch=None):\n",
    "        super(EvidentialTuringProcess, self).__init__()\n",
    "\n",
    "        self.arch = deepcopy(arch)\n",
    "        self.hyperparams = ETPHyperParams(self.arch.n_classes)\n",
    "\n",
    "        self.isbnn = self.arch.isbnn\n",
    "\n",
    "        x_dim = [28, 32][self.arch.n_channels == 3];\n",
    "        y_dim = x_dim\n",
    "        self.x_dim = x_dim;\n",
    "        self.y_dim = y_dim;\n",
    "        self.n_channels = self.arch.n_channels\n",
    "\n",
    "        self.memory = nn.Parameter(th.Tensor(self.hyperparams.memory_size, self.arch.n_classes),\n",
    "                                   requires_grad=False).cuda()\n",
    "        self.memory.data.normal_(0, 0.01)\n",
    "        self.memory.data.pow_(2)\n",
    "\n",
    "        self.fc1_enc_to_pred = nn.Linear(self.arch.n_classes * 2, self.arch.n_classes)\n",
    "        self.fc1_key = nn.Linear(self.arch.n_classes, self.arch.n_classes)\n",
    "\n",
    "\n",
    "    def KL(self):\n",
    "        return sum(l.KL() for l in [self.arch.parameters()] if hasattr(l, \"KL\"))\n",
    "\n",
    "\n",
    "    def update_memory(self, x_embed, y, max_size=50):\n",
    "        n_context = np.random.randint(3, max_size)\n",
    "        x_given_embed = x_embed[:n_context, :]\n",
    "        y_given = y[:n_context].view(-1, 1)\n",
    "\n",
    "        mem_sample = self.get_memory_sample()\n",
    "\n",
    "        new_element = F.one_hot(y_given, self.arch.n_classes).view(-1, self.arch.n_classes) + th.softmax(x_given_embed,\n",
    "                                                                                                         1)\n",
    "        weight_new_element = self.get_attention_weights(x_given_embed, mem_sample)\n",
    "        add_new_element = th.mm(weight_new_element.transpose(0, 1), new_element)\n",
    "        gamma = self.hyperparams.memory_learning_rate\n",
    "\n",
    "        mem_offset = self.memory * (gamma - 1) + add_new_element * (1 - gamma)\n",
    "\n",
    "        self.memory.data.add_(mem_offset)\n",
    "        self.memory.data.tanh_()\n",
    "\n",
    "    def get_memory_sample(self):\n",
    "        sig2 = self.hyperparams.memo_variance\n",
    "        sig2_vec = th.ones(self.memory.shape).cuda() * sig2\n",
    "        return Normal(self.memory, sig2_vec).rsample()\n",
    "\n",
    "    def get_attention_weights(self, x_embed, mem_sample):\n",
    "        keys = self.fc1_key(mem_sample)\n",
    "        kq = th.mm(x_embed, keys.transpose(0, 1)) / np.sqrt(self.arch.n_classes)\n",
    "        return F.softmax(kq, 1)\n",
    "\n",
    "    def get_attention(self, x_embed):\n",
    "        mem_sample = self.get_memory_sample()\n",
    "        weights = self.get_attention_weights(x_embed, mem_sample)\n",
    "        return th.mm(weights, mem_sample)\n",
    "\n",
    "    def ood_predict(self, data, logits):\n",
    "        probs = sum(F.softmax(logits, 1) for _ in range(self.arch.n_samples)) / self.arch.n_samples\n",
    "        entropy_of_exp = -th.sum(probs * th.log(probs + 1e-8), axis=1)\n",
    "        alpha = th.exp(logits)\n",
    "        S = alpha.sum(1, keepdims=True)\n",
    "        expected_entropy = -th.sum((alpha / S) * (th.digamma(alpha + 1) - th.digamma(S + 1.0)), axis=1)\n",
    "        return entropy_of_exp # - expected_entropy\n",
    "\n",
    "    def predict(self, data):\n",
    "        x_embed = self.forward(data)\n",
    "        alpha = th.exp(x_embed)\n",
    "        S = alpha.sum(1, keepdims=True)\n",
    "        probs = alpha / S\n",
    "        classes = probs.max(1, keepdim=True)[1]\n",
    "        return classes, probs\n",
    "\n",
    "    def forward(self, input):\n",
    "        x_embed = self.arch(input)\n",
    "        attention = self.get_attention(x_embed)\n",
    "        logit = self.fc1_enc_to_pred(th.cat((x_embed, attention), dim=1))\n",
    "        return logit.clamp(max=15)\n",
    "\n",
    "    def kl_dirichlet(self, alpha, beta):\n",
    "        q = Dirichlet(alpha)\n",
    "        p = Dirichlet(beta)\n",
    "        return kl_divergence(q, p)\n",
    "\n",
    "    def loss(self, x, y, epoch):\n",
    "        y_one_hot = F.one_hot(y, self.arch.n_classes).view(-1, self.arch.n_classes)\n",
    "        x_embed_pre = self.arch(x)\n",
    "        attention = self.get_attention(x_embed_pre)\n",
    "        logit = self.fc1_enc_to_pred(th.cat((x_embed_pre, attention), dim=1))\n",
    "        x_embed = logit.clamp(max=15)\n",
    "\n",
    "        self.update_memory(x_embed_pre, y)\n",
    "\n",
    "        alpha = th.exp(x_embed)\n",
    "        S = alpha.sum(1, keepdims=True)\n",
    "        fit_term = (y_one_hot * (th.digamma(S + 1e-8) - th.digamma(alpha + 1e-8))).sum(axis=1)\n",
    "        reg_term = self.kl_dirichlet(alpha, th.exp(attention))\n",
    "        loss = (fit_term + reg_term * self.hyperparams.anneal_factor).mean()\n",
    "        \n",
    "        if self.isbnn:\n",
    "            loss += self.KL()\n",
    "        \n",
    "        return loss\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader, ood_loader, epoch, opt, scheduler, n_classes=10, lrate=1e-3, verbose=False):\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "    for (data, target) in (train_loader):\n",
    "        data, target = data.cuda(), target.cuda()\n",
    "        opt.zero_grad()\n",
    "    \n",
    "        loss = model.loss(data, target, epoch)\n",
    "\n",
    "        loss.backward()\n",
    "        total_loss += loss.item()\n",
    "        opt.step()\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"{epoch}: AvgLoss = {total_loss / len(train_loader):.010f}\")\n",
    "\n",
    "    return opt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def err(preds, target, minibatch=True):\n",
    "    preds = preds.argmax(1)\n",
    "    target = target.argmax(1)\n",
    "    if minibatch:\n",
    "        return ((preds != target) * 1.0).sum() * 100\n",
    "    else:\n",
    "        return ((preds != target) * 1.0).mean() * 100\n",
    "\n",
    "    \n",
    "def nll(preds, target, minibatch=True):\n",
    "    logpred = th.log(preds + 1e-8)\n",
    "    if minibatch:\n",
    "        return -(logpred * target).sum(1)\n",
    "    else:\n",
    "        return -(logpred * target).sum(1).mean()\n",
    "\n",
    "\n",
    "def ece(preds, target, minibatch=True):\n",
    "    confidences, predictions = th.max(preds, 1)\n",
    "    _, target_cls = th.max(target, 1)\n",
    "    accuracies = predictions.eq(target_cls)\n",
    "    n_bins = 100  # 30000\n",
    "    bin_boundaries = th.linspace(0, 1, n_bins + 1)\n",
    "    bin_lowers = bin_boundaries[:-1]\n",
    "    bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "    ece = th.zeros(1, device=\"cuda\")\n",
    "    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):\n",
    "        # Calculated |confidence - accuracy| in each bin\n",
    "        in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())\n",
    "        prop_in_bin = in_bin.float().mean()\n",
    "        if prop_in_bin.item() > 0:\n",
    "            accuracy_in_bin = accuracies[in_bin].float().mean()\n",
    "            avg_confidence_in_bin = confidences[in_bin].mean()\n",
    "            ece += th.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin * 100\n",
    "\n",
    "    return ece.item()\n",
    "    \n",
    "def compute_score(model, loader):\n",
    "    model.eval()\n",
    "    preds = []\n",
    "    targets = []\n",
    "    with th.no_grad():\n",
    "        for data, target in loader:\n",
    "            data, target = data.cuda(), target.cuda()\n",
    "\n",
    "            target = F.one_hot(target.long(), model.arch.n_classes)\n",
    "            _, pred_prob = model.predict(data)\n",
    "            preds.append(pred_prob)\n",
    "            targets.append(target)\n",
    "\n",
    "    targets = th.cat(targets).cuda()\n",
    "    preds = th.cat(preds).cuda()\n",
    "    _err = err(preds, targets, minibatch=False).item()\n",
    "    _nll = nll(preds, targets, minibatch=False).item()\n",
    "    _ece = ece(preds, targets, minibatch=False)\n",
    "\n",
    "    print(f\"Error: {_err:.5f}, NLL {_nll:.5f}, ECE {_ece:.5f}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader, ood_loader, n_channel, n_classes = prepare_data(batch_size)\n",
    "\n",
    "arch = LeNet5(isbnn=True) # bnn\n",
    "\n",
    "arch.dataset_size = len(train_loader.dataset)\n",
    "\n",
    "model = EvidentialTuringProcess(arch)\n",
    "\n",
    "lrate = 1e-3\n",
    "\n",
    "print(\"Training: \")\n",
    "model.cuda()\n",
    "\n",
    "opt = th.optim.Adam(model.parameters(), lr=lrate)\n",
    "scheduler = th.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max_epochs)\n",
    "for epoch in range(1, max_epochs + 1):\n",
    "    opt = train(model, train_loader, ood_loader, epoch, opt, scheduler, n_classes=n_classes, lrate=lrate, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_score(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
