{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "rejexp_aleatoric.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Rejection experiment: aleatoric (input-noise) uncertainty\n",
        "\n",
        "Compares DUQ models with a 5-member deep ensemble through accuracy-vs-rejection curves on a test set mixed with noisy copies of its own images. Pretrained checkpoints are read from Google Drive, so the notebook expects Colab with a GPU runtime."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7NdrkVE3JhPt"
      },
      "source": [
        "# Mount Google Drive; the checkpoints loaded in the next cell live under\n",
        "# '/content/gdrive/My Drive/Colab Notebooks/'.\n",
        "from google.colab import drive as colab_drive\n",
        "\n",
        "colab_drive.mount('/content/gdrive')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "31kViAtB9PEe"
      },
      "source": [
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from utils.cnn_duq import CNN_DUQ\n",
        "from utils.datasets import all_datasets\n",
        "from utils.cnn_duq import SoftmaxModel as CNN\n",
        "\n",
        "from utils.resnet import ResNet\n",
        "from utils.resnet_duq import ResNet_DUQ\n",
        "from utils.evaluate_ood import get_cifar_svhn_ood, get_auroc_classification\n",
        "\n",
        "# Configuration: 'FMnist' loads FashionMNIST (+ MNIST), 'CIFAR10' loads CIFAR10 (+ SVHN).\n",
        "mod = 'FMnist'  # one of ['CIFAR10', 'FMnist']\n",
        "# Directory on the mounted Drive holding the pretrained checkpoints.\n",
        "CKPT_DIR = '/content/gdrive/My Drive/Colab Notebooks/'\n",
        "\n",
        "if mod == 'FMnist':\n",
        "    ds1 = all_datasets['FashionMNIST']()\n",
        "    ds2 = all_datasets['MNIST']()\n",
        "    input_size = 28\n",
        "    num_classes = 10\n",
        "    embedding_size = 256\n",
        "    learnable_length_scale = False\n",
        "    gamma = 0.999\n",
        "    length_scale = 0.1\n",
        "    d = 28  # image height/width in pixels\n",
        "    c = 1   # image channels\n",
        "\n",
        "    model = CNN_DUQ(\n",
        "        input_size,\n",
        "        num_classes,\n",
        "        embedding_size,\n",
        "        learnable_length_scale,\n",
        "        length_scale,\n",
        "        gamma,\n",
        "    ).cuda()\n",
        "    model.load_state_dict(torch.load(CKPT_DIR + 'DUQ_FM_30_FULL.pt'))\n",
        "\n",
        "    # Second DUQ variant (plotted as 'Ls Control DUQ'); the extra positional True is\n",
        "    # forwarded to CNN_DUQ -- see utils/cnn_duq.py for what it switches on.\n",
        "    model_new = CNN_DUQ(\n",
        "        input_size,\n",
        "        num_classes,\n",
        "        embedding_size,\n",
        "        learnable_length_scale,\n",
        "        length_scale,\n",
        "        gamma,\n",
        "        True,\n",
        "    ).cuda()\n",
        "    model_new.load_state_dict(torch.load(CKPT_DIR + 'DUQ_FM_30_FULL (1).pt'))\n",
        "    # Evaluate it in eval mode like the other models; left in train mode, layers such as\n",
        "    # BatchNorm/dropout (if the architecture has them) would distort its predictions.\n",
        "    model_new.eval()\n",
        "\n",
        "    ensemble = torch.nn.ModuleList([CNN(input_size, num_classes).cuda() for _ in range(5)])\n",
        "    ensemble.load_state_dict(torch.load(CKPT_DIR + 'FM_5_ensemble_30.pt'))\n",
        "\n",
        "else:\n",
        "    ds1 = all_datasets['CIFAR10']()\n",
        "    ds2 = all_datasets['SVHN']()\n",
        "    length_scale = 0.1\n",
        "    input_size, num_classes, dataset, test_dataset = ds1\n",
        "    centroid_size = 512\n",
        "    model_output_size = 512\n",
        "    gamma = 0.999\n",
        "    d = 32  # image height/width in pixels\n",
        "    c = 3   # image channels\n",
        "\n",
        "    model = ResNet_DUQ(\n",
        "        input_size, num_classes, centroid_size, model_output_size, length_scale, gamma\n",
        "    ).cuda()\n",
        "    model.load_state_dict(torch.load(CKPT_DIR + 'DUQ_CIFAR_75.pt'))\n",
        "    ensemble = torch.nn.ModuleList([ResNet(input_size, num_classes).cuda() for _ in range(5)])\n",
        "    ensemble.load_state_dict(torch.load(CKPT_DIR + 'CIFAR10_5_ensemble.pt'))\n",
        "\n",
        "model.eval()\n",
        "ensemble.eval();"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Rejection experiment under input noise\n",
        "\n",
        "For each noise level `e` in `er`, every clean test image gets a copy with additive Gaussian noise of standard deviation `e * m` (`m` is the maximum pixel value of the first test image). The copies are labelled 100, outside the `num_classes` class indices, so they always count as errors.\n",
        "\n",
        "Each model ranks the combined set by its confidence. The least-confident fraction is rejected, and accuracy is measured on what remains. Because the noisy copies always count as errors, a model whose confidence is lower on noisy inputs rejects them first and reaches higher accuracy.\n",
        "\n",
        "- DUQ confidence: the maximum of the model's second output, taken over classes.\n",
        "- Deep-ensemble confidence: negative entropy of the averaged member probabilities.\n",
        "- The 'Ls Control DUQ' model only exists in the `FMnist` configuration."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def make_batches(dataset, batch_size=100):\n",
        "    \"\"\"Group an indexable dataset of (image, label) items into (images, labels) tensor batches.\n",
        "\n",
        "    Items are visited once, in order; the last batch may be smaller than batch_size.\n",
        "    \"\"\"\n",
        "    batches = []\n",
        "    for start in range(0, len(dataset), batch_size):\n",
        "        items = [dataset[k] for k in range(start, min(start + batch_size, len(dataset)))]\n",
        "        images = torch.stack([item[0] for item in items])\n",
        "        labels = torch.tensor([item[1] for item in items])\n",
        "        batches.append((images, labels))\n",
        "    return batches\n",
        "\n",
        "\n",
        "def duq_scores(duq_model, images):\n",
        "    \"\"\"Return (confidence, prediction): max and argmax over classes of the DUQ model's second output.\"\"\"\n",
        "    _, class_scores = duq_model(images)\n",
        "    confidence, prediction = class_scores.max(1)\n",
        "    return confidence, prediction\n",
        "\n",
        "\n",
        "def ensemble_scores(members, images):\n",
        "    \"\"\"Return (confidence, prediction) for a deep ensemble.\n",
        "\n",
        "    Member outputs are exponentiated before averaging, i.e. they are assumed to be\n",
        "    log-probabilities. Confidence is the negative entropy of the mean distribution,\n",
        "    so larger values mean more confident.\n",
        "    \"\"\"\n",
        "    log_probs = torch.stack([member(images) for member in members])\n",
        "    mean_prob = log_probs.exp().mean(dim=0)\n",
        "    # Clamp so a probability that underflows to 0 contributes 0 instead of 0 * log(0) = nan;\n",
        "    # nan confidences would sort last and never be rejected.\n",
        "    neg_entropy = torch.sum(mean_prob * torch.log(mean_prob.clamp(min=1e-12)), dim=1)\n",
        "    return neg_entropy, mean_prob.max(1)[1]\n",
        "\n",
        "\n",
        "def collect_predictions(score_fn, batches):\n",
        "    \"\"\"Run score_fn on every batch (moved to the GPU) without gradients.\n",
        "\n",
        "    Returns flat numpy arrays (targets, predictions, confidences) in batch order.\n",
        "    \"\"\"\n",
        "    targets, predictions, confidences = [], [], []\n",
        "    with torch.no_grad():\n",
        "        for images, labels in batches:\n",
        "            confidence, prediction = score_fn(images.cuda())\n",
        "            targets.append(labels.numpy())\n",
        "            predictions.append(prediction.cpu().numpy())\n",
        "            confidences.append(confidence.cpu().numpy())\n",
        "    return np.concatenate(targets), np.concatenate(predictions), np.concatenate(confidences)\n",
        "\n",
        "\n",
        "def rejection_curve(targets, predictions, confidences, rejection_list):\n",
        "    \"\"\"Accuracy on the samples kept after rejecting the least-confident fraction.\n",
        "\n",
        "    Returns (rejected_percent, accuracy), each of shape (len(rejection_list), 1).\n",
        "    \"\"\"\n",
        "    correct = (targets == predictions)[np.argsort(confidences)]\n",
        "    num = len(correct)\n",
        "    rejected = np.zeros((len(rejection_list), 1))\n",
        "    accuracy = np.zeros((len(rejection_list), 1))\n",
        "    for k, reject in enumerate(rejection_list):\n",
        "        kept = correct[int(reject * num):]\n",
        "        # Divide by the number actually kept; (1 - reject) * num differs from it\n",
        "        # whenever reject * num is not an integer.\n",
        "        accuracy[k] = kept.mean() if len(kept) else np.nan\n",
        "        rejected[k] = reject * 100\n",
        "    return rejected, accuracy"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NT-7TDRzBwAa"
      },
      "source": [
        "SEED = 0\n",
        "torch.manual_seed(SEED)  # the injected noise is random; seed it so the curves are reproducible\n",
        "\n",
        "input_size, num_classes, _, test_dataset_n = ds1\n",
        "er = [0, 0.05, 0.1, 0.3, 0.5]  # noise levels, as multiples of m\n",
        "rejection_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
        "NOISY_LABEL = 100  # outside the class indices, so noisy copies always count as errors\n",
        "BATCH_SIZE = 100\n",
        "\n",
        "# (legend label, colour, scoring function); model_new only exists for mod == 'FMnist'.\n",
        "scorers = [('DUQ', 'blue', lambda x: duq_scores(model, x))]\n",
        "if mod == 'FMnist':\n",
        "    scorers.append(('Ls Control DUQ', 'red', lambda x: duq_scores(model_new, x)))\n",
        "scorers.append(('5-Deep Ensemble', 'orange', lambda x: ensemble_scores(ensemble, x)))\n",
        "\n",
        "m = test_dataset_n[0][0].max()\n",
        "fig, axes = plt.subplots(len(er), 2, figsize=(12, 4 * len(er)), squeeze=False)\n",
        "\n",
        "for p, e in enumerate(er):\n",
        "    print(p, 'noise=', e)\n",
        "    sample = (test_dataset_n[10][0] + e * m * torch.randn(c, d, d)).numpy()\n",
        "    # Noisy copy of every clean test image, labelled NOISY_LABEL.\n",
        "    test_dataset_e = [\n",
        "        [test_dataset_n[i][0] + torch.randn(c, d, d) * e * m, NOISY_LABEL]\n",
        "        for i in range(len(test_dataset_n))\n",
        "    ]\n",
        "    batches = make_batches(test_dataset_n + test_dataset_e, batch_size=BATCH_SIZE)\n",
        "\n",
        "    ax = axes[p, 0]\n",
        "    ax.set(xlabel='%', ylabel='Acc', title='noise = ' + str(e), xlim=(0, 100), ylim=(0.4, 1.01))\n",
        "    for label, color, score_fn in scorers:\n",
        "        rejected, accuracy = rejection_curve(*collect_predictions(score_fn, batches), rejection_list)\n",
        "        ax.plot(rejected, accuracy, color=color, linewidth=2,\n",
        "                marker='o', markerfacecolor=color, markersize=5, label=label)\n",
        "    ax.legend()\n",
        "\n",
        "    # Example noisy input, rescaled to [0, 1] for display.\n",
        "    ax = axes[p, 1]\n",
        "    sample -= sample.min()\n",
        "    if sample.max() > 0:  # a constant image would otherwise divide by zero\n",
        "        sample /= sample.max()\n",
        "    if sample.shape[0] > 1:\n",
        "        ax.imshow(sample.transpose((1, 2, 0)))\n",
        "    else:\n",
        "        ax.imshow(sample.reshape(d, d), cmap='gray')\n",
        "    ax.set_title('example input, noise = ' + str(e))\n",
        "\n",
        "fig.tight_layout()\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}