{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DUQ_CIFAR10_final.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "KooQVzCOXU-V"
      },
      "source": [
        "# Install Ignite with the %pip magic so it lands in the running kernel's environment.\n",
        "%pip install pytorch-ignite"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hM66GgdYaIu7"
      },
      "source": [
        "import argparse\n",
        "import copy\n",
        "import json\n",
        "import pathlib\n",
        "import random\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import torch.utils.data\n",
        "\n",
        "from ignite.engine import Events, Engine\n",
        "from ignite.metrics import Accuracy, Average, Loss\n",
        "from ignite.contrib.handlers import ProgressBar\n",
        "\n",
        "from utils.resnet_duq import ResNet_DUQ\n",
        "from utils.datasets import all_datasets\n",
        "from utils.evaluate_ood import get_cifar_svhn_ood, get_auroc_classification"
      ],
      "execution_count": 18,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iD-Lt3L_SVwP"
      },
      "source": [
        "# Module-level handles shared with the driver cell: main() rebinds `model` to\n",
        "# the most recently trained network, and the driver cell prints `results`.\n",
        "model={}\n",
        "results=[]\n",
        "\n",
        "\n",
        "def main(\n",
        "    batch_size,\n",
        "    epochs,\n",
        "    length_scale,\n",
        "    centroid_size,\n",
        "    model_output_size,\n",
        "    learning_rate,\n",
        "    l_gradient_penalty,\n",
        "    gamma,\n",
        "    weight_decay,\n",
        "    final_model,\n",
        "    input_dep_ls,\n",
        "    use_grad_norm\n",
        "):\n",
        "    \"\"\"Train a ResNet_DUQ model on CIFAR10 and evaluate it.\n",
        "\n",
        "    Args:\n",
        "        batch_size: mini-batch size of the training loader.\n",
        "        epochs: number of epochs the Ignite trainer runs.\n",
        "        length_scale: forwarded unchanged to ``ResNet_DUQ``.\n",
        "        centroid_size: forwarded unchanged to ``ResNet_DUQ``.\n",
        "        model_output_size: forwarded unchanged to ``ResNet_DUQ``.\n",
        "        learning_rate: learning rate of the SGD optimiser.\n",
        "        l_gradient_penalty: weight of the two-sided gradient penalty; the\n",
        "            penalty is only computed when this is greater than 0.\n",
        "        gamma: forwarded unchanged to ``ResNet_DUQ``.\n",
        "        weight_decay: weight decay of the SGD optimiser.\n",
        "        final_model: if true, train on the whole training set and use the\n",
        "            test set for validation; otherwise use a random 80/20\n",
        "            train/validation split of the training set.\n",
        "        input_dep_ls: accepted for interface compatibility; not used by\n",
        "            this function.\n",
        "        use_grad_norm: if true, divide the loss by ``1 + l_gradient_penalty``.\n",
        "\n",
        "    Returns:\n",
        "        The tuple ``(testacc, auroc_cifsv, val_acc, self_auroc)`` obtained\n",
        "        from ``get_cifar_svhn_ood`` and ``get_auroc_classification`` once\n",
        "        training has finished.\n",
        "\n",
        "    Side effects:\n",
        "        Rebinds the module-level ``model``, moves the model and batches to\n",
        "        CUDA, prints evaluation results, and writes ``model_{epochs}.pt``\n",
        "        after the last epoch.\n",
        "    \"\"\"\n",
        "    global model\n",
        "\n",
        "    # Dataset prep\n",
        "    input_size, num_classes, dataset, test_dataset = all_datasets[\"CIFAR10\"]()\n",
        "\n",
        "    # Shuffle the training indices used for the train/validation split\n",
        "    idx = list(range(len(dataset)))\n",
        "    random.shuffle(idx)\n",
        "\n",
        "    if final_model:\n",
        "        train_dataset = dataset\n",
        "        val_dataset = test_dataset\n",
        "    else:\n",
        "        train_size = int(len(dataset) * 0.8)\n",
        "        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])\n",
        "\n",
        "        # Assigning `.transform` on a Subset has no effect because a Subset\n",
        "        # reads its items from the dataset it wraps. Wrap a shallow copy of\n",
        "        # the training set that carries the test-time transform instead, so\n",
        "        # the training subset keeps its own transform.\n",
        "        val_base = copy.copy(dataset)\n",
        "        val_base.transform = test_dataset.transform\n",
        "        val_dataset = torch.utils.data.Subset(val_base, idx[train_size:])\n",
        "\n",
        "    kwargs = {\"num_workers\": 4, \"pin_memory\": True}\n",
        "    train_loader = torch.utils.data.DataLoader(\n",
        "        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs\n",
        "    )\n",
        "\n",
        "    # Model\n",
        "    model = ResNet_DUQ(\n",
        "        input_size, num_classes, centroid_size, model_output_size, length_scale, gamma\n",
        "    )\n",
        "    model = model.cuda()\n",
        "    # Uncomment to start from a previously saved state dict:\n",
        "    #model.load_state_dict(torch.load(\"DUQ_CIFAR_75.pt\"))\n",
        "\n",
        "    # Optimiser\n",
        "    optimizer = torch.optim.SGD(\n",
        "        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay\n",
        "    )\n",
        "\n",
        "    # Learning rate is multiplied by 0.2 at epochs 25, 50 and 75\n",
        "    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
        "        optimizer, milestones=[25, 50, 75], gamma=0.2\n",
        "    )\n",
        "\n",
        "    def bce_loss_fn(y_pred, y):\n",
        "        \"\"\"Summed binary cross-entropy divided by ``num_classes * batch size``.\"\"\"\n",
        "        return F.binary_cross_entropy(y_pred, y, reduction=\"sum\").div(\n",
        "            num_classes * y_pred.shape[0]\n",
        "        )\n",
        "\n",
        "    def calc_gradients_input(x, y_pred):\n",
        "        \"\"\"Gradient of the summed ``y_pred`` w.r.t. ``x``, flattened per sample.\"\"\"\n",
        "        gradients = torch.autograd.grad(\n",
        "            outputs=y_pred,\n",
        "            inputs=x,\n",
        "            grad_outputs=torch.ones_like(y_pred),\n",
        "            # keep the graph so the penalty itself can be back-propagated\n",
        "            create_graph=True,\n",
        "        )[0]\n",
        "\n",
        "        return gradients.flatten(start_dim=1)\n",
        "\n",
        "    def calc_gradient_penalty(x, y_pred):\n",
        "        \"\"\"Mean squared deviation of the per-sample input-gradient L2 norm from 1.\"\"\"\n",
        "        gradients = calc_gradients_input(x, y_pred)\n",
        "\n",
        "        # L2 norm\n",
        "        grad_norm = gradients.norm(2, dim=1)\n",
        "\n",
        "        # Two sided penalty\n",
        "        return ((grad_norm - 1) ** 2).mean()\n",
        "\n",
        "    def step(engine, batch):\n",
        "        \"\"\"Run one optimisation step on ``batch`` and return the scalar loss.\"\"\"\n",
        "        model.train()\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        x, y = batch\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # Input gradients are only needed for the gradient penalty\n",
        "        if l_gradient_penalty > 0:\n",
        "            x.requires_grad_(True)\n",
        "\n",
        "        z, y_pred = model(x)\n",
        "        y = F.one_hot(y, num_classes).float()\n",
        "\n",
        "        loss = bce_loss_fn(y_pred, y)\n",
        "\n",
        "        # Skip the penalty computation entirely when it is disabled\n",
        "        if l_gradient_penalty > 0:\n",
        "            loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred)\n",
        "\n",
        "        if use_grad_norm:\n",
        "            # gradient normalization\n",
        "            loss /= (1 + l_gradient_penalty)\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        x.requires_grad_(False)\n",
        "\n",
        "        # Let the model update its embeddings from this batch (one-hot y),\n",
        "        # in eval mode and without tracking gradients\n",
        "        with torch.no_grad():\n",
        "            model.eval()\n",
        "            model.update_embeddings(x, y)\n",
        "\n",
        "        return loss.item()\n",
        "\n",
        "    trainer = Engine(step)\n",
        "\n",
        "    @trainer.on(Events.EPOCH_COMPLETED)\n",
        "    def log_results(trainer):\n",
        "        \"\"\"Evaluate periodically, advance the LR schedule, checkpoint the last epoch.\"\"\"\n",
        "        epoch = trainer.state.epoch\n",
        "\n",
        "        # logging every 10 epoch or last epochs\n",
        "        if epoch % 10 == 0 or epoch > epochs - 5:\n",
        "\n",
        "            # acc on cifar test set and auroc on cifar+svhn testsets\n",
        "            testacc, auroc_cifsv = get_cifar_svhn_ood(model)\n",
        "\n",
        "            # acc on cifar val set and self auroc on cifar valset\n",
        "            val_acc, self_auroc = get_auroc_classification(val_dataset, model)\n",
        "\n",
        "            print(f\"Test Accuracy: {testacc}, AUROC: {auroc_cifsv}\")\n",
        "            print(f\"AUROC - uncertainty: {self_auroc}, Val Accuracy : {val_acc}\")\n",
        "\n",
        "        scheduler.step()\n",
        "\n",
        "        # save: Ignite's state.epoch is 1-based, so the last completed epoch\n",
        "        # equals `epochs` (comparing against epochs-1 saved the penultimate one)\n",
        "        if epoch == epochs:\n",
        "            torch.save(model.state_dict(), f\"model_{epoch}.pt\")\n",
        "\n",
        "    pbar = ProgressBar(dynamic_ncols=True)\n",
        "    pbar.attach(trainer)\n",
        "    trainer.run(train_loader, max_epochs=epochs)\n",
        "\n",
        "    # Final evaluation of the trained model\n",
        "    testacc, auroc_cifsv = get_cifar_svhn_ood(model)\n",
        "    val_acc, self_auroc = get_auroc_classification(val_dataset, model)\n",
        "\n",
        "    return testacc, auroc_cifsv, val_acc, self_auroc"
      ],
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gj24X-O9SYsK"
      },
      "source": [
        "# Driver: trains the model several times with fixed hyper-parameters and\n",
        "# reports mean/std of each metric (requires numpy, imported as `np` in the imports cell).\n",
        "if __name__ == \"__main__\":\n",
        "\n",
        "    final_model = True  # Train on full train dataset \n",
        "    input_dep_ls = False #input dependent length scale (sigma)\n",
        "    use_grad_norm = False #gradient normalization\n",
        "    \n",
        "    # Number of independent training runs to aggregate over\n",
        "    repetitions = 3\n",
        "    test_acc_l=[]\n",
        "    auroc_cifsv_l=[]\n",
        "    val_acc_l=[]\n",
        "    self_auroc_l=[]\n",
        "    \n",
        "    for i in range(repetitions):\n",
        "        print(f\"run {i}\\n\")\n",
        "\n",
        "        testacc,auroc_cifsv,val_acc, self_auroc= main(batch_size=128, \n",
        "              epochs=75,\n",
        "              length_scale=0.1,\n",
        "              centroid_size=512,\n",
        "              model_output_size=512,\n",
        "              learning_rate=0.05,\n",
        "              l_gradient_penalty=0,         \n",
        "              gamma=0.999,\n",
        "              weight_decay=5e-4,\n",
        "              final_model=final_model,\n",
        "              input_dep_ls = input_dep_ls, \n",
        "              use_grad_norm=use_grad_norm)\n",
        "        test_acc_l.append(testacc)\n",
        "        auroc_cifsv_l.append(auroc_cifsv)\n",
        "        val_acc_l.append(val_acc)\n",
        "        self_auroc_l.append(self_auroc)\n",
        "    \n",
        "    # (metric name, mean over runs, std over runs)\n",
        "    print([\n",
        "                (\"val acc\", np.mean(val_acc_l),np.std(val_acc_l)),\n",
        "                (\"test acc\", np.mean(test_acc_l), np.std(test_acc_l)),\n",
        "                (\"CIFAR_SVHN auroc\", np.mean(auroc_cifsv_l), np.std(auroc_cifsv_l)),\n",
        "                (\"Self auroc\", np.mean(self_auroc_l), np.std(self_auroc_l)),\n",
        "            ])\n",
        "    \n",
        "    # `model` is the module-level handle rebound by main(), so this saves\n",
        "    # only the network from the last run\n",
        "    torch.save(model.state_dict(), \"DUQ_CIFAR_75.pt\")\n",
        "    print(results)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}