{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DUQ_hist_CIFAR.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6Xl8zK63Km6P",
        "outputId": "0b66bcff-648e-4a70-c5ef-d03fe4797469"
      },
      "source": [
        "# Mount Google Drive inside the Colab runtime; later cells read files under this path.\n",
        "from google.colab import drive\n",
        "\n",
        "GDRIVE_MOUNT_POINT = '/content/gdrive'\n",
        "drive.mount(GDRIVE_MOUNT_POINT)"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Mounted at /content/gdrive\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9xphocsGyXUa",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3e6254c8-2f98-4269-88b6-fc8eed11e621"
      },
      "source": [
        "import torch\n",
        "from utils.datasets import all_datasets\n",
        "import matplotlib.pyplot as plt \n",
        "import numpy as np\n",
        "\n",
        "from utils.cnn_duq import CNN_DUQ\n",
        "from utils.cnn_duq import SoftmaxModel as CNN\n",
        "\n",
        "from utils.resnet import ResNet\n",
        "from utils.resnet_duq import ResNet_DUQ\n",
        "\n",
        "\n",
        "# --- Configuration ----------------------------------------------------------\n",
        "mod = 'DUQ'  # which model to evaluate: 'DUQ' or 'DE' (ensemble of ResNets)\n",
        "BATCH_SIZE = 50\n",
        "ENSEMBLE_SIZE = 5\n",
        "PROGRESS_EVERY = 1000  # print the running sample count at multiples of this\n",
        "SEED = 0\n",
        "NOISE_LEVELS = [0]  # noise std per histogram row, as a multiple of pixel_max\n",
        "DUQ_CHECKPOINT = \"/content/gdrive/My Drive/Colab Notebooks/DUQ_CIFAR_75.pt\"\n",
        "ENSEMBLE_CHECKPOINT = '/content/gdrive/My Drive/Colab Notebooks/CIFAR10_5_ensemble.pt'\n",
        "\n",
        "if mod not in ('DUQ', 'DE'):\n",
        "    raise ValueError('mod must be DUQ or DE, got ' + repr(mod))\n",
        "\n",
        "# Use the GPU when one is attached, otherwise fall back to the CPU instead of\n",
        "# failing inside .cuda().\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "# The noise below comes from torch's global RNG; seed it so runs with a\n",
        "# non-zero noise level are repeatable.\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "\n",
        "# --- Data -------------------------------------------------------------------\n",
        "# Each factory returns a 4-tuple; only the elements unpacked here are used.\n",
        "input_size, num_classes, _, test_dataset_CIFAR = all_datasets[\"CIFAR10\"]()\n",
        "_, _, _, test_dataset_SVHN = all_datasets[\"SVHN\"]()\n",
        "\n",
        "\n",
        "# --- Model ------------------------------------------------------------------\n",
        "if mod == 'DUQ':\n",
        "    # Positional hyper-parameters are kept exactly as in the original call\n",
        "    # -- TODO confirm their meaning against utils.resnet_duq.\n",
        "    model_CIFAR = ResNet_DUQ(input_size, num_classes, 512, 512, 0.1, 0.999).to(device)\n",
        "    checkpoint_path = DUQ_CHECKPOINT\n",
        "else:\n",
        "    model_CIFAR = torch.nn.ModuleList(\n",
        "        [ResNet(input_size, num_classes) for _ in range(ENSEMBLE_SIZE)]\n",
        "    ).to(device)\n",
        "    checkpoint_path = ENSEMBLE_CHECKPOINT\n",
        "# map_location lets a checkpoint saved from a GPU load on whatever device is in use.\n",
        "model_CIFAR.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
        "model_CIFAR.eval()\n",
        "\n",
        "\n",
        "# --- Helpers ----------------------------------------------------------------\n",
        "def iter_image_batches(dataset, n_samples, batch_size):\n",
        "    \"\"\"Yield stacked image tensors covering dataset[0:n_samples] in order.\n",
        "\n",
        "    Only element 0 of each dataset item is used. The final batch is smaller\n",
        "    than batch_size when n_samples is not a multiple of it.\n",
        "    \"\"\"\n",
        "    for start in range(0, n_samples, batch_size):\n",
        "        stop = min(start + batch_size, n_samples)\n",
        "        yield torch.stack([dataset[k][0] for k in range(start, stop)])\n",
        "\n",
        "\n",
        "def confidence_scores(model, batch):\n",
        "    \"\"\"Return one score per sample of `batch` as a 1-D tensor.\n",
        "\n",
        "    DUQ: the maximum over dim 1 of the model's second output.\n",
        "    DE:  sum over dim 1 of p * log(p), where p is the mean over ensemble\n",
        "         members of exp(member output). This is the negative entropy of the\n",
        "         mean prediction if members return log-probabilities -- TODO confirm.\n",
        "    \"\"\"\n",
        "    if mod == 'DUQ':\n",
        "        return model(batch)[1].max(1)[0]\n",
        "    predictions = torch.stack([member(batch) for member in model])\n",
        "    mean_prediction = torch.mean(predictions.exp(), dim=0)\n",
        "    return torch.sum(mean_prediction * torch.log(mean_prediction), dim=1)\n",
        "\n",
        "\n",
        "def score_batches(model, batches, n_samples, noise_scale=None):\n",
        "    \"\"\"Score every batch and return the scores as a NumPy array of length n_samples.\n",
        "\n",
        "    batches:     iterable of CPU image tensors holding n_samples images in total.\n",
        "    noise_scale: if given, Gaussian noise with this standard deviation is added\n",
        "                 to each batch before scoring (a scale of 0 leaves it unchanged).\n",
        "    \"\"\"\n",
        "    scores = np.zeros(n_samples)\n",
        "    done = 0\n",
        "    # Pure inference: no_grad avoids building an autograd graph for every batch.\n",
        "    with torch.no_grad():\n",
        "        for batch in batches:\n",
        "            if noise_scale is not None:\n",
        "                # randn_like follows the batch shape instead of assuming 3x32x32.\n",
        "                batch = batch + noise_scale * torch.randn_like(batch)\n",
        "            batch_scores = confidence_scores(model, batch.to(device))\n",
        "            # One device-to-host copy per batch rather than one .item() per sample.\n",
        "            scores[done:done + len(batch_scores)] = batch_scores.cpu().numpy()\n",
        "            done += len(batch_scores)\n",
        "            if done % PROGRESS_EVERY == 0:\n",
        "                print(done)\n",
        "    return scores\n",
        "\n",
        "\n",
        "# --- Scoring ----------------------------------------------------------------\n",
        "n_cifar = len(test_dataset_CIFAR)\n",
        "# Reference intensity for scaling the noise: the maximum of the FIRST test image only.\n",
        "pixel_max = test_dataset_CIFAR[0][0].max()\n",
        "# CIFAR batches are materialised once so every noise level reuses the same tensors.\n",
        "cifar_batches = list(iter_image_batches(test_dataset_CIFAR, n_cifar, BATCH_SIZE))\n",
        "\n",
        "# Score as many SVHN images as there are CIFAR test images, capped at the SVHN\n",
        "# size so a shorter dataset can neither raise IndexError nor leave zero padding.\n",
        "n_svhn = min(n_cifar, len(test_dataset_SVHN))\n",
        "scores_SVHN = score_batches(\n",
        "    model_CIFAR,\n",
        "    iter_image_batches(test_dataset_SVHN, n_svhn, BATCH_SIZE),\n",
        "    n_svhn,\n",
        ")\n",
        "\n",
        "\n",
        "# --- Histograms: one row per noise level -------------------------------------\n",
        "fig, axes = plt.subplots(len(NOISE_LEVELS), 1, figsize=(5, 3), squeeze=False)\n",
        "\n",
        "for ax, noise_level in zip(axes[:, 0], NOISE_LEVELS):\n",
        "    scores_CIFAR = score_batches(\n",
        "        model_CIFAR, cifar_batches, n_cifar, noise_scale=noise_level * pixel_max\n",
        "    )\n",
        "    ax.hist(x=[scores_CIFAR, scores_SVHN], bins='auto', color=['#0000FF', '#FFA500'])\n",
        "    ax.set(xlabel='Confidence', ylabel='Freq', title='Hist')\n",
        "    ax.legend(['CIFAR', 'SVHN'], loc='upper left')\n",
        "\n",
        "fig.tight_layout()\n",
        "plt.show()\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}