{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "DUQ_hist_CIFAR.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cSFTyGnz9975",
        "outputId": "1af3d83d-bdd9-4b63-e64c-949e04a15693"
      },
      "source": [
        "# Mount Google Drive inside the Colab VM; the model checkpoints loaded further\n",
        "# down in this notebook are read from paths under this mount point.\n",
        "from google.colab import drive\n",
        "\n",
        "DRIVE_MOUNT_POINT = '/content/gdrive'\n",
        "drive.mount(DRIVE_MOUNT_POINT)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9xphocsGyXUa",
        "outputId": "a2fb0fc0-be64-4c4f-abf9-a0b1f1cbb99c"
      },
      "source": [
        "import torch\n",
        "from utils.datasets import all_datasets\n",
        "import matplotlib.pyplot as plt \n",
        "import numpy as np\n",
        "\n",
        "from utils.cnn_duq import CNN_DUQ\n",
        "from utils.cnn_duq import SoftmaxModel as CNN\n",
        "\n",
        "from utils.resnet import ResNet\n",
        "from utils.resnet_duq import ResNet_DUQ\n",
        "\n",
        "\n",
        "# --- Datasets ---------------------------------------------------------------\n",
        "# Each factory result is unpacked as a 4-item sequence. CIFAR10 supplies the\n",
        "# model's input size and class count; from SVHN only the last item is kept.\n",
        "ds3 = all_datasets[\"CIFAR10\"]()\n",
        "input_size, num_classes, _, test_dataset_CIFAR = ds3\n",
        "ds4 = all_datasets[\"SVHN\"]()\n",
        "_, _, _, test_dataset_SVHN = ds4\n",
        "\n",
        "# --- Model ------------------------------------------------------------------\n",
        "mod = 'DUQ'  # which model to evaluate, one of ['DUQ', 'DE']\n",
        "\n",
        "DUQ_CHECKPOINT = \"/content/gdrive/My Drive/Colab Notebooks/DUQ_CIFAR_75.pt\"\n",
        "ENSEMBLE_CHECKPOINT = '/content/gdrive/My Drive/Colab Notebooks/CIFAR10_5_ensemble.pt'\n",
        "\n",
        "if mod != 'DUQ':\n",
        "    # Any value other than 'DUQ' selects the ensemble of five ResNets, wrapped in\n",
        "    # a ModuleList so a single state dict restores every member.\n",
        "    ensemble = [ResNet(input_size, num_classes).cuda() for _ in range(5)]\n",
        "    model_CIFAR = torch.nn.ModuleList(ensemble)\n",
        "    model_CIFAR.load_state_dict(torch.load(ENSEMBLE_CHECKPOINT))\n",
        "else:\n",
        "    # NOTE(review): the four trailing positional arguments are passed through as\n",
        "    # written; their meaning is defined in utils.resnet_duq - confirm there.\n",
        "    model_CIFAR = ResNet_DUQ(input_size, num_classes, 512, 512, 0.1, 0.999).cuda()\n",
        "    model_CIFAR.load_state_dict(torch.load(DUQ_CHECKPOINT))\n",
        "\n",
        "# Evaluation mode for everything that follows.\n",
        "model_CIFAR.eval()\n",
        "\n",
        "# Batch size, number of CIFAR test samples, and the maximum value of the first\n",
        "# CIFAR test image (used further down as the unit for the noise amplitude).\n",
        "b = 50\n",
        "r = len(test_dataset_CIFAR)\n",
        "m = test_dataset_CIFAR[0][0].max()\n",
        "\n",
        "# Pre-stack the CIFAR test images into consecutive batches of at most b samples;\n",
        "# the last batch is shorter when r is not a multiple of b.\n",
        "ls = [\n",
        "    torch.stack([test_dataset_CIFAR[k][0] for k in range(first, min(first + b, r))])\n",
        "    for first in range(0, r, b)\n",
        "]\n",
        "\n",
        "\n",
        "# Score SVHN test images with the CIFAR model, one scalar per image.\n",
        "# As before, at most r images (the CIFAR test-set size) are scored; the count is\n",
        "# additionally capped at the SVHN length so a shorter dataset cannot raise\n",
        "# IndexError.\n",
        "n_svhn = min(r, len(test_dataset_SVHN))\n",
        "kernel_distace_SVHN = np.zeros(n_svhn)\n",
        "\n",
        "# Inference only: no_grad avoids building an autograd graph for every batch.\n",
        "with torch.no_grad():\n",
        "    for i, start in enumerate(range(0, n_svhn, b)):\n",
        "        stop = min(start + b, n_svhn)\n",
        "        data = torch.stack([test_dataset_SVHN[k][0] for k in range(start, stop)])\n",
        "\n",
        "        if mod == 'DUQ':\n",
        "            # Per-sample maximum along dim 1 of the model's second output.\n",
        "            res = model_CIFAR(data.cuda())[1].max(1)[0]\n",
        "        else:\n",
        "            # Average the exponentiated member outputs, then sum p * log(p) over dim 1.\n",
        "            predictions = torch.stack([model(data.cuda()) for model in model_CIFAR])\n",
        "            mean_prediction = torch.mean(predictions.exp(), dim=0)\n",
        "            # Clamp only inside the log: an entry that underflowed to 0 would\n",
        "            # otherwise give 0 * log(0) = NaN and poison the histogram.\n",
        "            tiny = torch.finfo(mean_prediction.dtype).tiny\n",
        "            res = torch.sum(mean_prediction * torch.log(mean_prediction.clamp(min=tiny)), dim=1)\n",
        "\n",
        "        # One device-to-host copy per batch instead of one .item() sync per sample.\n",
        "        kernel_distace_SVHN[start:start + len(res)] = res.detach().cpu().numpy()\n",
        "\n",
        "        # Progress: index of the last scored sample, every 20th batch.\n",
        "        if i % 20 == 0:\n",
        "            print(start + len(res) - 1)\n",
        "\n",
        "\n",
        "# Noise sweep: for each noise level, perturb every CIFAR test batch with Gaussian\n",
        "# noise, score it with the model, and plot the score histogram next to one noisy\n",
        "# example image.\n",
        "\n",
        "# Fix the RNG so the perturbations (and therefore the figure) are reproducible.\n",
        "NOISE_SEED = 0\n",
        "torch.manual_seed(NOISE_SEED)\n",
        "\n",
        "fig = plt.figure(figsize=(10,20))\n",
        "# Noise standard deviations, as multiples of m. The first entry must stay 0:\n",
        "# kernel_distace_CIFAR_e[0] is plotted as the clean 'CIFAR' reference in every row.\n",
        "er = [0,0.05,0.07,0.08,0.1,0.2,0.4]\n",
        "kernel_distace_CIFAR_e = []\n",
        "\n",
        "for p, e in enumerate(er):\n",
        "    print(p,'noise=',e)\n",
        "\n",
        "    # Example image (test index 10) with this noise level, for the right-hand panel.\n",
        "    example = test_dataset_CIFAR[10][0]\n",
        "    sample = (example + e*m*torch.randn(example.shape)).numpy()\n",
        "\n",
        "    scores = np.zeros(r)\n",
        "    kernel_distace_CIFAR_e.append(scores)\n",
        "    c = 0  # number of samples scored so far at this noise level\n",
        "\n",
        "    # Inference only: no_grad avoids building an autograd graph for every batch.\n",
        "    with torch.no_grad():\n",
        "        for data in ls:\n",
        "            # Noise takes the batch's own shape instead of a hard-coded 3x32x32.\n",
        "            Data = data + e*m*torch.randn(data.shape)\n",
        "            if mod == 'DUQ':\n",
        "                # Per-sample maximum along dim 1 of the model's second output.\n",
        "                res = model_CIFAR(Data.cuda())[1].max(1)[0]\n",
        "            else:\n",
        "                # Average the exponentiated member outputs, then sum p * log(p) over dim 1.\n",
        "                predictions = torch.stack([model(Data.cuda()) for model in model_CIFAR])\n",
        "                mean_prediction = torch.mean(predictions.exp(), dim=0)\n",
        "                # Clamp only inside the log: an entry that underflowed to 0 would\n",
        "                # otherwise give 0 * log(0) = NaN, which breaks bins='auto' below.\n",
        "                tiny = torch.finfo(mean_prediction.dtype).tiny\n",
        "                res = torch.sum(mean_prediction * torch.log(mean_prediction.clamp(min=tiny)), dim=1)\n",
        "\n",
        "            # One device-to-host copy per batch instead of one .item() sync per sample.\n",
        "            scores[c:c + len(res)] = res.detach().cpu().numpy()\n",
        "            c += len(res)\n",
        "            if c % 1000 == 0:\n",
        "                print(c)\n",
        "\n",
        "    # Left panel: clean CIFAR vs. noisy CIFAR vs. SVHN score distributions.\n",
        "    ax = fig.add_subplot(len(er),2,2*p+1,xlabel='Value',ylabel='Freq',title='noise = '+str(e))\n",
        "    ax.hist(x=[kernel_distace_CIFAR_e[0],kernel_distace_CIFAR_e[p],kernel_distace_SVHN], bins='auto', color=['#0000FF','#FF0000','#00FF00'])\n",
        "    ax.legend(['CIFAR','Noise','SVHN'],loc='upper left')\n",
        "\n",
        "    # Right panel: the noisy example, min-max rescaled to [0, 1] and moved to\n",
        "    # channels-last layout for imshow.\n",
        "    ax = fig.add_subplot(len(er),2,2*p+2)\n",
        "    sample -= sample.min()\n",
        "    sample /= sample.max()\n",
        "    sample = sample.transpose((1,2,0))\n",
        "    ax.imshow(sample)\n",
        "\n",
        "fig.tight_layout()\n",
        "plt.show()\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n",
            "Using downloaded and verified file: ./data/SVHN/train_32x32.mat\n",
            "Using downloaded and verified file: ./data/SVHN/test_32x32.mat\n",
            "49\n",
            "1049\n",
            "2049\n",
            "3049\n",
            "4049\n",
            "5049\n",
            "6049\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}