{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DUQ_FM_final.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "2lppBolPJ60z"
      },
      "source": [
        "# Download notMNIST_small.mat into ./data. -f makes curl fail on an HTTP error instead of\n",
        "# saving the error page under the .mat name; -L follows redirects.\n",
        "!mkdir -p data && cd data && curl -fLO \"http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat\"\n",
        "\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install pytorch-ignite"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hdvCGqwyfbcY"
      },
      "source": [
        "import random\n",
        "import numpy as np\n",
        "\n",
        "import torch\n",
        "import torch.utils.data\n",
        "from torch.nn import functional as F\n",
        "\n",
        "from ignite.engine import Events, Engine\n",
        "from ignite.metrics import Accuracy, Loss\n",
        "\n",
        "from ignite.contrib.handlers.tqdm_logger import ProgressBar\n",
        "\n",
        "from utils.evaluate_ood import (\n",
        "    get_fashionmnist_mnist_ood,\n",
        "    get_fashionmnist_notmnist_ood,\n",
        ")\n",
        "from utils.datasets import FastFashionMNIST, get_FashionMNIST\n",
        "from utils.cnn_duq import CNN_DUQ\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tJOnZiIwRjnw"
      },
      "source": [
        "# Module-level handle to the latest network; train_model rebinds it through `global model`.\n",
        "model={}\n",
        "def train_model(l_gradient_penalty, length_scale, final_model,epochs,input_dep_ls,use_grad_norm):\n",
        "    \"\"\"Train a CNN_DUQ classifier on FashionMNIST and return it with its accuracies.\n",
        "\n",
        "    Args:\n",
        "        l_gradient_penalty: Weight of the gradient-penalty term added to the BCE loss.\n",
        "        length_scale: Passed through to the CNN_DUQ constructor.\n",
        "        final_model: If True, train on the whole training set and use the test set\n",
        "            for validation; otherwise train on the first 55000 shuffled training\n",
        "            indices and validate on the remaining ones.\n",
        "        epochs: Number of epochs handed to the ignite trainer.\n",
        "        input_dep_ls: Passed through to the CNN_DUQ constructor.\n",
        "        use_grad_norm: If True, the summed loss is divided by (1 + l_gradient_penalty).\n",
        "\n",
        "    Returns:\n",
        "        (model, val_accuracy, test_accuracy) for the trained network.\n",
        "\n",
        "    Note:\n",
        "        Needs a CUDA device (model and batches are moved with .cuda()) and rebinds\n",
        "        the module-level name `model` as a side effect.\n",
        "    \"\"\"\n",
        "\n",
        "    input_size = 28\n",
        "    num_classes = 10\n",
        "    embedding_size = 256\n",
        "    learnable_length_scale = False #Learnable length scale\n",
        "    gamma = 0.999  # forwarded to CNN_DUQ; meaning is defined in utils.cnn_duq\n",
        "\n",
        "    if input_dep_ls and learnable_length_scale: #only one can be True\n",
        "        learnable_length_scale=False\n",
        "\n",
        "\n",
        "    ## Main (FashionMNIST) and ood (Mnist) Dataset\n",
        "    dataset = FastFashionMNIST(\"data/\", train=True, download=True)\n",
        "    test_dataset = FastFashionMNIST(\"data/\", train=False, download=True)\n",
        "    # Index list covers the actual training set instead of a hard-coded 60000.\n",
        "    idx = list(range(len(dataset)))\n",
        "    random.shuffle(idx)\n",
        "\n",
        "    if final_model:\n",
        "        train_dataset = dataset\n",
        "        val_dataset = test_dataset\n",
        "    else:\n",
        "        train_dataset = torch.utils.data.Subset(dataset, indices=idx[:55000])\n",
        "        val_dataset = torch.utils.data.Subset(dataset, indices=idx[55000:])\n",
        "\n",
        "    dl_train = torch.utils.data.DataLoader(\n",
        "        train_dataset, batch_size=128, shuffle=True, num_workers=0, drop_last=True\n",
        "    )\n",
        "\n",
        "    dl_val = torch.utils.data.DataLoader(\n",
        "        val_dataset, batch_size=2000, shuffle=False, num_workers=0\n",
        "    )\n",
        "\n",
        "    dl_test = torch.utils.data.DataLoader(\n",
        "        test_dataset, batch_size=2000, shuffle=False, num_workers=0\n",
        "    )\n",
        "\n",
        "\n",
        "    # Model\n",
        "    global model\n",
        "    model = CNN_DUQ(\n",
        "        input_size,\n",
        "        num_classes,\n",
        "        embedding_size,\n",
        "        learnable_length_scale,\n",
        "        length_scale,\n",
        "        gamma,\n",
        "        input_dep_ls\n",
        "    )\n",
        "\n",
        "    model = model.cuda()\n",
        "    #model.load_state_dict(torch.load(\"DUQ_FM_30_FULL.pt\"))\n",
        "\n",
        "    # Optimiser\n",
        "    optimizer = torch.optim.SGD(\n",
        "        model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4\n",
        "    )\n",
        "\n",
        "    # The evaluator emits (y_pred, y, x, y_pred_sum); each metric picks what it needs.\n",
        "    def output_transform_bce(output):\n",
        "        \"\"\"Select (y_pred, y) for the BCE metric.\"\"\"\n",
        "        y_pred, y, _, _ = output\n",
        "        return y_pred, y\n",
        "\n",
        "    def output_transform_acc(output):\n",
        "        \"\"\"Select (y_pred, class index) for the accuracy metric; y is one-hot.\"\"\"\n",
        "        y_pred, y, _, _ = output\n",
        "        return y_pred, torch.argmax(y, dim=1)\n",
        "\n",
        "    def output_transform_gp(output):\n",
        "        \"\"\"Select (x, y_pred_sum) for the gradient-penalty metric.\"\"\"\n",
        "        _, _, x, y_pred_sum = output\n",
        "        return x, y_pred_sum\n",
        "\n",
        "    def calc_gradient_penalty(x, y_pred_sum):\n",
        "        \"\"\"Return the mean of (||d y_pred_sum / d x||_2 - 1)^2 over the batch.\"\"\"\n",
        "        gradients = torch.autograd.grad(\n",
        "            outputs=y_pred_sum,\n",
        "            inputs=x,\n",
        "            grad_outputs=torch.ones_like(y_pred_sum),\n",
        "            create_graph=True,\n",
        "            retain_graph=True,\n",
        "        )[0]\n",
        "\n",
        "        # One gradient vector per sample\n",
        "        gradients = gradients.flatten(start_dim=1)\n",
        "\n",
        "        # L2 norm\n",
        "        grad_norm = gradients.norm(2, dim=1)\n",
        "\n",
        "        # Two sided penalty\n",
        "        gradient_penalty = ((grad_norm - 1) ** 2).mean()\n",
        "\n",
        "        return gradient_penalty\n",
        "\n",
        "    def step(engine, batch):\n",
        "        \"\"\"Run one optimisation step on a batch and return the scalar loss.\"\"\"\n",
        "        model.train()\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        x, y = batch\n",
        "        # Use the same class count the model was built with.\n",
        "        y = F.one_hot(y, num_classes=num_classes).float()\n",
        "\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # The penalty differentiates the output with respect to the input.\n",
        "        x.requires_grad_(True)\n",
        "\n",
        "        _, y_pred = model(x)\n",
        "\n",
        "        loss = F.binary_cross_entropy(y_pred, y)\n",
        "        loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred.sum(1))\n",
        "\n",
        "        if use_grad_norm:\n",
        "            #gradient normalization\n",
        "            loss/=(1+l_gradient_penalty)\n",
        "\n",
        "        x.requires_grad_(False)\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Embedding update runs after the weight update, in eval mode and without autograd.\n",
        "        with torch.no_grad():\n",
        "            model.eval()\n",
        "            model.update_embeddings(x, y)\n",
        "\n",
        "        return loss.item()\n",
        "\n",
        "    def eval_step(engine, batch):\n",
        "        \"\"\"Forward a batch in eval mode and return (y_pred, y, x, y_pred.sum(1)).\"\"\"\n",
        "        model.eval()\n",
        "\n",
        "        x, y = batch\n",
        "        y = F.one_hot(y, num_classes=num_classes).float()\n",
        "\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # Kept differentiable so the gradient-penalty metric can be evaluated.\n",
        "        x.requires_grad_(True)\n",
        "\n",
        "        _, y_pred = model(x)\n",
        "\n",
        "        return y_pred, y, x, y_pred.sum(1)\n",
        "\n",
        "    trainer = Engine(step)\n",
        "    evaluator = Engine(eval_step)\n",
        "\n",
        "    metric = Accuracy(output_transform=output_transform_acc)\n",
        "    metric.attach(evaluator, \"accuracy\")\n",
        "\n",
        "    metric = Loss(F.binary_cross_entropy, output_transform=output_transform_bce)\n",
        "    metric.attach(evaluator, \"bce\")\n",
        "\n",
        "    metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp)\n",
        "    metric.attach(evaluator, \"gradient_penalty\")\n",
        "\n",
        "    # Learning rate is multiplied by 0.2 after epochs 10 and 20.\n",
        "    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
        "        optimizer, milestones=[10, 20], gamma=0.2\n",
        "    )\n",
        "\n",
        "    pbar = ProgressBar()\n",
        "    pbar.attach(trainer)\n",
        "\n",
        "    @trainer.on(Events.EPOCH_COMPLETED)\n",
        "    def log_results(trainer):\n",
        "        \"\"\"Advance the LR schedule and, every 5th epoch, print validation and OOD results.\"\"\"\n",
        "        scheduler.step()\n",
        "\n",
        "        # logging every 5 epoch\n",
        "        if trainer.state.epoch % 5 == 0:\n",
        "            evaluator.run(dl_val)\n",
        "\n",
        "            # AUROC on FashionMNIST + Mnist / NotMnist\n",
        "            _, roc_auc_mnist = get_fashionmnist_mnist_ood(model)\n",
        "            _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)\n",
        "            metrics = evaluator.state.metrics\n",
        "\n",
        "            print(\n",
        "                f\"Validation Results - Epoch: {trainer.state.epoch} \"\n",
        "                f\"Val_Acc: {metrics['accuracy']:.4f} \"\n",
        "                f\"BCE: {metrics['bce']:.2f} \"\n",
        "                f\"GP: {metrics['gradient_penalty']:.6f} \"\n",
        "                f\"AUROC MNIST: {roc_auc_mnist:.4f} \"\n",
        "                f\"AUROC NotMNIST: {roc_auc_notmnist:.2f} \"\n",
        "            )\n",
        "            #print(f\"Sigma: {model.sigma}\")\n",
        "\n",
        "    # Train\n",
        "    trainer.run(dl_train, max_epochs=epochs)\n",
        "\n",
        "    # Validation\n",
        "    evaluator.run(dl_val)\n",
        "    val_accuracy = evaluator.state.metrics[\"accuracy\"]\n",
        "\n",
        "    # Test\n",
        "    evaluator.run(dl_test)\n",
        "    test_accuracy = evaluator.state.metrics[\"accuracy\"]\n",
        "\n",
        "    return model, val_accuracy, test_accuracy"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tdAmrBrcSCRc"
      },
      "source": [
        "if __name__ == \"__main__\":\n",
        "    # Only the fourth value returned by get_FashionMNIST is bound to a name.\n",
        "    _, _, _, fashionmnist_test_dataset = get_FashionMNIST()\n",
        "\n",
        "    # Hyper-parameter grid: every (gradient penalty, length scale) pair is trained.\n",
        "    l_gradient_penalties = [0.05]\n",
        "    length_scales = [0.1]\n",
        "    epochs=30\n",
        "\n",
        "    # Run settings\n",
        "    repetition = 3  # trainings per grid point; mean and std are taken over these\n",
        "    final_model = True  # True: train on the full train set\n",
        "    input_dep_ls = False  # input dependent length scale (sigma)\n",
        "    use_grad_norm = False  # gradient normalization\n",
        "\n",
        "    results = {}\n",
        "\n",
        "    for l_gradient_penalty in l_gradient_penalties:\n",
        "        for length_scale in length_scales:\n",
        "            # Per-repetition scores for this grid point\n",
        "            val_accuracies, test_accuracies = [], []\n",
        "            roc_aucs_mnist, roc_aucs_notmnist = [], []\n",
        "\n",
        "            for _ in range(repetition):\n",
        "                print(\" ### NEW MODEL ### \")\n",
        "                model, val_accuracy, test_accuracy = train_model(\n",
        "                    l_gradient_penalty,\n",
        "                    length_scale,\n",
        "                    final_model,\n",
        "                    epochs,\n",
        "                    input_dep_ls,\n",
        "                    use_grad_norm,\n",
        "                )\n",
        "                accuracy, roc_auc_mnist = get_fashionmnist_mnist_ood(model)\n",
        "                _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)\n",
        "\n",
        "                val_accuracies.append(val_accuracy)\n",
        "                test_accuracies.append(test_accuracy)\n",
        "                roc_aucs_mnist.append(roc_auc_mnist)\n",
        "                roc_aucs_notmnist.append(roc_auc_notmnist)\n",
        "\n",
        "            # All stats: one (label, mean, std) tuple per collected score list\n",
        "            summary_inputs = (\n",
        "                (\"val acc\", val_accuracies),\n",
        "                (\"test acc\", test_accuracies),\n",
        "                (\"M auroc\", roc_aucs_mnist),\n",
        "                (\"NM auroc\", roc_aucs_notmnist),\n",
        "            )\n",
        "            results[f\"lgp{l_gradient_penalty}_ls{length_scale}\"] = [\n",
        "                (label, np.mean(scores), np.std(scores))\n",
        "                for label, scores in summary_inputs\n",
        "            ]\n",
        "\n",
        "    # Save the weights of the last model trained above, then show the summary.\n",
        "    torch.save(model.state_dict(), \"DUQ_FM_30_FULL.pt\")\n",
        "    print(results)\n"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}