{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fWNBW_QzYrvs"
      },
      "source": [
        "# Preparation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "M9WBEG59zoyA",
        "outputId": "7d37cbc8-be87-4383-82fe-6c0545cc644c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting torch==1.13.1\n",
            "  Downloading torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl (887.5 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m887.5/887.5 MB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting torchvision==0.14.1\n",
            "  Downloading torchvision-0.14.1-cp310-cp310-manylinux1_x86_64.whl (24.2 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.2/24.2 MB\u001b[0m \u001b[31m65.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting torchaudio==0.13.1\n",
            "  Downloading torchaudio-0.13.1-cp310-cp310-manylinux1_x86_64.whl (4.2 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.2/4.2 MB\u001b[0m \u001b[31m108.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting torchdata==0.5.1\n",
            "  Downloading torchdata-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m103.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting torchtext==0.14.1\n",
            "  Downloading torchtext-0.14.1-cp310-cp310-manylinux1_x86_64.whl (2.0 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m101.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch==1.13.1) (4.5.0)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /usr/local/lib/python3.10/dist-packages (from torch==1.13.1) (11.7.99)\n",
            "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /usr/local/lib/python3.10/dist-packages (from torch==1.13.1) (8.5.0.96)\n",
            "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /usr/local/lib/python3.10/dist-packages (from torch==1.13.1) (11.10.3.66)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /usr/local/lib/python3.10/dist-packages (from torch==1.13.1) (11.7.99)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision==0.14.1) (1.23.5)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision==0.14.1) (2.31.0)\n",
            "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision==0.14.1) (9.4.0)\n",
            "Requirement already satisfied: urllib3>=1.25 in /usr/local/lib/python3.10/dist-packages (from torchdata==0.5.1) (2.0.7)\n",
            "Collecting portalocker>=2.0.0 (from torchdata==0.5.1)\n",
            "  Downloading portalocker-2.8.2-py3-none-any.whl (17 kB)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from torchtext==0.14.1) (4.66.1)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch==1.13.1) (67.7.2)\n",
            "Requirement already satisfied: wheel in /usr/local/lib/python3.10/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch==1.13.1) (0.41.3)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision==0.14.1) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision==0.14.1) (3.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision==0.14.1) (2023.7.22)\n",
            "Installing collected packages: portalocker, torch, torchvision, torchtext, torchdata, torchaudio\n",
            "  Attempting uninstall: torch\n",
            "    Found existing installation: torch 2.0.1\n",
            "    Uninstalling torch-2.0.1:\n",
            "      Successfully uninstalled torch-2.0.1\n",
            "  Attempting uninstall: torchvision\n",
            "    Found existing installation: torchvision 0.15.2\n",
            "    Uninstalling torchvision-0.15.2:\n",
            "      Successfully uninstalled torchvision-0.15.2\n",
            "  Attempting uninstall: torchtext\n",
            "    Found existing installation: torchtext 0.15.2\n",
            "    Uninstalling torchtext-0.15.2:\n",
            "      Successfully uninstalled torchtext-0.15.2\n",
            "  Attempting uninstall: torchdata\n",
            "    Found existing installation: torchdata 0.6.1\n",
            "    Uninstalling torchdata-0.6.1:\n",
            "      Successfully uninstalled torchdata-0.6.1\n",
            "  Attempting uninstall: torchaudio\n",
            "    Found existing installation: torchaudio 2.0.2\n",
            "    Uninstalling torchaudio-2.0.2:\n",
            "      Successfully uninstalled torchaudio-2.0.2\n",
            "Successfully installed portalocker-2.8.2 torch-1.13.1 torchaudio-0.13.1 torchdata-0.5.1 torchtext-0.14.1 torchvision-0.14.1\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "torch",
                  "torchdata",
                  "torchtext",
                  "torchvision"
                ]
              }
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Pinned as a mutually compatible set (torchtext 0.14.1 / torchdata 0.5.1 target torch 1.13.1).\n",
        "# %pip installs into the running kernel's environment. Restart the runtime afterwards so the\n",
        "# downgraded packages are the ones actually imported.\n",
        "%pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 torchdata==0.5.1 torchtext==0.14.1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BgGoqg0DYP92",
        "outputId": "6a2f7674-d80e-4ba9-d516-06a13de00b1e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'MultiBench'...\n",
            "remote: Enumerating objects: 6937, done.\u001b[K\n",
            "remote: Counting objects: 100% (148/148), done.\u001b[K\n",
            "remote: Compressing objects: 100% (88/88), done.\u001b[K\n",
            "remote: Total 6937 (delta 68), reused 121 (delta 60), pack-reused 6789\u001b[K\n",
            "Receiving objects: 100% (6937/6937), 51.07 MiB | 21.70 MiB/s, done.\n",
            "Resolving deltas: 100% (4254/4254), done.\n"
          ]
        }
      ],
      "source": [
        "# Skip the clone if the repo is already present so the cell can be re-run.\n",
        "![ -d MultiBench ] || git clone https://github.com/pliang279/MultiBench"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "r1KGIlFvYqYH",
        "outputId": "5be863d3-36df-43a0-996c-227524321871"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/MultiBench\n"
          ]
        }
      ],
      "source": [
        "# Absolute path (matching the data paths used below) keeps this cell safe to re-run\n",
        "# after the later `%cd FactorCL`.\n",
        "%cd /content/MultiBench"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "S3RpE80CYxIY",
        "outputId": "695bcbf0-d0a9-4020-d6d6-721610ba23d1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: gdown in /usr/local/lib/python3.10/dist-packages (4.6.6)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from gdown) (3.13.1)\n",
            "Requirement already satisfied: requests[socks] in /usr/local/lib/python3.10/dist-packages (from gdown) (2.31.0)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from gdown) (1.16.0)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from gdown) (4.66.1)\n",
            "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from gdown) (4.11.2)\n",
            "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4->gdown) (2.5)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (3.4)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (2023.7.22)\n",
            "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (1.7.1)\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1_XdzdW8UNG1TTS6QcX10uhoS6N11OBit\n",
            "To: /content/MultiBench/mosi_data.pkl\n",
            "100% 154M/154M [00:03<00:00, 42.9MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=180l4pN6XAv8-OAYQ6OrMheFUMwtqUWbz\n",
            "To: /content/MultiBench/mosei_senti_data.pkl\n",
            "100% 3.73G/3.73G [00:44<00:00, 84.3MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1EMBUmUL5B0PTncGx3L-sBElGOmjFBR_h\n",
            "To: /content/MultiBench/sarcasm.pkl\n",
            "100% 208M/208M [00:00<00:00, 232MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1L5slPmYyhEVtwGyM1kgcFMjeBpXLZGT0\n",
            "To: /content/MultiBench/humor.pkl\n",
            "100% 1.22G/1.22G [00:09<00:00, 122MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1SuTPg0MTo4P8dXLKFjK7LSoXWqm1YP92\n",
            "To: /content/MultiBench/im.pk\n",
            "100% 83.2M/83.2M [00:00<00:00, 190MB/s]\n"
          ]
        }
      ],
      "source": [
        "%pip install gdown\n",
        "\n",
        "# URLs are quoted: an unquoted '&' makes the shell background gdown and treat\n",
        "# 'export=download' as a separate variable assignment. Files already on disk are skipped.\n",
        "![ -f mosi_data.pkl ] || gdown \"https://drive.google.com/uc?id=1_XdzdW8UNG1TTS6QcX10uhoS6N11OBit&export=download\"  # MOSI data\n",
        "![ -f mosei_senti_data.pkl ] || gdown \"https://drive.google.com/uc?id=180l4pN6XAv8-OAYQ6OrMheFUMwtqUWbz&export=download\"  # MOSEI data\n",
        "![ -f sarcasm.pkl ] || gdown \"https://drive.google.com/uc?id=1EMBUmUL5B0PTncGx3L-sBElGOmjFBR_h&export=download\"  # Sarcasm data\n",
        "![ -f humor.pkl ] || gdown \"https://drive.google.com/uc?id=1L5slPmYyhEVtwGyM1kgcFMjeBpXLZGT0&export=download\"  # Humor data\n",
        "![ -f im.pk ] || gdown \"https://drive.google.com/uc?id=1SuTPg0MTo4P8dXLKFjK7LSoXWqm1YP92&export=download\"  # MIMIC data\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IHKnbVkoASPI",
        "outputId": "ba0a3375-4bec-4a28-dbfa-46fe55068de2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'FactorCL'...\n",
            "remote: Enumerating objects: 128, done.\u001b[K\n",
            "remote: Counting objects: 100% (12/12), done.\u001b[K\n",
            "remote: Compressing objects: 100% (11/11), done.\u001b[K\n",
            "remote: Total 128 (delta 2), reused 0 (delta 0), pack-reused 116\u001b[K\n",
            "Receiving objects: 100% (128/128), 288.20 KiB | 1.43 MiB/s, done.\n",
            "Resolving deltas: 100% (59/59), done.\n"
          ]
        }
      ],
      "source": [
        "# Cloned into the current directory (the MultiBench checkout). Skip if already present so the cell can be re-run.\n",
        "![ -d FactorCL ] || git clone https://github.com/pliang279/FactorCL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "A0mgHF3bAgvo"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "# MultiBench modules. Importable because the working directory is the MultiBench checkout.\n",
        "from unimodals.common_models import GRU, GRUWithLinear, MLP, Transformer\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "ZBH2aOwge2Lb"
      },
      "outputs": [],
      "source": [
        "# Keep the MultiBench root on sys.path so `datasets`/`unimodals` stay importable after\n",
        "# `%cd FactorCL` below. The absolute path and membership check make the cell safe to re-run.\n",
        "MULTIBENCH_ROOT = \"/content/MultiBench\"\n",
        "if MULTIBENCH_ROOT not in sys.path:\n",
        "    sys.path.append(MULTIBENCH_ROOT)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-9fKJqfVdYbH",
        "outputId": "751d6093-aa2f-4aa0-fb4a-3e821f1499b0"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/MultiBench/FactorCL\n"
          ]
        }
      ],
      "source": [
        "# Absolute path so re-running this cell does not try to enter FactorCL/FactorCL.\n",
        "%cd /content/MultiBench/FactorCL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "r2H_R1qlgQKt"
      },
      "outputs": [],
      "source": [
        "# FactorCL's MultiBench helpers. Supplies FactorCLSUP, train_sup_mosi and mosi_label used below\n",
        "# (importable because the working directory is now the FactorCL checkout).\n",
        "# TODO: replace the star import with an explicit list of names.\n",
        "from multibench_model import *"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LgpMeb4eRI5Z"
      },
      "source": [
        "# MOSI"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 57,
      "metadata": {
        "id": "SKkJkrCQt2uR"
      },
      "outputs": [],
      "source": [
        "from datasets.affect.get_data import get_dataloader\n",
        "\n",
        "MOSI_PATH = '/content/MultiBench/mosi_data.pkl'\n",
        "\n",
        "# Shuffled loaders drive training.\n",
        "train_loader, valid_loader, test_loader = get_dataloader(\n",
        "    MOSI_PATH, robust_test=False, batch_size=32, train_shuffle=True)\n",
        "\n",
        "# Unshuffled copies for embedding extraction, so batch order is stable across passes.\n",
        "eval_train_loader, eval_valid_loader, eval_test_loader = get_dataloader(\n",
        "    MOSI_PATH, robust_test=False, batch_size=32, train_shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AQ8W129stEfo"
      },
      "source": [
        "## FactorCL-SUP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 58,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ha5JwFzHRKjd",
        "outputId": "9cee9423-88c9-432c-ec2c-1026902c0450"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  0.0011624246835708618\n",
            "iter:  1  i_batch:  0  loss:  -0.03190102055668831\n",
            "iter:  2  i_batch:  0  loss:  -0.4311522841453552\n",
            "iter:  3  i_batch:  0  loss:  -0.9014155864715576\n",
            "iter:  4  i_batch:  0  loss:  -2.1317553520202637\n",
            "iter:  5  i_batch:  0  loss:  -2.976452350616455\n",
            "iter:  6  i_batch:  0  loss:  -3.3434486389160156\n",
            "iter:  7  i_batch:  0  loss:  -3.346254348754883\n",
            "iter:  8  i_batch:  0  loss:  -4.042529106140137\n",
            "iter:  9  i_batch:  0  loss:  -4.252300262451172\n",
            "iter:  10  i_batch:  0  loss:  -4.225554466247559\n",
            "iter:  11  i_batch:  0  loss:  1.0005271434783936\n",
            "iter:  12  i_batch:  0  loss:  -4.045951843261719\n",
            "iter:  13  i_batch:  0  loss:  -4.469624996185303\n",
            "iter:  14  i_batch:  0  loss:  -4.786789894104004\n",
            "iter:  15  i_batch:  0  loss:  -4.101232528686523\n",
            "iter:  16  i_batch:  0  loss:  -4.211440086364746\n",
            "iter:  17  i_batch:  0  loss:  -3.848883628845215\n",
            "iter:  18  i_batch:  0  loss:  -3.969515085220337\n",
            "iter:  19  i_batch:  0  loss:  -4.6729512214660645\n",
            "iter:  20  i_batch:  0  loss:  -4.789951324462891\n",
            "iter:  21  i_batch:  0  loss:  -5.008730888366699\n",
            "iter:  22  i_batch:  0  loss:  -4.737300872802734\n",
            "iter:  23  i_batch:  0  loss:  -4.246749401092529\n",
            "iter:  24  i_batch:  0  loss:  -4.636990547180176\n",
            "iter:  25  i_batch:  0  loss:  -3.654554843902588\n",
            "iter:  26  i_batch:  0  loss:  -3.8198657035827637\n",
            "iter:  27  i_batch:  0  loss:  -5.670473098754883\n",
            "iter:  28  i_batch:  0  loss:  -4.847062110900879\n",
            "iter:  29  i_batch:  0  loss:  -4.916831970214844\n",
            "iter:  30  i_batch:  0  loss:  -0.6473921537399292\n",
            "iter:  31  i_batch:  0  loss:  -5.111954689025879\n",
            "iter:  32  i_batch:  0  loss:  -5.216065883636475\n",
            "iter:  33  i_batch:  0  loss:  -3.4189281463623047\n",
            "iter:  34  i_batch:  0  loss:  -4.672704696655273\n",
            "iter:  35  i_batch:  0  loss:  -4.497556686401367\n",
            "iter:  36  i_batch:  0  loss:  -5.14064359664917\n",
            "iter:  37  i_batch:  0  loss:  -5.370944023132324\n",
            "iter:  38  i_batch:  0  loss:  -4.58414888381958\n",
            "iter:  39  i_batch:  0  loss:  -5.476208686828613\n",
            "iter:  40  i_batch:  0  loss:  -4.5191802978515625\n",
            "iter:  41  i_batch:  0  loss:  -4.727219104766846\n",
            "iter:  42  i_batch:  0  loss:  -4.7529616355896\n",
            "iter:  43  i_batch:  0  loss:  -5.122702598571777\n",
            "iter:  44  i_batch:  0  loss:  -5.684665679931641\n",
            "iter:  45  i_batch:  0  loss:  -2.4282925128936768\n",
            "iter:  46  i_batch:  0  loss:  -5.595916748046875\n",
            "iter:  47  i_batch:  0  loss:  -5.13438606262207\n",
            "iter:  48  i_batch:  0  loss:  -6.019073963165283\n",
            "iter:  49  i_batch:  0  loss:  -5.550722599029541\n",
            "iter:  50  i_batch:  0  loss:  -3.9192724227905273\n",
            "iter:  51  i_batch:  0  loss:  -5.574321746826172\n",
            "iter:  52  i_batch:  0  loss:  -5.297717094421387\n",
            "iter:  53  i_batch:  0  loss:  -4.170279026031494\n",
            "iter:  54  i_batch:  0  loss:  -5.7246012687683105\n",
            "iter:  55  i_batch:  0  loss:  -5.899558067321777\n",
            "iter:  56  i_batch:  0  loss:  -5.035403251647949\n",
            "iter:  57  i_batch:  0  loss:  -5.506241321563721\n",
            "iter:  58  i_batch:  0  loss:  -5.031591415405273\n",
            "iter:  59  i_batch:  0  loss:  -6.013862609863281\n",
            "iter:  60  i_batch:  0  loss:  -4.940839767456055\n",
            "iter:  61  i_batch:  0  loss:  -4.712447166442871\n",
            "iter:  62  i_batch:  0  loss:  -4.295894622802734\n",
            "iter:  63  i_batch:  0  loss:  -5.478939056396484\n",
            "iter:  64  i_batch:  0  loss:  -5.582670211791992\n",
            "iter:  65  i_batch:  0  loss:  -5.190284729003906\n",
            "iter:  66  i_batch:  0  loss:  -5.443666458129883\n",
            "iter:  67  i_batch:  0  loss:  -3.967256546020508\n",
            "iter:  68  i_batch:  0  loss:  -6.085321426391602\n",
            "iter:  69  i_batch:  0  loss:  -4.760692596435547\n",
            "iter:  70  i_batch:  0  loss:  -5.3948259353637695\n",
            "iter:  71  i_batch:  0  loss:  -5.857318878173828\n",
            "iter:  72  i_batch:  0  loss:  -5.607830047607422\n",
            "iter:  73  i_batch:  0  loss:  -5.552946090698242\n",
            "iter:  74  i_batch:  0  loss:  -5.454460144042969\n",
            "iter:  75  i_batch:  0  loss:  -5.54594612121582\n",
            "iter:  76  i_batch:  0  loss:  -3.8459858894348145\n",
            "iter:  77  i_batch:  0  loss:  -4.436514377593994\n",
            "iter:  78  i_batch:  0  loss:  -5.631844520568848\n",
            "iter:  79  i_batch:  0  loss:  -5.609157562255859\n",
            "iter:  80  i_batch:  0  loss:  -6.110953330993652\n",
            "iter:  81  i_batch:  0  loss:  -5.901972770690918\n",
            "iter:  82  i_batch:  0  loss:  -5.723435401916504\n",
            "iter:  83  i_batch:  0  loss:  -5.417698860168457\n",
            "iter:  84  i_batch:  0  loss:  -5.827786445617676\n",
            "iter:  85  i_batch:  0  loss:  -6.005405426025391\n",
            "iter:  86  i_batch:  0  loss:  -5.556831359863281\n",
            "iter:  87  i_batch:  0  loss:  -4.285808086395264\n",
            "iter:  88  i_batch:  0  loss:  3.1289167404174805\n",
            "iter:  89  i_batch:  0  loss:  -5.291248321533203\n",
            "iter:  90  i_batch:  0  loss:  -6.017822742462158\n",
            "iter:  91  i_batch:  0  loss:  -5.869406700134277\n",
            "iter:  92  i_batch:  0  loss:  -5.685717582702637\n",
            "iter:  93  i_batch:  0  loss:  -6.027961254119873\n",
            "iter:  94  i_batch:  0  loss:  -5.728565692901611\n",
            "iter:  95  i_batch:  0  loss:  -6.010775566101074\n",
            "iter:  96  i_batch:  0  loss:  -5.595132827758789\n",
            "iter:  97  i_batch:  0  loss:  -5.609217643737793\n",
            "iter:  98  i_batch:  0  loss:  -6.420151233673096\n",
            "iter:  99  i_batch:  0  loss:  -5.681667327880859\n"
          ]
        }
      ],
      "source": [
        "import random\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "\n",
        "# Seed the Python, NumPy and torch RNGs (weight init, loader shuffling) so the run is reproducible.\n",
        "SEED = 0\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# feat_dims must equal the encoders' output widths (assumed to be the second Transformer argument; TODO confirm).\n",
        "ENC_OUT_DIMS = [40, 600]\n",
        "encoders = [Transformer(20, ENC_OUT_DIMS[0]), Transformer(300, ENC_OUT_DIMS[1])]\n",
        "factorcl_sup = FactorCLSUP(encoders=encoders, feat_dims=ENC_OUT_DIMS, y_ohe_dim=3).cuda()\n",
        "train_sup_mosi(factorcl_sup, train_loader, num_epoch=100, num_club_iter=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4cCSnp8_Zv6I",
        "outputId": "b9cc1e6e-f1df-4076-8cc7-a8b2a1238112"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n",
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "\n",
        "\n",
        "def extract_mosi_embeddings(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` in a single pass.\n",
        "\n",
        "    Modalities data[0][0] and data[0][2] are fed to `model.get_embedding`. Its first two\n",
        "    outputs are collected as the per-modality embeddings, and data[3] as the raw labels.\n",
        "    Returns (embeds_x1, embeds_x2, raw_labels) as numpy arrays concatenated over batches.\n",
        "    \"\"\"\n",
        "    embeds_x1, embeds_x2, labels = [], [], []\n",
        "    with torch.no_grad():  # inference only: no autograd graph needed\n",
        "        for data in loader:\n",
        "            embeds = model.get_embedding(data[0][0].cuda(), data[0][2].cuda())\n",
        "            embeds_x1.append(embeds[0].detach().cpu().numpy())\n",
        "            embeds_x2.append(embeds[1].detach().cpu().numpy())\n",
        "            labels.append(data[3].detach().cpu().numpy())\n",
        "    return np.concatenate(embeds_x1), np.concatenate(embeds_x2), np.concatenate(labels)\n",
        "\n",
        "\n",
        "factorcl_sup.eval()\n",
        "\n",
        "# Embeddings and labels come from the same pass over each loader, so their rows stay aligned.\n",
        "train_embeds_x1, train_embeds_x2, train_labels = extract_mosi_embeddings(factorcl_sup, eval_train_loader)\n",
        "train_embeds = np.concatenate([train_embeds_x1, train_embeds_x2], axis=1)\n",
        "train_labels = np.ravel(mosi_label(train_labels))  # 1-D y avoids sklearn's DataConversionWarning\n",
        "\n",
        "test_embeds_x1, test_embeds_x2, test_labels = extract_mosi_embeddings(factorcl_sup, eval_test_loader)\n",
        "test_embeds = np.concatenate([test_embeds_x1, test_embeds_x2], axis=1)\n",
        "test_labels = np.ravel(mosi_label(test_labels))\n",
        "\n",
        "# Linear probe on the frozen embeddings. NOTE: lbfgs may not converge within max_iter=200 on these\n",
        "# unscaled features. Feature scaling or a larger max_iter could change the score.\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oTEJEZ1wdE_M",
        "outputId": "1e4566ab-cd75-415b-e6e4-428ffef88f65"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.6720116618075802"
            ]
          },
          "execution_count": 22,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Mean test-set accuracy of the logistic-regression probe (LogisticRegression.score).\n",
        "score"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AQGVBuCDt6nv"
      },
      "source": [
        "## FactorCL-SSL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "L2JmyUO-t6nv",
        "outputId": "bf67eb40-74ba-400d-c921-1ed8f315faa0"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/transformer.py:282: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
            "  warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "iter:  0  i_batch:  0  loss:  -0.002487795427441597\n",
            "iter:  1  i_batch:  0  loss:  -1.830185890197754\n",
            "iter:  2  i_batch:  0  loss:  -3.1662981510162354\n",
            "iter:  3  i_batch:  0  loss:  -3.681149482727051\n",
            "iter:  4  i_batch:  0  loss:  -4.313801288604736\n",
            "iter:  5  i_batch:  0  loss:  -4.4881367683410645\n",
            "iter:  6  i_batch:  0  loss:  -4.696531295776367\n",
            "iter:  7  i_batch:  0  loss:  -4.604437351226807\n",
            "iter:  8  i_batch:  0  loss:  -6.114076614379883\n",
            "iter:  9  i_batch:  0  loss:  -6.7265400886535645\n",
            "iter:  10  i_batch:  0  loss:  -6.742205619812012\n",
            "iter:  11  i_batch:  0  loss:  -7.247134208679199\n",
            "iter:  12  i_batch:  0  loss:  -8.011838912963867\n",
            "iter:  13  i_batch:  0  loss:  -8.293033599853516\n",
            "iter:  14  i_batch:  0  loss:  -5.3193159103393555\n",
            "iter:  15  i_batch:  0  loss:  -8.390777587890625\n",
            "iter:  16  i_batch:  0  loss:  -8.524253845214844\n",
            "iter:  17  i_batch:  0  loss:  -7.289597511291504\n",
            "iter:  18  i_batch:  0  loss:  -9.07352066040039\n",
            "iter:  19  i_batch:  0  loss:  -7.363587379455566\n",
            "iter:  20  i_batch:  0  loss:  -8.358322143554688\n",
            "iter:  21  i_batch:  0  loss:  -8.734292030334473\n",
            "iter:  22  i_batch:  0  loss:  -8.809568405151367\n",
            "iter:  23  i_batch:  0  loss:  -8.602814674377441\n",
            "iter:  24  i_batch:  0  loss:  -9.089998245239258\n",
            "iter:  25  i_batch:  0  loss:  -8.887147903442383\n",
            "iter:  26  i_batch:  0  loss:  -8.333454132080078\n",
            "iter:  27  i_batch:  0  loss:  -9.577531814575195\n",
            "iter:  28  i_batch:  0  loss:  -9.318632125854492\n",
            "iter:  29  i_batch:  0  loss:  -9.441801071166992\n",
            "iter:  30  i_batch:  0  loss:  -9.611943244934082\n",
            "iter:  31  i_batch:  0  loss:  -9.811370849609375\n",
            "iter:  32  i_batch:  0  loss:  -8.55459976196289\n",
            "iter:  33  i_batch:  0  loss:  -8.980246543884277\n",
            "iter:  34  i_batch:  0  loss:  -9.60719108581543\n",
            "iter:  35  i_batch:  0  loss:  -8.611591339111328\n",
            "iter:  36  i_batch:  0  loss:  -9.681583404541016\n",
            "iter:  37  i_batch:  0  loss:  -9.74519157409668\n",
            "iter:  38  i_batch:  0  loss:  -9.833538055419922\n",
            "iter:  39  i_batch:  0  loss:  -9.5485258102417\n",
            "iter:  40  i_batch:  0  loss:  -9.292667388916016\n",
            "iter:  41  i_batch:  0  loss:  -10.301009178161621\n",
            "iter:  42  i_batch:  0  loss:  -9.700072288513184\n",
            "iter:  43  i_batch:  0  loss:  -10.055258750915527\n",
            "iter:  44  i_batch:  0  loss:  -9.476600646972656\n",
            "iter:  45  i_batch:  0  loss:  -8.880590438842773\n",
            "iter:  46  i_batch:  0  loss:  -9.685264587402344\n",
            "iter:  47  i_batch:  0  loss:  -9.662805557250977\n",
            "iter:  48  i_batch:  0  loss:  -9.99907398223877\n",
            "iter:  49  i_batch:  0  loss:  -8.704071044921875\n",
            "iter:  50  i_batch:  0  loss:  -10.071045875549316\n",
            "iter:  51  i_batch:  0  loss:  -9.820974349975586\n",
            "iter:  52  i_batch:  0  loss:  -9.052451133728027\n",
            "iter:  53  i_batch:  0  loss:  -10.254377365112305\n",
            "iter:  54  i_batch:  0  loss:  -9.101917266845703\n",
            "iter:  55  i_batch:  0  loss:  -10.659225463867188\n",
            "iter:  56  i_batch:  0  loss:  -9.395548820495605\n",
            "iter:  57  i_batch:  0  loss:  -9.780409812927246\n",
            "iter:  58  i_batch:  0  loss:  -9.663912773132324\n",
            "iter:  59  i_batch:  0  loss:  -9.981481552124023\n",
            "iter:  60  i_batch:  0  loss:  -10.325379371643066\n",
            "iter:  61  i_batch:  0  loss:  -10.244674682617188\n",
            "iter:  62  i_batch:  0  loss:  -10.248712539672852\n",
            "iter:  63  i_batch:  0  loss:  -10.596391677856445\n",
            "iter:  64  i_batch:  0  loss:  -9.929593086242676\n",
            "iter:  65  i_batch:  0  loss:  -8.307312965393066\n",
            "iter:  66  i_batch:  0  loss:  -10.014596939086914\n",
            "iter:  67  i_batch:  0  loss:  -8.15733528137207\n",
            "iter:  68  i_batch:  0  loss:  -9.8958740234375\n",
            "iter:  69  i_batch:  0  loss:  -10.367892265319824\n",
            "iter:  70  i_batch:  0  loss:  -9.459168434143066\n",
            "iter:  71  i_batch:  0  loss:  -10.081171035766602\n",
            "iter:  72  i_batch:  0  loss:  -9.856391906738281\n",
            "iter:  73  i_batch:  0  loss:  -9.965957641601562\n",
            "iter:  74  i_batch:  0  loss:  -10.20479965209961\n",
            "iter:  75  i_batch:  0  loss:  -10.243646621704102\n",
            "iter:  76  i_batch:  0  loss:  -10.218008041381836\n",
            "iter:  77  i_batch:  0  loss:  -10.071391105651855\n",
            "iter:  78  i_batch:  0  loss:  -9.449466705322266\n",
            "iter:  79  i_batch:  0  loss:  -10.068207740783691\n",
            "iter:  80  i_batch:  0  loss:  -10.4512939453125\n",
            "iter:  81  i_batch:  0  loss:  -9.146720886230469\n",
            "iter:  82  i_batch:  0  loss:  -9.211112022399902\n",
            "iter:  83  i_batch:  0  loss:  -10.577682495117188\n",
            "iter:  84  i_batch:  0  loss:  -10.82564926147461\n",
            "iter:  85  i_batch:  0  loss:  -10.904359817504883\n",
            "iter:  86  i_batch:  0  loss:  -10.144580841064453\n",
            "iter:  87  i_batch:  0  loss:  -8.592063903808594\n",
            "iter:  88  i_batch:  0  loss:  -8.313719749450684\n",
            "iter:  89  i_batch:  0  loss:  -10.202726364135742\n",
            "iter:  90  i_batch:  0  loss:  -9.839406967163086\n",
            "iter:  91  i_batch:  0  loss:  -10.685945510864258\n",
            "iter:  92  i_batch:  0  loss:  -10.541755676269531\n",
            "iter:  93  i_batch:  0  loss:  -10.253667831420898\n",
            "iter:  94  i_batch:  0  loss:  -9.885324478149414\n",
            "iter:  95  i_batch:  0  loss:  -7.863533020019531\n",
            "iter:  96  i_batch:  0  loss:  -9.460472106933594\n",
            "iter:  97  i_batch:  0  loss:  -10.282123565673828\n",
            "iter:  98  i_batch:  0  loss:  -9.427762985229492\n",
            "iter:  99  i_batch:  0  loss:  -9.637114524841309\n"
          ]
        }
      ],
      "source": [
        "# Self-supervised FactorCL (FactorCL-SSL) on MOSI, using the two modalities that the\n",
        "# probe cell embeds (data[0][0] and data[0][2]). Each feat_dims entry mirrors the\n",
        "# second Transformer argument of the matching encoder (presumably its output width;\n",
        "# confirm against the Transformer definition).\n",
        "encoders = [\n",
        "    Transformer(20, 40),\n",
        "    Transformer(300, 600),\n",
        "]\n",
        "factorcl_ssl = FactorCLSSL(\n",
        "    encoders=encoders,\n",
        "    feat_dims=[40, 600],\n",
        "    y_ohe_dim=3,\n",
        ").cuda()\n",
        "\n",
        "train_ssl_mosi(factorcl_ssl, train_loader, num_epoch=100, num_club_iter=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WKs0Pr9Zt6nv",
        "outputId": "8436560c-e800-4233-c16a-0bb12c830f23"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n",
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "\n",
        "\n",
        "def _embed_split(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` in a single pass.\n",
        "\n",
        "    Calls model.get_embedding once per batch on modalities batch[0][0] and\n",
        "    batch[0][2] under torch.no_grad(), so no autograd graph is built and the\n",
        "    two embedding arrays and the raw labels (batch[3]) stay row-aligned.\n",
        "\n",
        "    Returns:\n",
        "        (x1_embeds, x2_embeds, raw_labels): numpy arrays concatenated over batches.\n",
        "    \"\"\"\n",
        "    x1_parts, x2_parts, label_parts = [], [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in loader:\n",
        "            batch_embeds = model.get_embedding(batch[0][0].cuda(), batch[0][2].cuda())\n",
        "            x1_parts.append(batch_embeds[0].detach().cpu().numpy())\n",
        "            x2_parts.append(batch_embeds[1].detach().cpu().numpy())\n",
        "            label_parts.append(batch[3].detach().cpu().numpy())\n",
        "    return np.concatenate(x1_parts), np.concatenate(x2_parts), np.concatenate(label_parts)\n",
        "\n",
        "\n",
        "factorcl_ssl.eval()\n",
        "\n",
        "train_embeds_x1, train_embeds_x2, train_labels = _embed_split(factorcl_ssl, eval_train_loader)\n",
        "train_embeds = np.concatenate([train_embeds_x1, train_embeds_x2], axis=1)\n",
        "train_labels = mosi_label(train_labels)\n",
        "\n",
        "test_embeds_x1, test_embeds_x2, test_labels = _embed_split(factorcl_ssl, eval_test_loader)\n",
        "test_embeds = np.concatenate([test_embeds_x1, test_embeds_x2], axis=1)\n",
        "test_labels = mosi_label(test_labels)\n",
        "\n",
        "# Train Logistic Classifier (linear probe on the frozen embeddings).\n",
        "# np.ravel hands sklearn a 1-D y; an (n, 1) label column would otherwise trigger a\n",
        "# DataConversionWarning (sklearn ravels it internally, so the fit is unchanged).\n",
        "# max_iter=200 is kept for comparability; lbfgs may stop there before converging.\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, np.ravel(train_labels))\n",
        "score = clf.score(test_embeds, np.ravel(test_labels))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wnoaN73pt6nw",
        "outputId": "4a0145a6-179c-48f6-9624-929979c4df13"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.5408163265306123"
            ]
          },
          "execution_count": 30,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "score  # test-set accuracy of the logistic-regression probe fitted in the previous cell"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cbm0ctEWtGfe"
      },
      "source": [
        "## SupCon"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 61,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hOJzWZQCtH97",
        "outputId": "36b43501-8795-45f9-9373-04c5b9edaf54"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  40.36393737792969\n",
            "iter:  1  i_batch:  0  loss:  29.615131378173828\n",
            "iter:  2  i_batch:  0  loss:  29.602519989013672\n",
            "iter:  3  i_batch:  0  loss:  29.601627349853516\n",
            "iter:  4  i_batch:  0  loss:  29.598114013671875\n",
            "iter:  5  i_batch:  0  loss:  29.59786033630371\n",
            "iter:  6  i_batch:  0  loss:  29.591426849365234\n",
            "iter:  7  i_batch:  0  loss:  29.59424591064453\n",
            "iter:  8  i_batch:  0  loss:  29.598228454589844\n",
            "iter:  9  i_batch:  0  loss:  29.596210479736328\n"
          ]
        }
      ],
      "source": [
        "# Supervised contrastive baseline on MOSI (use_label=True; use_label=False gives SimCLR).\n",
        "# NOTE(review): this run uses num_epoch=10 while the other runs use 100, and its logged\n",
        "# loss is flat at about 29.6 from epoch 2 onward -- confirm the short schedule is intended.\n",
        "encoders = [\n",
        "    Transformer(20, 40),\n",
        "    Transformer(300, 40),\n",
        "]\n",
        "\n",
        "supcon_model = SupConModel(\n",
        "    temperature=0.5,\n",
        "    encoders=encoders,\n",
        "    dim_ins=[40, 40],\n",
        "    feat_dims=[40, 40],\n",
        "    use_label=True,\n",
        ").cuda()\n",
        "supcon_optim = optim.Adam(supcon_model.parameters())\n",
        "train_supcon_mosi(supcon_model, train_loader, supcon_optim, num_epoch=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hJJkTLwDtyIH",
        "outputId": "786d0ac9-3c62-4080-82ad-ba56a7f42a48"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n",
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "\n",
        "\n",
        "def _embed_split(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` in a single pass.\n",
        "\n",
        "    Calls model.get_embedding once per batch on modalities batch[0][0] and\n",
        "    batch[0][2] under torch.no_grad(), so no autograd graph is built and the\n",
        "    two embedding arrays and the raw labels (batch[3]) stay row-aligned.\n",
        "\n",
        "    Returns:\n",
        "        (x1_embeds, x2_embeds, raw_labels): numpy arrays concatenated over batches.\n",
        "    \"\"\"\n",
        "    x1_parts, x2_parts, label_parts = [], [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in loader:\n",
        "            batch_embeds = model.get_embedding(batch[0][0].cuda(), batch[0][2].cuda())\n",
        "            x1_parts.append(batch_embeds[0].detach().cpu().numpy())\n",
        "            x2_parts.append(batch_embeds[1].detach().cpu().numpy())\n",
        "            label_parts.append(batch[3].detach().cpu().numpy())\n",
        "    return np.concatenate(x1_parts), np.concatenate(x2_parts), np.concatenate(label_parts)\n",
        "\n",
        "\n",
        "supcon_model.eval()\n",
        "\n",
        "train_embeds_x1, train_embeds_x2, train_labels = _embed_split(supcon_model, eval_train_loader)\n",
        "train_embeds = np.concatenate([train_embeds_x1, train_embeds_x2], axis=1)\n",
        "train_labels = mosi_label(train_labels)\n",
        "\n",
        "test_embeds_x1, test_embeds_x2, test_labels = _embed_split(supcon_model, eval_test_loader)\n",
        "test_embeds = np.concatenate([test_embeds_x1, test_embeds_x2], axis=1)\n",
        "test_labels = mosi_label(test_labels)\n",
        "\n",
        "# Train Logistic Classifier (linear probe on the frozen embeddings).\n",
        "# np.ravel hands sklearn a 1-D y; an (n, 1) label column would otherwise trigger a\n",
        "# DataConversionWarning (sklearn ravels it internally, so the fit is unchanged).\n",
        "# max_iter=200 is kept for comparability; lbfgs may stop there before converging.\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, np.ravel(train_labels))\n",
        "score = clf.score(test_embeds, np.ravel(test_labels))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "92NOZhHDtzLn",
        "outputId": "d1f7dd25-6cf7-4de2-bd29-86e89d3136b2"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.4620991253644315"
            ]
          },
          "execution_count": 33,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "score  # test-set accuracy of the logistic-regression probe fitted in the previous cell"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jt1RD-kdDYqj"
      },
      "source": [
        "# Humor"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 54,
      "metadata": {
        "id": "OMi9uTMhDYqq"
      },
      "outputs": [],
      "source": [
        "from datasets.affect.get_data import get_dataloader\n",
        "\n",
        "# Humor loaders. The eval_* set passes train_shuffle=False so the training split is\n",
        "# read in a fixed order when the probe cells extract embeddings.\n",
        "# NOTE(review): the pickle path is a hard-coded Colab location.\n",
        "train_loader, valid_loader, test_loader = get_dataloader(\n",
        "    '/content/MultiBench/humor.pkl',\n",
        "    batch_size=128,\n",
        "    data_type='humor',\n",
        ")\n",
        "\n",
        "eval_train_loader, eval_valid_loader, eval_test_loader = get_dataloader(\n",
        "    '/content/MultiBench/humor.pkl',\n",
        "    batch_size=128,\n",
        "    data_type='humor',\n",
        "    train_shuffle=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "30jDD1AaDYqq"
      },
      "source": [
        "## FactorCL-SUP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 55,
      "metadata": {
        "id": "piPCeZLVDYqq",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "338464a9-b2c1-499c-8c3b-9ea718e13e82"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  0.0013007273664698005\n",
            "iter:  1  i_batch:  0  loss:  -0.1793300211429596\n",
            "iter:  2  i_batch:  0  loss:  -0.36496207118034363\n",
            "iter:  3  i_batch:  0  loss:  -0.05508889630436897\n",
            "iter:  4  i_batch:  0  loss:  0.24322691559791565\n",
            "iter:  5  i_batch:  0  loss:  0.6587186455726624\n",
            "iter:  6  i_batch:  0  loss:  1.0358939170837402\n",
            "iter:  7  i_batch:  0  loss:  1.116131067276001\n",
            "iter:  8  i_batch:  0  loss:  1.1662193536758423\n",
            "iter:  9  i_batch:  0  loss:  1.4159873723983765\n",
            "iter:  10  i_batch:  0  loss:  1.4103529453277588\n",
            "iter:  11  i_batch:  0  loss:  1.6665235757827759\n",
            "iter:  12  i_batch:  0  loss:  1.6012569665908813\n",
            "iter:  13  i_batch:  0  loss:  1.6238175630569458\n",
            "iter:  14  i_batch:  0  loss:  1.7627021074295044\n",
            "iter:  15  i_batch:  0  loss:  1.966027855873108\n",
            "iter:  16  i_batch:  0  loss:  1.733707070350647\n",
            "iter:  17  i_batch:  0  loss:  1.935907244682312\n",
            "iter:  18  i_batch:  0  loss:  2.0887207984924316\n",
            "iter:  19  i_batch:  0  loss:  2.180851697921753\n",
            "iter:  20  i_batch:  0  loss:  2.1452243328094482\n",
            "iter:  21  i_batch:  0  loss:  2.1883225440979004\n",
            "iter:  22  i_batch:  0  loss:  2.2730600833892822\n",
            "iter:  23  i_batch:  0  loss:  2.1712002754211426\n",
            "iter:  24  i_batch:  0  loss:  2.4163832664489746\n",
            "iter:  25  i_batch:  0  loss:  2.4386444091796875\n",
            "iter:  26  i_batch:  0  loss:  2.5235021114349365\n",
            "iter:  27  i_batch:  0  loss:  2.6258342266082764\n",
            "iter:  28  i_batch:  0  loss:  2.656353235244751\n",
            "iter:  29  i_batch:  0  loss:  2.6597535610198975\n",
            "iter:  30  i_batch:  0  loss:  2.691378116607666\n",
            "iter:  31  i_batch:  0  loss:  2.701841354370117\n",
            "iter:  32  i_batch:  0  loss:  2.722193717956543\n",
            "iter:  33  i_batch:  0  loss:  2.741342306137085\n",
            "iter:  34  i_batch:  0  loss:  2.830249786376953\n",
            "iter:  35  i_batch:  0  loss:  2.641406774520874\n",
            "iter:  36  i_batch:  0  loss:  2.4510467052459717\n",
            "iter:  37  i_batch:  0  loss:  2.7503247261047363\n",
            "iter:  38  i_batch:  0  loss:  2.953279733657837\n",
            "iter:  39  i_batch:  0  loss:  2.953406572341919\n",
            "iter:  40  i_batch:  0  loss:  3.0367555618286133\n",
            "iter:  41  i_batch:  0  loss:  2.8526992797851562\n",
            "iter:  42  i_batch:  0  loss:  2.9538142681121826\n",
            "iter:  43  i_batch:  0  loss:  2.9913129806518555\n",
            "iter:  44  i_batch:  0  loss:  3.0023884773254395\n",
            "iter:  45  i_batch:  0  loss:  2.9009594917297363\n",
            "iter:  46  i_batch:  0  loss:  2.934813976287842\n",
            "iter:  47  i_batch:  0  loss:  3.0924527645111084\n",
            "iter:  48  i_batch:  0  loss:  3.1950736045837402\n",
            "iter:  49  i_batch:  0  loss:  3.1898369789123535\n",
            "iter:  50  i_batch:  0  loss:  3.18286395072937\n",
            "iter:  51  i_batch:  0  loss:  3.3345236778259277\n",
            "iter:  52  i_batch:  0  loss:  3.3449974060058594\n",
            "iter:  53  i_batch:  0  loss:  2.9209837913513184\n",
            "iter:  54  i_batch:  0  loss:  2.6646478176116943\n",
            "iter:  55  i_batch:  0  loss:  2.760484218597412\n",
            "iter:  56  i_batch:  0  loss:  2.770448923110962\n",
            "iter:  57  i_batch:  0  loss:  2.9298133850097656\n",
            "iter:  58  i_batch:  0  loss:  2.6132373809814453\n",
            "iter:  59  i_batch:  0  loss:  2.9053807258605957\n",
            "iter:  60  i_batch:  0  loss:  2.98210072517395\n",
            "iter:  61  i_batch:  0  loss:  3.0822980403900146\n",
            "iter:  62  i_batch:  0  loss:  2.986262559890747\n",
            "iter:  63  i_batch:  0  loss:  2.765538215637207\n",
            "iter:  64  i_batch:  0  loss:  2.6932382583618164\n",
            "iter:  65  i_batch:  0  loss:  2.6932923793792725\n",
            "iter:  66  i_batch:  0  loss:  2.908081531524658\n",
            "iter:  67  i_batch:  0  loss:  2.88899827003479\n",
            "iter:  68  i_batch:  0  loss:  2.803131341934204\n",
            "iter:  69  i_batch:  0  loss:  2.954724073410034\n",
            "iter:  70  i_batch:  0  loss:  2.802137613296509\n",
            "iter:  71  i_batch:  0  loss:  2.8641364574432373\n",
            "iter:  72  i_batch:  0  loss:  2.96498966217041\n",
            "iter:  73  i_batch:  0  loss:  3.199286699295044\n",
            "iter:  74  i_batch:  0  loss:  3.076458215713501\n",
            "iter:  75  i_batch:  0  loss:  3.165349245071411\n",
            "iter:  76  i_batch:  0  loss:  2.8884785175323486\n",
            "iter:  77  i_batch:  0  loss:  3.083167314529419\n",
            "iter:  78  i_batch:  0  loss:  3.178025722503662\n",
            "iter:  79  i_batch:  0  loss:  3.422637939453125\n",
            "iter:  80  i_batch:  0  loss:  3.467125415802002\n",
            "iter:  81  i_batch:  0  loss:  3.5072438716888428\n",
            "iter:  82  i_batch:  0  loss:  3.562509298324585\n",
            "iter:  83  i_batch:  0  loss:  3.375946044921875\n",
            "iter:  84  i_batch:  0  loss:  3.5278494358062744\n",
            "iter:  85  i_batch:  0  loss:  3.233327865600586\n",
            "iter:  86  i_batch:  0  loss:  3.135256052017212\n",
            "iter:  87  i_batch:  0  loss:  3.080659866333008\n",
            "iter:  88  i_batch:  0  loss:  3.1237056255340576\n",
            "iter:  89  i_batch:  0  loss:  3.0468668937683105\n",
            "iter:  90  i_batch:  0  loss:  3.1506381034851074\n",
            "iter:  91  i_batch:  0  loss:  3.2998621463775635\n",
            "iter:  92  i_batch:  0  loss:  3.2804951667785645\n",
            "iter:  93  i_batch:  0  loss:  3.1534981727600098\n",
            "iter:  94  i_batch:  0  loss:  3.2469987869262695\n",
            "iter:  95  i_batch:  0  loss:  3.0937764644622803\n",
            "iter:  96  i_batch:  0  loss:  3.0271410942077637\n",
            "iter:  97  i_batch:  0  loss:  3.3754680156707764\n",
            "iter:  98  i_batch:  0  loss:  3.1432812213897705\n",
            "iter:  99  i_batch:  0  loss:  3.0338351726531982\n"
          ]
        }
      ],
      "source": [
        "# Supervised FactorCL (FactorCL-SUP) on the humor data.\n",
        "# NOTE(review): this reuses the sarcasm helpers (train_sup_sarcasm here, sarcasm_label\n",
        "# in the probe cell) and y_ohe_dim=3 -- confirm both match the humor label encoding.\n",
        "encoders = [\n",
        "    Transformer(371, 40),\n",
        "    Transformer(300, 40),\n",
        "]\n",
        "factorcl_sup = FactorCLSUP(\n",
        "    encoders=encoders,\n",
        "    feat_dims=[40, 40],\n",
        "    y_ohe_dim=3,\n",
        ").cuda()\n",
        "\n",
        "train_sup_sarcasm(factorcl_sup, train_loader, num_epoch=100, num_club_iter=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fopg4T_fDYqq"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "\n",
        "def _embed_split(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` in a single pass.\n",
        "\n",
        "    Calls model.get_embedding once per batch on modalities batch[0][0] and\n",
        "    batch[0][2] under torch.no_grad(), so no autograd graph is built and the\n",
        "    two embedding arrays and the raw labels (batch[3]) stay row-aligned.\n",
        "\n",
        "    Returns:\n",
        "        (x1_embeds, x2_embeds, raw_labels): numpy arrays concatenated over batches.\n",
        "    \"\"\"\n",
        "    x1_parts, x2_parts, label_parts = [], [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in loader:\n",
        "            batch_embeds = model.get_embedding(batch[0][0].cuda(), batch[0][2].cuda())\n",
        "            x1_parts.append(batch_embeds[0].detach().cpu().numpy())\n",
        "            x2_parts.append(batch_embeds[1].detach().cpu().numpy())\n",
        "            label_parts.append(batch[3].detach().cpu().numpy())\n",
        "    return np.concatenate(x1_parts), np.concatenate(x2_parts), np.concatenate(label_parts)\n",
        "\n",
        "\n",
        "factorcl_sup.eval()\n",
        "\n",
        "train_embeds_x1, train_embeds_x2, train_labels = _embed_split(factorcl_sup, eval_train_loader)\n",
        "train_embeds = np.concatenate([train_embeds_x1, train_embeds_x2], axis=1)\n",
        "train_labels = sarcasm_label(train_labels)\n",
        "\n",
        "test_embeds_x1, test_embeds_x2, test_labels = _embed_split(factorcl_sup, eval_test_loader)\n",
        "test_embeds = np.concatenate([test_embeds_x1, test_embeds_x2], axis=1)\n",
        "test_labels = sarcasm_label(test_labels)\n",
        "\n",
        "# Train Logistic Classifier (linear probe on the frozen embeddings).\n",
        "# np.ravel hands sklearn a 1-D y; an (n, 1) label column would otherwise trigger a\n",
        "# DataConversionWarning (sklearn ravels it internally, so the fit is unchanged).\n",
        "# max_iter=200 is kept for comparability; lbfgs may stop there before converging.\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, np.ravel(train_labels))\n",
        "score = clf.score(test_embeds, np.ravel(test_labels))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RVOYRYf5DYqq",
        "outputId": "423b3fa9-f836-481d-ac46-97f090cb78f7"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.6294896030245747"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "score  # test-set accuracy of the logistic-regression probe fitted in the previous cell"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AKmM07BzDYqq"
      },
      "source": [
        "## FactorCL-SSL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p02lzxQ_DYqr"
      },
      "outputs": [],
      "source": [
        "# Self-supervised FactorCL (FactorCL-SSL) on the humor data.\n",
        "# NOTE(review): this reuses train_ssl_sarcasm and y_ohe_dim=3 -- confirm they match the\n",
        "# humor label encoding.\n",
        "encoders = [\n",
        "    Transformer(371, 40),\n",
        "    Transformer(300, 40),\n",
        "]\n",
        "factorcl_ssl = FactorCLSSL(\n",
        "    encoders=encoders,\n",
        "    feat_dims=[40, 40],\n",
        "    y_ohe_dim=3,\n",
        ").cuda()\n",
        "\n",
        "train_ssl_sarcasm(factorcl_ssl, train_loader, num_epoch=100, num_club_iter=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DDlSUoOsDYqr"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "\n",
        "def _embed_split(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` in a single pass.\n",
        "\n",
        "    Calls model.get_embedding once per batch on modalities batch[0][0] and\n",
        "    batch[0][2] under torch.no_grad(), so no autograd graph is built and the\n",
        "    two embedding arrays and the raw labels (batch[3]) stay row-aligned.\n",
        "\n",
        "    Returns:\n",
        "        (x1_embeds, x2_embeds, raw_labels): numpy arrays concatenated over batches.\n",
        "    \"\"\"\n",
        "    x1_parts, x2_parts, label_parts = [], [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in loader:\n",
        "            batch_embeds = model.get_embedding(batch[0][0].cuda(), batch[0][2].cuda())\n",
        "            x1_parts.append(batch_embeds[0].detach().cpu().numpy())\n",
        "            x2_parts.append(batch_embeds[1].detach().cpu().numpy())\n",
        "            label_parts.append(batch[3].detach().cpu().numpy())\n",
        "    return np.concatenate(x1_parts), np.concatenate(x2_parts), np.concatenate(label_parts)\n",
        "\n",
        "\n",
        "factorcl_ssl.eval()\n",
        "\n",
        "train_embeds_x1, train_embeds_x2, train_labels = _embed_split(factorcl_ssl, eval_train_loader)\n",
        "train_embeds = np.concatenate([train_embeds_x1, train_embeds_x2], axis=1)\n",
        "train_labels = sarcasm_label(train_labels)\n",
        "\n",
        "test_embeds_x1, test_embeds_x2, test_labels = _embed_split(factorcl_ssl, eval_test_loader)\n",
        "test_embeds = np.concatenate([test_embeds_x1, test_embeds_x2], axis=1)\n",
        "test_labels = sarcasm_label(test_labels)\n",
        "\n",
        "# Train Logistic Classifier (linear probe on the frozen embeddings).\n",
        "# np.ravel hands sklearn a 1-D y; an (n, 1) label column would otherwise trigger a\n",
        "# DataConversionWarning (sklearn ravels it internally, so the fit is unchanged).\n",
        "# max_iter=200 is kept for comparability; lbfgs may stop there before converging.\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, np.ravel(train_labels))\n",
        "score = clf.score(test_embeds, np.ravel(test_labels))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "obx0KSM2DYqr",
        "outputId": "047fbdd5-977d-4f57-c05f-1781bd06fc3b"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.5964083175803403"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "score  # test-set accuracy of the logistic-regression probe fitted in the previous cell"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DYkAL9rWDYqr"
      },
      "source": [
        "## SupCon"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iaPGutr8DYqr"
      },
      "outputs": [],
      "source": [
        "encoders = [\n",
        "    Transformer(371, 40),\n",
        "    Transformer(300, 40),\n",
        "]\n",
        "\n",
        "# Supervised contrastive baseline on the humor data; set use_label=False for SimCLR.\n",
        "supcon_model = SupConModel(\n",
        "    temperature=0.5,\n",
        "    encoders=encoders,\n",
        "    dim_ins=[40, 40],\n",
        "    feat_dims=[40, 40],\n",
        "    use_label=True,\n",
        ").cuda()\n",
        "\n",
        "supcon_optim = optim.Adam(supcon_model.parameters())\n",
        "# modalities=[0, 2] selects the same two inputs the probe cell embeds (data[0][0], data[0][2]).\n",
        "train_supcon_sarcasm(supcon_model, train_loader, supcon_optim, modalities=[0, 2], num_epoch=100)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "koY0WidZDYqr"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "\n",
        "def _embed_split(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` in a single pass.\n",
        "\n",
        "    Calls model.get_embedding once per batch on modalities batch[0][0] and\n",
        "    batch[0][2] under torch.no_grad(), so no autograd graph is built and the\n",
        "    two embedding arrays and the raw labels (batch[3]) stay row-aligned.\n",
        "\n",
        "    Returns:\n",
        "        (x1_embeds, x2_embeds, raw_labels): numpy arrays concatenated over batches.\n",
        "    \"\"\"\n",
        "    x1_parts, x2_parts, label_parts = [], [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in loader:\n",
        "            batch_embeds = model.get_embedding(batch[0][0].cuda(), batch[0][2].cuda())\n",
        "            x1_parts.append(batch_embeds[0].detach().cpu().numpy())\n",
        "            x2_parts.append(batch_embeds[1].detach().cpu().numpy())\n",
        "            label_parts.append(batch[3].detach().cpu().numpy())\n",
        "    return np.concatenate(x1_parts), np.concatenate(x2_parts), np.concatenate(label_parts)\n",
        "\n",
        "\n",
        "supcon_model.eval()\n",
        "\n",
        "train_embeds_x1, train_embeds_x2, train_labels = _embed_split(supcon_model, eval_train_loader)\n",
        "train_embeds = np.concatenate([train_embeds_x1, train_embeds_x2], axis=1)\n",
        "train_labels = sarcasm_label(train_labels)\n",
        "\n",
        "test_embeds_x1, test_embeds_x2, test_labels = _embed_split(supcon_model, eval_test_loader)\n",
        "test_embeds = np.concatenate([test_embeds_x1, test_embeds_x2], axis=1)\n",
        "test_labels = sarcasm_label(test_labels)\n",
        "\n",
        "# Train Logistic Classifier (linear probe on the frozen embeddings).\n",
        "# np.ravel hands sklearn a 1-D y; an (n, 1) label column would otherwise trigger a\n",
        "# DataConversionWarning (sklearn ravels it internally, so the fit is unchanged).\n",
        "# max_iter=200 is kept for comparability; lbfgs may stop there before converging.\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, np.ravel(train_labels))\n",
        "score = clf.score(test_embeds, np.ravel(test_labels))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "27ExxUC6DYqr",
        "outputId": "55754174-9987-4887-cb5e-645c2a23f1bc"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.5160680529300568"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "score  # test-set accuracy of the logistic-regression probe fitted in the previous cell"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WhwX0KRkHSqL"
      },
      "source": [
        "# MIMIC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 64,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WqFO6jL0HSqS",
        "outputId": "d5299cf3-a1fb-4745-8b0c-63d5397eeebb"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 11/11 [00:05<00:00,  2.11it/s]\n",
            "100%|██████████| 11/11 [00:05<00:00,  2.05it/s]\n"
          ]
        }
      ],
      "source": [
        "from datasets.mimic.get_data import get_dataloader\n",
        "\n",
        "# MIMIC loaders for task id 7 (its meaning is defined by datasets.mimic.get_data --\n",
        "# confirm), read from the imputed pickle. The eval_* set passes train_shuffle=False so\n",
        "# the training split is read in a fixed order for embedding extraction.\n",
        "# NOTE(review): each call appears to reload the data (two progress bars in the output),\n",
        "# and the pickle path is a hard-coded Colab location.\n",
        "train_loader, valid_loader, test_loader = get_dataloader(\n",
        "    7,\n",
        "    imputed_path='/content/MultiBench/im.pk',\n",
        ")\n",
        "\n",
        "eval_train_loader, eval_valid_loader, eval_test_loader = get_dataloader(\n",
        "    7,\n",
        "    imputed_path='/content/MultiBench/im.pk',\n",
        "    train_shuffle=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QyxFC1whHSqS"
      },
      "source": [
        "## FactorCL-SUP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 65,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cQb7qw-HHSqS",
        "outputId": "b39b0bba-1716-4d73-8786-50b35324dc24"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  -0.0012989966198801994\n",
            "iter:  0  i_batch:  100  loss:  -0.2463875561952591\n",
            "iter:  0  i_batch:  200  loss:  -0.5832105875015259\n",
            "iter:  0  i_batch:  300  loss:  -0.6508282423019409\n",
            "iter:  0  i_batch:  400  loss:  -0.7278008460998535\n",
            "iter:  0  i_batch:  500  loss:  -0.7705122232437134\n",
            "iter:  0  i_batch:  600  loss:  -0.8182592391967773\n",
            "iter:  1  i_batch:  0  loss:  -1.0183138847351074\n",
            "iter:  1  i_batch:  100  loss:  -0.8504148125648499\n",
            "iter:  1  i_batch:  200  loss:  -0.8186317682266235\n",
            "iter:  1  i_batch:  300  loss:  -1.3324253559112549\n",
            "iter:  1  i_batch:  400  loss:  -1.222513198852539\n",
            "iter:  1  i_batch:  500  loss:  -0.9067964553833008\n",
            "iter:  1  i_batch:  600  loss:  -0.8614742159843445\n",
            "iter:  2  i_batch:  0  loss:  -1.2362432479858398\n",
            "iter:  2  i_batch:  100  loss:  -0.9707288146018982\n",
            "iter:  2  i_batch:  200  loss:  -0.9806453585624695\n",
            "iter:  2  i_batch:  300  loss:  -1.0745757818222046\n",
            "iter:  2  i_batch:  400  loss:  -0.6654452681541443\n",
            "iter:  2  i_batch:  500  loss:  -1.2130237817764282\n",
            "iter:  2  i_batch:  600  loss:  -1.3717926740646362\n",
            "iter:  3  i_batch:  0  loss:  -1.385075569152832\n",
            "iter:  3  i_batch:  100  loss:  -1.177972674369812\n",
            "iter:  3  i_batch:  200  loss:  -1.138024091720581\n",
            "iter:  3  i_batch:  300  loss:  -1.347625494003296\n",
            "iter:  3  i_batch:  400  loss:  -1.3332023620605469\n",
            "iter:  3  i_batch:  500  loss:  -1.3283668756484985\n",
            "iter:  3  i_batch:  600  loss:  -1.3809776306152344\n",
            "iter:  4  i_batch:  0  loss:  -1.5541335344314575\n",
            "iter:  4  i_batch:  100  loss:  -1.1414012908935547\n",
            "iter:  4  i_batch:  200  loss:  -1.089038610458374\n",
            "iter:  4  i_batch:  300  loss:  -1.3287062644958496\n",
            "iter:  4  i_batch:  400  loss:  -1.7988265752792358\n",
            "iter:  4  i_batch:  500  loss:  -1.5673491954803467\n",
            "iter:  4  i_batch:  600  loss:  -1.6143760681152344\n",
            "iter:  5  i_batch:  0  loss:  -1.4291075468063354\n",
            "iter:  5  i_batch:  100  loss:  -1.5899670124053955\n",
            "iter:  5  i_batch:  200  loss:  -1.564143180847168\n",
            "iter:  5  i_batch:  300  loss:  -0.9226900339126587\n",
            "iter:  5  i_batch:  400  loss:  -1.2433569431304932\n",
            "iter:  5  i_batch:  500  loss:  -1.4997469186782837\n",
            "iter:  5  i_batch:  600  loss:  -1.5986305475234985\n",
            "iter:  6  i_batch:  0  loss:  -1.7559144496917725\n",
            "iter:  6  i_batch:  100  loss:  -1.8194341659545898\n",
            "iter:  6  i_batch:  200  loss:  -1.6761971712112427\n",
            "iter:  6  i_batch:  300  loss:  -1.3228517770767212\n",
            "iter:  6  i_batch:  400  loss:  -1.801643967628479\n",
            "iter:  6  i_batch:  500  loss:  -1.797190546989441\n",
            "iter:  6  i_batch:  600  loss:  -1.213261604309082\n",
            "iter:  7  i_batch:  0  loss:  -1.234779715538025\n",
            "iter:  7  i_batch:  100  loss:  -1.8459200859069824\n",
            "iter:  7  i_batch:  200  loss:  -1.6291757822036743\n",
            "iter:  7  i_batch:  300  loss:  -1.8458904027938843\n",
            "iter:  7  i_batch:  400  loss:  -1.3553235530853271\n",
            "iter:  7  i_batch:  500  loss:  -1.153630256652832\n",
            "iter:  7  i_batch:  600  loss:  -1.6691190004348755\n",
            "iter:  8  i_batch:  0  loss:  -1.840094804763794\n",
            "iter:  8  i_batch:  100  loss:  -1.2960703372955322\n",
            "iter:  8  i_batch:  200  loss:  -1.373004674911499\n",
            "iter:  8  i_batch:  300  loss:  -1.6621216535568237\n",
            "iter:  8  i_batch:  400  loss:  -1.7253553867340088\n",
            "iter:  8  i_batch:  500  loss:  -1.8216700553894043\n",
            "iter:  8  i_batch:  600  loss:  -1.9599533081054688\n",
            "iter:  9  i_batch:  0  loss:  -1.5207377672195435\n",
            "iter:  9  i_batch:  100  loss:  -1.497999668121338\n",
            "iter:  9  i_batch:  200  loss:  -1.7067142724990845\n",
            "iter:  9  i_batch:  300  loss:  -1.9368577003479004\n",
            "iter:  9  i_batch:  400  loss:  -1.8783372640609741\n",
            "iter:  9  i_batch:  500  loss:  -1.6679706573486328\n",
            "iter:  9  i_batch:  600  loss:  -1.08932626247406\n",
            "iter:  10  i_batch:  0  loss:  -1.4916956424713135\n",
            "iter:  10  i_batch:  100  loss:  -1.5659552812576294\n",
            "iter:  10  i_batch:  200  loss:  -1.4953460693359375\n",
            "iter:  10  i_batch:  300  loss:  -1.531096339225769\n",
            "iter:  10  i_batch:  400  loss:  -1.703382134437561\n",
            "iter:  10  i_batch:  500  loss:  -2.255861282348633\n",
            "iter:  10  i_batch:  600  loss:  -1.7569451332092285\n",
            "iter:  11  i_batch:  0  loss:  -1.9469014406204224\n",
            "iter:  11  i_batch:  100  loss:  -1.5182437896728516\n",
            "iter:  11  i_batch:  200  loss:  -1.637331485748291\n",
            "iter:  11  i_batch:  300  loss:  -1.7552509307861328\n",
            "iter:  11  i_batch:  400  loss:  -1.6280875205993652\n",
            "iter:  11  i_batch:  500  loss:  -1.7466896772384644\n",
            "iter:  11  i_batch:  600  loss:  -1.443756103515625\n",
            "iter:  12  i_batch:  0  loss:  -2.244349479675293\n",
            "iter:  12  i_batch:  100  loss:  -1.544471025466919\n",
            "iter:  12  i_batch:  200  loss:  -1.1707839965820312\n",
            "iter:  12  i_batch:  300  loss:  -1.6514174938201904\n",
            "iter:  12  i_batch:  400  loss:  -1.3828468322753906\n",
            "iter:  12  i_batch:  500  loss:  -1.3446025848388672\n",
            "iter:  12  i_batch:  600  loss:  -1.4015986919403076\n",
            "iter:  13  i_batch:  0  loss:  -1.9635004997253418\n",
            "iter:  13  i_batch:  100  loss:  -1.615199327468872\n",
            "iter:  13  i_batch:  200  loss:  -1.7992749214172363\n",
            "iter:  13  i_batch:  300  loss:  -2.03013277053833\n",
            "iter:  13  i_batch:  400  loss:  -1.839434266090393\n",
            "iter:  13  i_batch:  500  loss:  -1.841787338256836\n",
            "iter:  13  i_batch:  600  loss:  -1.8203049898147583\n",
            "iter:  14  i_batch:  0  loss:  -2.0607786178588867\n",
            "iter:  14  i_batch:  100  loss:  -2.15256404876709\n",
            "iter:  14  i_batch:  200  loss:  -1.8011703491210938\n",
            "iter:  14  i_batch:  300  loss:  -2.056570053100586\n",
            "iter:  14  i_batch:  400  loss:  -1.654267430305481\n",
            "iter:  14  i_batch:  500  loss:  -2.0437984466552734\n",
            "iter:  14  i_batch:  600  loss:  -1.7611712217330933\n",
            "iter:  15  i_batch:  0  loss:  -1.6885292530059814\n",
            "iter:  15  i_batch:  100  loss:  -1.8389852046966553\n",
            "iter:  15  i_batch:  200  loss:  -2.0198099613189697\n",
            "iter:  15  i_batch:  300  loss:  -1.784873604774475\n",
            "iter:  15  i_batch:  400  loss:  -1.9512768983840942\n",
            "iter:  15  i_batch:  500  loss:  -2.0298337936401367\n",
            "iter:  15  i_batch:  600  loss:  -2.0258781909942627\n",
            "iter:  16  i_batch:  0  loss:  -2.24532151222229\n",
            "iter:  16  i_batch:  100  loss:  -1.8902900218963623\n",
            "iter:  16  i_batch:  200  loss:  -1.8092625141143799\n",
            "iter:  16  i_batch:  300  loss:  -1.2135753631591797\n",
            "iter:  16  i_batch:  400  loss:  -1.6297876834869385\n",
            "iter:  16  i_batch:  500  loss:  -1.8067348003387451\n",
            "iter:  16  i_batch:  600  loss:  -1.5464894771575928\n",
            "iter:  17  i_batch:  0  loss:  -1.8933161497116089\n",
            "iter:  17  i_batch:  100  loss:  -2.54934024810791\n",
            "iter:  17  i_batch:  200  loss:  -2.266740322113037\n",
            "iter:  17  i_batch:  300  loss:  -2.0111539363861084\n",
            "iter:  17  i_batch:  400  loss:  -1.597668170928955\n",
            "iter:  17  i_batch:  500  loss:  -2.57501220703125\n",
            "iter:  17  i_batch:  600  loss:  -2.0966124534606934\n",
            "iter:  18  i_batch:  0  loss:  -2.2372779846191406\n",
            "iter:  18  i_batch:  100  loss:  -1.8401401042938232\n",
            "iter:  18  i_batch:  200  loss:  -2.474608898162842\n",
            "iter:  18  i_batch:  300  loss:  -2.2913312911987305\n",
            "iter:  18  i_batch:  400  loss:  -2.3477883338928223\n",
            "iter:  18  i_batch:  500  loss:  -1.7914838790893555\n",
            "iter:  18  i_batch:  600  loss:  -1.8324592113494873\n",
            "iter:  19  i_batch:  0  loss:  -2.524364948272705\n",
            "iter:  19  i_batch:  100  loss:  -1.814623236656189\n",
            "iter:  19  i_batch:  200  loss:  -2.4873862266540527\n",
            "iter:  19  i_batch:  300  loss:  -2.426302909851074\n",
            "iter:  19  i_batch:  400  loss:  -1.82527494430542\n",
            "iter:  19  i_batch:  500  loss:  -2.0646514892578125\n",
            "iter:  19  i_batch:  600  loss:  -1.9947938919067383\n",
            "iter:  20  i_batch:  0  loss:  -1.9884772300720215\n",
            "iter:  20  i_batch:  100  loss:  -2.561591625213623\n",
            "iter:  20  i_batch:  200  loss:  -2.016880989074707\n",
            "iter:  20  i_batch:  300  loss:  -2.343414783477783\n",
            "iter:  20  i_batch:  400  loss:  -2.3082656860351562\n",
            "iter:  20  i_batch:  500  loss:  -2.534693956375122\n",
            "iter:  20  i_batch:  600  loss:  -2.0034615993499756\n",
            "iter:  21  i_batch:  0  loss:  -2.3769819736480713\n",
            "iter:  21  i_batch:  100  loss:  -2.365145683288574\n",
            "iter:  21  i_batch:  200  loss:  -2.7161202430725098\n",
            "iter:  21  i_batch:  300  loss:  -1.9398205280303955\n",
            "iter:  21  i_batch:  400  loss:  -2.540271759033203\n",
            "iter:  21  i_batch:  500  loss:  -2.4574451446533203\n",
            "iter:  21  i_batch:  600  loss:  -3.16694974899292\n",
            "iter:  22  i_batch:  0  loss:  -2.472994327545166\n",
            "iter:  22  i_batch:  100  loss:  -2.6019554138183594\n",
            "iter:  22  i_batch:  200  loss:  -2.8366174697875977\n",
            "iter:  22  i_batch:  300  loss:  -2.390040159225464\n",
            "iter:  22  i_batch:  400  loss:  -2.529249668121338\n",
            "iter:  22  i_batch:  500  loss:  -2.278858184814453\n",
            "iter:  22  i_batch:  600  loss:  -2.4989142417907715\n",
            "iter:  23  i_batch:  0  loss:  -2.5472538471221924\n",
            "iter:  23  i_batch:  100  loss:  -2.6855032444000244\n",
            "iter:  23  i_batch:  200  loss:  -2.7795872688293457\n",
            "iter:  23  i_batch:  300  loss:  -2.698218584060669\n",
            "iter:  23  i_batch:  400  loss:  -2.273698329925537\n",
            "iter:  23  i_batch:  500  loss:  -2.0608136653900146\n",
            "iter:  23  i_batch:  600  loss:  -3.0302248001098633\n",
            "iter:  24  i_batch:  0  loss:  -2.7528414726257324\n",
            "iter:  24  i_batch:  100  loss:  -2.8396499156951904\n",
            "iter:  24  i_batch:  200  loss:  -2.859177350997925\n",
            "iter:  24  i_batch:  300  loss:  -2.9230635166168213\n",
            "iter:  24  i_batch:  400  loss:  -2.409689426422119\n",
            "iter:  24  i_batch:  500  loss:  -2.7113020420074463\n",
            "iter:  24  i_batch:  600  loss:  -2.456892490386963\n",
            "iter:  25  i_batch:  0  loss:  -2.531332492828369\n",
            "iter:  25  i_batch:  100  loss:  -2.53786301612854\n",
            "iter:  25  i_batch:  200  loss:  -2.7011396884918213\n",
            "iter:  25  i_batch:  300  loss:  -2.1107828617095947\n",
            "iter:  25  i_batch:  400  loss:  -2.9995405673980713\n",
            "iter:  25  i_batch:  500  loss:  -3.012389659881592\n",
            "iter:  25  i_batch:  600  loss:  -2.6809074878692627\n",
            "iter:  26  i_batch:  0  loss:  -2.744882106781006\n",
            "iter:  26  i_batch:  100  loss:  -3.090906858444214\n",
            "iter:  26  i_batch:  200  loss:  -2.744732618331909\n",
            "iter:  26  i_batch:  300  loss:  -2.83301043510437\n",
            "iter:  26  i_batch:  400  loss:  -2.932544231414795\n",
            "iter:  26  i_batch:  500  loss:  -2.364086151123047\n",
            "iter:  26  i_batch:  600  loss:  -2.485280990600586\n",
            "iter:  27  i_batch:  0  loss:  -3.1551637649536133\n",
            "iter:  27  i_batch:  100  loss:  -2.9077672958374023\n",
            "iter:  27  i_batch:  200  loss:  -3.0302023887634277\n",
            "iter:  27  i_batch:  300  loss:  -3.0141537189483643\n",
            "iter:  27  i_batch:  400  loss:  -2.660649061203003\n",
            "iter:  27  i_batch:  500  loss:  -3.066291332244873\n",
            "iter:  27  i_batch:  600  loss:  -2.754704713821411\n",
            "iter:  28  i_batch:  0  loss:  -2.9776110649108887\n",
            "iter:  28  i_batch:  100  loss:  -2.579167366027832\n",
            "iter:  28  i_batch:  200  loss:  -3.305588483810425\n",
            "iter:  28  i_batch:  300  loss:  -2.7967567443847656\n",
            "iter:  28  i_batch:  400  loss:  -2.782949686050415\n",
            "iter:  28  i_batch:  500  loss:  -2.8204870223999023\n",
            "iter:  28  i_batch:  600  loss:  -2.8337202072143555\n",
            "iter:  29  i_batch:  0  loss:  -2.994739294052124\n",
            "iter:  29  i_batch:  100  loss:  -3.2117507457733154\n",
            "iter:  29  i_batch:  200  loss:  -3.019622802734375\n",
            "iter:  29  i_batch:  300  loss:  -2.8001060485839844\n",
            "iter:  29  i_batch:  400  loss:  -2.9164724349975586\n",
            "iter:  29  i_batch:  500  loss:  -3.059520959854126\n",
            "iter:  29  i_batch:  600  loss:  -2.6599678993225098\n"
          ]
        }
      ],
      "source": [
        "# Two encoders, one per input passed to the model: an MLP and a flattening GRU.\n",
        "encoders = [MLP(5, 10, 10, dropout=False),\n",
        "            GRU(12, 30, dropout=False, batch_first=True, flatten=True)]\n",
        "\n",
        "# feat_dims presumably lists each encoder's output size: 10 for the MLP and 720 for the\n",
        "# flattened GRU (likely hidden size 30 x 24 time steps -- TODO confirm against the data).\n",
        "factorcl_sup = FactorCLSUP(encoders=encoders, feat_dims=[10, 720], y_ohe_dim=2).cuda()\n",
        "train_sup_mimic(factorcl_sup, train_loader, num_epoch=30, num_club_iter=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QTJiVmH3HSqS",
        "outputId": "7065427a-669c-4051-eab3-f52100ef1b1f"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ],
      "source": [
        "# Linear-probe evaluation: fit a logistic regression on the frozen FactorCL-SUP\n",
        "# embeddings of both inputs, concatenated along the feature axis.\n",
        "factorcl_sup.eval()\n",
        "\n",
        "def embed_loader(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` in a single pass.\n",
        "\n",
        "    One pass per loader keeps the x1/x2 embeddings and the labels row-aligned even if\n",
        "    the loader reshuffles between iterations, and calls get_embedding once per batch.\n",
        "    Returns (x1_embeds, x2_embeds, labels) as NumPy arrays concatenated over batches.\n",
        "    \"\"\"\n",
        "    import torch\n",
        "\n",
        "    x1_parts, x2_parts, label_parts = [], [], []\n",
        "    with torch.no_grad():  # inference only: no autograd graph needed\n",
        "        for data in loader:\n",
        "            embeds = model.get_embedding(data[0].float().cuda(), data[1].float().cuda())\n",
        "            x1_parts.append(embeds[0].detach().cpu().numpy())\n",
        "            x2_parts.append(embeds[1].detach().cpu().numpy())\n",
        "            label_parts.append(data[2].detach().cpu().numpy())\n",
        "    return np.concatenate(x1_parts), np.concatenate(x2_parts), np.concatenate(label_parts)\n",
        "\n",
        "train_embeds_x1, train_embeds_x2, train_labels = embed_loader(factorcl_sup, eval_train_loader)\n",
        "train_embeds = np.concatenate([train_embeds_x1, train_embeds_x2], axis=1)\n",
        "\n",
        "# FIXME(review): the probe is scored on eval_train_loader, i.e. on the same samples it is\n",
        "# fitted on, so `score` is a training accuracy rather than a held-out one. Point\n",
        "# probe_test_loader at the held-out evaluation loader (defined outside this cell).\n",
        "probe_test_loader = eval_train_loader\n",
        "test_embeds_x1, test_embeds_x2, test_labels = embed_loader(factorcl_sup, probe_test_loader)\n",
        "test_embeds = np.concatenate([test_embeds_x1, test_embeds_x2], axis=1)\n",
        "\n",
        "# Train Logistic Classifier\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3w0dkIepHSqS",
        "outputId": "11a241b5-679e-4eb2-81f5-d95765b2e507"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.7897903652320546"
            ]
          },
          "execution_count": 35,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "score  # logistic-regression probe accuracy computed in the previous cell"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CJBo3NAtHSqT"
      },
      "source": [
        "## FactorCL-SSL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1uT354BBHSqT",
        "outputId": "ab7cd438-6695-4889-daa2-80c337377560"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "iter:  0  i_batch:  0  loss:  0.0006223861128091812\n",
            "iter:  0  i_batch:  100  loss:  -3.5920844078063965\n",
            "iter:  0  i_batch:  200  loss:  -4.292657852172852\n",
            "iter:  0  i_batch:  300  loss:  -4.513991355895996\n",
            "iter:  0  i_batch:  400  loss:  -5.4547119140625\n",
            "iter:  0  i_batch:  500  loss:  -5.684263706207275\n",
            "iter:  0  i_batch:  600  loss:  -3.789016008377075\n",
            "iter:  1  i_batch:  0  loss:  -5.70589542388916\n",
            "iter:  1  i_batch:  100  loss:  -6.036656856536865\n",
            "iter:  1  i_batch:  200  loss:  -5.840413570404053\n",
            "iter:  1  i_batch:  300  loss:  -6.080589294433594\n",
            "iter:  1  i_batch:  400  loss:  -5.495912551879883\n",
            "iter:  1  i_batch:  500  loss:  -5.943778038024902\n",
            "iter:  1  i_batch:  600  loss:  -7.005795955657959\n",
            "iter:  2  i_batch:  0  loss:  -5.923552989959717\n",
            "iter:  2  i_batch:  100  loss:  -6.397305488586426\n",
            "iter:  2  i_batch:  200  loss:  -7.055599212646484\n",
            "iter:  2  i_batch:  300  loss:  -6.3207011222839355\n",
            "iter:  2  i_batch:  400  loss:  -6.897139072418213\n",
            "iter:  2  i_batch:  500  loss:  -6.27215051651001\n",
            "iter:  2  i_batch:  600  loss:  -6.160101413726807\n",
            "iter:  3  i_batch:  0  loss:  -6.150088310241699\n",
            "iter:  3  i_batch:  100  loss:  -6.3176164627075195\n",
            "iter:  3  i_batch:  200  loss:  -7.36838960647583\n",
            "iter:  3  i_batch:  300  loss:  -4.8686203956604\n",
            "iter:  3  i_batch:  400  loss:  -6.931423187255859\n",
            "iter:  3  i_batch:  500  loss:  -7.055408477783203\n",
            "iter:  3  i_batch:  600  loss:  -5.7565693855285645\n",
            "iter:  4  i_batch:  0  loss:  -5.697240829467773\n",
            "iter:  4  i_batch:  100  loss:  -7.213937759399414\n",
            "iter:  4  i_batch:  200  loss:  -6.414517879486084\n",
            "iter:  4  i_batch:  300  loss:  -6.63500452041626\n",
            "iter:  4  i_batch:  400  loss:  -6.925203800201416\n",
            "iter:  4  i_batch:  500  loss:  -7.971752166748047\n",
            "iter:  4  i_batch:  600  loss:  -6.317779064178467\n",
            "iter:  5  i_batch:  0  loss:  -7.426606178283691\n",
            "iter:  5  i_batch:  100  loss:  -7.102625846862793\n",
            "iter:  5  i_batch:  200  loss:  -5.554515361785889\n",
            "iter:  5  i_batch:  300  loss:  -6.984551906585693\n",
            "iter:  5  i_batch:  400  loss:  -4.592079162597656\n",
            "iter:  5  i_batch:  500  loss:  -6.657525539398193\n",
            "iter:  5  i_batch:  600  loss:  -6.299755096435547\n",
            "iter:  6  i_batch:  0  loss:  -6.40272855758667\n",
            "iter:  6  i_batch:  100  loss:  -6.424211502075195\n",
            "iter:  6  i_batch:  200  loss:  -6.2604875564575195\n",
            "iter:  6  i_batch:  300  loss:  -7.516908168792725\n",
            "iter:  6  i_batch:  400  loss:  -6.286506652832031\n",
            "iter:  6  i_batch:  500  loss:  -6.51344633102417\n",
            "iter:  6  i_batch:  600  loss:  -6.947033882141113\n",
            "iter:  7  i_batch:  0  loss:  -6.254836559295654\n",
            "iter:  7  i_batch:  100  loss:  -6.139508247375488\n",
            "iter:  7  i_batch:  200  loss:  -6.215014457702637\n",
            "iter:  7  i_batch:  300  loss:  -5.48787784576416\n",
            "iter:  7  i_batch:  400  loss:  -6.263719081878662\n",
            "iter:  7  i_batch:  500  loss:  -7.171139240264893\n",
            "iter:  7  i_batch:  600  loss:  -5.714439868927002\n",
            "iter:  8  i_batch:  0  loss:  -6.583385467529297\n",
            "iter:  8  i_batch:  100  loss:  -7.731355667114258\n",
            "iter:  8  i_batch:  200  loss:  -7.505163192749023\n",
            "iter:  8  i_batch:  300  loss:  -6.248068332672119\n",
            "iter:  8  i_batch:  400  loss:  -4.602375507354736\n",
            "iter:  8  i_batch:  500  loss:  -7.244750499725342\n",
            "iter:  8  i_batch:  600  loss:  -6.506217002868652\n",
            "iter:  9  i_batch:  0  loss:  -6.971248149871826\n",
            "iter:  9  i_batch:  100  loss:  -4.557550430297852\n",
            "iter:  9  i_batch:  200  loss:  -7.830191612243652\n",
            "iter:  9  i_batch:  300  loss:  -6.44489860534668\n",
            "iter:  9  i_batch:  400  loss:  -6.08656644821167\n",
            "iter:  9  i_batch:  500  loss:  -8.11015510559082\n",
            "iter:  9  i_batch:  600  loss:  -7.577457427978516\n",
            "iter:  10  i_batch:  0  loss:  -7.516018390655518\n",
            "iter:  10  i_batch:  100  loss:  -5.959195137023926\n",
            "iter:  10  i_batch:  200  loss:  -4.36962366104126\n",
            "iter:  10  i_batch:  300  loss:  -8.048295974731445\n",
            "iter:  10  i_batch:  400  loss:  -5.96358585357666\n",
            "iter:  10  i_batch:  500  loss:  -7.688719272613525\n",
            "iter:  10  i_batch:  600  loss:  -8.245295524597168\n",
            "iter:  11  i_batch:  0  loss:  -5.118178844451904\n",
            "iter:  11  i_batch:  100  loss:  -8.154534339904785\n",
            "iter:  11  i_batch:  200  loss:  -6.680581092834473\n",
            "iter:  11  i_batch:  300  loss:  -8.038406372070312\n",
            "iter:  11  i_batch:  400  loss:  -8.089735984802246\n",
            "iter:  11  i_batch:  500  loss:  -7.34019660949707\n",
            "iter:  11  i_batch:  600  loss:  -7.888523101806641\n",
            "iter:  12  i_batch:  0  loss:  -4.5143842697143555\n",
            "iter:  12  i_batch:  100  loss:  -6.66663932800293\n",
            "iter:  12  i_batch:  200  loss:  -7.833639621734619\n",
            "iter:  12  i_batch:  300  loss:  -8.422159194946289\n",
            "iter:  12  i_batch:  400  loss:  -8.054841995239258\n",
            "iter:  12  i_batch:  500  loss:  -8.221678733825684\n",
            "iter:  12  i_batch:  600  loss:  -6.471945762634277\n",
            "iter:  13  i_batch:  0  loss:  -8.356306076049805\n",
            "iter:  13  i_batch:  100  loss:  -5.165944576263428\n",
            "iter:  13  i_batch:  200  loss:  -6.411352634429932\n",
            "iter:  13  i_batch:  300  loss:  -7.84371280670166\n",
            "iter:  13  i_batch:  400  loss:  -8.068949699401855\n",
            "iter:  13  i_batch:  500  loss:  -6.515161037445068\n",
            "iter:  13  i_batch:  600  loss:  -6.923438549041748\n",
            "iter:  14  i_batch:  0  loss:  -7.490320205688477\n",
            "iter:  14  i_batch:  100  loss:  -8.613906860351562\n",
            "iter:  14  i_batch:  200  loss:  -7.052690029144287\n",
            "iter:  14  i_batch:  300  loss:  -8.652649879455566\n",
            "iter:  14  i_batch:  400  loss:  -7.130681037902832\n",
            "iter:  14  i_batch:  500  loss:  -7.813324928283691\n",
            "iter:  14  i_batch:  600  loss:  -8.504664421081543\n",
            "iter:  15  i_batch:  0  loss:  -7.317049026489258\n",
            "iter:  15  i_batch:  100  loss:  -7.196036338806152\n",
            "iter:  15  i_batch:  200  loss:  -5.255153179168701\n",
            "iter:  15  i_batch:  300  loss:  -8.50355052947998\n",
            "iter:  15  i_batch:  400  loss:  -8.269947052001953\n",
            "iter:  15  i_batch:  500  loss:  -8.589110374450684\n",
            "iter:  15  i_batch:  600  loss:  -7.614760398864746\n",
            "iter:  16  i_batch:  0  loss:  -8.675949096679688\n",
            "iter:  16  i_batch:  100  loss:  -6.538235664367676\n",
            "iter:  16  i_batch:  200  loss:  -8.300971984863281\n",
            "iter:  16  i_batch:  300  loss:  -5.886682987213135\n",
            "iter:  16  i_batch:  400  loss:  -4.30562162399292\n",
            "iter:  16  i_batch:  500  loss:  -7.736268043518066\n",
            "iter:  16  i_batch:  600  loss:  -8.489278793334961\n",
            "iter:  17  i_batch:  0  loss:  -3.587075710296631\n",
            "iter:  17  i_batch:  100  loss:  -8.465181350708008\n",
            "iter:  17  i_batch:  200  loss:  -8.856968879699707\n",
            "iter:  17  i_batch:  300  loss:  -4.699093341827393\n",
            "iter:  17  i_batch:  400  loss:  -4.983674049377441\n",
            "iter:  17  i_batch:  500  loss:  -8.748300552368164\n",
            "iter:  17  i_batch:  600  loss:  -8.238848686218262\n",
            "iter:  18  i_batch:  0  loss:  -8.542336463928223\n",
            "iter:  18  i_batch:  100  loss:  -9.665167808532715\n",
            "iter:  18  i_batch:  200  loss:  -7.548946380615234\n",
            "iter:  18  i_batch:  300  loss:  -5.013677597045898\n",
            "iter:  18  i_batch:  400  loss:  -8.583763122558594\n",
            "iter:  18  i_batch:  500  loss:  -6.808431148529053\n",
            "iter:  18  i_batch:  600  loss:  -9.065102577209473\n",
            "iter:  19  i_batch:  0  loss:  -4.052534103393555\n",
            "iter:  19  i_batch:  100  loss:  -5.0537824630737305\n",
            "iter:  19  i_batch:  200  loss:  -1.40336012840271\n",
            "iter:  19  i_batch:  300  loss:  -7.359401226043701\n",
            "iter:  19  i_batch:  400  loss:  -8.437915802001953\n",
            "iter:  19  i_batch:  500  loss:  -7.731228828430176\n",
            "iter:  19  i_batch:  600  loss:  -7.636984825134277\n",
            "iter:  20  i_batch:  0  loss:  -8.974037170410156\n",
            "iter:  20  i_batch:  100  loss:  -4.965627670288086\n",
            "iter:  20  i_batch:  200  loss:  -7.816171646118164\n",
            "iter:  20  i_batch:  300  loss:  -8.403440475463867\n",
            "iter:  20  i_batch:  400  loss:  -8.98563003540039\n",
            "iter:  20  i_batch:  500  loss:  -8.637144088745117\n",
            "iter:  20  i_batch:  600  loss:  -4.473073959350586\n",
            "iter:  21  i_batch:  0  loss:  -9.426736831665039\n",
            "iter:  21  i_batch:  100  loss:  -4.90972900390625\n",
            "iter:  21  i_batch:  200  loss:  -8.362104415893555\n",
            "iter:  21  i_batch:  300  loss:  -5.8875274658203125\n",
            "iter:  21  i_batch:  400  loss:  -5.314492225646973\n",
            "iter:  21  i_batch:  500  loss:  -3.6536128520965576\n",
            "iter:  21  i_batch:  600  loss:  -9.017744064331055\n",
            "iter:  22  i_batch:  0  loss:  -9.072072982788086\n",
            "iter:  22  i_batch:  100  loss:  -7.221712589263916\n",
            "iter:  22  i_batch:  200  loss:  -8.57865047454834\n",
            "iter:  22  i_batch:  300  loss:  -7.848546981811523\n",
            "iter:  22  i_batch:  400  loss:  -9.13613510131836\n",
            "iter:  22  i_batch:  500  loss:  -6.314939498901367\n",
            "iter:  22  i_batch:  600  loss:  -5.843133449554443\n",
            "iter:  23  i_batch:  0  loss:  -5.908993721008301\n",
            "iter:  23  i_batch:  100  loss:  -8.650643348693848\n",
            "iter:  23  i_batch:  200  loss:  -9.246330261230469\n",
            "iter:  23  i_batch:  300  loss:  -7.282070159912109\n",
            "iter:  23  i_batch:  400  loss:  -6.859094619750977\n",
            "iter:  23  i_batch:  500  loss:  -9.24421501159668\n",
            "iter:  23  i_batch:  600  loss:  -6.79875373840332\n",
            "iter:  24  i_batch:  0  loss:  -8.0668363571167\n",
            "iter:  24  i_batch:  100  loss:  -8.475898742675781\n",
            "iter:  24  i_batch:  200  loss:  -9.451217651367188\n",
            "iter:  24  i_batch:  300  loss:  -7.478752136230469\n",
            "iter:  24  i_batch:  400  loss:  -5.5274739265441895\n",
            "iter:  24  i_batch:  500  loss:  -6.142536163330078\n",
            "iter:  24  i_batch:  600  loss:  -9.182416915893555\n",
            "iter:  25  i_batch:  0  loss:  -7.5402679443359375\n",
            "iter:  25  i_batch:  100  loss:  -7.615032196044922\n",
            "iter:  25  i_batch:  200  loss:  -6.453925132751465\n",
            "iter:  25  i_batch:  300  loss:  -9.628724098205566\n",
            "iter:  25  i_batch:  400  loss:  -5.430488586425781\n",
            "iter:  25  i_batch:  500  loss:  -6.8378987312316895\n",
            "iter:  25  i_batch:  600  loss:  -8.33421516418457\n",
            "iter:  26  i_batch:  0  loss:  -9.905061721801758\n",
            "iter:  26  i_batch:  100  loss:  -5.788702964782715\n",
            "iter:  26  i_batch:  200  loss:  -7.536861896514893\n",
            "iter:  26  i_batch:  300  loss:  -7.227022171020508\n",
            "iter:  26  i_batch:  400  loss:  -8.080766677856445\n",
            "iter:  26  i_batch:  500  loss:  -7.957943439483643\n",
            "iter:  26  i_batch:  600  loss:  -6.604194641113281\n",
            "iter:  27  i_batch:  0  loss:  -8.834962844848633\n",
            "iter:  27  i_batch:  100  loss:  -6.8940749168396\n",
            "iter:  27  i_batch:  200  loss:  -8.67134952545166\n",
            "iter:  27  i_batch:  300  loss:  -9.645859718322754\n",
            "iter:  27  i_batch:  400  loss:  -8.515857696533203\n",
            "iter:  27  i_batch:  500  loss:  -7.9745893478393555\n",
            "iter:  27  i_batch:  600  loss:  -7.718776702880859\n",
            "iter:  28  i_batch:  0  loss:  -9.768109321594238\n",
            "iter:  28  i_batch:  100  loss:  -5.882572174072266\n",
            "iter:  28  i_batch:  200  loss:  -8.279226303100586\n",
            "iter:  28  i_batch:  300  loss:  -9.559467315673828\n",
            "iter:  28  i_batch:  400  loss:  -8.50065803527832\n",
            "iter:  28  i_batch:  500  loss:  -8.330896377563477\n",
            "iter:  28  i_batch:  600  loss:  -9.753822326660156\n",
            "iter:  29  i_batch:  0  loss:  -8.616969108581543\n",
            "iter:  29  i_batch:  100  loss:  -6.553716659545898\n",
            "iter:  29  i_batch:  200  loss:  -10.237489700317383\n",
            "iter:  29  i_batch:  300  loss:  -7.970396041870117\n",
            "iter:  29  i_batch:  400  loss:  -6.438853740692139\n",
            "iter:  29  i_batch:  500  loss:  -7.133041858673096\n",
            "iter:  29  i_batch:  600  loss:  -9.860801696777344\n"
          ]
        }
      ],
      "source": [
        "# Two encoders, one per input passed to the model: an MLP and a flattening GRU.\n",
        "encoders = [MLP(5, 10, 10, dropout=False),\n",
        "            GRU(12, 30, dropout=False, batch_first=True, flatten=True)]\n",
        "\n",
        "# feat_dims presumably lists each encoder's output size: 10 for the MLP and 720 for the\n",
        "# flattened GRU (likely hidden size 30 x 24 time steps -- TODO confirm against the data).\n",
        "factorcl_ssl = FactorCLSSL(encoders=encoders, feat_dims=[10, 720], y_ohe_dim=2).cuda()\n",
        "train_ssl_mimic(factorcl_ssl, train_loader, num_epoch=30, num_club_iter=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "diBx-pONHSqT",
        "outputId": "f1eb1c87-5e76-422b-c67e-144774493d8a"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ],
      "source": [
        "# Linear-probe evaluation: fit a logistic regression on the frozen FactorCL-SSL\n",
        "# embeddings of both inputs, concatenated along the feature axis.\n",
        "factorcl_ssl.eval()\n",
        "\n",
        "def embed_loader(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` in a single pass.\n",
        "\n",
        "    One pass per loader keeps the x1/x2 embeddings and the labels row-aligned even if\n",
        "    the loader reshuffles between iterations, and calls get_embedding once per batch.\n",
        "    Returns (x1_embeds, x2_embeds, labels) as NumPy arrays concatenated over batches.\n",
        "    \"\"\"\n",
        "    import torch\n",
        "\n",
        "    x1_parts, x2_parts, label_parts = [], [], []\n",
        "    with torch.no_grad():  # inference only: no autograd graph needed\n",
        "        for data in loader:\n",
        "            embeds = model.get_embedding(data[0].float().cuda(), data[1].float().cuda())\n",
        "            x1_parts.append(embeds[0].detach().cpu().numpy())\n",
        "            x2_parts.append(embeds[1].detach().cpu().numpy())\n",
        "            label_parts.append(data[2].detach().cpu().numpy())\n",
        "    return np.concatenate(x1_parts), np.concatenate(x2_parts), np.concatenate(label_parts)\n",
        "\n",
        "train_embeds_x1, train_embeds_x2, train_labels = embed_loader(factorcl_ssl, eval_train_loader)\n",
        "train_embeds = np.concatenate([train_embeds_x1, train_embeds_x2], axis=1)\n",
        "\n",
        "# FIXME(review): the probe is scored on eval_train_loader, i.e. on the same samples it is\n",
        "# fitted on, so `score` is a training accuracy rather than a held-out one. Point\n",
        "# probe_test_loader at the held-out evaluation loader (defined outside this cell).\n",
        "probe_test_loader = eval_train_loader\n",
        "test_embeds_x1, test_embeds_x2, test_labels = embed_loader(factorcl_ssl, probe_test_loader)\n",
        "test_embeds = np.concatenate([test_embeds_x1, test_embeds_x2], axis=1)\n",
        "\n",
        "# Train Logistic Classifier\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3j63rz23HSqT",
        "outputId": "d23d0994-7e2a-4982-bf91-3026830797ec"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.6653891848388457"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Accuracy of the logistic-regression probe fit in the previous cell\n",
        "score"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WmI_QMOgHSqT"
      },
      "source": [
        "## SupCon"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mF6m3Y8KHSqT",
        "outputId": "2e1eb96d-702a-4168-cfd1-222a17feb21f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "iter:  0  i_batch:  0  loss:  33.60186767578125\n",
            "iter:  0  i_batch:  100  loss:  31.204633712768555\n",
            "iter:  0  i_batch:  200  loss:  31.2003116607666\n",
            "iter:  0  i_batch:  300  loss:  31.190750122070312\n",
            "iter:  0  i_batch:  400  loss:  31.21358299255371\n",
            "iter:  0  i_batch:  500  loss:  31.224945068359375\n",
            "iter:  0  i_batch:  600  loss:  31.217573165893555\n",
            "iter:  1  i_batch:  0  loss:  31.2016544342041\n",
            "iter:  1  i_batch:  100  loss:  31.154462814331055\n",
            "iter:  1  i_batch:  200  loss:  31.185989379882812\n",
            "iter:  1  i_batch:  300  loss:  31.21261978149414\n",
            "iter:  1  i_batch:  400  loss:  31.23200035095215\n",
            "iter:  1  i_batch:  500  loss:  31.17060661315918\n",
            "iter:  1  i_batch:  600  loss:  31.194486618041992\n",
            "iter:  2  i_batch:  0  loss:  31.2025089263916\n",
            "iter:  2  i_batch:  100  loss:  31.1802978515625\n",
            "iter:  2  i_batch:  200  loss:  31.162372589111328\n",
            "iter:  2  i_batch:  300  loss:  31.20770263671875\n",
            "iter:  2  i_batch:  400  loss:  31.207172393798828\n",
            "iter:  2  i_batch:  500  loss:  31.17645835876465\n",
            "iter:  2  i_batch:  600  loss:  31.118268966674805\n",
            "iter:  3  i_batch:  0  loss:  31.18617057800293\n",
            "iter:  3  i_batch:  100  loss:  31.13673210144043\n",
            "iter:  3  i_batch:  200  loss:  31.215503692626953\n",
            "iter:  3  i_batch:  300  loss:  31.205286026000977\n",
            "iter:  3  i_batch:  400  loss:  31.211902618408203\n",
            "iter:  3  i_batch:  500  loss:  31.063934326171875\n",
            "iter:  3  i_batch:  600  loss:  31.207727432250977\n",
            "iter:  4  i_batch:  0  loss:  31.223188400268555\n",
            "iter:  4  i_batch:  100  loss:  31.180438995361328\n",
            "iter:  4  i_batch:  200  loss:  31.10986328125\n",
            "iter:  4  i_batch:  300  loss:  31.112592697143555\n",
            "iter:  4  i_batch:  400  loss:  31.229507446289062\n",
            "iter:  4  i_batch:  500  loss:  31.20889663696289\n",
            "iter:  4  i_batch:  600  loss:  31.14251136779785\n"
          ]
        }
      ],
      "source": [
        "# Same encoder pair as the FactorCL-SSL cell above: an MLP for the first\n",
        "# modality and a flattened GRU for the second.\n",
        "# Alternative second encoder (not used): GRUWithLinear(12, 30, 15, flatten=True, batch_first=True)\n",
        "encoders = [\n",
        "    MLP(5, 10, 10, dropout=False),\n",
        "    GRU(12, 30, dropout=False, batch_first=True, flatten=True),\n",
        "]\n",
        "\n",
        "# dim_ins presumably has to match the encoder output widths (10 and 720) - verify.\n",
        "# use_label=True trains SupCon; use_label=False would give SimCLR.\n",
        "supcon_model = SupConModel(\n",
        "    temperature=0.5,\n",
        "    encoders=encoders,\n",
        "    dim_ins=[10, 720],\n",
        "    feat_dims=[40, 40],\n",
        "    use_label=True,\n",
        ").cuda()\n",
        "supcon_optim = optim.Adam(supcon_model.parameters())\n",
        "train_supcon_mimic(supcon_model, train_loader, supcon_optim, num_epoch=5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OV0057jWHSqT",
        "outputId": "13cce792-eb78-4fc0-9058-76d0dbde94c6"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ],
      "source": [
        "supcon_model.eval()\n",
        "\n",
        "def embed_mimic_split(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` and collect its labels.\n",
        "\n",
        "    `model.get_embedding` is called once per batch on batch elements 0 and 1\n",
        "    (cast to float and moved to the GPU); its first two outputs are\n",
        "    concatenated along the feature axis. Labels are batch element 2.\n",
        "    Returns (embeddings, labels) as numpy arrays.\n",
        "    \"\"\"\n",
        "    feats, labels = [], []\n",
        "    for data in loader:\n",
        "        out = model.get_embedding(data[0].float().cuda(), data[1].float().cuda())\n",
        "        feats.append(np.concatenate([out[0].detach().cpu().numpy(),\n",
        "                                     out[1].detach().cpu().numpy()], axis=1))\n",
        "        labels.append(data[2].detach().cpu().numpy())\n",
        "    return np.concatenate(feats), np.concatenate(labels)\n",
        "\n",
        "train_embeds, train_labels = embed_mimic_split(supcon_model, eval_train_loader)\n",
        "# Score on the held-out test loader: the probe must not be evaluated on the\n",
        "# same samples it was fit on.\n",
        "test_embeds, test_labels = embed_mimic_split(supcon_model, eval_test_loader)\n",
        "\n",
        "# Train Logistic Classifier on train embeddings, score it on test embeddings\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SrSTDRxDHSqT",
        "outputId": "e1f2ef71-fd82-4b3f-d720-dd208bec9e7d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.6641628022841375"
            ]
          },
          "execution_count": 20,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Accuracy of the logistic-regression probe fit in the previous cell\n",
        "score"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_K1VRs2vBo4E"
      },
      "source": [
        "# Sarcasm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "po4Sf4EiBo4K"
      },
      "outputs": [],
      "source": [
        "from datasets.affect.get_data import get_dataloader\n",
        "\n",
        "# Both loader sets read the same pickle with the same settings; the eval set\n",
        "# only adds train_shuffle=False so its train loader is not shuffled.\n",
        "SARCASM_PKL = '/content/MultiBench/sarcasm.pkl'\n",
        "sarcasm_loader_kwargs = dict(batch_size=128, data_type='sarcasm')\n",
        "\n",
        "train_loader, valid_loader, test_loader = get_dataloader(SARCASM_PKL, **sarcasm_loader_kwargs)\n",
        "eval_train_loader, eval_valid_loader, eval_test_loader = get_dataloader(\n",
        "    SARCASM_PKL, train_shuffle=False, **sarcasm_loader_kwargs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "paAPfIG0Bo4K"
      },
      "source": [
        "## FactorCL-SUP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 40,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lCXOY87MBo4K",
        "outputId": "ee0c0a82-24df-4814-98fc-8ba5ea9442b3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  0.00022673769854009151\n",
            "iter:  1  i_batch:  0  loss:  -0.010921637527644634\n",
            "iter:  2  i_batch:  0  loss:  -0.019974887371063232\n",
            "iter:  3  i_batch:  0  loss:  -0.028301645070314407\n",
            "iter:  4  i_batch:  0  loss:  -0.03775062412023544\n",
            "iter:  5  i_batch:  0  loss:  -0.05263204872608185\n",
            "iter:  6  i_batch:  0  loss:  -0.06939707696437836\n",
            "iter:  7  i_batch:  0  loss:  -0.09028887748718262\n",
            "iter:  8  i_batch:  0  loss:  -0.12177552282810211\n",
            "iter:  9  i_batch:  0  loss:  -0.1584542989730835\n",
            "iter:  10  i_batch:  0  loss:  -0.20651768147945404\n",
            "iter:  11  i_batch:  0  loss:  -0.24353653192520142\n",
            "iter:  12  i_batch:  0  loss:  -0.29933953285217285\n",
            "iter:  13  i_batch:  0  loss:  -0.39696305990219116\n",
            "iter:  14  i_batch:  0  loss:  -0.4287801682949066\n",
            "iter:  15  i_batch:  0  loss:  -0.4027155935764313\n",
            "iter:  16  i_batch:  0  loss:  -0.44266486167907715\n",
            "iter:  17  i_batch:  0  loss:  -0.5260065197944641\n",
            "iter:  18  i_batch:  0  loss:  -0.5915610194206238\n",
            "iter:  19  i_batch:  0  loss:  -0.7357954978942871\n",
            "iter:  20  i_batch:  0  loss:  -0.7489528656005859\n",
            "iter:  21  i_batch:  0  loss:  -0.7687628865242004\n",
            "iter:  22  i_batch:  0  loss:  -0.9730485081672668\n",
            "iter:  23  i_batch:  0  loss:  -0.9870645999908447\n",
            "iter:  24  i_batch:  0  loss:  -1.058022141456604\n",
            "iter:  25  i_batch:  0  loss:  -1.051788568496704\n",
            "iter:  26  i_batch:  0  loss:  -1.2126520872116089\n",
            "iter:  27  i_batch:  0  loss:  -1.3974565267562866\n",
            "iter:  28  i_batch:  0  loss:  -1.4189915657043457\n",
            "iter:  29  i_batch:  0  loss:  -1.5475046634674072\n",
            "iter:  30  i_batch:  0  loss:  -1.766325831413269\n",
            "iter:  31  i_batch:  0  loss:  -1.7021958827972412\n",
            "iter:  32  i_batch:  0  loss:  -1.8907331228256226\n",
            "iter:  33  i_batch:  0  loss:  -1.9576383829116821\n",
            "iter:  34  i_batch:  0  loss:  -1.9860942363739014\n",
            "iter:  35  i_batch:  0  loss:  -2.1162188053131104\n",
            "iter:  36  i_batch:  0  loss:  -1.773028016090393\n",
            "iter:  37  i_batch:  0  loss:  -2.0484628677368164\n",
            "iter:  38  i_batch:  0  loss:  -1.825541377067566\n",
            "iter:  39  i_batch:  0  loss:  -2.3038570880889893\n",
            "iter:  40  i_batch:  0  loss:  -1.433091640472412\n",
            "iter:  41  i_batch:  0  loss:  -1.899625301361084\n",
            "iter:  42  i_batch:  0  loss:  -2.283717155456543\n",
            "iter:  43  i_batch:  0  loss:  -2.3012969493865967\n",
            "iter:  44  i_batch:  0  loss:  -2.25911808013916\n",
            "iter:  45  i_batch:  0  loss:  -2.341892719268799\n",
            "iter:  46  i_batch:  0  loss:  -2.395864963531494\n",
            "iter:  47  i_batch:  0  loss:  -2.553396701812744\n",
            "iter:  48  i_batch:  0  loss:  -2.3940608501434326\n",
            "iter:  49  i_batch:  0  loss:  -2.509097099304199\n",
            "iter:  50  i_batch:  0  loss:  -2.440009832382202\n",
            "iter:  51  i_batch:  0  loss:  -2.3597967624664307\n",
            "iter:  52  i_batch:  0  loss:  -2.5262451171875\n",
            "iter:  53  i_batch:  0  loss:  -2.7186007499694824\n",
            "iter:  54  i_batch:  0  loss:  -2.7624025344848633\n",
            "iter:  55  i_batch:  0  loss:  -2.0573763847351074\n",
            "iter:  56  i_batch:  0  loss:  -2.4475317001342773\n",
            "iter:  57  i_batch:  0  loss:  -2.641329288482666\n",
            "iter:  58  i_batch:  0  loss:  -2.5970282554626465\n",
            "iter:  59  i_batch:  0  loss:  -2.7980642318725586\n",
            "iter:  60  i_batch:  0  loss:  -2.694045066833496\n",
            "iter:  61  i_batch:  0  loss:  -2.5757060050964355\n",
            "iter:  62  i_batch:  0  loss:  -2.857912302017212\n",
            "iter:  63  i_batch:  0  loss:  -2.6499955654144287\n",
            "iter:  64  i_batch:  0  loss:  -2.220844268798828\n",
            "iter:  65  i_batch:  0  loss:  -2.581118583679199\n",
            "iter:  66  i_batch:  0  loss:  -2.2565157413482666\n",
            "iter:  67  i_batch:  0  loss:  -2.534916877746582\n",
            "iter:  68  i_batch:  0  loss:  -2.8464951515197754\n",
            "iter:  69  i_batch:  0  loss:  -2.438075065612793\n",
            "iter:  70  i_batch:  0  loss:  -2.7606804370880127\n",
            "iter:  71  i_batch:  0  loss:  -2.034738540649414\n",
            "iter:  72  i_batch:  0  loss:  -1.9310920238494873\n",
            "iter:  73  i_batch:  0  loss:  -2.9413986206054688\n",
            "iter:  74  i_batch:  0  loss:  -3.006438970565796\n",
            "iter:  75  i_batch:  0  loss:  -2.9395294189453125\n",
            "iter:  76  i_batch:  0  loss:  -3.01055908203125\n",
            "iter:  77  i_batch:  0  loss:  -3.1051008701324463\n",
            "iter:  78  i_batch:  0  loss:  -2.8905184268951416\n",
            "iter:  79  i_batch:  0  loss:  -2.7825329303741455\n",
            "iter:  80  i_batch:  0  loss:  -2.8001997470855713\n",
            "iter:  81  i_batch:  0  loss:  -2.843099355697632\n",
            "iter:  82  i_batch:  0  loss:  -2.40347957611084\n",
            "iter:  83  i_batch:  0  loss:  -2.491931676864624\n",
            "iter:  84  i_batch:  0  loss:  -2.6011252403259277\n",
            "iter:  85  i_batch:  0  loss:  -2.5312771797180176\n",
            "iter:  86  i_batch:  0  loss:  -2.655810832977295\n",
            "iter:  87  i_batch:  0  loss:  -2.7279114723205566\n",
            "iter:  88  i_batch:  0  loss:  -2.7898061275482178\n",
            "iter:  89  i_batch:  0  loss:  -2.998659610748291\n",
            "iter:  90  i_batch:  0  loss:  -2.860271453857422\n",
            "iter:  91  i_batch:  0  loss:  -2.4903526306152344\n",
            "iter:  92  i_batch:  0  loss:  -2.908328056335449\n",
            "iter:  93  i_batch:  0  loss:  -2.798001289367676\n",
            "iter:  94  i_batch:  0  loss:  -2.562255620956421\n",
            "iter:  95  i_batch:  0  loss:  -3.1179356575012207\n",
            "iter:  96  i_batch:  0  loss:  -2.7856342792510986\n",
            "iter:  97  i_batch:  0  loss:  -2.95184588432312\n",
            "iter:  98  i_batch:  0  loss:  -2.887889862060547\n",
            "iter:  99  i_batch:  0  loss:  -2.313115119934082\n"
          ]
        }
      ],
      "source": [
        "# One Transformer encoder per modality used in this section (input widths 371\n",
        "# and 300; the second argument presumably is the 40-dim output - verify).\n",
        "encoders = [\n",
        "    Transformer(371, 40),\n",
        "    Transformer(300, 40),\n",
        "]\n",
        "factorcl_sup = FactorCLSUP(\n",
        "    encoders=encoders,\n",
        "    feat_dims=[40, 40],\n",
        "    y_ohe_dim=3,\n",
        ").cuda()\n",
        "train_sup_sarcasm(factorcl_sup, train_loader, num_epoch=100, num_club_iter=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 52,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fLEdAxnVBo4K",
        "outputId": "4cb80133-a3f3-4a77-f44f-5b580c5d3005"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n"
          ]
        }
      ],
      "source": [
        "factorcl_sup.eval()\n",
        "\n",
        "def embed_sarcasm_split(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` and collect the raw labels.\n",
        "\n",
        "    Modalities 0 and 2 of each batch (`data[0][0]`, `data[0][2]`) are passed to\n",
        "    `model.get_embedding` once; its first two outputs are concatenated along\n",
        "    the feature axis. Raw labels are batch element 3.\n",
        "    Returns (embeddings, raw_labels) as numpy arrays.\n",
        "    \"\"\"\n",
        "    feats, labels = [], []\n",
        "    for data in loader:\n",
        "        out = model.get_embedding(data[0][0].cuda(), data[0][2].cuda())\n",
        "        feats.append(np.concatenate([out[0].detach().cpu().numpy(),\n",
        "                                     out[1].detach().cpu().numpy()], axis=1))\n",
        "        labels.append(data[3].detach().cpu().numpy())\n",
        "    return np.concatenate(feats), np.concatenate(labels)\n",
        "\n",
        "train_embeds, train_labels = embed_sarcasm_split(factorcl_sup, eval_train_loader)\n",
        "test_embeds, test_labels = embed_sarcasm_split(factorcl_sup, eval_test_loader)\n",
        "\n",
        "# The converted labels reach scikit-learn as a column vector (the source of the\n",
        "# DataConversionWarning); flatten them to the 1-d y it expects. The fitted\n",
        "# model and the score are unchanged.\n",
        "train_labels = np.ravel(sarcasm_label(train_labels))\n",
        "test_labels = np.ravel(sarcasm_label(test_labels))\n",
        "\n",
        "# Train Logistic Classifier on train embeddings, score it on test embeddings\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 53,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "b7iZFTL7Bo4L",
        "outputId": "b84c1a9c-1a25-4462-91ae-80441c9e51d7"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.6666666666666666"
            ]
          },
          "metadata": {},
          "execution_count": 53
        }
      ],
      "source": [
        "# Test accuracy of the logistic-regression probe fit in the previous cell\n",
        "score"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BQ2bKUZUBo4L"
      },
      "source": [
        "## FactorCL-SSL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ys1U5xc0Bo4L",
        "outputId": "8cd4430d-8d74-454a-c937-8865ad3b81ea"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "iter:  0  i_batch:  0  loss:  0.0006452705711126328\n",
            "iter:  1  i_batch:  0  loss:  -0.012466572225093842\n",
            "iter:  2  i_batch:  0  loss:  -0.042170800268650055\n",
            "iter:  3  i_batch:  0  loss:  -0.08974642306566238\n",
            "iter:  4  i_batch:  0  loss:  -0.18145647644996643\n",
            "iter:  5  i_batch:  0  loss:  -0.3081078827381134\n",
            "iter:  6  i_batch:  0  loss:  -0.4158934950828552\n",
            "iter:  7  i_batch:  0  loss:  -0.601329505443573\n",
            "iter:  8  i_batch:  0  loss:  -0.7000826597213745\n",
            "iter:  9  i_batch:  0  loss:  -0.8313818573951721\n",
            "iter:  10  i_batch:  0  loss:  -0.9973404407501221\n",
            "iter:  11  i_batch:  0  loss:  -1.053339958190918\n",
            "iter:  12  i_batch:  0  loss:  -1.1727815866470337\n",
            "iter:  13  i_batch:  0  loss:  -1.2500641345977783\n",
            "iter:  14  i_batch:  0  loss:  -1.4549574851989746\n",
            "iter:  15  i_batch:  0  loss:  -1.601521372795105\n",
            "iter:  16  i_batch:  0  loss:  -1.6972362995147705\n",
            "iter:  17  i_batch:  0  loss:  -1.7666714191436768\n",
            "iter:  18  i_batch:  0  loss:  -2.0024752616882324\n",
            "iter:  19  i_batch:  0  loss:  -2.244112968444824\n",
            "iter:  20  i_batch:  0  loss:  -2.3511226177215576\n",
            "iter:  21  i_batch:  0  loss:  -2.5767922401428223\n",
            "iter:  22  i_batch:  0  loss:  -2.6088366508483887\n",
            "iter:  23  i_batch:  0  loss:  -3.258474826812744\n",
            "iter:  24  i_batch:  0  loss:  -3.4310712814331055\n",
            "iter:  25  i_batch:  0  loss:  -3.3126354217529297\n",
            "iter:  26  i_batch:  0  loss:  -3.735377311706543\n",
            "iter:  27  i_batch:  0  loss:  -3.675428867340088\n",
            "iter:  28  i_batch:  0  loss:  -4.181037425994873\n",
            "iter:  29  i_batch:  0  loss:  -4.495911598205566\n",
            "iter:  30  i_batch:  0  loss:  -4.633365154266357\n",
            "iter:  31  i_batch:  0  loss:  -4.620932579040527\n",
            "iter:  32  i_batch:  0  loss:  -4.7253923416137695\n",
            "iter:  33  i_batch:  0  loss:  -5.132877826690674\n",
            "iter:  34  i_batch:  0  loss:  -5.3106536865234375\n",
            "iter:  35  i_batch:  0  loss:  -5.282912254333496\n",
            "iter:  36  i_batch:  0  loss:  -5.252291202545166\n",
            "iter:  37  i_batch:  0  loss:  -5.4765400886535645\n",
            "iter:  38  i_batch:  0  loss:  -5.090197563171387\n",
            "iter:  39  i_batch:  0  loss:  -4.9729108810424805\n",
            "iter:  40  i_batch:  0  loss:  -5.6346435546875\n",
            "iter:  41  i_batch:  0  loss:  -5.911499500274658\n",
            "iter:  42  i_batch:  0  loss:  -6.384456634521484\n",
            "iter:  43  i_batch:  0  loss:  -6.103287696838379\n",
            "iter:  44  i_batch:  0  loss:  -6.488170623779297\n",
            "iter:  45  i_batch:  0  loss:  -6.35069465637207\n",
            "iter:  46  i_batch:  0  loss:  -6.494596004486084\n",
            "iter:  47  i_batch:  0  loss:  -6.896271705627441\n",
            "iter:  48  i_batch:  0  loss:  -6.844601154327393\n",
            "iter:  49  i_batch:  0  loss:  -6.947994709014893\n",
            "iter:  50  i_batch:  0  loss:  -7.07313871383667\n",
            "iter:  51  i_batch:  0  loss:  -7.221263408660889\n",
            "iter:  52  i_batch:  0  loss:  -7.095343589782715\n",
            "iter:  53  i_batch:  0  loss:  -6.829585552215576\n",
            "iter:  54  i_batch:  0  loss:  -7.424076080322266\n",
            "iter:  55  i_batch:  0  loss:  -6.653810501098633\n",
            "iter:  56  i_batch:  0  loss:  -7.698529243469238\n",
            "iter:  57  i_batch:  0  loss:  -7.418730735778809\n",
            "iter:  58  i_batch:  0  loss:  -7.064931392669678\n",
            "iter:  59  i_batch:  0  loss:  -7.676008224487305\n",
            "iter:  60  i_batch:  0  loss:  -7.1770195960998535\n",
            "iter:  61  i_batch:  0  loss:  -7.444337844848633\n",
            "iter:  62  i_batch:  0  loss:  -7.535398960113525\n",
            "iter:  63  i_batch:  0  loss:  -7.8711981773376465\n",
            "iter:  64  i_batch:  0  loss:  -7.914351940155029\n",
            "iter:  65  i_batch:  0  loss:  -8.048909187316895\n",
            "iter:  66  i_batch:  0  loss:  -7.429112434387207\n",
            "iter:  67  i_batch:  0  loss:  -7.764862537384033\n",
            "iter:  68  i_batch:  0  loss:  -7.939072608947754\n",
            "iter:  69  i_batch:  0  loss:  -8.092193603515625\n",
            "iter:  70  i_batch:  0  loss:  -7.949345111846924\n",
            "iter:  71  i_batch:  0  loss:  -7.876206398010254\n",
            "iter:  72  i_batch:  0  loss:  -8.117385864257812\n",
            "iter:  73  i_batch:  0  loss:  -7.99519157409668\n",
            "iter:  74  i_batch:  0  loss:  -8.126961708068848\n",
            "iter:  75  i_batch:  0  loss:  -8.082456588745117\n",
            "iter:  76  i_batch:  0  loss:  -7.721489906311035\n",
            "iter:  77  i_batch:  0  loss:  -8.204150199890137\n",
            "iter:  78  i_batch:  0  loss:  -7.992454528808594\n",
            "iter:  79  i_batch:  0  loss:  -8.469961166381836\n",
            "iter:  80  i_batch:  0  loss:  -7.911052703857422\n",
            "iter:  81  i_batch:  0  loss:  -7.828105449676514\n",
            "iter:  82  i_batch:  0  loss:  -7.810555934906006\n",
            "iter:  83  i_batch:  0  loss:  -8.74462890625\n",
            "iter:  84  i_batch:  0  loss:  -8.445594787597656\n",
            "iter:  85  i_batch:  0  loss:  -8.421957969665527\n",
            "iter:  86  i_batch:  0  loss:  -8.416309356689453\n",
            "iter:  87  i_batch:  0  loss:  -8.421586990356445\n",
            "iter:  88  i_batch:  0  loss:  -8.095366477966309\n",
            "iter:  89  i_batch:  0  loss:  -8.776723861694336\n",
            "iter:  90  i_batch:  0  loss:  -8.161399841308594\n",
            "iter:  91  i_batch:  0  loss:  -9.122365951538086\n",
            "iter:  92  i_batch:  0  loss:  -8.383651733398438\n",
            "iter:  93  i_batch:  0  loss:  -8.461047172546387\n",
            "iter:  94  i_batch:  0  loss:  -8.793039321899414\n",
            "iter:  95  i_batch:  0  loss:  -9.0235595703125\n",
            "iter:  96  i_batch:  0  loss:  -8.830117225646973\n",
            "iter:  97  i_batch:  0  loss:  -8.772846221923828\n",
            "iter:  98  i_batch:  0  loss:  -8.866401672363281\n",
            "iter:  99  i_batch:  0  loss:  -8.066728591918945\n"
          ]
        }
      ],
      "source": [
        "# Fresh encoders (not shared with the FactorCL-SUP run): one Transformer per\n",
        "# modality, input widths 371 and 300, matching feat_dims of 40 each.\n",
        "encoders = [\n",
        "    Transformer(371, 40),\n",
        "    Transformer(300, 40),\n",
        "]\n",
        "factorcl_ssl = FactorCLSSL(\n",
        "    encoders=encoders,\n",
        "    feat_dims=[40, 40],\n",
        "    y_ohe_dim=3,\n",
        ").cuda()\n",
        "train_ssl_sarcasm(factorcl_ssl, train_loader, num_epoch=100, num_club_iter=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IwOOVsr8Bo4L",
        "outputId": "3f29f754-52cc-4802-e375-0de7dfefba2a"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n",
            "/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  n_iter_i = _check_optimize_result(\n"
          ]
        }
      ],
      "source": [
        "factorcl_ssl.eval()\n",
        "\n",
        "def embed_sarcasm_split(model, loader):\n",
        "    \"\"\"Embed every batch of `loader` with `model` and collect the raw labels.\n",
        "\n",
        "    Modalities 0 and 2 of each batch (`data[0][0]`, `data[0][2]`) are passed to\n",
        "    `model.get_embedding` once; its first two outputs are concatenated along\n",
        "    the feature axis. Raw labels are batch element 3.\n",
        "    Returns (embeddings, raw_labels) as numpy arrays.\n",
        "    \"\"\"\n",
        "    feats, labels = [], []\n",
        "    for data in loader:\n",
        "        out = model.get_embedding(data[0][0].cuda(), data[0][2].cuda())\n",
        "        feats.append(np.concatenate([out[0].detach().cpu().numpy(),\n",
        "                                     out[1].detach().cpu().numpy()], axis=1))\n",
        "        labels.append(data[3].detach().cpu().numpy())\n",
        "    return np.concatenate(feats), np.concatenate(labels)\n",
        "\n",
        "train_embeds, train_labels = embed_sarcasm_split(factorcl_ssl, eval_train_loader)\n",
        "test_embeds, test_labels = embed_sarcasm_split(factorcl_ssl, eval_test_loader)\n",
        "\n",
        "# The converted labels reach scikit-learn as a column vector (the source of the\n",
        "# DataConversionWarning); flatten them to the 1-d y it expects. The fitted\n",
        "# model and the score are unchanged.\n",
        "train_labels = np.ravel(sarcasm_label(train_labels))\n",
        "test_labels = np.ravel(sarcasm_label(test_labels))\n",
        "\n",
        "# Train Logistic Classifier on train embeddings, score it on test embeddings\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "j6SAQ36fBo4L",
        "outputId": "ab36a954-4f1d-484b-ee08-cc6321a23c46"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.6014492753623188"
            ]
          },
          "execution_count": 55,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Test accuracy of the logistic-regression probe fit in the previous cell\n",
        "score"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wtnYxDD3Bo4L"
      },
      "source": [
        "## SupCon"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5SsFqRPJBo4L",
        "outputId": "df6cc126-b4ad-4194-9181-1dfbcf81d977"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "iter:  0  i_batch:  0  loss:  49.70327377319336\n",
            "iter:  1  i_batch:  0  loss:  39.85050964355469\n",
            "iter:  2  i_batch:  0  loss:  39.70650863647461\n",
            "iter:  3  i_batch:  0  loss:  39.652442932128906\n",
            "iter:  4  i_batch:  0  loss:  39.626487731933594\n",
            "iter:  5  i_batch:  0  loss:  39.60759353637695\n",
            "iter:  6  i_batch:  0  loss:  39.59840393066406\n",
            "iter:  7  i_batch:  0  loss:  39.597259521484375\n",
            "iter:  8  i_batch:  0  loss:  39.5950927734375\n",
            "iter:  9  i_batch:  0  loss:  39.590240478515625\n",
            "iter:  10  i_batch:  0  loss:  39.590660095214844\n",
            "iter:  11  i_batch:  0  loss:  39.59142303466797\n",
            "iter:  12  i_batch:  0  loss:  39.58800506591797\n",
            "iter:  13  i_batch:  0  loss:  39.587642669677734\n",
            "iter:  14  i_batch:  0  loss:  39.589324951171875\n",
            "iter:  15  i_batch:  0  loss:  39.58818054199219\n",
            "iter:  16  i_batch:  0  loss:  39.58576202392578\n",
            "iter:  17  i_batch:  0  loss:  39.58720397949219\n",
            "iter:  18  i_batch:  0  loss:  39.5872802734375\n",
            "iter:  19  i_batch:  0  loss:  39.585548400878906\n"
          ]
        }
      ],
      "source": [
        "# One Transformer encoder per modality. NOTE(review): the first argument looks like the\n",
        "# per-modality input size and the second (40) the output size that dim_ins expects - confirm.\n",
        "encoders = [Transformer(371, 40), Transformer(300, 40)]\n",
        "\n",
        "# use_label=True presumably trains the supervised (SupCon) variant; set use_label=False for SimCLR\n",
        "supcon_model = SupConModel(\n",
        "    temperature=0.5,\n",
        "    encoders=encoders,\n",
        "    dim_ins=[40, 40],\n",
        "    feat_dims=[40, 40],\n",
        "    use_label=True,\n",
        ").cuda()\n",
        "\n",
        "supcon_optim = optim.Adam(supcon_model.parameters())  # default Adam hyper-parameters\n",
        "train_supcon_sarcasm(\n",
        "    supcon_model,\n",
        "    train_loader,\n",
        "    supcon_optim,\n",
        "    modalities=[0, 2],  # must match the data[0][0] / data[0][2] inputs used for evaluation\n",
        "    num_epoch=20,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eX4qYFGgBo4L",
        "outputId": "e971d70a-0c66-497f-fe2b-87343121a09a"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
            "  y = column_or_1d(y, warn=True)\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "\n",
        "supcon_model.eval()\n",
        "\n",
        "\n",
        "def embed_supcon_loader(loader):\n",
        "    \"\"\"Embed every batch of `loader` with `supcon_model` in a single pass.\n",
        "\n",
        "    Returns (embeds, labels): the first two outputs of get_embedding\n",
        "    concatenated along axis 1, and the sarcasm_label-mapped labels flattened\n",
        "    to 1-D, both in the same row order.\n",
        "    \"\"\"\n",
        "    x1_parts, x2_parts, label_parts = [], [], []\n",
        "    # A single pass keeps embeddings and labels aligned even if `loader` shuffles,\n",
        "    # and avoids running the encoders once per embedding output.\n",
        "    with torch.no_grad():  # inference only: don't build autograd graphs\n",
        "        for data in loader:\n",
        "            # data[0][0] / data[0][2] correspond to modalities=[0,2] used in training\n",
        "            out = supcon_model.get_embedding(data[0][0].cuda(), data[0][2].cuda())\n",
        "            x1_parts.append(out[0].detach().cpu().numpy())\n",
        "            x2_parts.append(out[1].detach().cpu().numpy())\n",
        "            label_parts.append(data[3].detach().cpu().numpy())\n",
        "    embeds = np.concatenate([np.concatenate(x1_parts), np.concatenate(x2_parts)], axis=1)\n",
        "    # ravel: sklearn expects y with shape (n_samples,), not a column vector\n",
        "    labels = np.ravel(sarcasm_label(np.concatenate(label_parts)))\n",
        "    return embeds, labels\n",
        "\n",
        "\n",
        "train_embeds, train_labels = embed_supcon_loader(eval_train_loader)\n",
        "test_embeds, test_labels = embed_supcon_loader(eval_test_loader)\n",
        "\n",
        "# Train Logistic Classifier (linear probe on the SupCon embeddings)\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VLW-V0KkBo4M",
        "outputId": "d4aafbf0-882a-43cd-9600-5d5a18f0985d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.572463768115942"
            ]
          },
          "execution_count": 52,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "score  # mean accuracy of the logistic-regression probe on the test embeddings"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "AQ8W129stEfo",
        "Jt1RD-kdDYqj",
        "_K1VRs2vBo4E",
        "WhwX0KRkHSqL",
        "QyxFC1whHSqS",
        "CJBo3NAtHSqT",
        "WmI_QMOgHSqT"
      ],
      "machine_shape": "hm",
      "provenance": [],
      "gpuType": "V100"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}