{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PaQpd8KQYncG"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as f\n",
        "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n",
        "from collections import OrderedDict\n",
        "from matplotlib.colors import LinearSegmentedColormap\n",
        "import torchvision\n",
        "import torchvision.datasets as datasets\n",
        "import torchvision.transforms as transforms\n",
        "from torchvision.models import resnet18, ResNet18_Weights\n",
        "from torch.autograd import Function"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "class Head(nn.Module):\n",
        "    def __init__(self, dim_in, num_classes, transfer=nn.ReLU()):\n",
        "        super(Head, self).__init__()\n",
        "        self.transfer = transfer\n",
        "        self.logits = nn.Linear(dim_in, num_classes)\n",
        "\n",
        "        nn.init.constant_(self.logits.weight, 0)\n",
        "        nn.init.constant_(self.logits.bias, 0)\n",
        "\n",
        "    def forward(self, x):\n",
        "        transfer = self.transfer(x)\n",
        "        output = self.logits(transfer)\n",
        "        return output"
      ],
      "metadata": {
        "id": "VJ1lFOxHY_Df"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class MNISTSplitDataset(Dataset):\n",
        "  def __init__(self, mnist_split, mapping=lambda x: x):\n",
        "    self.mnist_split = [(image, mapping(label)) for image, label in mnist_split if mapping(label) != 255]\n",
        "    self.resize = transforms.Resize([32,32])\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.mnist_split)\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    batch_item = {}\n",
        "    mnist_image = self.mnist_split[idx][0]\n",
        "    batch_item[\"input\"] = self.resize(torch.cat([mnist_image, mnist_image, mnist_image], axis=0))\n",
        "    batch_item[\"label\"] = torch.LongTensor(np.array([self.mnist_split[idx][1]]))\n",
        "    return batch_item"
      ],
      "metadata": {
        "id": "-RIiF1dCaGYy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())\n",
        "mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DH-2vPhLaJFu",
        "outputId": "b2e5dfc3-aec1-447f-fe4a-31520f1a5779"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 9.91M/9.91M [00:00<00:00, 53.9MB/s]\n",
            "100%|██████████| 28.9k/28.9k [00:00<00:00, 1.61MB/s]\n",
            "100%|██████████| 1.65M/1.65M [00:00<00:00, 14.5MB/s]\n",
            "100%|██████████| 4.54k/4.54k [00:00<00:00, 14.0MB/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "mnist_trainset_list = [(image, label) for image, label in mnist_trainset]\n",
        "\n",
        "mnist_trainset_1_train = mnist_trainset_list[:25000]\n",
        "mnist_trainset_1_val = mnist_trainset_list[25000:30000]\n",
        "mnist_trainset_2_train = mnist_trainset_list[30000:55000]\n",
        "mnist_trainset_2_val = mnist_trainset_list[55000:]\n",
        "\n",
        "mapping_1 = [0, 1, 1, 2, 3, 4, 4, 4, 5, 255]\n",
        "n_class_1 = 6\n",
        "\n",
        "mapping_2 = [255, 0, 1, 1, 2, 3, 4, 5, 6, 7]\n",
        "n_class_2 = 8"
      ],
      "metadata": {
        "id": "IVsIoYKuaL4c"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "dataset_1_train = MNISTSplitDataset(mnist_trainset_1_train, lambda x: mapping_1[x])\n",
        "dataset_1_val = MNISTSplitDataset(mnist_trainset_1_val, lambda x: mapping_1[x])\n",
        "dataset_2_train = MNISTSplitDataset(mnist_trainset_2_train, lambda x: mapping_2[x])\n",
        "dataset_2_val = MNISTSplitDataset(mnist_trainset_2_val, lambda x: mapping_2[x])"
      ],
      "metadata": {
        "id": "YxZ9bBJ4alrb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_loader_1 = DataLoader(dataset_1_train, batch_size=50, shuffle=True)\n",
        "train_loader_2 = DataLoader(dataset_2_train, batch_size=50, shuffle=True)\n",
        "val_loader_1 = DataLoader(dataset_1_val, batch_size=1, shuffle=False)\n",
        "val_loader_2 = DataLoader(dataset_2_val, batch_size=1, shuffle=False)"
      ],
      "metadata": {
        "id": "la3oE6GKaoOt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "activation = {}\n",
        "def getActivation(name):\n",
        "    # the hook signature\n",
        "    def hook(model, input, output):\n",
        "        activation[name] = output\n",
        "    return hook"
      ],
      "metadata": {
        "id": "-UXoYfnGarse"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "resnet_weights = ResNet18_Weights.DEFAULT\n",
        "model_features_1 = resnet18(weights=resnet_weights, progress=False)\n",
        "h1 = model_features_1.avgpool.register_forward_hook(getActivation('avgpool_1'))\n",
        "model_head_1 = Head(512, n_class_1)\n",
        "\n",
        "loss_1 = nn.CrossEntropyLoss()\n",
        "\n",
        "model_features_1.cuda()\n",
        "model_head_1.cuda()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5jVhLk1oauKo",
        "outputId": "240268bc-99d3-48af-9bbc-e46b5167a654"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Head(\n",
              "  (transfer): ReLU()\n",
              "  (logits): Linear(in_features=512, out_features=6, bias=True)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "lr = 0.01\n",
        "weight_decay = 0\n",
        "optimizer_1 = torch.optim.Adam(list(model_features_1.parameters()) + list(model_head_1.parameters()), lr=lr, weight_decay=weight_decay)"
      ],
      "metadata": {
        "id": "LOmFqqTKaxTq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "epochs = 5\n",
        "for epoch in range(epochs):\n",
        "    model_features_1.train()\n",
        "    model_head_1.train()\n",
        "\n",
        "    for i, batch in enumerate(train_loader_1):\n",
        "        optimizer_1.zero_grad()\n",
        "        model_features_1(batch['input'].cuda())\n",
        "        features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "        head_val = model_head_1(features_val)\n",
        "        loss_val = loss_1(head_val, batch['label'][:,0].cuda())\n",
        "        loss_val.backward()\n",
        "        optimizer_1.step()\n",
        "        if i%100 == 0:\n",
        "            print(f'Loss: {loss_val}')\n",
        "\n",
        "    if epoch%1==0:\n",
        "        model_features_1.eval()\n",
        "        model_head_1.eval()\n",
        "\n",
        "        correct = 0\n",
        "        total = 0\n",
        "        with torch.no_grad():\n",
        "            for i, batch in enumerate(val_loader_1):\n",
        "                total += 1\n",
        "                model_features_1(batch['input'].cuda())\n",
        "                features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "                head_val = model_head_1(features_val).softmax(1)\n",
        "                pred = head_val.argmax(1).detach().cpu()\n",
        "                if pred == batch['label']:\n",
        "                    correct += 1\n",
        "        acc = correct / total\n",
        "\n",
        "        print(f'Accuracy: {acc}')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LFkqPPX4a2tc",
        "outputId": "fd032a75-4349-46a3-b022-ad7e8556e24e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loss: 1.7917605638504028\n",
            "Loss: 0.5224153399467468\n",
            "Loss: 0.10932692885398865\n",
            "Loss: 0.16138581931591034\n",
            "Loss: 0.09360399842262268\n",
            "Accuracy: 0.6327934324384291\n",
            "Loss: 0.4977143406867981\n",
            "Loss: 0.07449705898761749\n",
            "Loss: 0.028845930472016335\n",
            "Loss: 0.10989704728126526\n",
            "Loss: 0.037191241979599\n",
            "Accuracy: 0.9667184379853561\n",
            "Loss: 0.06580374389886856\n",
            "Loss: 0.1298539936542511\n",
            "Loss: 0.007319376338273287\n",
            "Loss: 0.037454504519701004\n",
            "Loss: 0.10685894638299942\n",
            "Accuracy: 0.9778122919902374\n",
            "Loss: 0.09644032269716263\n",
            "Loss: 0.16484558582305908\n",
            "Loss: 0.03130181506276131\n",
            "Loss: 0.24443362653255463\n",
            "Loss: 0.22640539705753326\n",
            "Accuracy: 0.9846904814732638\n",
            "Loss: 0.004782047588378191\n",
            "Loss: 0.010640471242368221\n",
            "Loss: 0.22804246842861176\n",
            "Loss: 0.306806355714798\n",
            "Loss: 0.004525087773799896\n",
            "Accuracy: 0.9578433547814511\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "resnet_weights = ResNet18_Weights.DEFAULT\n",
        "model_features_2 = resnet18(weights=resnet_weights, progress=False)\n",
        "h1 = model_features_2.avgpool.register_forward_hook(getActivation('avgpool_2'))\n",
        "model_head_2 = Head(512, n_class_2)\n",
        "\n",
        "loss_2 = nn.CrossEntropyLoss()\n",
        "\n",
        "model_features_2.cuda()\n",
        "model_head_2.cuda()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UtifmXr1bnKd",
        "outputId": "0596a170-7e83-447c-f56c-c29405766af2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Head(\n",
              "  (transfer): ReLU()\n",
              "  (logits): Linear(in_features=512, out_features=8, bias=True)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "optimizer_2 = torch.optim.Adam(list(model_features_2.parameters()) + list(model_head_2.parameters()), lr=lr, weight_decay=weight_decay)"
      ],
      "metadata": {
        "id": "05OsuvBKbqiH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "epochs = 5\n",
        "for epoch in range(epochs):\n",
        "    model_features_2.train()\n",
        "    model_head_2.train()\n",
        "\n",
        "    for i, batch in enumerate(train_loader_2):\n",
        "        optimizer_2.zero_grad()\n",
        "        model_features_2(batch['input'].cuda())\n",
        "        features_val = activation['avgpool_2'].squeeze((2,3))\n",
        "        head_val = model_head_2(features_val)\n",
        "        loss_val = loss_2(head_val, batch['label'][:,0].cuda())\n",
        "        loss_val.backward()\n",
        "        optimizer_2.step()\n",
        "        if i%100 == 0:\n",
        "            print(f'Loss: {loss_val}')\n",
        "\n",
        "    if epoch%1==0:\n",
        "        model_features_2.eval()\n",
        "        model_head_2.eval()\n",
        "\n",
        "        correct = 0\n",
        "        total = 0\n",
        "        with torch.no_grad():\n",
        "            for i, batch in enumerate(val_loader_2):\n",
        "                total += 1\n",
        "                model_features_2(batch['input'].cuda())\n",
        "                features_val = activation['avgpool_2'].squeeze((2,3))\n",
        "                head_val = model_head_2(features_val).softmax(1)\n",
        "                pred = head_val.argmax(1).detach().cpu()\n",
        "                if pred == batch['label']:\n",
        "                    correct += 1\n",
        "        acc = correct / total\n",
        "\n",
        "        print(f'Accuracy: {acc}')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "D2lreuQlbuNa",
        "outputId": "a32054bd-8c25-4dfc-9e46-18f8febc1709"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loss: 2.079442262649536\n",
            "Loss: 0.371639609336853\n",
            "Loss: 0.9233502149581909\n",
            "Loss: 0.2956743538379669\n",
            "Loss: 0.2522566318511963\n",
            "Accuracy: 0.9552205719352693\n",
            "Loss: 0.20914539694786072\n",
            "Loss: 0.06403876841068268\n",
            "Loss: 0.27128246426582336\n",
            "Loss: 1.1651761531829834\n",
            "Loss: 0.23054225742816925\n",
            "Accuracy: 0.9510086455331412\n",
            "Loss: 0.08772919327020645\n",
            "Loss: 0.2339497208595276\n",
            "Loss: 0.14963069558143616\n",
            "Loss: 0.04184463620185852\n",
            "Loss: 0.12704597413539886\n",
            "Accuracy: 0.9647528264242962\n",
            "Loss: 0.19729752838611603\n",
            "Loss: 0.15919718146324158\n",
            "Loss: 0.22495122253894806\n",
            "Loss: 0.08478151261806488\n",
            "Loss: 0.00686968257650733\n",
            "Accuracy: 0.9727333185546442\n",
            "Loss: 0.020419657230377197\n",
            "Loss: 0.12483371049165726\n",
            "Loss: 0.15372627973556519\n",
            "Loss: 0.30317071080207825\n",
            "Loss: 0.21469475328922272\n",
            "Accuracy: 0.9521170472179118\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_eval_loader_1 = DataLoader(dataset_1_train, batch_size=1, shuffle=False)\n",
        "train_eval_loader_2 = DataLoader(dataset_2_train, batch_size=1, shuffle=False)"
      ],
      "metadata": {
        "id": "O9M1ynjociMt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model_features_2.eval()\n",
        "model_head_2.eval()\n",
        "\n",
        "correlation_matrix_gt_1_pred_2 = np.zeros([n_class_1, n_class_2+1], dtype=np.uint32)\n",
        "\n",
        "with torch.no_grad():\n",
        "    for i, batch in enumerate(train_eval_loader_1):\n",
        "        model_features_2(batch['input'].cuda())\n",
        "        features_val = activation['avgpool_2'].squeeze((2,3))\n",
        "        head_val = model_head_2(features_val).softmax(1)\n",
        "        conf = head_val.max(1)[0]\n",
        "        pred = head_val.argmax(1).detach().cpu()\n",
        "        #print(pred, conf)\n",
        "        if conf > 0.5:\n",
        "            correlation_matrix_gt_1_pred_2[batch['label'], pred] += 1\n",
        "        else:\n",
        "            correlation_matrix_gt_1_pred_2[batch['label'], n_class_2] += 1"
      ],
      "metadata": {
        "id": "GPgNX7u1cnS7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model_features_1.eval()\n",
        "model_head_1.eval()\n",
        "\n",
        "correlation_matrix_gt_2_pred_1 = np.zeros([n_class_2, n_class_1+1], dtype=np.uint32)\n",
        "\n",
        "with torch.no_grad():\n",
        "    for i, batch in enumerate(train_eval_loader_2):\n",
        "        model_features_1(batch['input'].cuda())\n",
        "        features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "        head_val = model_head_1(features_val).softmax(1)\n",
        "        conf = head_val.max(1)[0]\n",
        "        pred = head_val.argmax(1).detach().cpu()\n",
        "        if conf > 0.5:\n",
        "            correlation_matrix_gt_2_pred_1[batch['label'], pred] += 1\n",
        "        else:\n",
        "            correlation_matrix_gt_2_pred_1[batch['label'], n_class_1] += 1"
      ],
      "metadata": {
        "id": "HdQPsO_vcq-q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for i in range(n_class_1):\n",
        "    print('|',end=' ')\n",
        "    for j in range(n_class_2+1):\n",
        "        print(f\"{correlation_matrix_gt_1_pred_2[i, j]:4}\", end=' | ')\n",
        "    print()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jprrvD0FdA_8",
        "outputId": "41679f5a-9302-49a9-aabd-c8f93d8a2e99"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "|  128 |    9 |   45 |  142 | 1899 |   55 |    0 |   61 |  133 | \n",
            "| 2840 | 2124 |   65 |   10 |   20 |   54 |   13 |   79 |   94 | \n",
            "|    8 | 2356 |    4 |   51 |    1 |   19 |    6 |   91 |   56 | \n",
            "|   17 |    0 | 2380 |    0 |    3 |    7 |    0 |   31 |    7 | \n",
            "|   18 |    9 |   10 | 2153 | 2481 | 2555 |    3 |   46 |   19 | \n",
            "|  118 |   18 |   40 |   46 |  122 |   14 | 1854 |   44 |  132 | \n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "for i in range(n_class_2):\n",
        "    print('|',end=' ')\n",
        "    for j in range(n_class_1+1):\n",
        "        print(f\"{correlation_matrix_gt_2_pred_1[i, j]:4}\", end=' | ')\n",
        "    print()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vheLy7XMdYuC",
        "outputId": "433e3769-1ac9-403b-c3f4-45d6a3ec036b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "|    0 | 2522 |    2 |    8 |    5 |  214 |   38 | \n",
            "|    1 | 2417 | 2429 |    2 |   32 |  149 |   36 | \n",
            "|    3 |    4 |    0 | 2341 |   16 |   37 |   16 | \n",
            "|    1 |    2 |   24 |    2 | 2152 |   52 |   21 | \n",
            "|    1 |    3 |    0 |    2 | 2415 |   35 |    5 | \n",
            "|    3 |    6 |    9 |   17 | 2500 |   26 |   34 | \n",
            "|    0 |    1 |    5 |    3 |   11 | 2455 |    7 | \n",
            "|   11 |    9 |   38 | 1085 |  400 |  640 |  280 | \n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "percentage_1 = np.zeros([n_class_1, n_class_2+1], dtype=np.float32)\n",
        "\n",
        "for i in range(n_class_1):\n",
        "    for j in range(n_class_2+1):\n",
        "        percentage_1[i,j] = correlation_matrix_gt_1_pred_2[i, j] / correlation_matrix_gt_1_pred_2[i,:].sum()\n",
        "\n",
        "print(np.round(percentage_1, 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7GzGdLR0ddqF",
        "outputId": "b7a563c5-8895-4eaf-d99d-f4e1b19109a0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.05 0.   0.02 0.06 0.77 0.02 0.   0.02 0.05]\n",
            " [0.54 0.4  0.01 0.   0.   0.01 0.   0.01 0.02]\n",
            " [0.   0.91 0.   0.02 0.   0.01 0.   0.04 0.02]\n",
            " [0.01 0.   0.97 0.   0.   0.   0.   0.01 0.  ]\n",
            " [0.   0.   0.   0.3  0.34 0.35 0.   0.01 0.  ]\n",
            " [0.05 0.01 0.02 0.02 0.05 0.01 0.78 0.02 0.06]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "percentage_2 = np.zeros([n_class_2, n_class_1+1], dtype=np.float32)\n",
        "\n",
        "for i in range(n_class_2):\n",
        "    for j in range(n_class_1+1):\n",
        "        percentage_2[i,j] = correlation_matrix_gt_2_pred_1[i, j] / correlation_matrix_gt_2_pred_1[i,:].sum()\n",
        "\n",
        "print(np.round(percentage_2, 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "t1woVuMrdiCx",
        "outputId": "4632545a-e8be-45ad-9c90-ee414e876cde"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.   0.9  0.   0.   0.   0.08 0.01]\n",
            " [0.   0.48 0.48 0.   0.01 0.03 0.01]\n",
            " [0.   0.   0.   0.97 0.01 0.02 0.01]\n",
            " [0.   0.   0.01 0.   0.95 0.02 0.01]\n",
            " [0.   0.   0.   0.   0.98 0.01 0.  ]\n",
            " [0.   0.   0.   0.01 0.96 0.01 0.01]\n",
            " [0.   0.   0.   0.   0.   0.99 0.  ]\n",
            " [0.   0.   0.02 0.44 0.16 0.26 0.11]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "percentage_1_shared = percentage_1[:,:n_class_2]\n",
        "percentage_1_ds_specific = percentage_1[:, n_class_2]\n",
        "\n",
        "percentage_2_shared = percentage_2[:,:n_class_1].transpose()\n",
        "percentage_2_ds_specific = percentage_2[:, n_class_1]"
      ],
      "metadata": {
        "id": "BHWHNOJodpTt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "test_mapping = torch.tensor([48, 8, 9, 17, 26, 35, 36, 37, 46, 61])\n",
        "\n",
        "true_relations = torch.zeros((62), dtype=torch.int32)\n",
        "\n",
        "for true_relation in test_mapping:\n",
        "    true_relations[true_relation] = 1"
      ],
      "metadata": {
        "id": "LOIotdy8d9iv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "percentage_shared = np.maximum(percentage_1_shared, percentage_2_shared).reshape((48))"
      ],
      "metadata": {
        "id": "NvK6EnqweNjR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "final_percentages = np.concatenate((percentage_shared, percentage_1_ds_specific, percentage_2_ds_specific)).reshape((62))"
      ],
      "metadata": {
        "id": "JTZZ967IeXAO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def average_precision_vectorized(y_true, y_scores):\n",
        "    device = torch.device(\"cpu\")\n",
        "    y_true = y_true.to(device)\n",
        "    y_scores = y_scores.to(device)\n",
        "\n",
        "    sorted_indices = torch.argsort(y_scores, descending=True)\n",
        "    y_true = y_true[sorted_indices]\n",
        "\n",
        "    cumulative_true = torch.cumsum(y_true, dim=0)\n",
        "    precision_values = cumulative_true / (torch.arange(1, len(y_true) + 1, device=device))\n",
        "\n",
        "    average_precision = torch.sum(precision_values * y_true) / torch.sum(y_true)\n",
        "\n",
        "    return average_precision.item() if torch.sum(y_true) > 0 else 0.0"
      ],
      "metadata": {
        "id": "tBJYuFPseYd5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "average_precision_vectorized(torch.tensor(true_relations), torch.tensor(final_percentages))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "J1Xbcwf_ecc1",
        "outputId": "9b5d5b4b-182b-47fd-ab52-ae353e6ec1de"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-39-1587174842.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  average_precision_vectorized(torch.tensor(true_relations), torch.tensor(final_percentages))\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.9169431924819946"
            ]
          },
          "metadata": {},
          "execution_count": 39
        }
      ]
    }
  ]
}