{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "V100",
      "collapsed_sections": [
        "UP_E2H4EBnsh",
        "z_gdUqNO5NYM"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# FactorCL on Synthetic Dataset"
      ],
      "metadata": {
        "id": "LTEV_elPBK32"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!git clone https://github.com/pliang279/FactorCL"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IHKnbVkoASPI",
        "outputId": "c0e82d17-f04a-4550-96f6-bf60af567c61"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'FactorCL'...\n",
            "remote: Enumerating objects: 146, done.\u001b[K\n",
            "remote: Counting objects: 100% (30/30), done.\u001b[K\n",
            "remote: Compressing objects: 100% (29/29), done.\u001b[K\n",
            "remote: Total 146 (delta 11), reused 0 (delta 0), pack-reused 116\u001b[K\n",
            "Receiving objects: 100% (146/146), 317.87 KiB | 9.08 MiB/s, done.\n",
            "Resolving deltas: 100% (68/68), done.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%cd FactorCL"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "R_UCHEyzAbYQ",
        "outputId": "21e1d2a8-4e84-4bfc-b15f-cee17d7da7e1"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/FactorCL\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from Synthetic.dataset import*\n",
        "from synthetic_model import*\n",
        "from torch.utils.data import DataLoader\n",
        "from sklearn.linear_model import LogisticRegression"
      ],
      "metadata": {
        "id": "A0mgHF3bAgvo"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "hidden_dim=512\n",
        "embed_dim=128"
      ],
      "metadata": {
        "id": "5RWziN1_ANhB"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "glk_GDn0AEvs",
        "outputId": "faacc9a2-e622-49be-c4e4-0e1f690cf0a7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "['1', '2', '12']\n",
            "{'12': 10, '1': 6, '2': 6}\n",
            "{'12': 10, '1': 6, '2': 6}\n"
          ]
        }
      ],
      "source": [
        "# Feature and label dimensions for each modality subset\n",
        "feature_dim_info = {'12': 10, '1': 6, '2': 6}\n",
        "label_dim_info = {'12': 10, '1': 6, '2': 6}\n",
        "\n",
        "intersections = get_intersections(num_modalities=2)\n",
        "\n",
        "print(intersections, feature_dim_info, label_dim_info, sep='\\n')\n",
        "\n",
        "# Generate the data, then derive the labels from the raw features\n",
        "total_data, total_labels, total_raw_features = generate_data(30000, 2, feature_dim_info, label_dim_info)\n",
        "total_labels = get_labels(label_dim_info, total_raw_features)\n",
        "\n",
        "dataset = MultimodalDataset(total_data, total_labels)\n",
        "\n",
        "# 80/20 train/test split\n",
        "batch_size = 256\n",
        "num_data = total_labels.shape[0]\n",
        "num_train = int(0.8 * num_data)\n",
        "train_dataset, test_dataset = torch.utils.data.random_split(dataset, [num_train, num_data - num_train])\n",
        "\n",
        "# Dataloaders: only the training loader shuffles and drops the last partial batch\n",
        "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)\n",
        "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)\n",
        "data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Linear Probing"
      ],
      "metadata": {
        "id": "xzKxUpkiwzhb"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### FactorCL-SUP"
      ],
      "metadata": {
        "id": "UP_E2H4EBnsh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the supervised FactorCL model and move it to the GPU\n",
        "factorcl_sup = FactorCLSUP(A_dim, B_dim, 20, hidden_dim, embed_dim)\n",
        "factorcl_sup = factorcl_sup.cuda()\n",
        "\n",
        "# Train, then switch to eval mode (the module repr is the cell output)\n",
        "train_sup_model(\n",
        "    factorcl_sup,\n",
        "    train_loader,\n",
        "    dataset,\n",
        "    num_epoch=10,\n",
        "    num_club_iter=1,\n",
        ")\n",
        "factorcl_sup.eval()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oUrkdAhgAPgJ",
        "outputId": "0e02ddd8-9388-46e1-fdb6-4aa97bfcbdf0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  9.623612277209759e-05\n",
            "iter:  1  i_batch:  0  loss:  -3.0933303833007812\n",
            "iter:  2  i_batch:  0  loss:  -5.884997367858887\n",
            "iter:  3  i_batch:  0  loss:  -7.007236480712891\n",
            "iter:  4  i_batch:  0  loss:  -7.123583793640137\n",
            "iter:  5  i_batch:  0  loss:  -7.228031158447266\n",
            "iter:  6  i_batch:  0  loss:  -7.025860786437988\n",
            "iter:  7  i_batch:  0  loss:  -7.642902374267578\n",
            "iter:  8  i_batch:  0  loss:  -8.002883911132812\n",
            "iter:  9  i_batch:  0  loss:  -7.967354774475098\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "FactorCLSUP(\n",
              "  (backbones): ModuleList(\n",
              "    (0-1): 2 x Sequential(\n",
              "      (0): Linear(in_features=100, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (5): ReLU()\n",
              "      (6): Linear(in_features=512, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (linears_infonce_x1x2): ModuleList(\n",
              "    (0-1): 2 x Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (linears_club_x1x2_cond): ModuleList(\n",
              "    (0-1): 2 x Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (linears_infonce_x1y): Sequential(\n",
              "    (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (3): ReLU()\n",
              "    (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "  )\n",
              "  (linears_infonce_x2y): Sequential(\n",
              "    (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (3): ReLU()\n",
              "    (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "  )\n",
              "  (linears_infonce_x1x2_cond): ModuleList(\n",
              "    (0-1): 2 x Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (linears_club_x1x2): ModuleList(\n",
              "    (0-1): 2 x Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (infonce_x1x2): InfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (club_x1x2_cond): CLUBInfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=276, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (infonce_x1y): InfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=129, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (infonce_x2y): InfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=129, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (infonce_x1x2_cond): InfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=276, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (club_x1x2): CLUBInfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Simple evaluation using a linear logistic regression probe\n",
        "\n",
        "# Materialise each split once instead of indexing the dataset twice\n",
        "train_split = train_dataset[:]\n",
        "test_split = test_dataset[:]\n",
        "\n",
        "# Embeddings (inference only, so no autograd graph is needed)\n",
        "with torch.no_grad():\n",
        "    train_embeds = factorcl_sup.get_embedding(torch.stack(train_split[:-1]).cuda()).cpu().numpy()\n",
        "    test_embeds = factorcl_sup.get_embedding(torch.stack(test_split[:-1]).cuda()).cpu().numpy()\n",
        "\n",
        "# Labels arrive as a column vector; sklearn expects shape (n_samples,)\n",
        "train_labels = train_split[-1].detach().cpu().numpy().ravel()\n",
        "test_labels = test_split[-1].detach().cpu().numpy().ravel()\n",
        "\n",
        "# Train Logistic Classifier (max_iter raised: lbfgs hit the 200-iteration limit)\n",
        "clf = LogisticRegression(max_iter=1000).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ril93KsbBeHZ",
        "outputId": "e99477e0-9e3e-4748-fca8-5d05ccb8aa14"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n",
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "score"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hhFO10gs272V",
        "outputId": "899044cb-8c97-4d8f-b555-5f4643e8f234"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.9863333333333333"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### FactorCL-SSL"
      ],
      "metadata": {
        "id": "z_gdUqNO5NYM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the self-supervised FactorCL model and move it to the GPU\n",
        "factorcl_ssl = FactorCLSSL(A_dim, B_dim, hidden_dim, embed_dim)\n",
        "factorcl_ssl = factorcl_ssl.cuda()\n",
        "\n",
        "# NOTE(review): the SSL model is trained with train_sup_model -- confirm this is intended\n",
        "train_sup_model(\n",
        "    factorcl_ssl,\n",
        "    train_loader,\n",
        "    dataset,\n",
        "    num_epoch=10,\n",
        "    num_club_iter=1,\n",
        ")\n",
        "factorcl_ssl.eval()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "aqCK0pyg5NYS",
        "outputId": "b441fbac-ee51-4e7a-ccbe-1bce52fcece4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  3.082677721977234e-06\n",
            "iter:  1  i_batch:  0  loss:  -7.57092809677124\n",
            "iter:  2  i_batch:  0  loss:  -8.953313827514648\n",
            "iter:  3  i_batch:  0  loss:  -10.78848648071289\n",
            "iter:  4  i_batch:  0  loss:  -12.652485847473145\n",
            "iter:  5  i_batch:  0  loss:  -14.410904884338379\n",
            "iter:  6  i_batch:  0  loss:  -14.283191680908203\n",
            "iter:  7  i_batch:  0  loss:  -14.439704895019531\n",
            "iter:  8  i_batch:  0  loss:  -15.733597755432129\n",
            "iter:  9  i_batch:  0  loss:  -15.78852653503418\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "FactorCLSSL(\n",
              "  (backbones): ModuleList(\n",
              "    (0): Sequential(\n",
              "      (0): Linear(in_features=100, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (5): ReLU()\n",
              "      (6): Linear(in_features=512, out_features=128, bias=True)\n",
              "    )\n",
              "    (1): Sequential(\n",
              "      (0): Linear(in_features=100, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (5): ReLU()\n",
              "      (6): Linear(in_features=512, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (linears_infonce_x1x2): ModuleList(\n",
              "    (0): Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "    (1): Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (linears_club_x1x2_cond): ModuleList(\n",
              "    (0): Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "    (1): Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (linears_infonce_x1y): Sequential(\n",
              "    (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (3): ReLU()\n",
              "    (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "  )\n",
              "  (linears_infonce_x2y): Sequential(\n",
              "    (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (3): ReLU()\n",
              "    (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "  )\n",
              "  (linears_infonce_x1x2_cond): ModuleList(\n",
              "    (0): Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "    (1): Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (linears_club_x1x2): ModuleList(\n",
              "    (0): Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "    (1): Sequential(\n",
              "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (infonce_x1x2): InfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (club_x1x2_cond): CLUBInfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (infonce_x1y): InfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (infonce_x2y): InfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (infonce_x1x2_cond): InfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              "  (club_x1x2): CLUBInfoNCECritic(\n",
              "    (_f): Sequential(\n",
              "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
              "      (1): ReLU()\n",
              "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "      (3): ReLU()\n",
              "      (4): Linear(in_features=512, out_features=1, bias=True)\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Simple evaluation using a linear logistic regression probe\n",
        "\n",
        "# Materialise each split once instead of indexing the dataset twice\n",
        "train_split = train_dataset[:]\n",
        "test_split = test_dataset[:]\n",
        "\n",
        "# Embeddings (inference only, so no autograd graph is needed)\n",
        "with torch.no_grad():\n",
        "    train_embeds = factorcl_ssl.get_embedding(torch.stack(train_split[:-1]).cuda()).cpu().numpy()\n",
        "    test_embeds = factorcl_ssl.get_embedding(torch.stack(test_split[:-1]).cuda()).cpu().numpy()\n",
        "\n",
        "# Labels arrive as a column vector; sklearn expects shape (n_samples,)\n",
        "train_labels = train_split[-1].detach().cpu().numpy().ravel()\n",
        "test_labels = test_split[-1].detach().cpu().numpy().ravel()\n",
        "\n",
        "# Train Logistic Classifier (max_iter raised: lbfgs hit the 200-iteration limit)\n",
        "clf = LogisticRegression(max_iter=1000).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5a7075d7-8f47-43b5-d6de-5be795807bed",
        "id": "tpvNxNwG5NYS"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n",
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "score"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "adcf6b83-c595-49c3-a7c7-9f3363670ab2",
        "id": "1iHDduqL5NYS"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.9716666666666667"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### SimCLR"
      ],
      "metadata": {
        "id": "6icituq2Hkra"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# SimCLR baseline: SupConModel without labels (set use_label=True for SupCon model)\n",
        "simclr_model = SupConModel(A_dim, B_dim, hidden_dim, embed_dim, use_label=False)\n",
        "simclr_model = simclr_model.cuda()\n",
        "\n",
        "# The optimizer is created after the model has been moved to the GPU\n",
        "simclr_optim = optim.Adam(simclr_model.parameters(), lr=lr)\n",
        "train_supcon(\n",
        "    simclr_model,\n",
        "    train_loader,\n",
        "    simclr_optim,\n",
        "    num_epoch=20,\n",
        ")\n",
        "simclr_model.eval()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-pYPGQgNHiPN",
        "outputId": "f6bd0c12-f1f8-46c0-df8a-3856ee57effe"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  11.985010147094727\n",
            "iter:  1  i_batch:  0  loss:  5.873631477355957\n",
            "iter:  2  i_batch:  0  loss:  4.937842845916748\n",
            "iter:  3  i_batch:  0  loss:  4.577206611633301\n",
            "iter:  4  i_batch:  0  loss:  3.8152596950531006\n",
            "iter:  5  i_batch:  0  loss:  2.4593682289123535\n",
            "iter:  6  i_batch:  0  loss:  1.589794635772705\n",
            "iter:  7  i_batch:  0  loss:  1.177107810974121\n",
            "iter:  8  i_batch:  0  loss:  0.9892462491989136\n",
            "iter:  9  i_batch:  0  loss:  0.9151424169540405\n",
            "iter:  10  i_batch:  0  loss:  0.809130072593689\n",
            "iter:  11  i_batch:  0  loss:  0.6216927766799927\n",
            "iter:  12  i_batch:  0  loss:  0.600822925567627\n",
            "iter:  13  i_batch:  0  loss:  0.5172191858291626\n",
            "iter:  14  i_batch:  0  loss:  0.5168259143829346\n",
            "iter:  15  i_batch:  0  loss:  0.43819260597229004\n",
            "iter:  16  i_batch:  0  loss:  0.5183995962142944\n",
            "iter:  17  i_batch:  0  loss:  0.32648688554763794\n",
            "iter:  18  i_batch:  0  loss:  0.42932015657424927\n",
            "iter:  19  i_batch:  0  loss:  0.5603318810462952\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "SupConModel(\n",
              "  (encoder_x1): Sequential(\n",
              "    (0): Linear(in_features=100, out_features=512, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "    (3): ReLU()\n",
              "    (4): Linear(in_features=512, out_features=512, bias=True)\n",
              "    (5): ReLU()\n",
              "    (6): Linear(in_features=512, out_features=128, bias=True)\n",
              "  )\n",
              "  (encoder_x2): Sequential(\n",
              "    (0): Linear(in_features=100, out_features=512, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
              "    (3): ReLU()\n",
              "    (4): Linear(in_features=512, out_features=512, bias=True)\n",
              "    (5): ReLU()\n",
              "    (6): Linear(in_features=512, out_features=128, bias=True)\n",
              "  )\n",
              "  (projection_x1): Sequential(\n",
              "    (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (3): ReLU()\n",
              "    (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "  )\n",
              "  (projection_x2): Sequential(\n",
              "    (0): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=128, out_features=128, bias=True)\n",
              "    (3): ReLU()\n",
              "    (4): Linear(in_features=128, out_features=128, bias=True)\n",
              "  )\n",
              "  (critic): SupConLoss()\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_embeds = simclr_model.get_embedding(torch.stack(train_dataset[:][:-1]).cuda()).detach().cpu().numpy()\n",
        "train_labels = train_dataset[:][-1].detach().cpu().numpy()\n",
        "\n",
        "# Embeddings\n",
        "test_embeds = simclr_model.get_embedding(torch.stack(test_dataset[:][:-1]).cuda()).detach().cpu().numpy()\n",
        "test_labels = test_dataset[:][-1].detach().cpu().numpy()\n",
        "\n",
        "# Train Logistic Classifier\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FvHpTE5gHoi6",
        "outputId": "7128738e-9cd0-4c3c-b288-62f83bb8da8e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n",
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "score"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rKl16Y8s3l7Y",
        "outputId": "878d4697-aa22-445a-d4a5-bc62a663372b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.9555"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#MI Estimation"
      ],
      "metadata": {
        "id": "0BYAcarbw1aJ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can estimate the MI between any two sets of vectors using the Probablistic Classifier approach proposed in [Neural Methods for Point-wise Dependency Estimation](https://arxiv.org/abs/2006.05553)"
      ],
      "metadata": {
        "id": "UNUlLl-Sw4u1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Simple dataset for constructing paired data\n",
        "\n",
        "class PairedDataset(Dataset):\n",
        "  def __init__(self, data_A, data_B):\n",
        "    self.data_A = data_A.float()\n",
        "    self.data_B = data_B.float()\n",
        "\n",
        "  def __len__(self):\n",
        "    return self.data_A.shape[0]\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    return self.data_A[idx], self.data_B[idx]\n",
        "\n",
        "  def get_dims(self):\n",
        "    return self.data_A.shape[1], self.data_B.shape[1]\n",
        "\n",
        "\n",
        "def train_critic_paired(critic_model, train_loader, opt_crit, num_epoch=80):\n",
        "    MIs = []\n",
        "    for _iter in range(num_epoch):\n",
        "        for i_batch, (A_batch, B_batch) in enumerate(train_loader):\n",
        "            opt_crit.zero_grad()\n",
        "\n",
        "            x_batch = A_batch.cuda()\n",
        "            y_batch = B_batch.cuda()\n",
        "\n",
        "            scores = critic_model(x_batch, y_batch)\n",
        "            MIs.append(probabilistic_classifier_eval(scores))\n",
        "\n",
        "            negative_loss = probabilistic_classifier_obj(scores)\n",
        "            loss = -negative_loss\n",
        "\n",
        "            loss.backward()\n",
        "            opt_crit.step()\n",
        "\n",
        "            if i_batch%100 == 0:\n",
        "                print('iter: ', _iter, ' i_batch: ', i_batch, ' negative_loss: ', negative_loss.item())\n",
        "\n",
        "    return MIs\n",
        "\n",
        "\n",
        "def eval_MI_paired(critic_model, test_loader):\n",
        "    MIs = []\n",
        "\n",
        "    for i_batch, (A_batch, B_batch) in enumerate(test_loader):\n",
        "\n",
        "        x_batch = A_batch.cuda()\n",
        "        y_batch = B_batch.cuda()\n",
        "\n",
        "        scores = critic_model(x_batch, y_batch)\n",
        "        MIs.append(probabilistic_classifier_eval(scores))\n",
        "\n",
        "    MI = torch.stack(MIs).mean()\n",
        "    print(MI.item())\n",
        "    return MI.item()\n"
      ],
      "metadata": {
        "id": "1Y3i9XEAx8hR"
      },
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "layers = 1\n",
        "activation = 'relu'\n",
        "lr = 1e-4"
      ],
      "metadata": {
        "id": "Vxbxr9jZ1Xl0"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "data_A = train_dataset[:][0]\n",
        "data_B = train_dataset[:][1]\n",
        "\n",
        "mi_dataset = PairedDataset(data_A, data_B)\n",
        "mi_loader = DataLoader(mi_dataset,\n",
        "                                 shuffle=True,\n",
        "                                 batch_size=batch_size)\n",
        "\n",
        "dim1, dim2 = mi_dataset.get_dims()\n",
        "critic_model = SeparableCritic(x1_dim=dim1, x2_dim=dim2, hidden_dim=hidden_dim,\n",
        "                              embed_dim=embed_dim, layers=layers, activation=activation).cuda()\n",
        "\n",
        "opt_crit = optim.Adam(critic_model.parameters(), lr=lr)\n",
        "\n",
        "train_critic_paired(critic_model, mi_loader, opt_crit, num_epoch=30)\n",
        "MI = eval_MI_paired(critic_model, mi_loader)\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qdAjxJKYw3gU",
        "outputId": "1a76d26c-c545-4c01-f1af-54fde67e327f"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  negative_loss:  -0.6815298795700073\n",
            "iter:  1  i_batch:  0  negative_loss:  -0.02864418551325798\n",
            "iter:  2  i_batch:  0  negative_loss:  -0.025951214134693146\n",
            "iter:  3  i_batch:  0  negative_loss:  -0.025350451469421387\n",
            "iter:  4  i_batch:  0  negative_loss:  -0.022115405648946762\n",
            "iter:  5  i_batch:  0  negative_loss:  -0.018267907202243805\n",
            "iter:  6  i_batch:  0  negative_loss:  -0.014728059060871601\n",
            "iter:  7  i_batch:  0  negative_loss:  -0.012674244120717049\n",
            "iter:  8  i_batch:  0  negative_loss:  -0.008447285741567612\n",
            "iter:  9  i_batch:  0  negative_loss:  -0.007021321449428797\n",
            "iter:  10  i_batch:  0  negative_loss:  -0.005306956823915243\n",
            "iter:  11  i_batch:  0  negative_loss:  -0.004481859505176544\n",
            "iter:  12  i_batch:  0  negative_loss:  -0.0029017655178904533\n",
            "iter:  13  i_batch:  0  negative_loss:  -0.002666341606527567\n",
            "iter:  14  i_batch:  0  negative_loss:  -0.002147895749658346\n",
            "iter:  15  i_batch:  0  negative_loss:  -0.0017188391648232937\n",
            "iter:  16  i_batch:  0  negative_loss:  -0.0015495491679757833\n",
            "iter:  17  i_batch:  0  negative_loss:  -0.0013614825438708067\n",
            "iter:  18  i_batch:  0  negative_loss:  -0.0012494383845478296\n",
            "iter:  19  i_batch:  0  negative_loss:  -0.001161330845206976\n",
            "iter:  20  i_batch:  0  negative_loss:  -0.0012195127783343196\n",
            "iter:  21  i_batch:  0  negative_loss:  -0.0004979142686352134\n",
            "iter:  22  i_batch:  0  negative_loss:  -0.000861837062984705\n",
            "iter:  23  i_batch:  0  negative_loss:  -0.0006348080933094025\n",
            "iter:  24  i_batch:  0  negative_loss:  -0.0006930832751095295\n",
            "iter:  25  i_batch:  0  negative_loss:  -0.0007835590513423085\n",
            "iter:  26  i_batch:  0  negative_loss:  -0.0006960316095501184\n",
            "iter:  27  i_batch:  0  negative_loss:  -0.0004299855208955705\n",
            "iter:  28  i_batch:  0  negative_loss:  -0.0006372892530634999\n",
            "iter:  29  i_batch:  0  negative_loss:  -0.0007753260433673859\n",
            "9.910201072692871\n"
          ]
        }
      ]
    }
  ]
}