{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "_2ByVZlhemWj"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as f\n",
        "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n",
        "from collections import OrderedDict\n",
        "from matplotlib.colors import LinearSegmentedColormap\n",
        "import torchvision\n",
        "import torchvision.datasets as datasets\n",
        "import torchvision.transforms as transforms\n",
        "from torchvision.models import resnet18, ResNet18_Weights\n",
        "from torch.autograd import Function\n",
        "import copy"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "class Head(nn.Module):\n",
        "    def __init__(self, dim_in, num_classes, transfer=nn.ReLU()):\n",
        "        super(Head, self).__init__()\n",
        "        self.transfer = transfer\n",
        "        self.logits = nn.Linear(dim_in, num_classes)\n",
        "\n",
        "        nn.init.constant_(self.logits.weight, 0)\n",
        "        nn.init.constant_(self.logits.bias, 0)\n",
        "\n",
        "    def forward(self, x):\n",
        "        transfer = self.transfer(x)\n",
        "        output = self.logits(transfer)\n",
        "        return output"
      ],
      "metadata": {
        "id": "2sUg2eMuevyT"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class MNISTSplitDataset(Dataset):\n",
        "  def __init__(self, mnist_split, mapping=lambda x: x):\n",
        "    self.mnist_split = [(image, mapping(label)) for image, label in mnist_split if mapping(label) != 255]\n",
        "    self.resize = transforms.Resize([32,32])\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.mnist_split)\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    batch_item = {}\n",
        "    mnist_image = self.mnist_split[idx][0]\n",
        "    batch_item[\"input\"] = self.resize(torch.cat([mnist_image, mnist_image, mnist_image], axis=0))\n",
        "    batch_item[\"label\"] = torch.LongTensor(np.array([self.mnist_split[idx][1]]))\n",
        "    return batch_item"
      ],
      "metadata": {
        "id": "63v3o3fpezmz"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class OffsetDataset(Dataset):\n",
        "    def __init__(self, dataset, offset):\n",
        "        self.dataset = dataset\n",
        "        self.offset = offset\n",
        "\n",
        "    def __len__(self):\n",
        "        return self.dataset.__len__()\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        batch = self.dataset.__getitem__(idx)\n",
        "        batch['label'] += self.offset\n",
        "        return batch"
      ],
      "metadata": {
        "id": "0xwOtCJjfb3g"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())\n",
        "mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "I2UuhRqoe0mn",
        "outputId": "c751a4f9-cfbc-4c53-f28f-bea181cb6f40"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 9.91M/9.91M [00:02<00:00, 3.78MB/s]\n",
            "100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]\n",
            "100%|██████████| 1.65M/1.65M [00:01<00:00, 1.08MB/s]\n",
            "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.22MB/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "mnist_trainset_list = [(image, label) for image, label in mnist_trainset]\n",
        "\n",
        "mnist_trainset_1_train = mnist_trainset_list[:25000]\n",
        "mnist_trainset_1_val = mnist_trainset_list[25000:30000]\n",
        "mnist_trainset_2_train = mnist_trainset_list[30000:55000]\n",
        "mnist_trainset_2_val = mnist_trainset_list[55000:]\n",
        "\n",
        "mapping_1 = [0, 1, 1, 2, 3, 4, 4, 4, 5, 255]\n",
        "n_class_1 = 6\n",
        "\n",
        "mapping_2 = [255, 0, 1, 1, 2, 3, 4, 5, 6, 7]\n",
        "n_class_2 = 8"
      ],
      "metadata": {
        "id": "joYApuRLe30Q"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "dataset_1_train = MNISTSplitDataset(mnist_trainset_1_train, lambda x: mapping_1[x])\n",
        "dataset_1_val = MNISTSplitDataset(mnist_trainset_1_val, lambda x: mapping_1[x])\n",
        "dataset_2_train = OffsetDataset(MNISTSplitDataset(mnist_trainset_2_train, lambda x: mapping_2[x]), n_class_1)\n",
        "dataset_2_val = OffsetDataset(MNISTSplitDataset(mnist_trainset_2_val, lambda x: mapping_2[x]), n_class_1)"
      ],
      "metadata": {
        "id": "-HFtReN6e64B"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_loader_1 = DataLoader(dataset_1_train, batch_size=50, shuffle=True)\n",
        "train_loader_2 = DataLoader(dataset_2_train, batch_size=50, shuffle=True)\n",
        "val_loader_1 = DataLoader(dataset_1_val, batch_size=1, shuffle=False)\n",
        "val_loader_2 = DataLoader(dataset_2_val, batch_size=1, shuffle=False)"
      ],
      "metadata": {
        "id": "VVDyAMHBfJJk"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "activation = {}\n",
        "def getActivation(name):\n",
        "    # the hook signature\n",
        "    def hook(model, input, output):\n",
        "        activation[name] = output\n",
        "    return hook"
      ],
      "metadata": {
        "id": "PQRijU8qfKiZ"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "resnet_weights = ResNet18_Weights.DEFAULT\n",
        "model_features_1 = resnet18(weights=resnet_weights, progress=False)\n",
        "h1 = model_features_1.avgpool.register_forward_hook(getActivation('avgpool_1'))\n",
        "model_head_1 = Head(512, n_class_1 + n_class_2)\n",
        "\n",
        "loss_1 = nn.CrossEntropyLoss()\n",
        "\n",
        "model_features_1.cuda()\n",
        "model_head_1.cuda()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yVGJq9CwfMQo",
        "outputId": "2b766ef8-1110-459e-f0b3-09c790292ae8"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Head(\n",
              "  (transfer): ReLU()\n",
              "  (logits): Linear(in_features=512, out_features=14, bias=True)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "lr = 0.01\n",
        "weight_decay = 0\n",
        "optimizer_1 = torch.optim.Adam(list(model_features_1.parameters()) + list(model_head_1.parameters()), lr=lr, weight_decay=weight_decay)"
      ],
      "metadata": {
        "id": "C0uXPCP6foGW"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "epochs = 5\n",
        "for epoch in range(epochs):\n",
        "    print(f\"Epoch: {epoch}\")\n",
        "    print()\n",
        "    model_features_1.train()\n",
        "    model_head_1.train()\n",
        "\n",
        "    for i, (batch_1, batch_2) in enumerate(zip(train_loader_1, train_loader_2)):\n",
        "        optimizer_1.zero_grad()\n",
        "\n",
        "        batch_input =  torch.concatenate([batch_1['input'], batch_2['input']])\n",
        "        batch_labels = torch.concatenate([batch_1['label'], batch_2['label']])\n",
        "\n",
        "        model_features_1(batch_input.cuda())\n",
        "        features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "        head_val = model_head_1(features_val)\n",
        "        loss_val = loss_1(head_val, batch_labels[:,0].cuda())\n",
        "        loss_val.backward()\n",
        "        optimizer_1.step()\n",
        "        if i%100 == 0:\n",
        "            print(f'Loss: {loss_val}')\n",
        "\n",
        "    if epoch%1==0:\n",
        "        model_features_1.eval()\n",
        "        model_head_1.eval()\n",
        "\n",
        "        correct = 0\n",
        "        total = 0\n",
        "        with torch.no_grad():\n",
        "            for i, batch in enumerate(val_loader_1):\n",
        "                total += 1\n",
        "                model_features_1(batch['input'].cuda())\n",
        "                features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "                head_val = model_head_1(features_val).softmax(1)\n",
        "                pred = head_val.argmax(1).detach().cpu()\n",
        "                if pred == batch['label']:\n",
        "                    correct += 1\n",
        "        acc = correct / total\n",
        "\n",
        "        print()\n",
        "        print(f'Accuracy split 1: {acc}')\n",
        "        print()\n",
        "\n",
        "        correct = 0\n",
        "        total = 0\n",
        "        with torch.no_grad():\n",
        "            for i, batch in enumerate(val_loader_2):\n",
        "                total += 1\n",
        "                model_features_1(batch['input'].cuda())\n",
        "                features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "                head_val = model_head_1(features_val).softmax(1)\n",
        "                pred = head_val.argmax(1).detach().cpu()\n",
        "                if pred == batch['label']:\n",
        "                    correct += 1\n",
        "        acc = correct / total\n",
        "\n",
        "        print(f'Accuracy split 2: {acc}')\n",
        "        print()\n",
        "        print()\n",
        "        print()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZkGoTQ0ffqG1",
        "outputId": "27799ab9-f193-462f-a616-5e9ee142a998"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 0\n",
            "\n",
            "Loss: 2.6390578746795654\n",
            "Loss: 1.0371555089950562\n",
            "Loss: 0.9231243133544922\n",
            "Loss: 0.8360127806663513\n",
            "Loss: 0.766924262046814\n",
            "\n",
            "Accuracy split 1: 0.527845573552252\n",
            "\n",
            "Accuracy split 2: 0.5462203502549324\n",
            "\n",
            "\n",
            "\n",
            "Epoch: 1\n",
            "\n",
            "Loss: 0.7011138200759888\n",
            "Loss: 0.7506441473960876\n",
            "Loss: 0.7451788187026978\n",
            "Loss: 0.6705529093742371\n",
            "Loss: 0.7696570754051208\n",
            "\n",
            "Accuracy split 1: 0.5544708231639671\n",
            "\n",
            "Accuracy split 2: 0.5451119485701619\n",
            "\n",
            "\n",
            "\n",
            "Epoch: 2\n",
            "\n",
            "Loss: 0.6864511370658875\n",
            "Loss: 0.7193345427513123\n",
            "Loss: 0.6925226449966431\n",
            "Loss: 0.6938967704772949\n",
            "Loss: 0.6720374226570129\n",
            "\n",
            "Accuracy split 1: 0.6312402928777457\n",
            "\n",
            "Accuracy split 2: 0.45111948570161825\n",
            "\n",
            "\n",
            "\n",
            "Epoch: 3\n",
            "\n",
            "Loss: 0.6114894151687622\n",
            "Loss: 0.6656644940376282\n",
            "Loss: 0.9182350039482117\n",
            "Loss: 0.6250119209289551\n",
            "Loss: 0.6273298859596252\n",
            "\n",
            "Accuracy split 1: 0.5606833814067007\n",
            "\n",
            "Accuracy split 2: 0.5459986699179783\n",
            "\n",
            "\n",
            "\n",
            "Epoch: 4\n",
            "\n",
            "Loss: 0.7044827342033386\n",
            "Loss: 0.6188033819198608\n",
            "Loss: 0.6648716926574707\n",
            "Loss: 0.6514598727226257\n",
            "Loss: 0.8000199198722839\n",
            "\n",
            "Accuracy split 1: 0.39605058797426224\n",
            "\n",
            "Accuracy split 2: 0.6943028153402793\n",
            "\n",
            "\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_eval_loader_1 = DataLoader(dataset_1_train, batch_size=1, shuffle=False)\n",
        "train_eval_loader_2 = DataLoader(dataset_2_train, batch_size=1, shuffle=False)"
      ],
      "metadata": {
        "id": "deFkmhMJfyMU"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model_features_1.eval()\n",
        "model_head_1.eval()\n",
        "\n",
        "correlation_matrix_gt_2_pred_1 = np.zeros([n_class_2, n_class_1+1], dtype=np.uint32)\n",
        "\n",
        "with torch.no_grad():\n",
        "    for i, batch in enumerate(train_eval_loader_2):\n",
        "        model_features_1(batch['input'].cuda())\n",
        "        features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "        head_val = model_head_1(features_val)[:,:n_class_1].softmax(1)\n",
        "        pred = head_val.argmax(1).detach().cpu()\n",
        "        conf = head_val.max(1)[0]\n",
        "        if conf > 0.5:\n",
        "            correlation_matrix_gt_2_pred_1[batch['label']-n_class_1, pred] += 1\n",
        "        else:\n",
        "            correlation_matrix_gt_2_pred_1[batch['label']-n_class_1, n_class_1] += 1"
      ],
      "metadata": {
        "id": "3oOq1meVfy3p"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "correlation_matrix_gt_1_pred_2 = np.zeros([n_class_1, n_class_2+1], dtype=np.uint32)\n",
        "\n",
        "with torch.no_grad():\n",
        "    for i, batch in enumerate(train_eval_loader_1):\n",
        "        model_features_1(batch['input'].cuda())\n",
        "        features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "        head_val = model_head_1(features_val)[:, n_class_1:].softmax(1)\n",
        "        pred = head_val.argmax(1).detach().cpu()\n",
        "        conf = head_val.max(1)[0]\n",
        "        if conf > 0.5:\n",
        "            correlation_matrix_gt_1_pred_2[batch['label'], pred] += 1\n",
        "        else:\n",
        "            correlation_matrix_gt_1_pred_2[batch['label'], n_class_2] += 1"
      ],
      "metadata": {
        "id": "oulgEgjEf06d"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "correlation_matrix_gt_1_pred_2[n_class_1-1, n_class_2] = 0\n",
        "for i in range(n_class_1):\n",
        "    print('|',end=' ')\n",
        "    for j in range(n_class_2+1):\n",
        "        print(f\"{correlation_matrix_gt_1_pred_2[i, j]:4}\", end=' | ')\n",
        "    print()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "n026SI7of3WO",
        "outputId": "afeb4fd0-c164-401f-8def-7f65a92033de"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "|    0 |  153 |    0 |    0 |   52 |    0 | 2105 |    1 |  161 | \n",
            "| 2797 | 2434 |    0 |    0 |    8 |   23 |    7 |    0 |   30 | \n",
            "|    0 | 2550 |    0 |    5 |    0 |   25 |    2 |    4 |    6 | \n",
            "|    0 |    1 | 2302 |    0 |    2 |   13 |    2 |  115 |   10 | \n",
            "|    0 |    7 |    0 | 2199 | 2454 | 2595 |   12 |   11 |   16 | \n",
            "|    0 |    3 |    0 |    1 |    2 |    2 | 2375 |    3 |    0 | \n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "for i in range(n_class_2):\n",
        "    print('|',end=' ')\n",
        "    for j in range(n_class_1+1):\n",
        "        print(f\"{correlation_matrix_gt_2_pred_1[i, j]:4}\", end=' | ')\n",
        "    print()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YV0rGmaLf418",
        "outputId": "6adf8ce7-3d7d-45ed-bf89-f66b83b7d3cc"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "|    0 | 2770 |    0 |    1 |   12 |    1 |    5 | \n",
            "|    1 | 2506 | 2497 |    0 |   50 |    3 |    9 | \n",
            "|    0 |    3 |    0 | 2387 |   17 |    4 |    6 | \n",
            "|    2 |    0 |    7 |    0 | 2236 |    4 |    5 | \n",
            "|    1 |    0 |    0 |    0 | 2454 |    4 |    2 | \n",
            "|    0 |    6 |    0 |    0 | 2589 |    0 |    0 | \n",
            "|    1 |    7 |    1 |    0 |   11 | 2457 |    5 | \n",
            "|    6 |    2 |   16 | 2045 |  121 |   30 |  243 | \n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "percentage_1 = np.zeros([n_class_1, n_class_2+1], dtype=np.float32)\n",
        "\n",
        "for i in range(n_class_1):\n",
        "    for j in range(n_class_2+1):\n",
        "        percentage_1[i, j] = correlation_matrix_gt_1_pred_2[i, j] / correlation_matrix_gt_1_pred_2[i,:].max()\n",
        "\n",
        "print(percentage_1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QwJoY9oMf_R1",
        "outputId": "55588e59-dede-49d5-f424-d8b1001da4cb"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.0000000e+00 7.2684087e-02 0.0000000e+00 0.0000000e+00 2.4703087e-02\n",
            "  0.0000000e+00 1.0000000e+00 4.7505938e-04 7.6484561e-02]\n",
            " [1.0000000e+00 8.7021810e-01 0.0000000e+00 0.0000000e+00 2.8602073e-03\n",
            "  8.2230959e-03 2.5026815e-03 0.0000000e+00 1.0725778e-02]\n",
            " [0.0000000e+00 1.0000000e+00 0.0000000e+00 1.9607844e-03 0.0000000e+00\n",
            "  9.8039219e-03 7.8431371e-04 1.5686274e-03 2.3529411e-03]\n",
            " [0.0000000e+00 4.3440488e-04 1.0000000e+00 0.0000000e+00 8.6880976e-04\n",
            "  5.6472630e-03 8.6880976e-04 4.9956560e-02 4.3440484e-03]\n",
            " [0.0000000e+00 2.6974953e-03 0.0000000e+00 8.4739882e-01 9.4566476e-01\n",
            "  1.0000000e+00 4.6242774e-03 4.2389212e-03 6.1657033e-03]\n",
            " [0.0000000e+00 1.2631579e-03 0.0000000e+00 4.2105262e-04 8.4210525e-04\n",
            "  8.4210525e-04 1.0000000e+00 1.2631579e-03 0.0000000e+00]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "percentage_2 = np.zeros([n_class_2, n_class_1+1], dtype=np.float32)\n",
        "\n",
        "for i in range(n_class_2):\n",
        "    for j in range(n_class_1+1):\n",
        "        percentage_2[i, j] = correlation_matrix_gt_2_pred_1[i, j] / correlation_matrix_gt_2_pred_1[i,:].max()\n",
        "\n",
        "print(percentage_2)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FKH0wzrigAI6",
        "outputId": "27d73068-e6f2-471a-fca9-d3e06fad4098"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.00000000e+00 1.00000000e+00 0.00000000e+00 3.61010840e-04\n",
            "  4.33212984e-03 3.61010840e-04 1.80505414e-03]\n",
            " [3.99042299e-04 1.00000000e+00 9.96408641e-01 0.00000000e+00\n",
            "  1.99521147e-02 1.19712693e-03 3.59138078e-03]\n",
            " [0.00000000e+00 1.25680771e-03 0.00000000e+00 1.00000000e+00\n",
            "  7.12191034e-03 1.67574361e-03 2.51361541e-03]\n",
            " [8.94454366e-04 0.00000000e+00 3.13059031e-03 0.00000000e+00\n",
            "  1.00000000e+00 1.78890873e-03 2.23613600e-03]\n",
            " [4.07497952e-04 0.00000000e+00 0.00000000e+00 0.00000000e+00\n",
            "  1.00000000e+00 1.62999181e-03 8.14995903e-04]\n",
            " [0.00000000e+00 2.31749704e-03 0.00000000e+00 0.00000000e+00\n",
            "  1.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
            " [4.07000422e-04 2.84900283e-03 4.07000422e-04 0.00000000e+00\n",
            "  4.47700452e-03 1.00000000e+00 2.03500199e-03]\n",
            " [2.93398532e-03 9.77995107e-04 7.82396086e-03 1.00000000e+00\n",
            "  5.91687039e-02 1.46699268e-02 1.18826404e-01]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "d1_classes = [\"d1_0\", \"d1_1,2\", \"d1_3\", \"d1_4\", \"d1_5,6,7\", \"d1_8\"]\n",
        "d2_classes = [\"d2_1\", \"d2_2,3\", \"d2_4\", \"d2_5\", \"d2_6\", \"d2_7\", \"d2_8\", \"d2_9\"]\n",
        "d1_classes_target = [\"d1_0\", \"d1_1,2\", \"d1_3\", \"d1_4\", \"d1_5,6,7\", \"d1_8\", \"outlier\"]\n",
        "d2_classes_target = [\"d2_1\", \"d2_2,3\", \"d2_4\", \"d2_5\", \"d2_6\", \"d2_7\", \"d2_8\", \"d2_9\", \"outlier\"]\n",
        "\n",
        "concat_classes  = [\"d1_0\", \"d1_1,2\", \"d1_3\", \"d1_4\", \"d1_5,6,7\", \"d1_8\", \"d2_1\", \"d2_2,3\", \"d2_4\", \"d2_5\", \"d2_6\", \"d2_7\", \"d2_8\", \"d2_9\"]\n",
        "\n",
        "d1_mapping = {d1_class: [d1_class] for d1_class in d1_classes}\n",
        "d2_mapping = {d2_class: [d2_class] for d2_class in d2_classes}"
      ],
      "metadata": {
        "id": "ZQJIVtdlgCJg"
      },
      "execution_count": 20,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "edges_d1_gt_d2_pred = {d1_classes[ind]: d2_classes_target[int(pred)] for ind, pred in enumerate(correlation_matrix_gt_1_pred_2.argmax(1))}\n",
        "edges_d2_gt_d1_pred = {d2_classes[ind]: d1_classes_target[int(pred)] for ind, pred in enumerate(correlation_matrix_gt_2_pred_1.argmax(1))}\n",
        "\n",
        "print(\"EDGES d1 -> d2\")\n",
        "print()\n",
        "keys = list(edges_d1_gt_d2_pred.keys())\n",
        "for key in keys:\n",
        "    print(key + ' -> ' + edges_d1_gt_d2_pred[key])\n",
        "    if edges_d1_gt_d2_pred[key] == \"outlier\":\n",
        "        del(edges_d1_gt_d2_pred[key])\n",
        "\n",
        "print()\n",
        "print()\n",
        "print()\n",
        "print(\"EDGES d2 -> d1\")\n",
        "print()\n",
        "keys = list(edges_d2_gt_d1_pred.keys())\n",
        "for key in keys:\n",
        "    print(key + ' -> ' + edges_d2_gt_d1_pred[key])\n",
        "    if edges_d2_gt_d1_pred[key] == \"outlier\":\n",
        "        del(edges_d2_gt_d1_pred[key])\n",
        ""
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "d_9SyiPlgOst",
        "outputId": "47a66d94-52be-4236-fffc-f5a4b21a2f44"
      },
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "EDGES d1 -> d2\n",
            "\n",
            "d1_0 -> d2_8\n",
            "d1_1,2 -> d2_1\n",
            "d1_3 -> d2_2,3\n",
            "d1_4 -> d2_4\n",
            "d1_5,6,7 -> d2_7\n",
            "d1_8 -> d2_8\n",
            "\n",
            "\n",
            "\n",
            "EDGES d2 -> d1\n",
            "\n",
            "d2_1 -> d1_1,2\n",
            "d2_2,3 -> d1_1,2\n",
            "d2_4 -> d1_4\n",
            "d2_5 -> d1_5,6,7\n",
            "d2_6 -> d1_5,6,7\n",
            "d2_7 -> d1_5,6,7\n",
            "d2_8 -> d1_8\n",
            "d2_9 -> d1_4\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "bidirectional_edges = []\n",
        "\n",
        "items = list(edges_d1_gt_d2_pred.items())\n",
        "\n",
        "for d1_gt, d2_pred in items:\n",
        "    if d2_pred in edges_d2_gt_d1_pred and edges_d2_gt_d1_pred[d2_pred]==d1_gt:\n",
        "        bidirectional_edges.append((d1_gt, d2_pred))\n",
        "        del(edges_d2_gt_d1_pred[d2_pred])\n",
        "        del(edges_d1_gt_d2_pred[d1_gt])\n",
        "\n",
        "\n",
        "print(\"BIDIRECTIONAL EDGES\")\n",
        "print()\n",
        "for pair in bidirectional_edges:\n",
        "    print (pair[0] + ' <-> ' + pair[1])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "J32oykKmgSBc",
        "outputId": "20a78214-188b-4839-dfcc-fe9cb036e713"
      },
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "BIDIRECTIONAL EDGES\n",
            "\n",
            "d1_1,2 <-> d2_1\n",
            "d1_4 <-> d2_4\n",
            "d1_5,6,7 <-> d2_7\n",
            "d1_8 <-> d2_8\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def build_chain(current, visited, edges_1, edges_2):\n",
        "    if current in visited:\n",
        "        return 'CIRCLE'\n",
        "    if current not in edges_1:\n",
        "        return current\n",
        "    else:\n",
        "        visited.append(current)\n",
        "        return current + ' -> ' + build_chain(edges_1[current], visited, edges_2, edges_1)\n",
        "\n",
        "def build_chains(edges_1, edges_2):\n",
        "    chains = []\n",
        "    for key in edges_1.keys():\n",
        "        chains.append(build_chain(key, [], edges_1, edges_2))\n",
        "\n",
        "    for key in edges_2.keys():\n",
        "        chains.append(' -> ' + build_chain(key, [], edges_2, edges_1))\n",
        "\n",
        "    return chains"
      ],
      "metadata": {
        "id": "yG4QeBX1gTv9"
      },
      "execution_count": 24,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "chains = build_chains(edges_d1_gt_d2_pred, edges_d2_gt_d1_pred)"
      ],
      "metadata": {
        "id": "nrIhma1kgVt-"
      },
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "chains_sorted = sorted(chains, key=len)\n",
        "chains_filtered_subchains = []\n",
        "\n",
        "for i in range(len(chains_sorted)):\n",
        "    add = True\n",
        "    for j in range(i+1, len(chains_sorted)):\n",
        "        if chains_sorted[i] in chains_sorted[j]:\n",
        "            add=False\n",
        "    if add:\n",
        "        chains_filtered_subchains.append(chains_sorted[i])"
      ],
      "metadata": {
        "id": "iX_6ZjghgXaW"
      },
      "execution_count": 26,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "chains_filter_single = []\n",
        "accepted_edges = []\n",
        "\n",
        "for chain in chains_filtered_subchains:\n",
        "    nodes = chain.split(' -> ')\n",
        "    if len(nodes[0]) == 0:\n",
        "        nodes = nodes[1:]\n",
        "    if len(nodes)>2:\n",
        "        chains_filter_single.append(chain)\n",
        "    else:\n",
        "        accepted_edges.append(chain)"
      ],
      "metadata": {
        "id": "jRV40ArAgZ8v"
      },
      "execution_count": 27,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for chain in chains_filter_single:\n",
        "    print(chain)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dUUAHBf5gccr",
        "outputId": "391177b1-00aa-440d-9789-8d9c691a0bea"
      },
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "d1_3 -> d2_2,3 -> d1_1,2\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "for edge in accepted_edges:\n",
        "    print(edge)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "p82amf_MgdFs",
        "outputId": "9ee3a2d7-dc26-4857-86c6-bbf5878d60c7"
      },
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "d1_0 -> d2_8\n",
            " -> d2_9 -> d1_4\n",
            " -> d2_5 -> d1_5,6,7\n",
            " -> d2_6 -> d1_5,6,7\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "groups = []\n",
        "\n",
        "for chain in chains_filter_single:\n",
        "    new_group = True\n",
        "    for group in groups:\n",
        "        add_this_group=False\n",
        "        for group_chain in group:\n",
        "            chain_set = chain.split(' -> ')\n",
        "            if len(chain_set[0]) == 0:\n",
        "                chain_set = chain_set[1:]\n",
        "            chain_set = set(chain_set)\n",
        "\n",
        "            group_chain_set = group_chain.split(' -> ')\n",
        "            if len(group_chain_set[0]) == 0:\n",
        "                group_chain_set = group_chain_set[1:]\n",
        "            group_chain_set = set(group_chain_set)\n",
        "\n",
        "            if not group_chain_set.isdisjoint(chain_set):\n",
        "                add_this_group=True\n",
        "                new_group=False\n",
        "                break\n",
        "        if add_this_group:\n",
        "            group.append(chain)\n",
        "            break\n",
        "    if new_group:\n",
        "        groups.append([chain])"
      ],
      "metadata": {
        "id": "6wTEsXwfggt5"
      },
      "execution_count": 30,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(groups)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6ANDY0xrghol",
        "outputId": "698c736f-1466-417e-fa24-b06dfd01f150"
      },
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[['d1_3 -> d2_2,3 -> d1_1,2']]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "contradictory_pairs = []\n",
        "\n",
        "for group in groups:\n",
        "    contradictory_pair = (set(), set())\n",
        "    for chain in group:\n",
        "        vertices = chain.split(' -> ')\n",
        "        hypothetical_edges_0 = set()\n",
        "        hypothetical_edges_1 = set()\n",
        "        for i in range(1,len(vertices),2):\n",
        "            if len(vertices[i-1])>0:\n",
        "                hypothetical_edges_0.add(vertices[i-1] + ' -> ' + vertices[i])\n",
        "        for i in range(1, len(vertices)-1,2):\n",
        "            hypothetical_edges_1.add(vertices[i] + ' -> ' + vertices[i+1])\n",
        "        contradictory_pair = (contradictory_pair[0].union(hypothetical_edges_0), contradictory_pair[1].union(hypothetical_edges_1))\n",
        "    contradictory_pairs.append(contradictory_pair)"
      ],
      "metadata": {
        "id": "x-ZfWVUNgodp"
      },
      "execution_count": 32,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "d1_map_accepted = copy.deepcopy(d1_mapping)\n",
        "d2_map_accepted = copy.deepcopy(d2_mapping)\n",
        "\n",
        "[d1_map_accepted[d1_class].append(d2_class) for d1_class, d2_class in bidirectional_edges]\n",
        "[d2_map_accepted[d2_class].append(d1_class) for d1_class, d2_class in bidirectional_edges]\n",
        "\n",
        "for edge in accepted_edges:\n",
        "    nodes = edge.split(' -> ')\n",
        "    d1_node = nodes[2] if len(nodes[0])==0 else nodes[0]\n",
        "    d2_node = nodes[1]\n",
        "    d1_map_accepted[d1_node].append(d2_node)\n",
        "    d2_map_accepted[d2_node].append(d1_node)\n",
        "\n",
        "print(\"Dataset 1\")\n",
        "print()\n",
        "for key in d1_map_accepted:\n",
        "    print(key, d1_map_accepted[key])\n",
        "\n",
        "print()\n",
        "print()\n",
        "print(\"Dataset 2\")\n",
        "print()\n",
        "for key in d2_map_accepted:\n",
        "    print(key, d2_map_accepted[key])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hfpBC0ExgpOQ",
        "outputId": "a804a275-34b9-4810-84b1-bac7458532df"
      },
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset 1\n",
            "\n",
            "d1_0 ['d1_0', 'd2_8']\n",
            "d1_1,2 ['d1_1,2', 'd2_1']\n",
            "d1_3 ['d1_3']\n",
            "d1_4 ['d1_4', 'd2_4', 'd2_9']\n",
            "d1_5,6,7 ['d1_5,6,7', 'd2_7', 'd2_5', 'd2_6']\n",
            "d1_8 ['d1_8', 'd2_8']\n",
            "\n",
            "\n",
            "Dataset 2\n",
            "\n",
            "d2_1 ['d2_1', 'd1_1,2']\n",
            "d2_2,3 ['d2_2,3']\n",
            "d2_4 ['d2_4', 'd1_4']\n",
            "d2_5 ['d2_5', 'd1_5,6,7']\n",
            "d2_6 ['d2_6', 'd1_5,6,7']\n",
            "d2_7 ['d2_7', 'd1_5,6,7']\n",
            "d2_8 ['d2_8', 'd1_8', 'd1_0']\n",
            "d2_9 ['d2_9', 'd1_4']\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "dataset_2_train_no_offset = MNISTSplitDataset(mnist_trainset_2_train, lambda x: mapping_2[x])\n",
        "train_eval_loader_2_no_offset = DataLoader(dataset_2_train_no_offset, batch_size=1, shuffle=False)"
      ],
      "metadata": {
        "id": "bFOFRn-ogxKx"
      },
      "execution_count": 34,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def calculate_accuracy(loader, mapping_matrix):\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    with torch.no_grad():\n",
        "        for i, batch in enumerate(loader):\n",
        "            total += 1\n",
        "            model_features_1(batch['input'].cuda())\n",
        "            features_val = activation['avgpool_1'].squeeze((2,3))\n",
        "            head_val = model_head_1(features_val).softmax(1)\n",
        "            mapping_val = torch.matmul(head_val, mapping_matrix.permute(1,0))\n",
        "            pred = mapping_val.argmax(1).detach().cpu()\n",
        "            if pred == batch['label']:\n",
        "                correct += 1\n",
        "    acc = correct / total\n",
        "    return acc"
      ],
      "metadata": {
        "id": "yTZqtJ8Eg1xF"
      },
      "execution_count": 35,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "accepted_hypothetical_edges = []\n",
        "for i, (h0, h1) in enumerate(contradictory_pairs):\n",
        "    d1_map_h0 = copy.deepcopy(d1_map_accepted)\n",
        "    d2_map_h0 = copy.deepcopy(d2_map_accepted)\n",
        "\n",
        "    for edge in h0:\n",
        "        vertices = edge.split(\" -> \")\n",
        "        d1_map_h0[vertices[0]].append(vertices[1])\n",
        "        d2_map_h0[vertices[1]].append(vertices[0])\n",
        "\n",
        "    mapping_matrix_1 = torch.zeros((n_class_1, n_class_1 + n_class_2))\n",
        "\n",
        "    for key in d1_map_h0.keys():\n",
        "        for concat_class in d1_map_h0[key]:\n",
        "            mapping_matrix_1[d1_classes.index(key), concat_classes.index(concat_class)] = 1\n",
        "\n",
        "    mapping_matrix_1 = mapping_matrix_1.cuda()\n",
        "\n",
        "    mapping_matrix_2 = torch.zeros((n_class_2, n_class_1 + n_class_2))\n",
        "\n",
        "    for key in d2_map_h0.keys():\n",
        "        for concat_class in d2_map_h0[key]:\n",
        "            mapping_matrix_2[d2_classes.index(key), concat_classes.index(concat_class)] = 1\n",
        "\n",
        "    mapping_matrix_2 = mapping_matrix_2.cuda()\n",
        "\n",
        "    acc_1_h0 = calculate_accuracy(train_eval_loader_1, mapping_matrix_1)\n",
        "    acc_2_h0 = calculate_accuracy(train_eval_loader_2_no_offset, mapping_matrix_2)\n",
        "\n",
        "    print(f\"Accuracy h0: {acc_1_h0}, {acc_2_h0}\")\n",
        "\n",
        "    d1_map_h1 = copy.deepcopy(d1_map_accepted)\n",
        "    d2_map_h1 = copy.deepcopy(d2_map_accepted)\n",
        "\n",
        "    for edge in h1:\n",
        "        vertices = edge.split(\" -> \")\n",
        "        d1_map_h1[vertices[1]].append(vertices[0])\n",
        "        d2_map_h1[vertices[0]].append(vertices[1])\n",
        "\n",
        "    mapping_matrix_1 = torch.zeros((n_class_1, n_class_1 + n_class_2))\n",
        "\n",
        "    for key in d1_map_h1.keys():\n",
        "        for concat_class in d1_map_h1[key]:\n",
        "            mapping_matrix_1[d1_classes.index(key), concat_classes.index(concat_class)] = 1\n",
        "\n",
        "    mapping_matrix_1 = mapping_matrix_1.cuda()\n",
        "\n",
        "    mapping_matrix_2 = torch.zeros((n_class_2, n_class_1 + n_class_2))\n",
        "\n",
        "    for key in d2_map_h1.keys():\n",
        "        for concat_class in d2_map_h1[key]:\n",
        "            mapping_matrix_2[d2_classes.index(key), concat_classes.index(concat_class)] = 1\n",
        "\n",
        "    mapping_matrix_2 = mapping_matrix_2.cuda()\n",
        "\n",
        "    acc_1_h1 = calculate_accuracy(train_eval_loader_1, mapping_matrix_1)\n",
        "    acc_2_h1 = calculate_accuracy(train_eval_loader_2_no_offset, mapping_matrix_2)\n",
        "\n",
        "    print(f\"Accuracy h1: {acc_1_h1}, {acc_2_h1}\")\n",
        "\n",
        "    if (acc_1_h0 + acc_2_h0) > (acc_1_h1 + acc_2_h1):\n",
        "        d1_map_accepted = d1_map_h0\n",
        "        d2_map_accepted = d2_map_h0\n",
        "        accepted_hypothetical_edges.extend(h0)\n",
        "    else:\n",
        "        d1_map_accepted = d1_map_h1\n",
        "        d2_map_accepted = d2_map_h1\n",
        "        accepted_hypothetical_edges.extend(h1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5Wy20Jcjg58k",
        "outputId": "19a3ad8d-3271-435a-d821-28cb80341a79"
      },
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Accuracy h0: 0.8963983992885727, 0.9771829360323168\n",
            "Accuracy h1: 0.9530457981325033, 0.9839747858125805\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"Dataset 0\")\n",
        "print()\n",
        "for key in d1_map_accepted:\n",
        "    print(key, d1_map_accepted[key])\n",
        "\n",
        "print()\n",
        "print()\n",
        "print(\"Dataset 1\")\n",
        "print()\n",
        "for key in d2_map_accepted:\n",
        "    print(key, d2_map_accepted[key])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OuK-I_2zg6ri",
        "outputId": "150f3c0d-2afa-4c49-eb73-121dd7cacc1f"
      },
      "execution_count": 37,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset 0\n",
            "\n",
            "d1_0 ['d1_0', 'd2_8']\n",
            "d1_1,2 ['d1_1,2', 'd2_1', 'd2_2,3']\n",
            "d1_3 ['d1_3']\n",
            "d1_4 ['d1_4', 'd2_4', 'd2_9']\n",
            "d1_5,6,7 ['d1_5,6,7', 'd2_7', 'd2_5', 'd2_6']\n",
            "d1_8 ['d1_8', 'd2_8']\n",
            "\n",
            "\n",
            "Dataset 1\n",
            "\n",
            "d2_1 ['d2_1', 'd1_1,2']\n",
            "d2_2,3 ['d2_2,3', 'd1_1,2']\n",
            "d2_4 ['d2_4', 'd1_4']\n",
            "d2_5 ['d2_5', 'd1_5,6,7']\n",
            "d2_6 ['d2_6', 'd1_5,6,7']\n",
            "d2_7 ['d2_7', 'd1_5,6,7']\n",
            "d2_8 ['d2_8', 'd1_8', 'd1_0']\n",
            "d2_9 ['d2_9', 'd1_4']\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "test_mapping = torch.tensor([48, 8, 9, 17, 26, 35, 36, 37, 46, 61])\n",
        "\n",
        "true_relations = torch.zeros((62), dtype=torch.int32)\n",
        "\n",
        "for true_relation in test_mapping:\n",
        "    true_relations[true_relation] = 1"
      ],
      "metadata": {
        "id": "ovjPmo_qhAN2"
      },
      "execution_count": 39,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "shared_relation_predictions = torch.zeros((n_class_1, n_class_2), dtype=torch.int32)\n",
        "ds_1_relation_predictions = torch.zeros((n_class_1,), dtype=torch.int32)\n",
        "ds_2_relation_predictions = torch.zeros((n_class_2,), dtype=torch.int32)"
      ],
      "metadata": {
        "id": "raHVJByBhLr-"
      },
      "execution_count": 42,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for key in d1_map_accepted:\n",
        "    if len(d1_map_accepted[key]) == 1:\n",
        "        ds_1_relation_predictions[d1_classes.index(key)] = 1\n",
        "    else:\n",
        "        for mapped_class in d1_map_accepted[key]:\n",
        "            if mapped_class != key:\n",
        "                shared_relation_predictions[d1_classes.index(key), d2_classes.index(mapped_class)] = 1\n",
        "\n",
        "for key in d2_map_accepted:\n",
        "    if len(d2_map_accepted[key]) == 1:\n",
        "        ds_2_relation_predictions[d1_classes.index(key)] = 1\n",
        "    else:\n",
        "        for mapped_class in d2_map_accepted[key]:\n",
        "            if mapped_class != key:\n",
        "                shared_relation_predictions[d1_classes.index(mapped_class), d2_classes.index(key)] = 1"
      ],
      "metadata": {
        "id": "5i4Q-8VlhOuH"
      },
      "execution_count": 43,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "final_relation_predictions = np.concatenate((shared_relation_predictions.reshape(48), ds_1_relation_predictions, ds_2_relation_predictions)).reshape((62))"
      ],
      "metadata": {
        "id": "tVJYfDLehfJa"
      },
      "execution_count": 45,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(shared_relation_predictions)\n",
        "print(ds_1_relation_predictions)\n",
        "print(ds_2_relation_predictions)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KCLDXsHrhiwd",
        "outputId": "3cb9f5cc-7dfc-4e44-9403-3e6c27d059bb"
      },
      "execution_count": 46,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "tensor([[0, 0, 0, 0, 0, 0, 1, 0],\n",
            "        [1, 1, 0, 0, 0, 0, 0, 0],\n",
            "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
            "        [0, 0, 1, 0, 0, 0, 0, 1],\n",
            "        [0, 0, 0, 1, 1, 1, 0, 0],\n",
            "        [0, 0, 0, 0, 0, 0, 1, 0]], dtype=torch.int32)\n",
            "tensor([0, 0, 1, 0, 0, 0], dtype=torch.int32)\n",
            "tensor([0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "ged = torch.logical_xor(torch.tensor(true_relations), torch.tensor(final_relation_predictions)).sum().item()\n",
        "print(ged)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gxO4iIoXhkfF",
        "outputId": "5c9f9e89-7a46-478e-9b11-2864ede009ee"
      },
      "execution_count": 47,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "6\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-957551273.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  ged = torch.logical_xor(torch.tensor(true_relations), torch.tensor(final_relation_predictions)).sum().item()\n"
          ]
        }
      ]
    }
  ]
}