{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "##### Preparation:\n",
        "\n",
        "To do:\n",
        "\n",
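        "Download the ImageNet (ILSVRC2012) training data and put it under ./ILSVRC2012_img_train/\n",
        "\n",
        "Download the ImageNet validation data and put it under ./ILSVRC2012_img_val/\n",
        "\n",
        "Both folders are read with `torchvision.datasets.ImageFolder`, so each must contain one sub-folder per class."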
      ],
      "metadata": {
        "id": "0yuxFQLq9bYa"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "z4LrJI--XSvO",
        "outputId": "1697e496-254d-4767-942e-ab2d78963063"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Wed Jul 26 20:59:52 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   34C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        }
      ],
      "source": [
        "# `!nvidia-smi` captures the command output as a list of lines; join it and report the GPU, if any.\n",
        "gpu_info = !nvidia-smi\n",
        "gpu_info = '\\n'.join(gpu_info)\n",
        "print('Not connected to a GPU' if 'failed' in gpu_info else gpu_info)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YuaJeSxRGNXD"
      },
      "source": [
        "##### Architecture"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1Sb2wRNpFLlB"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import math\n",
        "import torch.utils.model_zoo as model_zoo\n",
        "import torch.nn.functional as F\n",
        "\n",
        "__all__ = ['ResNet', 'resnet18', 'resnet50']\n",
        "\n",
        "\n",
        "# Normalization layer used by the stem, every residual block and every downsample projection;\n",
        "# it is looked up when a module is constructed, so changing it here swaps it everywhere.\n",
        "normalization = nn.BatchNorm2d\n",
        "\n",
        "# Checkpoint URLs keyed by architecture name (read by resnet18 / resnet50 when pretrained=True).\n",
        "model_urls = dict(\n",
        "    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',\n",
        "    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n",
        "    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',\n",
        "    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n",
        "    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n",
        ")\n",
        "\n",
        "\n",
        "def conv3x3(in_planes, out_planes, stride=1):\n",
        "    \"\"\"Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 (spatial size kept at stride 1).\"\"\"\n",
        "    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)\n",
        "    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)\n",
        "\n",
        "\n",
        "class Identity(nn.Module):\n",
        "    \"\"\"Pass-through module stored as ``shortcut`` in the residual blocks.\n",
        "\n",
        "    Returns ``input + 0.0``: equal values, but a new tensor rather than ``input`` itself.\n",
        "    \"\"\"\n",
        "\n",
        "    def forward(self, input):\n",
        "        return input.add(0.0)\n",
        "\n",
        "class BasicBlock(nn.Module):\n",
        "    \"\"\"Residual block of two 3x3 conv + norm layers (expansion 1: outputs ``planes`` channels).\n",
        "\n",
        "    Besides ``forward`` it offers ``forward_masked`` (per-channel scale/shift of the block output)\n",
        "    and ``forward_threshold`` (zero channels whose spatial mean reaches a threshold).\n",
        "    \"\"\"\n",
        "    expansion = 1\n",
        "\n",
        "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
        "        super(BasicBlock, self).__init__()\n",
        "        # Submodule registration order fixes parameter / state_dict key order; keep it.\n",
        "        self.conv1 = conv3x3(inplanes, planes, stride)\n",
        "        self.bn1 = normalization(planes)\n",
        "        self.relu = nn.ReLU(inplace=False)\n",
        "        self.conv2 = conv3x3(planes, planes)\n",
        "        self.bn2 = normalization(planes)\n",
        "        self.shortcut = Identity()\n",
        "        self.downsample = downsample\n",
        "        self.stride = stride\n",
        "\n",
        "    def _residual_sum(self, x):\n",
        "        \"\"\"Main branch (conv-norm-ReLU-conv-norm) plus the identity / downsampled input, pre-activation.\"\"\"\n",
        "        branch = self.bn1(self.conv1(x))\n",
        "        branch = self.bn2(self.conv2(self.relu(branch)))\n",
        "        skip = x if self.downsample is None else self.downsample(x)\n",
        "        return branch + skip\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.relu(self.shortcut(self._residual_sum(x)))\n",
        "\n",
        "    def forward_masked(self, x, mask_weight=None, mask_bias=None):\n",
        "        \"\"\"``forward`` followed by an optional per-channel scale and shift of the output.\n",
        "\n",
        "        ``mask_weight`` / ``mask_bias`` are indexed as ``[None, :, None, None]``, i.e. 1-D\n",
        "        per-channel tensors; either may be None to skip that step.\n",
        "        \"\"\"\n",
        "        out = self.relu(self.shortcut(self._residual_sum(x)))\n",
        "        if mask_weight is not None:\n",
        "            out = out * mask_weight[None, :, None, None]\n",
        "        if mask_bias is not None:\n",
        "            out = out + mask_bias[None, :, None, None]\n",
        "        return out\n",
        "\n",
        "    def forward_threshold(self, x, threshold=1e10):\n",
        "        \"\"\"Block output (computed without the ``shortcut`` module) with every channel whose\n",
        "        spatial mean is >= ``threshold`` zeroed, decided independently per sample.\n",
        "        \"\"\"\n",
        "        out = self.relu(self._residual_sum(x))\n",
        "        n, c, _, _ = out.shape\n",
        "        keep = out.view(n, c, -1).mean(2) < threshold\n",
        "        return keep[:, :, None, None] * out\n",
        "\n",
        "\n",
        "class WideBasicBlock(nn.Module):\n",
        "    \"\"\"Two 3x3 conv + norm layers where the second conv widens the output to ``planes * 4`` channels.\"\"\"\n",
        "    expansion = 4\n",
        "\n",
        "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
        "        super(WideBasicBlock, self).__init__()\n",
        "        self.conv1 = conv3x3(inplanes, planes, stride)\n",
        "        self.bn1 = normalization(planes)\n",
        "        self.relu = nn.ReLU(inplace=False)\n",
        "        self.conv2 = conv3x3(planes, planes * 4)\n",
        "        self.bn2 = normalization(planes * 4)\n",
        "        self.downsample = downsample\n",
        "        self.stride = stride\n",
        "\n",
        "    def _pre_activation(self, x):\n",
        "        \"\"\"Residual sum (main branch + identity / downsampled input) before the final ReLU.\"\"\"\n",
        "        y = self.relu(self.bn1(self.conv1(x)))\n",
        "        y = self.bn2(self.conv2(y))\n",
        "        skip = x if self.downsample is None else self.downsample(x)\n",
        "        return y + skip\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.relu(self._pre_activation(x))\n",
        "\n",
        "    def forward_masked(self, x, mask=None):\n",
        "        \"\"\"``forward`` with an optional per-channel gate (1-D ``mask``) applied before the final ReLU.\"\"\"\n",
        "        gated = self._pre_activation(x)\n",
        "        if mask is not None:\n",
        "            gated = gated * mask[None, :, None, None]\n",
        "        return self.relu(gated)\n",
        "\n",
        "class Bottleneck(nn.Module):\n",
        "    \"\"\"1x1 -> 3x3 -> 1x1 bottleneck residual block; outputs ``planes * 4`` channels.\n",
        "\n",
        "    ``forward_masked`` and ``forward_threshold`` are variants used for OOD experiments; neither\n",
        "    calls the ``shortcut`` module.\n",
        "    \"\"\"\n",
        "    expansion = 4\n",
        "\n",
        "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
        "        super(Bottleneck, self).__init__()\n",
        "        # Submodule registration order fixes parameter / state_dict key order; keep it.\n",
        "        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n",
        "        self.bn1 = normalization(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n",
        "                               padding=1, bias=False)\n",
        "        self.bn2 = normalization(planes)\n",
        "        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n",
        "        self.bn3 = normalization(planes * 4)\n",
        "        self.relu = nn.ReLU(inplace=False)\n",
        "        self.shortcut = Identity()\n",
        "        self.downsample = downsample\n",
        "        self.stride = stride\n",
        "\n",
        "    def _residual_sum(self, x):\n",
        "        \"\"\"Three conv/norm stages (ReLU between them) plus the identity / downsampled input, pre-activation.\"\"\"\n",
        "        branch = self.relu(self.bn1(self.conv1(x)))\n",
        "        branch = self.relu(self.bn2(self.conv2(branch)))\n",
        "        branch = self.bn3(self.conv3(branch))\n",
        "        skip = x if self.downsample is None else self.downsample(x)\n",
        "        return branch + skip\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.relu(self.shortcut(self._residual_sum(x)))\n",
        "\n",
        "    def forward_masked(self, x, mask_weight=None, mask_bias=None):\n",
        "        \"\"\"Apply an optional per-channel scale / shift to the residual sum *before* the final ReLU.\n",
        "\n",
        "        ``mask_weight`` / ``mask_bias`` are indexed as ``[None, :, None, None]`` (1-D per channel).\n",
        "        \"\"\"\n",
        "        out = self._residual_sum(x)\n",
        "        if mask_weight is not None:\n",
        "            out = out * mask_weight[None, :, None, None]\n",
        "        if mask_bias is not None:\n",
        "            out = out + mask_bias[None, :, None, None]\n",
        "        return self.relu(out)\n",
        "\n",
        "    def forward_threshold(self, x, threshold=1e10):\n",
        "        \"\"\"Block output with every channel whose spatial mean is >= ``threshold`` zeroed (per sample).\"\"\"\n",
        "        out = self.relu(self._residual_sum(x))\n",
        "        n, c, _, _ = out.shape\n",
        "        keep = out.view(n, c, -1).mean(2) < threshold\n",
        "        return keep[:, :, None, None] * out\n",
        "\n",
        "\n",
        "class AbstractResNet(nn.Module):\n",
        "    \"\"\"Shared ResNet trunk: 7x7 stem, max-pool, four residual stages and a 7x7 average pool.\n",
        "\n",
        "    Subclasses must define ``self.fc`` (used by ``forward``). ``num_classes`` is accepted for\n",
        "    signature compatibility but is not used by this class.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, block, layers, num_classes=1000):\n",
        "        super(AbstractResNet, self).__init__()\n",
        "        self.inplanes = 64  # channel count entering the next stage; advanced by _make_layer\n",
        "        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,\n",
        "                               bias=False)\n",
        "        self.bn1 = normalization(64)\n",
        "        self.relu = nn.ReLU(inplace=False)\n",
        "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
        "        self.layer1 = self._make_layer(block, 64, layers[0])\n",
        "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
        "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
        "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
        "        # Fixed 7x7 window: the trunk has overall stride 32, so a 224x224 input gives a 7x7\n",
        "        # stage-4 map that this pools to 1x1; other input sizes give a different pooled size.\n",
        "        self.avgpool = nn.AvgPool2d(7, stride=1)\n",
        "\n",
        "    def _initial_weight(self):\n",
        "        \"\"\"He-style init: conv weights ~ normal(0, std=sqrt(2 / (k_h * k_w * out_channels))); BN weight 1, bias 0.\"\"\"\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, nn.Conv2d):\n",
        "                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
        "                m.weight.data.normal_(0, math.sqrt(2. / n))\n",
        "            elif isinstance(m, nn.BatchNorm2d):\n",
        "                m.weight.data.fill_(1)\n",
        "                m.bias.data.zero_()\n",
        "\n",
        "    def _make_layer(self, block, planes, blocks, stride=1):\n",
        "        \"\"\"Build one stage of ``blocks`` blocks producing ``planes * block.expansion`` channels.\n",
        "\n",
        "        Only the first block gets ``stride`` and, when the stride or channel count changes, a\n",
        "        1x1 conv + norm projection for its residual branch. Updates ``self.inplanes``.\n",
        "        \"\"\"\n",
        "        downsample = None\n",
        "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
        "            downsample = nn.Sequential(\n",
        "                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
        "                          kernel_size=1, stride=stride, bias=False),\n",
        "                normalization(planes * block.expansion),\n",
        "            )\n",
        "\n",
        "        layers = []\n",
        "        layers.append(block(self.inplanes, planes, stride, downsample))\n",
        "        self.inplanes = planes * block.expansion\n",
        "        for i in range(1, blocks):\n",
        "            layers.append(block(self.inplanes, planes))\n",
        "\n",
        "        return nn.Sequential(*layers)\n",
        "\n",
        "    def features(self, x):\n",
        "        \"\"\"Stem and the four stages; returns the stage-4 feature map (before pooling).\"\"\"\n",
        "        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))\n",
        "        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))\n",
        "        return x\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Logits: features -> avgpool -> flatten -> ``self.fc`` (defined by the subclass).\"\"\"\n",
        "        x = self.features(x)\n",
        "        x = self.avgpool(x)\n",
        "        x = x.view(x.size(0), -1)\n",
        "        x = self.fc(x)\n",
        "        return x\n",
        "\n",
        "    def load_state_dict(self, state_dict, strict=True):\n",
        "        \"\"\"Lenient replacement for ``nn.Module.load_state_dict``.\n",
        "\n",
        "        Recursively calls each submodule's ``_load_from_state_dict``. Missing / unexpected keys\n",
        "        (reported only when ``strict``) and any messages ``_load_from_state_dict`` appends to\n",
        "        ``error_msgs`` are printed as a warning instead of raised. Returns None.\n",
        "\n",
        "        NOTE(review): relies on the private ``_load_from_state_dict`` signature of the PyTorch\n",
        "        version this was written against -- re-check after PyTorch upgrades.\n",
        "        \"\"\"\n",
        "        missing_keys = []\n",
        "        unexpected_keys = []\n",
        "        error_msgs = []\n",
        "\n",
        "        # copy state_dict so _load_from_state_dict can modify it\n",
        "        metadata = getattr(state_dict, '_metadata', None)\n",
        "        state_dict = state_dict.copy()\n",
        "        if metadata is not None:\n",
        "            state_dict._metadata = metadata\n",
        "\n",
        "        # Depth-first walk; each module consumes the keys under its own dotted prefix.\n",
        "        def load(module, prefix=''):\n",
        "            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})\n",
        "            module._load_from_state_dict(\n",
        "                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)\n",
        "            for name, child in module._modules.items():\n",
        "                if child is not None:\n",
        "                    load(child, prefix + name + '.')\n",
        "\n",
        "        load(self)\n",
        "\n",
        "        if strict:\n",
        "            error_msg = ''\n",
        "            if len(unexpected_keys) > 0:\n",
        "                error_msgs.insert(\n",
        "                    0, 'Unexpected key(s) in state_dict: {}. '.format(\n",
        "                        ', '.join('\"{}\"'.format(k) for k in unexpected_keys)))\n",
        "            if len(missing_keys) > 0:\n",
        "                error_msgs.insert(\n",
        "                    0, 'Missing key(s) in state_dict: {}. '.format(\n",
        "                        ', '.join('\"{}\"'.format(k) for k in missing_keys)))\n",
        "\n",
        "        # Warn rather than raise: loading never fails hard here.\n",
        "        if len(error_msgs) > 0:\n",
        "            print('Warning(s) in loading state_dict for {}:\\n\\t{}'.format(self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n",
        "\n",
        "\n",
        "class ResNet(AbstractResNet):\n",
        "    \"\"\"ResNet classifier (trunk from ``AbstractResNet`` plus a linear head) with extra forward\n",
        "    variants used for OOD-detection experiments.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, block, layers, num_classes=1000):\n",
        "        super(ResNet, self).__init__(block, layers, num_classes)\n",
        "        self.fc = nn.Linear(512 * block.expansion, num_classes)\n",
        "        self._initial_weight()\n",
        "\n",
        "    def forward_masked(self, x, mask_weight=None, mask_bias=None):\n",
        "        \"\"\"Logits when the last block of ``layer4`` runs its ``forward_masked`` with the given masks.\"\"\"\n",
        "        h = self.maxpool(self.relu(self.bn1(self.conv1(x))))\n",
        "        for stage in (self.layer1, self.layer2, self.layer3):\n",
        "            h = stage(h)\n",
        "        *leading_blocks, last_block = self.layer4\n",
        "        for blk in leading_blocks:\n",
        "            h = blk(h)\n",
        "        h = last_block.forward_masked(h, mask_weight=mask_weight, mask_bias=mask_bias)\n",
        "        h = self.avgpool(h)\n",
        "        return self.fc(h.view(h.size(0), -1))\n",
        "\n",
        "    def forward_threshold(self, x, threshold=1e10):\n",
        "        \"\"\"Logits computed from the pooled features clipped from above at ``threshold``.\"\"\"\n",
        "        h = self.maxpool(self.relu(self.bn1(self.conv1(x))))\n",
        "        h = self.layer4(self.layer3(self.layer2(self.layer1(h))))\n",
        "        pooled = self.avgpool(h).clip(max=threshold)\n",
        "        return self.fc(pooled.view(pooled.size(0), -1))\n",
        "\n",
        "    def feature_list(self, x):\n",
        "        \"\"\"Return ``(logits, [layer1_out, layer2_out, layer3_out, layer4_out])``.\"\"\"\n",
        "        h = self.maxpool(F.relu(self.bn1(self.conv1(x))))\n",
        "        stage_outputs = []\n",
        "        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):\n",
        "            h = stage(h)\n",
        "            stage_outputs.append(h)\n",
        "        pooled = self.avgpool(h)\n",
        "        logits = self.fc(pooled.view(pooled.size(0), -1))\n",
        "        return logits, stage_outputs\n",
        "\n",
        "    def intermediate_forward(self, x, layer_index):\n",
        "        \"\"\"Average-pooled output of the full trunk (stem + layer1..layer4).\n",
        "\n",
        "        NOTE: ``layer_index`` is accepted but currently ignored -- all four stages always run.\n",
        "        \"\"\"\n",
        "        h = self.maxpool(F.relu(self.bn1(self.conv1(x))))\n",
        "        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):\n",
        "            h = stage(h)\n",
        "        return self.avgpool(h)\n",
        "\n",
        "\n",
        "def resnet18(pretrained=False, **kwargs):\n",
        "    \"\"\"ResNet-18 (BasicBlock, [2, 2, 2, 2]); loads ``model_urls['resnet18']`` if ``pretrained``.\n",
        "\n",
        "    Extra keyword arguments (e.g. ``num_classes``) are forwarded to ``ResNet``.\n",
        "    \"\"\"\n",
        "    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)\n",
        "    if not pretrained:\n",
        "        return net\n",
        "    net.load_state_dict(model_zoo.load_url(model_urls['resnet18']))\n",
        "    return net\n",
        "\n",
        "\n",
        "def resnet50(pretrained=False, **kwargs):\n",
        "    \"\"\"ResNet-50 (Bottleneck, [3, 4, 6, 3]); loads ``model_urls['resnet50']`` if ``pretrained``.\n",
        "\n",
        "    Extra keyword arguments (e.g. ``num_classes``) are forwarded to ``ResNet``.\n",
        "    \"\"\"\n",
        "    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)\n",
        "    if not pretrained:\n",
        "        return net\n",
        "    net.load_state_dict(model_zoo.load_url(model_urls['resnet50']))\n",
        "    return net"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TJBYKDKlGihD"
      },
      "source": [
        "##### Load model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rWg493a6GpRk",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "1935ec63-a1b3-45b0-8b68-af54c647a4a7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading: \"https://download.pytorch.org/models/resnet50-19c8e357.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth\n",
            "100%|██████████| 97.8M/97.8M [00:00<00:00, 211MB/s]\n"
          ]
        }
      ],
      "source": [
        "# ResNet-50 with the checkpoint from model_urls['resnet50'] (fetched via model_zoo.load_url).\n",
        "model = resnet50(pretrained=True, num_classes=1000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_xmH5dA6IQV8",
        "outputId": "71e3309c-53f3-45b9-e59b-0ae383a777b2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of model parameters: 25557032\n"
          ]
        }
      ],
      "source": [
        "# Use the GPU when available (same rule as `device` below) so the notebook also runs on CPU-only machines.\n",
        "model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model.eval()\n",
        "print('Number of model parameters: {}'.format(sum(p.numel() for p in model.parameters())))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5vpVuhaPRFVs"
      },
      "outputs": [],
      "source": [
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "\n",
        "class Normalize(nn.Module):\n",
        "    \"\"\"Per-channel ``(x - mean) / std`` for NCHW image batches.\n",
        "\n",
        "    ``mean`` / ``std`` are registered as non-persistent buffers, so they move with the module\n",
        "    under ``.to()`` / ``.cuda()`` but are neither saved to nor expected in the state_dict.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, mean, std):\n",
        "        super(Normalize, self).__init__()\n",
        "        self.register_buffer(\"mean\", torch.tensor(mean, device=device), persistent=False)\n",
        "        self.register_buffer(\"std\", torch.tensor(std, device=device), persistent=False)\n",
        "\n",
        "    def forward(self, input):  # input: [batch, channel, height, width]\n",
        "        # Reshape the per-channel statistics to (1, C, 1, 1) so they broadcast over N, H, W.\n",
        "        return (input - self.mean[None, :, None, None]) / self.std[None, :, None, None]\n",
        "\n",
        "\n",
        "# ImageNet channel statistics. The data transforms in this notebook stop at ToTensor,\n",
        "# so inputs reach model_cw in [0, 1] and are normalized here.\n",
        "model_cw = nn.Sequential(\n",
        "    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
        "    model\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BvebTk7tJjj5"
      },
      "source": [
        "##### Load in-distribution data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "StTftg8vq7Cl"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torchvision.datasets import ImageNet, ImageFolder\n",
        "from torch.utils.data import Subset\n",
        "from torchvision import transforms\n",
        "from easydict import EasyDict\n",
        "import numpy as np\n",
        "\n",
        "# Evaluation transform for the ImageNet and Texture loaders: resize, 224x224 centre crop,\n",
        "# [0, 1] tensor. No mean/std normalization here -- model_cw normalizes internally.\n",
        "cw_transform_test_largescale = transforms.Compose([\n",
        "    transforms.Resize(256),\n",
        "    transforms.CenterCrop(224),\n",
        "    transforms.ToTensor(),\n",
        "])\n",
        "\n",
        "\n",
        "def ld_cw_ImageNetVal(b_size=64, train_root=\"./ILSVRC2012_img_train/\", val_root=\"./ILSVRC2012_img_val/\",\n",
        "                      num_workers=2):\n",
        "    \"\"\"Build unshuffled ImageNet train and validation loaders using ``cw_transform_test_largescale``.\n",
        "\n",
        "    Both roots must use the ``ImageFolder`` layout (one sub-directory per class).\n",
        "\n",
        "    Returns:\n",
        "        EasyDict with ``train`` (training split) and ``test`` (validation split) DataLoaders.\n",
        "    \"\"\"\n",
        "    def make_loader(root):\n",
        "        dataset = ImageFolder(root=root, transform=cw_transform_test_largescale)\n",
        "        return torch.utils.data.DataLoader(\n",
        "            dataset, batch_size=b_size, shuffle=False, num_workers=num_workers, pin_memory=True\n",
        "        )\n",
        "\n",
        "    return EasyDict(train=make_loader(train_root), test=make_loader(val_root))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nPrhyJx_soyo"
      },
      "source": [
        "##### Prepare out-of-distribution (OOD) data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aKH6wnxo81mP"
      },
      "source": [
        "###### Texture"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PeqL2OVPsnTW"
      },
      "outputs": [],
      "source": [
        "import torchvision\n",
        "from torchvision.datasets import CIFAR10, CIFAR100, SVHN, DTD\n",
        "from torch.utils.data import ConcatDataset, Subset\n",
        "from easydict import EasyDict\n",
        "from torchvision.datasets import ImageFolder\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import DataLoader\n",
        "import os\n",
        "\n",
        "def ld_Texture(b_size=64, root=\"/tmp/data\", num_workers=2):\n",
        "    \"\"\"Load all of DTD (train + test + val splits, in that order) as a single OOD test loader.\n",
        "\n",
        "    Images use ``cw_transform_test_largescale``; DTD is downloaded into ``root`` if missing.\n",
        "    Returns ``EasyDict(test=loader)`` with an unshuffled loader.\n",
        "    \"\"\"\n",
        "    splits = [\n",
        "        DTD(root=root, split=split, transform=cw_transform_test_largescale, download=True)\n",
        "        for split in (\"train\", \"test\", \"val\")\n",
        "    ]\n",
        "    data_loader = torch.utils.data.DataLoader(\n",
        "        ConcatDataset(splits), batch_size=b_size, shuffle=False, num_workers=num_workers, pin_memory=True\n",
        "    )\n",
        "    return EasyDict(test=data_loader)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "###### Gaussian Noise"
      ],
      "metadata": {
        "id": "kxPy82dLMdgK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "class GaussianDataset(Dataset):\n",
        "    \"\"\"Synthetic OOD images of per-pixel Gaussian noise, clamped to [0, 1]; every label is 0.\n",
        "\n",
        "    Pixels are drawn from N(mean, std) (defaults: mean=0.5, std=1.0). Each image is generated on\n",
        "    demand from a generator seeded with ``seed + idx``, so the dataset is reproducible and allocates\n",
        "    nothing up front -- pre-computing 50000 images of 3 x 224 x 224 float32 (the size ``ld_Gaussian``\n",
        "    requests) would need about 30 GB of RAM.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_samples, image_size, mean=0.5, std=1.0, seed=0):\n",
        "        self.num_samples = num_samples\n",
        "        self.image_size = image_size\n",
        "        self.mean = mean\n",
        "        self.std = std\n",
        "        self.seed = seed\n",
        "\n",
        "    def __len__(self):\n",
        "        return self.num_samples\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        if idx < 0:  # negative indices count from the end, as for a sequence\n",
        "            idx += self.num_samples\n",
        "        if not 0 <= idx < self.num_samples:\n",
        "            raise IndexError(\"GaussianDataset index out of range\")\n",
        "        gen = torch.Generator().manual_seed(self.seed + idx)\n",
        "        noise = torch.randn(3, self.image_size, self.image_size, generator=gen)\n",
        "        return torch.clamp(noise * self.std + self.mean, 0, 1), 0"
      ],
      "metadata": {
        "id": "peZODodRM2fh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def ld_Gaussian(b_size=64, num_samples=50000, image_size=224):\n",
        "    \"\"\"Unshuffled loader over ``GaussianDataset`` noise images (an OOD test set).\n",
        "\n",
        "    Returns ``EasyDict(test=loader)``; there is no training split.\n",
        "    \"\"\"\n",
        "    test_dataset = GaussianDataset(num_samples, image_size)\n",
        "    test_loader = torch.utils.data.DataLoader(\n",
        "        test_dataset, batch_size=b_size, shuffle=False, num_workers=2, pin_memory=True\n",
        "    )\n",
        "    return EasyDict(test=test_loader)"
      ],
      "metadata": {
        "id": "D-FStDqcNQlx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7YGRAEbfvVNV"
      },
      "source": [
        "##### Extract classifier weights and the mean training feature"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WmYomxSkk39n"
      },
      "outputs": [],
      "source": [
        "model = model_cw\n",
        "model.eval()\n",
        "# Linear head applied to the flattened pooled features.\n",
        "net = nn.Sequential(torch.nn.Flatten(), model[1].fc)\n",
        "# Normalization + ResNet trunk up to and including the average pool (penultimate features).\n",
        "feature_net = nn.Sequential(\n",
        "    model[0],\n",
        "    *(getattr(model[1], name) for name in\n",
        "      (\"conv1\", \"bn1\", \"relu\", \"maxpool\", \"layer1\", \"layer2\", \"layer3\", \"layer4\", \"avgpool\")),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7Cj3S09O1fNd"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "# Final linear layer as NumPy arrays: weight w is (num_classes, feature_dim), bias b is (num_classes,).\n",
        "w, b = (param.data.cpu().numpy() for param in net.parameters())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hqBwFrGX3FoM"
      },
      "outputs": [],
      "source": [
        "def get_train_mean(b_size=64):\n",
        "  \"\"\"Return the mean ``feature_net`` output vector over the ImageNet training loader.\n",
        "\n",
        "  Features are summed batch by batch (float64, under ``torch.no_grad()``) instead of being\n",
        "  stored in one (num_images x feature_dim) matrix; for the full ImageNet-1k training set\n",
        "  (about 1.28M images x 2048 features) that matrix alone would need roughly 21 GB of RAM.\n",
        "\n",
        "  Args:\n",
        "    b_size: batch size passed to ``ld_cw_ImageNetVal``.\n",
        "\n",
        "  Returns:\n",
        "    float64 NumPy array of shape (feature_dim,).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if the training loader yields no samples.\n",
        "  \"\"\"\n",
        "  loader = ld_cw_ImageNetVal(b_size=b_size).train\n",
        "  feature_sum = None\n",
        "  count = 0\n",
        "  with torch.no_grad():\n",
        "    for x, _ in loader:\n",
        "      f_x = feature_net(x.to(device)).flatten(1).double()\n",
        "      batch_sum = f_x.sum(dim=0)\n",
        "      feature_sum = batch_sum if feature_sum is None else feature_sum + batch_sum\n",
        "      count += f_x.shape[0]\n",
        "  if count == 0:\n",
        "    raise ValueError(\"training loader is empty; cannot compute the feature mean\")\n",
        "  return (feature_sum / count).cpu().numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dbudFVGDt19g",
        "outputId": "616372d2-a509-46e3-9e68-9c11b8fed1fb"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([0.38581224, 0.63779744, 0.48358519, ..., 0.35021857, 0.47102501,\n",
              "       0.36980772])"
            ]
          },
          "metadata": {},
          "execution_count": 38
        }
      ],
      "source": [
        "# Mean training-set feature vector; cal_NC_score subtracts it to centre features.\n",
        "train_mean = get_train_mean()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "##### OOD detection"
      ],
      "metadata": {
        "id": "0xTnbyy5oeCS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def cal_NC_score(data, w, alpha):\n",
        "  \"\"\"Per-sample OOD score: centred-feature projection onto the predicted class weight plus alpha * L1 norm.\n",
        "\n",
        "  For each sample with features f (output of `feature_net`) and predicted class c (argmax of `net`):\n",
        "    score = dot(w[c], f - train_mean) / ||f - train_mean||_2 + alpha * ||f||_1\n",
        "  The evaluation cells treat higher scores as in-distribution.\n",
        "\n",
        "  Args:\n",
        "    data: a loader yielding (x, y) batches, or a key of `name_dic` (e.g. 'imagenet').\n",
        "    w: weight matrix of the final fc layer; row c is used for predicted class c.\n",
        "    alpha: weight of the L1-norm term.\n",
        "\n",
        "  Returns:\n",
        "    1-D numpy array with one score per sample, in loader order.\n",
        "  \"\"\"\n",
        "  # Callers may pass a dataset name; resolve it to the loader stored in name_dic.\n",
        "  if isinstance(data, str):\n",
        "    data = name_dic[data]\n",
        "\n",
        "  batch_scores = []\n",
        "  with torch.no_grad():\n",
        "    for x, _ in data:\n",
        "      f_x = feature_net(x.to(device))\n",
        "\n",
        "      ### get prediction\n",
        "      _, nn_idx = net(f_x).max(1)\n",
        "      predicted_label = nn_idx.cpu().numpy().flatten()\n",
        "\n",
        "      feats = f_x.view(len(f_x), -1).cpu().numpy()\n",
        "\n",
        "      ### get l1 norm\n",
        "      l1_norm = np.linalg.norm(feats, ord=1, axis=1)\n",
        "\n",
        "      ### get centered features and their l2 norm (computed once, reused below)\n",
        "      f_x_centered = feats - train_mean\n",
        "      centered_norm = np.linalg.norm(f_x_centered, ord=2, axis=1)\n",
        "\n",
        "      ### select the weight row of each predicted class and compute the score\n",
        "      predicted_w = w[predicted_label, :]\n",
        "      projection = np.sum(predicted_w * f_x_centered, axis=1)\n",
        "      batch_scores.append(projection / centered_norm + alpha * l1_norm)\n",
        "\n",
        "  return np.concatenate(batch_scores) if batch_scores else np.empty(0)"
      ],
      "metadata": {
        "id": "Ha5zhOh5I82I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "from sklearn import metrics\n",
        "\n",
        "def eval_missed_detection(in_sta, out_sta, false_alarm):\n",
        "  \"\"\"Return (and print) the OOD false-positive rate, in percent, at a fixed in-distribution TPR.\n",
        "\n",
        "  Scores follow the 'higher = in-distribution' convention. The threshold is the\n",
        "  n_keep-th largest in-distribution score, with n_keep = ceil(N_in * (1 - false_alarm)).\n",
        "  A sample is accepted when its score is >= threshold, so at least n_keep\n",
        "  in-distribution samples are kept (exactly n_keep barring ties). That gives a TPR\n",
        "  of at least 1 - false_alarm; with false_alarm=0.05 this is FPR@95%TPR.\n",
        "\n",
        "  Args:\n",
        "    in_sta: in-distribution scores (list or array).\n",
        "    out_sta: out-of-distribution scores (list or array).\n",
        "    false_alarm: fraction of in-distribution samples allowed below the threshold, in [0, 1).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: on empty score inputs or false_alarm outside [0, 1).\n",
        "  \"\"\"\n",
        "  # asarray so plain Python lists work: `list > float` raises TypeError.\n",
        "  in_sta = np.asarray(in_sta, dtype=float)\n",
        "  out_sta = np.asarray(out_sta, dtype=float)\n",
        "  if in_sta.size == 0 or out_sta.size == 0:\n",
        "    raise ValueError('in_sta and out_sta must be non-empty')\n",
        "  if not 0 <= false_alarm < 1:\n",
        "    raise ValueError('false_alarm must be in [0, 1)')\n",
        "\n",
        "  n_keep = math.ceil(in_sta.size * (1 - false_alarm))\n",
        "  threshold = np.sort(in_sta)[-n_keep]\n",
        "  # `>=` (not `>`): a strict comparison rejects the threshold sample itself and\n",
        "  # drops the TPR to (n_keep - 1) / N_in, below the requested level.\n",
        "  fpr = np.sum(out_sta >= threshold) / out_sta.size * 100\n",
        "  print('FPR', fpr)\n",
        "  return fpr"
      ],
      "metadata": {
        "id": "kY9HPEv2JQ5_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Test loaders keyed by dataset name: 'imagenet' is the in-distribution set,\n",
        "# 'Texture' and 'Gaussian' are the OOD sets compared against it below.\n",
        "name_dic = {\n",
        "    'imagenet': ld_cw_ImageNetVal(b_size=b_size).test,\n",
        "    'Texture': ld_Texture(b_size=b_size).test,\n",
        "    'Gaussian': ld_Gaussian(b_size=b_size).test,\n",
        "}"
      ],
      "metadata": {
        "id": "sB40XZvpJjbn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sweep the L1-term weight alpha: ImageNet (in-distribution) vs. Gaussian noise (OOD).\n",
        "for alpha in [0.001, 0.01, 0.1, 1]:\n",
        "  print('#####################################')\n",
        "  print(alpha)\n",
        "\n",
        "  # cal_NC_score needs the loader, the fc weight `w` and alpha, in that order.\n",
        "  in_sta = cal_NC_score(name_dic['imagenet'], w, alpha)\n",
        "  out_sta = cal_NC_score(name_dic['Gaussian'], w, alpha)\n",
        "\n",
        "  fpr = eval_missed_detection(in_sta, out_sta, 0.05)\n",
        "  labels = [1]*len(in_sta) + [0]*len(out_sta)  # in-distribution is the positive class\n",
        "  auroc = metrics.roc_auc_score(labels, np.append(in_sta, out_sta))*100\n",
        "  print('AUROC:', auroc)"
      ],
      "metadata": {
        "id": "Y7HQJt_PK8c5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Single run at a fixed alpha: ImageNet (in-distribution) vs. Texture (OOD).\n",
        "alpha = 0.001\n",
        "\n",
        "# cal_NC_score needs the loader, the fc weight `w` and alpha, in that order.\n",
        "in_sta = cal_NC_score(name_dic['imagenet'], w, alpha)\n",
        "out_sta = cal_NC_score(name_dic['Texture'], w, alpha)\n",
        "fpr = eval_missed_detection(in_sta, out_sta, 0.05)\n",
        "labels = [1]*len(in_sta) + [0]*len(out_sta)  # in-distribution is the positive class\n",
        "auroc = metrics.roc_auc_score(labels, np.append(in_sta, out_sta))*100\n",
        "print('AUROC:', auroc)"
      ],
      "metadata": {
        "id": "4ac2uWjuLxMw"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "YuaJeSxRGNXD",
        "hsGZjopOWw7g",
        "hAa3aHzBQSkY",
        "aKH6wnxo81mP",
        "_Ewl7ZqH1a0k"
      ],
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}