{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DE_final.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "bjTf41lfeC0j"
      },
      "source": [
        "# Fetch notMNIST_small.mat into ./data. -f makes curl fail on HTTP errors instead of saving the error page as the .mat file; -L follows redirects.\n",
        "!mkdir -p data && cd data && curl -fL -O \"http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZnDmCGJiaWRm"
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "from tqdm import tqdm\n",
        "\n",
        "from utils.datasets import all_datasets\n",
        "from utils.cnn_duq import SoftmaxModel as CNN\n",
        "from torchvision.models import resnet18\n",
        "from utils.resnet import ResNet\n",
        "from utils.ensemble_eval import (get_fm_mnist_ood_ensemble, get_cifar10_svhn_ood_ensemble)\n",
        "\n",
        "\n",
        "def train(model, train_loader, optimizer, epoch, loss_fn):\n",
        "    \"\"\"Train ``model`` for one epoch and print the mean batch loss.\n",
        "\n",
        "    Args:\n",
        "        model: Network to optimise; it is switched to training mode here.\n",
        "        train_loader: Iterable yielding ``(data, target)`` batches.\n",
        "        optimizer: Optimiser that is zeroed and stepped once per batch.\n",
        "        epoch: Epoch number, used only in the printed summary.\n",
        "        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar loss.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "\n",
        "    # Follow the model's own device instead of hard-coding CUDA, so the loop\n",
        "    # also runs on CPU-only machines. A parameter-free model keeps the\n",
        "    # original preference for the GPU whenever one is available.\n",
        "    first_param = next(model.parameters(), None)\n",
        "    if first_param is not None:\n",
        "        device = first_param.device\n",
        "    else:\n",
        "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    total_loss = []\n",
        "\n",
        "    for data, target in tqdm(train_loader):\n",
        "        data = data.to(device)\n",
        "        target = target.to(device)\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        prediction = model(data)\n",
        "        loss = loss_fn(prediction, target)\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Store a plain float so the autograd graph of each batch can be freed.\n",
        "        total_loss.append(loss.item())\n",
        "\n",
        "    avg_loss = torch.tensor(total_loss).mean()\n",
        "    print(f\"Epoch: {epoch}:\")\n",
        "    print(f\"Train Set: Average Loss: {avg_loss:.2f}\")\n",
        "\n",
        "\n",
        "def test(models, test_loader, loss_fn):\n",
        "    \"\"\"Evaluate the ensemble on ``test_loader`` and print loss and accuracy.\n",
        "\n",
        "    Args:\n",
        "        models: Module container of ensemble members (``main`` passes an\n",
        "            ``nn.ModuleList``). Member outputs are exponentiated before being\n",
        "            averaged, i.e. they are treated as log-probabilities.\n",
        "        test_loader: DataLoader over the evaluation set; the sample count is\n",
        "            taken from ``len(test_loader.dataset)``.\n",
        "        loss_fn: Callable ``loss_fn(prediction, target, reduction=\"sum\")``.\n",
        "\n",
        "    Returns:\n",
        "        Tuple ``(loss, percentage_correct)``: the per-sample loss averaged over\n",
        "        the ensemble members, and the accuracy in percent of the averaged\n",
        "        prediction.\n",
        "    \"\"\"\n",
        "    models.eval()\n",
        "\n",
        "    # Follow the ensemble's own device instead of hard-coding CUDA, so the\n",
        "    # evaluation also runs on CPU-only machines. A parameter-free ensemble\n",
        "    # keeps the original preference for the GPU whenever one is available.\n",
        "    first_param = next(models.parameters(), None)\n",
        "    if first_param is not None:\n",
        "        device = first_param.device\n",
        "    else:\n",
        "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    loss = 0\n",
        "    correct = 0\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for data, target in test_loader:\n",
        "            data = data.to(device)\n",
        "            target = target.to(device)\n",
        "\n",
        "            predictions = torch.stack([model(data) for model in models])\n",
        "\n",
        "            # One summed loss per member; their mean is the ensemble-average\n",
        "            # summed loss for this batch. No per-sample buffer is needed.\n",
        "            member_losses = torch.stack(\n",
        "                [loss_fn(prediction, target, reduction=\"sum\") for prediction in predictions]\n",
        "            )\n",
        "            loss += member_losses.mean().cpu()\n",
        "\n",
        "            # Average the members' probabilities, then pick the most likely class.\n",
        "            avg_prediction = predictions.exp().mean(0)\n",
        "            class_prediction = avg_prediction.argmax(dim=1)\n",
        "            correct += (\n",
        "                class_prediction.eq(target.view_as(class_prediction)).sum().item()\n",
        "            )\n",
        "\n",
        "    num_samples = len(test_loader.dataset)\n",
        "    loss /= num_samples\n",
        "\n",
        "    percentage_correct = 100.0 * correct / num_samples\n",
        "\n",
        "    print(\n",
        "        \"Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\".format(\n",
        "            loss, correct, num_samples, percentage_correct\n",
        "        )\n",
        "    )\n",
        "\n",
        "    return loss, percentage_correct\n",
        "\n",
        "\n",
        "def main():\n",
        "    \"\"\"Train a deep ensemble, evaluate it after every epoch, and save the weights.\n",
        "\n",
        "    The run is driven by the ``args`` dict below. With ``dataset`` set to\n",
        "    \"FashionMNIST\" the members are CNNs and OOD AUROC comes from\n",
        "    ``get_fm_mnist_ood_ensemble``; any other dataset key uses ResNet members\n",
        "    and ``get_cifar10_svhn_ood_ensemble``.\n",
        "    \"\"\"\n",
        "    # Run configuration: epochs, learning rate, ensemble size, dataset key.\n",
        "    args = {'epochs': 30, 'lr': 0.05, 'ensemble': 5, 'dataset': \"FashionMNIST\"}\n",
        "    loss_fn = F.nll_loss\n",
        "\n",
        "    # Main dataset (paired OOD set): FashionMNIST (MNIST) or CIFAR10 (SVHN).\n",
        "    input_size, num_classes, train_dataset, test_dataset = all_datasets[args['dataset']]()\n",
        "\n",
        "    loader_options = {\"num_workers\": 4, \"pin_memory\": True}\n",
        "    train_loader = torch.utils.data.DataLoader(\n",
        "        train_dataset, batch_size=128, shuffle=True, **loader_options\n",
        "    )\n",
        "    test_loader = torch.utils.data.DataLoader(\n",
        "        test_dataset, batch_size=5000, shuffle=False, **loader_options\n",
        "    )\n",
        "\n",
        "    # Architecture, LR milestones and OOD benchmark all depend on the dataset.\n",
        "    if args['dataset'] == \"FashionMNIST\":\n",
        "        milestones = [10, 20]\n",
        "        member_cls = CNN\n",
        "        ood_eval, auroc_key = get_fm_mnist_ood_ensemble, 'mnist_ood_auroc'\n",
        "    else:\n",
        "        # NOTE(review): with epochs=30 the second milestone (50) is never reached.\n",
        "        milestones = [25, 50]\n",
        "        member_cls = ResNet\n",
        "        ood_eval, auroc_key = get_cifar10_svhn_ood_ensemble, 'cifar10_ood_auroc'\n",
        "\n",
        "    ensemble = torch.nn.ModuleList(\n",
        "        [member_cls(input_size, num_classes).cuda() for _ in range(args['ensemble'])]\n",
        "    )\n",
        "    # To resume from a checkpoint instead of training from scratch:\n",
        "    # ensemble.load_state_dict(torch.load(\"FM_5_ensemble_30.pt\"))\n",
        "\n",
        "    # One optimiser per member, so momentum and weight decay only affect the\n",
        "    # member that is currently being trained.\n",
        "    optimizers = [\n",
        "        torch.optim.SGD(\n",
        "            member.parameters(), lr=args['lr'], momentum=0.9, weight_decay=5e-4\n",
        "        )\n",
        "        for member in ensemble\n",
        "    ]\n",
        "    schedulers = [\n",
        "        torch.optim.lr_scheduler.MultiStepLR(\n",
        "            member_optimizer, milestones=milestones, gamma=0.1\n",
        "        )\n",
        "        for member_optimizer in optimizers\n",
        "    ]\n",
        "\n",
        "    for epoch in range(1, args['epochs'] + 1):\n",
        "        # Train every member for one epoch, then advance its LR schedule.\n",
        "        for member, member_optimizer, member_scheduler in zip(\n",
        "            ensemble, optimizers, schedulers\n",
        "        ):\n",
        "            train(member, train_loader, member_optimizer, epoch, loss_fn)\n",
        "            member_scheduler.step()\n",
        "\n",
        "        # Accuracy and loss on the test split of the main dataset.\n",
        "        test(ensemble, test_loader, loss_fn)\n",
        "\n",
        "        # AUROC for separating the main dataset from its OOD counterpart.\n",
        "        _, auroc = ood_eval(ensemble)\n",
        "        print({auroc_key: auroc})\n",
        "\n",
        "    # Persist the weights of the whole ensemble in one file.\n",
        "    torch.save(\n",
        "        ensemble.state_dict(),\n",
        "        f\"model{args['dataset']}_{len(ensemble)}\" + \"_ensemble.pt\",\n",
        "    )\n",
        "\n",
        "\n",
        "# Entry point: start the training run when this code executes as the main module.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}