{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "LSexp_FM_DUQ.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "WJRVo8zSLiAc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "e275fc46-6ce6-49a3-a81e-216c98309bf5"
      },
      "source": [
        "# Fetch notMNIST_small.mat into data/ (presumably the file read by\n",
        "# get_fashionmnist_notmnist_ood -- confirm against utils.evaluate_ood).\n",
        "# --fail makes curl exit non-zero on an HTTP error instead of saving the\n",
        "# server's error page under the .mat name; --location follows redirects.\n",
        "!mkdir -p data && cd data && curl --fail --location -O \"http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat\""
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
            "                                 Dload  Upload   Total   Spent    Left  Speed\n",
            "100  112M  100  112M    0     0  38.1M      0  0:00:02  0:00:02 --:--:-- 38.1M\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uKMj1TGKJ2RI"
      },
      "source": [
        "import random\n",
        "import numpy as np\n",
        "\n",
        "import torch\n",
        "import torch.utils.data\n",
        "from torch.nn import functional as F\n",
        "!pip install pytorch-ignite\n",
        "from ignite.engine import Events, Engine\n",
        "from ignite.metrics import Accuracy, Loss\n",
        "\n",
        "from ignite.contrib.handlers.tqdm_logger import ProgressBar\n",
        "\n",
        "from utils.evaluate_ood import (\n",
        "    get_fashionmnist_mnist_ood,\n",
        "    get_fashionmnist_notmnist_ood,\n",
        ")\n",
        "from utils.datasets import FastFashionMNIST, get_FashionMNIST\n",
        "from utils.cnn_duq import CNN_DUQ\n",
        "\n",
        "\n",
        "def train_model(l_gradient_penalty, length_scale, final_model):\n",
        "    \"\"\"Train a CNN_DUQ classifier on FashionMNIST and measure its accuracy.\n",
        "\n",
        "    The training loss is binary cross-entropy on one-hot targets plus\n",
        "    ``l_gradient_penalty`` times a two-sided gradient penalty that pushes the\n",
        "    norm of d(sum of class outputs)/d(input) towards 1. Training runs for 30\n",
        "    epochs; every fifth epoch the validation metrics and the MNIST / NotMNIST\n",
        "    AUROC values are printed.\n",
        "\n",
        "    Args:\n",
        "        l_gradient_penalty: Weight of the gradient penalty term in the loss.\n",
        "        length_scale: Length scale forwarded to ``CNN_DUQ``.\n",
        "        final_model: If True, train on the full training set and validate on\n",
        "            the test set. Otherwise shuffle the training indices and hold\n",
        "            out the last 5000 of them for validation.\n",
        "\n",
        "    Returns:\n",
        "        Tuple ``(model, val_accuracy, test_accuracy)``.\n",
        "\n",
        "    Note:\n",
        "        The model and every batch are moved with ``.cuda()``, so a CUDA device\n",
        "        is required.\n",
        "    \"\"\"\n",
        "    dataset = FastFashionMNIST(\"data/\", train=True, download=True)\n",
        "    test_dataset = FastFashionMNIST(\"data/\", train=False, download=True)\n",
        "\n",
        "    # Size of the hold-out split used when final_model is False.\n",
        "    val_size = 5000\n",
        "\n",
        "    # Derive the index range from the dataset rather than hard-coding 60000.\n",
        "    idx = list(range(len(dataset)))\n",
        "    random.shuffle(idx)\n",
        "\n",
        "    if final_model:\n",
        "        train_dataset = dataset\n",
        "        val_dataset = test_dataset\n",
        "    else:\n",
        "        split = len(dataset) - val_size\n",
        "        train_dataset = torch.utils.data.Subset(dataset, indices=idx[:split])\n",
        "        val_dataset = torch.utils.data.Subset(dataset, indices=idx[split:])\n",
        "\n",
        "    input_size = 28\n",
        "    num_classes = 10\n",
        "    embedding_size = 256\n",
        "    learnable_length_scale = False\n",
        "    # Forwarded to CNN_DUQ; presumably the decay of its running centroid\n",
        "    # statistics -- TODO confirm against utils.cnn_duq.\n",
        "    gamma = 0.999\n",
        "\n",
        "    model = CNN_DUQ(\n",
        "        input_size,\n",
        "        num_classes,\n",
        "        embedding_size,\n",
        "        learnable_length_scale,\n",
        "        length_scale,\n",
        "        gamma,\n",
        "    )\n",
        "    model = model.cuda()\n",
        "\n",
        "    optimizer = torch.optim.SGD(\n",
        "        model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4\n",
        "    )\n",
        "\n",
        "    # The evaluator emits (y_pred, y, x, y_pred_sum); each metric selects the\n",
        "    # pair of values it needs through one of the transforms below.\n",
        "    def output_transform_bce(output):\n",
        "        y_pred, y, _, _ = output\n",
        "        return y_pred, y\n",
        "\n",
        "    def output_transform_acc(output):\n",
        "        y_pred, y, _, _ = output\n",
        "        # Accuracy is given class indices, so undo the one-hot encoding.\n",
        "        return y_pred, torch.argmax(y, dim=1)\n",
        "\n",
        "    def output_transform_gp(output):\n",
        "        _, _, x, y_pred_sum = output\n",
        "        return x, y_pred_sum\n",
        "\n",
        "    def calc_gradient_penalty(x, y_pred_sum):\n",
        "        \"\"\"Return the mean two-sided penalty (||d y_pred_sum / d x||_2 - 1)^2.\"\"\"\n",
        "        # create_graph=True keeps the gradient differentiable so the penalty\n",
        "        # can itself be backpropagated as part of the training loss.\n",
        "        gradients = torch.autograd.grad(\n",
        "            outputs=y_pred_sum,\n",
        "            inputs=x,\n",
        "            grad_outputs=torch.ones_like(y_pred_sum),\n",
        "            create_graph=True,\n",
        "            retain_graph=True,\n",
        "        )[0]\n",
        "\n",
        "        # One flat gradient vector per sample.\n",
        "        gradients = gradients.flatten(start_dim=1)\n",
        "\n",
        "        # L2 norm\n",
        "        grad_norm = gradients.norm(2, dim=1)\n",
        "\n",
        "        # Two sided penalty\n",
        "        gradient_penalty = ((grad_norm - 1) ** 2).mean()\n",
        "\n",
        "        return gradient_penalty\n",
        "\n",
        "    def step(engine, batch):\n",
        "        \"\"\"Run one optimisation step on ``batch`` and return the scalar loss.\"\"\"\n",
        "        model.train()\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        x, y = batch\n",
        "        y = F.one_hot(y, num_classes=num_classes).float()\n",
        "\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # The gradient penalty differentiates with respect to the input.\n",
        "        x.requires_grad_(True)\n",
        "\n",
        "        _, y_pred = model(x)\n",
        "\n",
        "        loss = F.binary_cross_entropy(y_pred, y)\n",
        "        loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred.sum(1))\n",
        "\n",
        "        x.requires_grad_(False)\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Update the model's embeddings from this batch after the weight\n",
        "        # update, in eval mode and without tracking gradients.\n",
        "        with torch.no_grad():\n",
        "            model.eval()\n",
        "            model.update_embeddings(x, y)\n",
        "\n",
        "        return loss.item()\n",
        "\n",
        "    def eval_step(engine, batch):\n",
        "        \"\"\"Forward ``batch`` and return ``(y_pred, y, x, y_pred.sum(1))``.\"\"\"\n",
        "        model.eval()\n",
        "\n",
        "        x, y = batch\n",
        "        y = F.one_hot(y, num_classes=num_classes).float()\n",
        "\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # Needed by the gradient_penalty metric, which differentiates the\n",
        "        # summed prediction with respect to x.\n",
        "        x.requires_grad_(True)\n",
        "\n",
        "        _, y_pred = model(x)\n",
        "\n",
        "        return y_pred, y, x, y_pred.sum(1)\n",
        "\n",
        "    trainer = Engine(step)\n",
        "    evaluator = Engine(eval_step)\n",
        "\n",
        "    metric = Accuracy(output_transform=output_transform_acc)\n",
        "    metric.attach(evaluator, \"accuracy\")\n",
        "\n",
        "    metric = Loss(F.binary_cross_entropy, output_transform=output_transform_bce)\n",
        "    metric.attach(evaluator, \"bce\")\n",
        "\n",
        "    metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp)\n",
        "    metric.attach(evaluator, \"gradient_penalty\")\n",
        "\n",
        "    # Multiply the learning rate by 0.2 after epochs 10 and 20.\n",
        "    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
        "        optimizer, milestones=[10, 20], gamma=0.2\n",
        "    )\n",
        "\n",
        "    dl_train = torch.utils.data.DataLoader(\n",
        "        train_dataset, batch_size=128, shuffle=True, num_workers=0, drop_last=True\n",
        "    )\n",
        "\n",
        "    dl_val = torch.utils.data.DataLoader(\n",
        "        val_dataset, batch_size=2000, shuffle=False, num_workers=0\n",
        "    )\n",
        "\n",
        "    dl_test = torch.utils.data.DataLoader(\n",
        "        test_dataset, batch_size=2000, shuffle=False, num_workers=0\n",
        "    )\n",
        "\n",
        "    pbar = ProgressBar()\n",
        "    pbar.attach(trainer)\n",
        "\n",
        "    @trainer.on(Events.EPOCH_COMPLETED)\n",
        "    def log_results(trainer):\n",
        "        # The scheduler advances once per epoch; the full report is only\n",
        "        # produced every fifth epoch.\n",
        "        scheduler.step()\n",
        "\n",
        "        if trainer.state.epoch % 5 == 0:\n",
        "            evaluator.run(dl_val)\n",
        "            # Only the AUROC values are reported here; accuracy comes from\n",
        "            # the evaluator metrics.\n",
        "            _, roc_auc_mnist = get_fashionmnist_mnist_ood(model)\n",
        "            _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)\n",
        "\n",
        "            metrics = evaluator.state.metrics\n",
        "\n",
        "            print(\n",
        "                f\"Validation Results - Epoch: {trainer.state.epoch} \"\n",
        "                f\"Acc: {metrics['accuracy']:.4f} \"\n",
        "                f\"BCE: {metrics['bce']:.2f} \"\n",
        "                f\"GP: {metrics['gradient_penalty']:.6f} \"\n",
        "                f\"AUROC MNIST: {roc_auc_mnist:.2f} \"\n",
        "                f\"AUROC NotMNIST: {roc_auc_notmnist:.2f} \"\n",
        "            )\n",
        "            print(f\"Sigma: {model.sigma}\")\n",
        "\n",
        "    trainer.run(dl_train, max_epochs=30)\n",
        "\n",
        "    # Final accuracy on the validation split, then on the test set.\n",
        "    evaluator.run(dl_val)\n",
        "    val_accuracy = evaluator.state.metrics[\"accuracy\"]\n",
        "\n",
        "    evaluator.run(dl_test)\n",
        "    test_accuracy = evaluator.state.metrics[\"accuracy\"]\n",
        "\n",
        "    return model, val_accuracy, test_accuracy\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Only the fourth value returned by get_FashionMNIST() is bound to a\n",
        "    # name; it is not referenced again in this cell.\n",
        "    _, _, _, fashionmnist_test_dataset = get_FashionMNIST()\n",
        "\n",
        "    # Sweep 1: pick the length scale by validation accuracy.\n",
        "    l_gradient_penalties = [0.0]\n",
        "    length_scales = [0.05, 0.1, 0.2, 0.3, 0.5, 1.0]\n",
        "\n",
        "    # Sweep 2: pick the gradient penalty by AUROC on NotMNIST. Swap in:\n",
        "    # l_gradient_penalties = [0.0, 0.05, 0.1, 0.2, 0.3, 0.5, 1.0]\n",
        "    # length_scales = [0.1]\n",
        "\n",
        "    repetition = 1  # number of models trained per configuration\n",
        "    final_model = False  # True trains the final model on the full train set\n",
        "\n",
        "    # Order of the averaged entries stored for every configuration.\n",
        "    metric_labels = (\"val acc\", \"test acc\", \"M auroc\", \"NM auroc\")\n",
        "\n",
        "    results = {}\n",
        "\n",
        "    for l_gradient_penalty in l_gradient_penalties:\n",
        "        for length_scale in length_scales:\n",
        "            # One tuple per repetition, ordered like metric_labels.\n",
        "            runs = []\n",
        "\n",
        "            for _ in range(repetition):\n",
        "                print(\" ### NEW MODEL ### \")\n",
        "                model, val_accuracy, test_accuracy = train_model(\n",
        "                    l_gradient_penalty, length_scale, final_model\n",
        "                )\n",
        "                accuracy, roc_auc_mnist = get_fashionmnist_mnist_ood(model)\n",
        "                _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)\n",
        "\n",
        "                runs.append(\n",
        "                    (val_accuracy, test_accuracy, roc_auc_mnist, roc_auc_notmnist)\n",
        "                )\n",
        "\n",
        "            # Average every metric over the repetitions of this configuration.\n",
        "            results[f\"lgp{l_gradient_penalty}_ls{length_scale}\"] = [\n",
        "                (label, np.mean([run[column] for run in runs]))\n",
        "                for column, label in enumerate(metric_labels)\n",
        "            ]\n",
        "\n",
        "    print(results)\n"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}