{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DUQ_FM_final.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "2lppBolPJ60z",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "124542a1-9326-45d2-c65b-5686000873bb"
      },
      "source": [
        "# Download the small notMNIST archive into ./data (presumably the file read by\n",
        "# get_fashionmnist_notmnist_ood -- confirm in utils.evaluate_ood).\n",
        "# -f makes curl exit non-zero on an HTTP error instead of saving the error page\n",
        "# as the .mat file; -L follows redirects.\n",
        "!mkdir -p data && cd data && curl -fL -O \"http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat\"\n",
        "\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "# The saved output of this cell shows it was last run against pytorch-ignite 0.4.2.\n",
        "%pip install pytorch-ignite"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
            "                                 Dload  Upload   Total   Spent    Left  Speed\n",
            "100  112M  100  112M    0     0  12.8M      0  0:00:08  0:00:08 --:--:-- 16.3M\n",
            "Requirement already satisfied: pytorch-ignite in /usr/local/lib/python3.6/dist-packages (0.4.2)\n",
            "Requirement already satisfied: torch<2,>=1.3 in /usr/local/lib/python3.6/dist-packages (from pytorch-ignite) (1.7.0+cu101)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch<2,>=1.3->pytorch-ignite) (1.18.5)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch<2,>=1.3->pytorch-ignite) (0.16.0)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch<2,>=1.3->pytorch-ignite) (3.7.4.3)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch<2,>=1.3->pytorch-ignite) (0.8)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hdvCGqwyfbcY"
      },
      "source": [
        "import random\n",
        "import numpy as np\n",
        "\n",
        "import torch\n",
        "import torch.utils.data\n",
        "from torch.nn import functional as F\n",
        "\n",
        "from ignite.engine import Events, Engine\n",
        "from ignite.metrics import Accuracy, Loss\n",
        "\n",
        "from ignite.contrib.handlers.tqdm_logger import ProgressBar\n",
        "\n",
        "from utils.evaluate_ood import (\n",
        "    get_fashionmnist_mnist_ood,\n",
        "    get_fashionmnist_notmnist_ood,\n",
        ")\n",
        "from utils.datasets import FastFashionMNIST, get_FashionMNIST\n",
        "from utils.cnn_duq import CNN_DUQ\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tJOnZiIwRjnw"
      },
      "source": [
        "# Module-level handle to the most recently built network; train_model rebinds it via `global`.\n",
        "model={}\n",
        "def train_model(l_gradient_penalty, length_scale, final_model,epochs):\n",
        "    \"\"\"Train a CNN_DUQ classifier on FashionMNIST and report its accuracy.\n",
        "\n",
        "    Args:\n",
        "        l_gradient_penalty: Weight of the two-sided gradient penalty relative to\n",
        "            the BCE term; the weighted sum is divided by (1 + l_gradient_penalty).\n",
        "        length_scale: Forwarded unchanged to the CNN_DUQ constructor.\n",
        "        final_model: If true, train on the whole training set and validate on\n",
        "            the test set; otherwise hold out the last 5000 shuffled training\n",
        "            examples for validation.\n",
        "        epochs: Number of epochs handed to the ignite trainer.\n",
        "\n",
        "    Returns:\n",
        "        Tuple (model, val_accuracy, test_accuracy). When final_model is true the\n",
        "        validation and test loaders wrap the same dataset.\n",
        "\n",
        "    Side effects:\n",
        "        Rebinds the module-level name `model`. The network and every batch are\n",
        "        moved with `.cuda()`, so a CUDA device is required.\n",
        "    \"\"\"\n",
        "\n",
        "    input_size = 28\n",
        "    num_classes = 10\n",
        "    embedding_size = 256\n",
        "    learnable_length_scale = False\n",
        "    gamma = 0.999  # passed to CNN_DUQ; unrelated to the LR scheduler's gamma below\n",
        "    val_size = 5000  # examples held out for validation when final_model is false\n",
        "\n",
        "\n",
        "    # FashionMNIST train / test splits\n",
        "    dataset = FastFashionMNIST(\"data/\", train=True, download=True)\n",
        "    test_dataset = FastFashionMNIST(\"data/\", train=False, download=True)\n",
        "\n",
        "    # Shuffled index list; only consumed by the hold-out split below.\n",
        "    idx = list(range(len(dataset)))\n",
        "    random.shuffle(idx)\n",
        "\n",
        "    if final_model:\n",
        "        train_dataset = dataset\n",
        "        val_dataset = test_dataset\n",
        "    else:\n",
        "        split = len(idx) - val_size\n",
        "        train_dataset = torch.utils.data.Subset(dataset, indices=idx[:split])\n",
        "        val_dataset = torch.utils.data.Subset(dataset, indices=idx[split:])\n",
        "\n",
        "    dl_train = torch.utils.data.DataLoader(\n",
        "        train_dataset, batch_size=128, shuffle=True, num_workers=0, drop_last=True\n",
        "    )\n",
        "\n",
        "    dl_val = torch.utils.data.DataLoader(\n",
        "        val_dataset, batch_size=2000, shuffle=False, num_workers=0\n",
        "    )\n",
        "\n",
        "    dl_test = torch.utils.data.DataLoader(\n",
        "        test_dataset, batch_size=2000, shuffle=False, num_workers=0\n",
        "    )\n",
        "\n",
        "\n",
        "    # Model\n",
        "    global model\n",
        "    model = CNN_DUQ(\n",
        "        input_size,\n",
        "        num_classes,\n",
        "        embedding_size,\n",
        "        learnable_length_scale,\n",
        "        length_scale,\n",
        "        gamma,\n",
        "    )\n",
        "\n",
        "    model = model.cuda()\n",
        "    # To warm-start from a saved checkpoint instead of training from scratch:\n",
        "    #model.load_state_dict(torch.load(\"DUQ_FM_30_FULL.pt\"))\n",
        "\n",
        "    # Optimiser\n",
        "    optimizer = torch.optim.SGD(\n",
        "        model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4\n",
        "    )\n",
        "\n",
        "    # The evaluator emits (y_pred, y, x, y_pred_sum); each metric picks what it needs.\n",
        "    def output_transform_bce(output):\n",
        "        y_pred, y, _, _ = output\n",
        "        return y_pred, y\n",
        "\n",
        "    def output_transform_acc(output):\n",
        "        # Accuracy gets class indices, so the one-hot targets are converted back.\n",
        "        y_pred, y, _, _ = output\n",
        "        return y_pred, torch.argmax(y, dim=1)\n",
        "\n",
        "    def output_transform_gp(output):\n",
        "        y_pred, y, x, y_pred_sum = output\n",
        "        return x, y_pred_sum\n",
        "\n",
        "    def calc_gradient_penalty(x, y_pred_sum):\n",
        "        \"\"\"Mean of (||d y_pred_sum / d x||_2 - 1)^2 over the batch.\"\"\"\n",
        "        # create_graph=True keeps the penalty itself differentiable, so it can be\n",
        "        # added to the training loss and backpropagated.\n",
        "        gradients = torch.autograd.grad(\n",
        "            outputs=y_pred_sum,\n",
        "            inputs=x,\n",
        "            grad_outputs=torch.ones_like(y_pred_sum),\n",
        "            create_graph=True,\n",
        "            retain_graph=True,\n",
        "        )[0]\n",
        "\n",
        "        gradients = gradients.flatten(start_dim=1)\n",
        "\n",
        "        # L2 norm\n",
        "        grad_norm = gradients.norm(2, dim=1)\n",
        "\n",
        "        # Two sided penalty\n",
        "        gradient_penalty = ((grad_norm - 1) ** 2).mean()\n",
        "\n",
        "        return gradient_penalty\n",
        "\n",
        "    def step(engine, batch):\n",
        "        \"\"\"One optimisation step; returns the scalar training loss.\"\"\"\n",
        "        model.train()\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        x, y = batch\n",
        "        y = F.one_hot(y, num_classes=num_classes).float()\n",
        "\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # The penalty differentiates the prediction with respect to the input.\n",
        "        x.requires_grad_(True)\n",
        "\n",
        "        z, y_pred = model(x)\n",
        "\n",
        "        # (BCE + lambda * GP) / (1 + lambda): the two terms' weights sum to one.\n",
        "        loss = F.binary_cross_entropy(y_pred, y)\n",
        "        loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred.sum(1))\n",
        "        loss/=(1+l_gradient_penalty)\n",
        "\n",
        "        x.requires_grad_(False)\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Runs after the weight update, in eval mode and without autograd.\n",
        "        # NOTE(review): presumably refreshes the model's per-class embeddings --\n",
        "        # confirm in utils.cnn_duq.\n",
        "        with torch.no_grad():\n",
        "            model.eval()\n",
        "            model.update_embeddings(x, y)\n",
        "\n",
        "        return loss.item()\n",
        "\n",
        "    def eval_step(engine, batch):\n",
        "        \"\"\"Forward pass for evaluation; returns (y_pred, y, x, y_pred.sum(1)).\"\"\"\n",
        "        model.eval()\n",
        "\n",
        "        x, y = batch\n",
        "        y = F.one_hot(y, num_classes=num_classes).float()\n",
        "\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # Autograd stays enabled here: the gradient_penalty metric below\n",
        "        # differentiates the returned sum with respect to x.\n",
        "        x.requires_grad_(True)\n",
        "\n",
        "        z, y_pred = model(x)\n",
        "\n",
        "        return y_pred, y, x, y_pred.sum(1)\n",
        "\n",
        "    trainer = Engine(step)\n",
        "    evaluator = Engine(eval_step)\n",
        "\n",
        "    metric = Accuracy(output_transform=output_transform_acc)\n",
        "    metric.attach(evaluator, \"accuracy\")\n",
        "\n",
        "    metric = Loss(F.binary_cross_entropy, output_transform=output_transform_bce)\n",
        "    metric.attach(evaluator, \"bce\")\n",
        "\n",
        "    metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp)\n",
        "    metric.attach(evaluator, \"gradient_penalty\")\n",
        "\n",
        "    # Learning rate is multiplied by 0.2 at epochs 10 and 20.\n",
        "    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
        "        optimizer, milestones=[10, 20], gamma=0.2\n",
        "    )\n",
        "\n",
        "    pbar = ProgressBar()\n",
        "    pbar.attach(trainer)\n",
        "\n",
        "    @trainer.on(Events.EPOCH_COMPLETED)\n",
        "    def log_results(trainer):\n",
        "        # The scheduler advances once per finished epoch.\n",
        "        scheduler.step()\n",
        "\n",
        "        # logging every 5 epoch\n",
        "        if trainer.state.epoch % 5 == 0:\n",
        "            evaluator.run(dl_val)\n",
        "\n",
        "            # AUROC on FashionMNIST + Mnist / NotMnist\n",
        "            # (the accuracy these helpers also return is not reported here)\n",
        "            _, roc_auc_mnist = get_fashionmnist_mnist_ood(model)\n",
        "            _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)\n",
        "            metrics = evaluator.state.metrics\n",
        "\n",
        "            print(\n",
        "                f\"Validation Results - Epoch: {trainer.state.epoch} \"\n",
        "                f\"Val_Acc: {metrics['accuracy']:.4f} \"\n",
        "                f\"BCE: {metrics['bce']:.2f} \"\n",
        "                f\"GP: {metrics['gradient_penalty']:.6f} \"\n",
        "                f\"AUROC MNIST: {roc_auc_mnist:.4f} \"\n",
        "                f\"AUROC NotMNIST: {roc_auc_notmnist:.2f} \"\n",
        "            )\n",
        "            print(f\"Sigma: {model.sigma}\")\n",
        "\n",
        "    # Train\n",
        "    trainer.run(dl_train, max_epochs=epochs)\n",
        "\n",
        "    # Validation\n",
        "    evaluator.run(dl_val)\n",
        "    val_accuracy = evaluator.state.metrics[\"accuracy\"]\n",
        "\n",
        "    # Test\n",
        "    evaluator.run(dl_test)\n",
        "    test_accuracy = evaluator.state.metrics[\"accuracy\"]\n",
        "\n",
        "    return model, val_accuracy, test_accuracy"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tdAmrBrcSCRc"
      },
      "source": [
        "if __name__ == \"__main__\":\n",
        "    # Seed the RNGs of random, numpy and torch once, before the sweep, so that a\n",
        "    # fresh Restart & Run All repeats the same experiment (GPU kernels may still\n",
        "    # add some nondeterminism). Seeding only once keeps the repetitions of one\n",
        "    # configuration different from each other.\n",
        "    seed = 0\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "\n",
        "    # The returned test dataset is not used by the sweep below.\n",
        "    _, _, _, fashionmnist_test_dataset = get_FashionMNIST()\n",
        "\n",
        "    # Sweep configuration: every (gradient penalty, length scale) pair is trained\n",
        "    # `repetition` times for `epochs` epochs.\n",
        "    l_gradient_penalties = [0.05,0.1,0.2,0.5,1,2]\n",
        "    length_scales = [0.1]\n",
        "    epochs=30\n",
        "\n",
        "    repetition = 3  # Increase for multiple repetitions\n",
        "    # With final_model = True, train_model validates on the test set, so the\n",
        "    # \"val acc\" and \"test acc\" entries below coincide.\n",
        "    final_model = True  # set true for final model to train on full train set\n",
        "\n",
        "    # Maps \"lgp<penalty>_ls<length scale>\" to a list of (label, mean over repetitions).\n",
        "    results = {}\n",
        "\n",
        "    for l_gradient_penalty in l_gradient_penalties:\n",
        "        for length_scale in length_scales:\n",
        "            val_accuracies = []\n",
        "            test_accuracies = []\n",
        "            roc_aucs_mnist = []\n",
        "            roc_aucs_notmnist = []\n",
        "\n",
        "            for _ in range(repetition):\n",
        "                print(f\" ### NEW MODEL ### gp = {l_gradient_penalty}\")\n",
        "                model, val_accuracy, test_accuracy = train_model(\n",
        "                    l_gradient_penalty, length_scale, final_model, epochs\n",
        "                )\n",
        "                # Out-of-distribution AUROC of the trained model against MNIST and notMNIST.\n",
        "                accuracy, roc_auc_mnist = get_fashionmnist_mnist_ood(model)\n",
        "                _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)\n",
        "\n",
        "                val_accuracies.append(val_accuracy)\n",
        "                test_accuracies.append(test_accuracy)\n",
        "                roc_aucs_mnist.append(roc_auc_mnist)\n",
        "                roc_aucs_notmnist.append(roc_auc_notmnist)\n",
        "\n",
        "            # All stats\n",
        "            results[f\"lgp{l_gradient_penalty}_ls{length_scale}\"] = [\n",
        "                (\"val acc\", np.mean(val_accuracies)),\n",
        "                (\"test acc\", np.mean(test_accuracies)),\n",
        "                (\"M auroc\", np.mean(roc_aucs_mnist)),\n",
        "                (\"NM auroc\", np.mean(roc_aucs_notmnist)),\n",
        "            ]\n",
        "\n",
        "    # Save\n",
        "    # Only the model trained last (final penalty value, final repetition) is\n",
        "    # written; the file name hard-codes the 30-epoch setting used above.\n",
        "    torch.save(model.state_dict(), \"DUQ_FM_30_FULL.pt\")\n",
        "    print(results)\n"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}