{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TjlkTQKVGQCf"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as f\n",
        "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n",
        "from collections import OrderedDict\n",
        "from matplotlib.colors import LinearSegmentedColormap\n",
        "import torchvision\n",
        "import torchvision.datasets as datasets\n",
        "import torchvision.transforms as transforms\n",
        "from torchvision.models import resnet18, ResNet18_Weights\n",
        "from torch.autograd import Function"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "class Head(nn.Module):\n",
        "    def __init__(self, dim_in, num_classes, transfer=nn.ReLU()):\n",
        "        super(Head, self).__init__()\n",
        "        self.transfer = transfer\n",
        "        self.logits = nn.Linear(dim_in, num_classes)\n",
        "\n",
        "        nn.init.constant_(self.logits.weight, 0)\n",
        "        nn.init.constant_(self.logits.bias, 0)\n",
        "\n",
        "    def forward(self, x):\n",
        "        transfer = self.transfer(x)\n",
        "        output = self.logits(transfer)\n",
        "        return output"
      ],
      "metadata": {
        "id": "u6ztf-25GlIv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class NLLPlusLoss(nn.Module):\n",
        "    def __init__(self, mask_nllplus):\n",
        "        super(NLLPlusLoss, self).__init__()\n",
        "        self.register_buffer('mask_nllplus', mask_nllplus)\n",
        "\n",
        "    def forward(self, input_logits, labels):\n",
        "        with torch.no_grad():\n",
        "            target_nllplus = self.mask_nllplus[labels.view(-1)]\n",
        "\n",
        "        loss = -input_logits.softmax(1).mul(target_nllplus).sum(1).log().mean()\n",
        "\n",
        "        return loss\n",
        "\n",
        "class StableNLLPlusLoss(nn.Module):\n",
        "    \"\"\"NLL+ loss module that delegates to the StableNLLPlusLossOptimized Function.\n",
        "\n",
        "    Row ``r`` of ``mask_nllplus`` is the class mask used for samples whose label is\n",
        "    ``r``; the loss value and its hand-written gradient are computed by the Function.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, mask_nllplus):\n",
        "        super().__init__()\n",
        "        self.register_buffer('mask_nllplus', mask_nllplus)\n",
        "\n",
        "    def forward(self, input_logits, labels):\n",
        "        with torch.no_grad():\n",
        "            per_sample_mask = self.mask_nllplus[labels.view(-1)]\n",
        "        return StableNLLPlusLossOptimized.apply(input_logits, per_sample_mask)\n",
        "\n",
        "class StableNLLPlusLossOptimized(Function):\n",
        "    @staticmethod\n",
        "    def forward(ctx, *args, **kwargs):\n",
        "        with torch.no_grad():\n",
        "            classification_head_val, target_nllplus = args\n",
        "\n",
        "            softmax = classification_head_val.softmax(1).mul_(target_nllplus).sum(1)\n",
        "\n",
        "            log_probs = -softmax.log_()\n",
        "            loss = log_probs.mean()\n",
        "\n",
        "            ctx.save_for_backward(classification_head_val, target_nllplus)\n",
        "\n",
        "            return loss\n",
        "\n",
        "    @staticmethod\n",
        "    def backward(ctx, *grad_outputs):\n",
        "        with torch.no_grad():\n",
        "            classification_head_val, target_nllplus = ctx.saved_tensors\n",
        "\n",
        "            N = classification_head_val.shape[0]\n",
        "\n",
        "            classification_head_val.sub_(classification_head_val.max(1, keepdim=True)[0])\n",
        "\n",
        "            univ_probs = classification_head_val.exp()\n",
        "            univ_probs.div_(univ_probs.sum(1, keepdim=True))#.mul_(-1)\n",
        "\n",
        "            mapped_probs = classification_head_val.mul_(target_nllplus)\n",
        "            mapped_probs.sub_(mapped_probs.min(1, keepdims=True)[0])\n",
        "            mapped_probs.mul_(target_nllplus)\n",
        "            mapped_probs.sub_(mapped_probs.max(1, keepdims=True)[0])\n",
        "            mapped_probs.exp_()\n",
        "            mapped_probs.mul_(target_nllplus)\n",
        "            mapped_probs.div_(mapped_probs.sum(1, keepdim=True))\n",
        "\n",
        "            classification_head_grads = univ_probs.sub_(mapped_probs).div_(N)\n",
        "\n",
        "            return grad_outputs[0] * classification_head_grads, None"
      ],
      "metadata": {
        "id": "kL2YapL_GwvO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class OffsetDataset(Dataset):\n",
        "    def __init__(self, dataset, offset):\n",
        "        self.dataset = dataset\n",
        "        self.offset = offset\n",
        "\n",
        "    def __len__(self):\n",
        "        return self.dataset.__len__()\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        batch = self.dataset.__getitem__(idx)\n",
        "        batch['label'] += self.offset\n",
        "        return batch"
      ],
      "metadata": {
        "id": "xkCl84r6Gz9T"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class MNISTSplitDataset(Dataset):\n",
        "    \"\"\"Dataset over remapped ``(image, label)`` pairs.\n",
        "\n",
        "    Pairs labelled 255 (digits dropped by the remapping) are filtered out once, at\n",
        "    construction. Each item is a dict with\n",
        "        'input': the ToTensor() image with its channels repeated three times\n",
        "                 (presumably to feed the 3-channel ResNet backbone -- confirm),\n",
        "        'label': a LongTensor of shape (1,) holding the remapped label.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, mnist_split):\n",
        "        kept = []\n",
        "        for image, label in mnist_split:\n",
        "            if label == 255:\n",
        "                continue\n",
        "            kept.append((image, label))\n",
        "        self.mnist_split = kept\n",
        "        self.to_tensor = transforms.ToTensor()\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.mnist_split)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image, label = self.mnist_split[idx]\n",
        "        pixels = self.to_tensor(image)\n",
        "        return {\n",
        "            'input': pixels.repeat(3, 1, 1),\n",
        "            'label': torch.tensor([label], dtype=torch.long),\n",
        "        }"
      ],
      "metadata": {
        "id": "tyaBTw7ZG3bc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load MNIST. Training images stay as PIL images (transform=None) because\n",
        "# MNISTSplitDataset converts them itself; the test split is converted here.\n",
        "mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)\n",
        "mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())\n",
        "\n",
        "mnist_trainset_list = [(image, label) for image, label in mnist_trainset]\n",
        "\n",
        "# Split the training set into two disjoint halves (deterministic; 30000 / 30000 for MNIST).\n",
        "split_point = len(mnist_trainset_list) // 2\n",
        "mnist_trainset_1 = mnist_trainset_list[:split_point]\n",
        "mnist_trainset_2 = mnist_trainset_list[split_point:]\n",
        "\n",
        "# Remap each half to its own coarse label set. mapping[d] is the coarse id of digit d;\n",
        "# ignore_label marks a digit that half does not use (it must match the 255 that\n",
        "# MNISTSplitDataset filters out).\n",
        "# D1 classes: {0}, {1, 2}, {3}, {4}, {5, 6, 7}, {8}        (digit 9 dropped)\n",
        "# D2 classes: {1}, {2, 3}, {4}, {5}, {6}, {7}, {8}, {9}    (digit 0 dropped)\n",
        "ignore_label = 255\n",
        "mapping_1 = [0, 1, 1, 2, 3, 4, 4, 4, 5, ignore_label]\n",
        "mapping_2 = [ignore_label, 0, 1, 1, 2, 3, 4, 5, 6, 7]\n",
        "\n",
        "# Class counts are derived from the mappings so they cannot drift out of sync with them.\n",
        "n_class_1 = max(c for c in mapping_1 if c != ignore_label) + 1\n",
        "n_class_2 = max(c for c in mapping_2 if c != ignore_label) + 1\n",
        "assert {c for c in mapping_1 if c != ignore_label} == set(range(n_class_1)), 'D1 coarse ids must be 0..n-1'\n",
        "assert {c for c in mapping_2 if c != ignore_label} == set(range(n_class_2)), 'D2 coarse ids must be 0..n-1'\n",
        "\n",
        "mnist_trainset_1_mapped = [(image, mapping_1[label]) for image, label in mnist_trainset_1]\n",
        "mnist_trainset_2_mapped = [(image, mapping_2[label]) for image, label in mnist_trainset_2]\n",
        "\n",
        "dataset_1_mapped = MNISTSplitDataset(mnist_trainset_1_mapped)\n",
        "dataset_2_mapped = MNISTSplitDataset(mnist_trainset_2_mapped)\n",
        "# Shift D2's ids to n_class_1 .. n_class_1 + n_class_2 - 1 so both label sets share one index space.\n",
        "dataset_2_mapped_offset = OffsetDataset(dataset_2_mapped, n_class_1)\n",
        "\n",
        "dataset_train = ConcatDataset([dataset_1_mapped, dataset_2_mapped_offset])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zQSaQ-NjG6Vy",
        "outputId": "5f289059-fd1f-47b9-a793-5e0d3035ce7b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 9.91M/9.91M [00:01<00:00, 5.09MB/s]\n",
            "100%|██████████| 28.9k/28.9k [00:00<00:00, 136kB/s]\n",
            "100%|██████████| 1.65M/1.65M [00:06<00:00, 247kB/s]\n",
            "100%|██████████| 4.54k/4.54k [00:00<00:00, 7.71MB/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# The classifier predicts the original digits: one logit per entry of a mapping.\n",
        "n_class_classification = len(mapping_1)\n",
        "assert len(mapping_2) == n_class_classification\n",
        "\n",
        "# Row r = coarse label r in the shared label space (D1 ids first, D2 ids offset by\n",
        "# n_class_1); column d = digit d. mask[r, d] = 1 iff digit d belongs to coarse class r,\n",
        "# so NLL+ scores a sample by the softmax mass on its row's digits.\n",
        "superclass_mask_nllplus = torch.zeros(n_class_1 + n_class_2, n_class_classification)\n",
        "\n",
        "for row_offset, mapping in ((0, mapping_1), (n_class_1, mapping_2)):\n",
        "    for class_i, mapped_class_i in enumerate(mapping):\n",
        "        if mapped_class_i != 255:\n",
        "            superclass_mask_nllplus[mapped_class_i + row_offset, class_i] = 1\n",
        "\n",
        "# A coarse class with no digits would give -log(0) = inf loss for its samples.\n",
        "assert (superclass_mask_nllplus.sum(1) > 0).all(), 'every coarse class needs at least one digit'"
      ],
      "metadata": {
        "id": "ZqBXKjgKG_aU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# One shuffled training stream per label set; the training loop zips them batch-by-batch.\n",
        "train_loader_1, train_loader_2 = (\n",
        "    DataLoader(split_dataset, batch_size=2, shuffle=True)\n",
        "    for split_dataset in (dataset_1_mapped, dataset_2_mapped_offset)\n",
        ")\n",
        "# Test images in their original order, one per batch.\n",
        "test_loader = DataLoader(mnist_testset, shuffle=False, batch_size=1)"
      ],
      "metadata": {
        "id": "3-O06tMRHIUg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "activation = {}\n",
        "def getActivation(name):\n",
        "    # the hook signature\n",
        "    def hook(model, input, output):\n",
        "        activation[name] = output\n",
        "    return hook"
      ],
      "metadata": {
        "id": "aCuyqjI5HLsc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Pretrained ResNet-18 backbone. Its avgpool output is captured by a forward hook\n",
        "# into activation['avgpool']; h1 is the hook handle (h1.remove() detaches it).\n",
        "resnet_weights = ResNet18_Weights.DEFAULT\n",
        "model_features = resnet18(weights=resnet_weights, progress=False)\n",
        "h1 = model_features.avgpool.register_forward_hook(getActivation('avgpool'))\n",
        "\n",
        "# Head: 512 input features -> one logit per digit; criterion: NLL+ over the coarse-label mask.\n",
        "model_classification_head = Head(512, n_class_classification)\n",
        "loss = StableNLLPlusLoss(superclass_mask_nllplus)\n",
        "\n",
        "for _net in (model_features, model_classification_head):\n",
        "    _net.cuda()\n",
        "loss.cuda()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pTCWWITtHNzg",
        "outputId": "5feb061f-e36b-490a-de9e-f6125efc33a5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "StableNLLPlusLoss()"
            ]
          },
          "metadata": {},
          "execution_count": 71
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "epochs = 2\n",
        "lr = 0.0005\n",
        "weight_decay = 0.00\n",
        "backbone_lr_factor = 0.3  # pretrained backbone learns at 0.3x the base rate\n",
        "head_lr_factor = 1\n",
        "\n",
        "# Adam keeps extra param-group keys such as 'lr_factor' but never reads them, so the\n",
        "# factor has to be applied to each group's own 'lr' ('lr_factor' is kept for reference).\n",
        "optimizer = torch.optim.Adam(\n",
        "    [\n",
        "        {'params': list(model_features.parameters()),\n",
        "         'lr': lr * backbone_lr_factor, 'lr_factor': backbone_lr_factor},\n",
        "        {'params': list(model_classification_head.parameters()),\n",
        "         'lr': lr * head_lr_factor, 'lr_factor': head_lr_factor},\n",
        "    ],\n",
        "    lr=lr, weight_decay=weight_decay)\n",
        "# Cosine decay of every group's rate with T_max=epochs, i.e. meant to be stepped once per epoch.\n",
        "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=1e-6)"
      ],
      "metadata": {
        "id": "T8XXLhx5HROt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for epoch in range(epochs):\n",
        "    model_features.train()\n",
        "    model_classification_head.train()\n",
        "\n",
        "    # Live learning rate of every param group. Deliberately not stored in `lr`: that is\n",
        "    # the base rate from the optimizer cell, and overwriting it would leak a decayed\n",
        "    # value into any later re-run of that cell.\n",
        "    current_lrs = [group['lr'] for group in optimizer.param_groups]\n",
        "\n",
        "    print(f'Epoch: {epoch}, LR: {current_lrs}')\n",
        "    print()\n",
        "\n",
        "    # One batch from each label set per step; zip stops at the shorter loader.\n",
        "    for i, (batch_1, batch_2) in enumerate(zip(train_loader_1, train_loader_2)):\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        batch_input = torch.cat([batch_1['input'], batch_2['input']])\n",
        "        batch_labels = torch.cat([batch_1['label'], batch_2['label']])\n",
        "\n",
        "        # The backbone's return value is unused: the avgpool forward hook leaves the\n",
        "        # pooled features in activation['avgpool'].\n",
        "        model_features(batch_input.cuda())\n",
        "        features_val = activation['avgpool'].squeeze((2, 3))\n",
        "        classification_head_val = model_classification_head(features_val)\n",
        "\n",
        "        loss_val = loss(classification_head_val, batch_labels.cuda())\n",
        "\n",
        "        loss_val.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        if i % 500 == 0:\n",
        "            print(f'Total loss: {loss_val.item()} ')\n",
        "\n",
        "    # Advance the cosine schedule once per epoch (it was built with T_max=epochs);\n",
        "    # without this call the learning rate never changes.\n",
        "    scheduler.step()\n",
        "\n",
        "    # Evaluate digit-level accuracy on the MNIST test split after every epoch.\n",
        "    model_features.eval()\n",
        "    model_classification_head.eval()\n",
        "\n",
        "    total = 0\n",
        "    correct = 0\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for images, labels in test_loader:\n",
        "            # Repeat the grey channel 3x, as done for the training inputs.\n",
        "            model_features(images.repeat(1, 3, 1, 1).cuda())\n",
        "            features_val = activation['avgpool'].squeeze((2, 3))\n",
        "            # argmax of the logits equals argmax of their softmax, so softmax is skipped.\n",
        "            pred = model_classification_head(features_val).argmax(1).cpu()\n",
        "\n",
        "            # Count element-wise so the loop is correct for any test_loader batch size.\n",
        "            correct += (pred == labels).sum().item()\n",
        "            total += labels.numel()\n",
        "\n",
        "    acc = correct / total\n",
        "\n",
        "    print()\n",
        "    print('TEST DATASET')\n",
        "    print(f'Accuracy: {acc}')\n",
        "\n",
        "    print()\n",
        "    print()\n",
        "    print()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cFYw5ieeKKnL",
        "outputId": "bf3ea5f4-4814-4672-d871-8880157253b0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 0, LR: 0.0005\n",
            "\n",
            "Total loss: 2.129298210144043 \n",
            "Total loss: 0.6390420198440552 \n",
            "Total loss: 0.8482486009597778 \n",
            "Total loss: 1.507961630821228 \n",
            "Total loss: 0.08511792123317719 \n",
            "Total loss: 0.7258517742156982 \n",
            "Total loss: 0.770119845867157 \n",
            "Total loss: 1.4686123132705688 \n",
            "Total loss: 0.5361353754997253 \n",
            "Total loss: 0.039429303258657455 \n",
            "Total loss: 3.7519314289093018 \n",
            "Total loss: 0.6730976104736328 \n",
            "Total loss: 0.8831077218055725 \n",
            "Total loss: 0.050707533955574036 \n",
            "Total loss: 0.06387367099523544 \n",
            "Total loss: 0.9928809404373169 \n",
            "Total loss: 0.05188272148370743 \n",
            "Total loss: 0.01740499958395958 \n",
            "Total loss: 0.017824610695242882 \n",
            "Total loss: 0.0042883711867034435 \n",
            "Total loss: 0.018012776970863342 \n",
            "Total loss: 0.44279026985168457 \n",
            "Total loss: 0.0008849854348227382 \n",
            "Total loss: 0.3320085108280182 \n",
            "Total loss: 4.201266288757324 \n",
            "Total loss: 0.7105538249015808 \n",
            "Total loss: 0.007401555776596069 \n",
            "\n",
            "TEST DATASET\n",
            "Accuracy: 0.9608\n",
            "\n",
            "\n",
            "\n",
            "Epoch: 1, LR: 0.0005\n",
            "\n",
            "Total loss: 0.037276990711688995 \n",
            "Total loss: 0.0034946228843182325 \n",
            "Total loss: 0.0016263547586277127 \n",
            "Total loss: 0.002603666391223669 \n",
            "Total loss: 0.009405193850398064 \n",
            "Total loss: 0.0032043226528912783 \n",
            "Total loss: 0.4675026535987854 \n",
            "Total loss: 0.03537721186876297 \n",
            "Total loss: 0.010555644519627094 \n",
            "Total loss: 0.00028056505834683776 \n",
            "Total loss: 0.0020650536753237247 \n",
            "Total loss: 1.906930923461914 \n",
            "Total loss: 0.8384901285171509 \n",
            "Total loss: 0.03608042746782303 \n",
            "Total loss: 0.0008701452170498669 \n",
            "Total loss: 0.0001194122523884289 \n",
            "Total loss: 0.0008788890554569662 \n",
            "Total loss: 0.001772381947375834 \n",
            "Total loss: 0.056542426347732544 \n",
            "Total loss: 0.003655378706753254 \n",
            "Total loss: 0.003133680671453476 \n",
            "Total loss: 0.0004406999214552343 \n",
            "Total loss: 0.3247745931148529 \n",
            "Total loss: 0.12781253457069397 \n",
            "Total loss: 0.0005702761700376868 \n",
            "Total loss: 0.07749534398317337 \n",
            "Total loss: 0.00039766920963302255 \n",
            "\n",
            "TEST DATASET\n",
            "Accuracy: 0.9698\n",
            "\n",
            "\n",
            "\n"
          ]
        }
      ]
    }
  ]
}