{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
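        "# FactorCL on IRFL\n",
        "\n",
        "This notebook downloads the IRFL understanding-task images and phrases, labels each image-phrase pair with its figurative type (idiom, simile or metaphor), and extracts CLIP image and text embeddings for those pairs."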
      ],
      "metadata": {
        "id": "LTEV_elPBK32"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fWNBW_QzYrvs"
      },
      "source": [
        "## Preparation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BgGoqg0DYP92",
        "outputId": "594cd117-4762-494f-9167-220079a111d4"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting datasets\n",
            "  Downloading datasets-2.14.6-py3-none-any.whl (493 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m493.7/493.7 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n",
            "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n",
            "Collecting dill<0.3.8,>=0.3.0 (from datasets)\n",
            "  Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m14.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
            "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n",
            "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
            "Collecting multiprocess (from datasets)\n",
            "  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m18.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.6)\n",
            "Collecting huggingface-hub<1.0.0,>=0.14.0 (from datasets)\n",
            "  Downloading huggingface_hub-0.19.0-py3-none-any.whl (311 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m311.2/311.2 kB\u001b[0m \u001b[31m30.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
            "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.2)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.13.1)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.5.0)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n",
            "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
            "Installing collected packages: dill, multiprocess, huggingface-hub, datasets\n",
            "Successfully installed datasets-2.14.6 dill-0.3.7 huggingface-hub-0.19.0 multiprocess-0.70.15\n",
            "Collecting transformers\n",
            "  Downloading transformers-4.35.0-py3-none-any.whl (7.9 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.9/7.9 MB\u001b[0m \u001b[31m50.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.0)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
            "Collecting tokenizers<0.15,>=0.14 (from transformers)\n",
            "  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m104.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting safetensors>=0.3.1 (from transformers)\n",
            "  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n",
            "Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)\n",
            "  Downloading huggingface_hub-0.17.3-py3-none-any.whl (295 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m295.0/295.0 kB\u001b[0m \u001b[31m35.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
            "Installing collected packages: safetensors, huggingface-hub, tokenizers, transformers\n",
            "  Attempting uninstall: huggingface-hub\n",
            "    Found existing installation: huggingface-hub 0.19.0\n",
            "    Uninstalling huggingface-hub-0.19.0:\n",
            "      Successfully uninstalled huggingface-hub-0.19.0\n",
            "Successfully installed huggingface-hub-0.17.3 safetensors-0.4.0 tokenizers-0.14.1 transformers-4.35.0\n"
          ]
        }
      ],
      "source": [
        "# Versions pinned to those installed by the recorded Colab run (Python 3.10).\n",
        "%pip install -q datasets==2.14.6 transformers==4.35.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "r1KGIlFvYqYH",
        "outputId": "94eb6864-6e97-4f4a-9224-23638e79ceb4"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'IRFL'...\n",
            "remote: Enumerating objects: 49, done.\u001b[K\n",
            "remote: Counting objects: 100% (13/13), done.\u001b[K\n",
            "remote: Compressing objects: 100% (12/12), done.\u001b[K\n",
            "remote: Total 49 (delta 1), reused 12 (delta 1), pack-reused 36\u001b[K\n",
            "Receiving objects: 100% (49/49), 45.70 MiB | 10.98 MiB/s, done.\n",
            "Resolving deltas: 100% (2/2), done.\n"
          ]
        }
      ],
      "source": [
        "# Skip the clone if it already exists so the cell can be re-run safely.\n",
        "![ -d /content/IRFL ] || git clone https://github.com/irfl-dataset/IRFL /content/IRFL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "S3RpE80CYxIY"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import json\n",
        "\n",
        "from torch.utils.data import Dataset\n",
        "from torch.utils.data import DataLoader\n",
        "import torch.optim as optim\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from datasets import load_dataset\n",
        "\n",
        "import PIL.Image as Image\n",
        "import requests\n",
        "from urllib.request import urlopen"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Skip the clone if it already exists so the cell can be re-run safely.\n",
        "![ -d /content/FactorCL ] || git clone https://github.com/pliang279/FactorCL /content/FactorCL"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IHKnbVkoASPI",
        "outputId": "c799ac31-899e-4e71-d055-10fdff9ce032"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'FactorCL'...\n",
            "remote: Enumerating objects: 104, done.\u001b[K\n",
            "remote: Counting objects: 100% (104/104), done.\u001b[K\n",
            "remote: Compressing objects: 100% (96/96), done.\u001b[K\n",
            "remote: Total 104 (delta 47), reused 0 (delta 0), pack-reused 0\u001b[K\n",
            "Receiving objects: 100% (104/104), 268.97 KiB | 9.96 MiB/s, done.\n",
            "Resolving deltas: 100% (47/47), done.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Absolute path so re-running this cell does not try to enter FactorCL/FactorCL.\n",
        "%cd /content/FactorCL"
      ],
      "metadata": {
        "id": "pyWpONMApjAO",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a6340568-d027-4938-b720-9a29ba8216c1"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/FactorCL\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import sys\n",
        "from torch.utils.data import DataLoader\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from datasets import load_dataset\n",
        "from transformers import AutoProcessor, CLIPModel\n",
        "\n",
        "from IRFL_model import*"
      ],
      "metadata": {
        "id": "A0mgHF3bAgvo"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## IRFL Dataset"
      ],
      "metadata": {
        "id": "AVfFndkOl_Tr"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Task tables from the IRFL repository cloned in the Preparation section.\n",
        "IRFL_TASKS_DIR = '/content/IRFL/assets/tasks'\n",
        "\n",
        "simile_df = pd.read_csv(f'{IRFL_TASKS_DIR}/simile_understanding_task.csv')\n",
        "idiom_df = pd.read_csv(f'{IRFL_TASKS_DIR}/idiom_understanding_task.csv')\n",
        "metaphor_df = pd.read_csv(f'{IRFL_TASKS_DIR}/metaphor_understanding_task.csv')\n",
        "\n",
        "# Fail fast if a table lacks columns that process_df reads.\n",
        "for _task_df in (simile_df, idiom_df, metaphor_df):\n",
        "    assert {'phrase', 'distractors', 'figurative_type'} <= set(_task_df.columns), list(_task_df.columns)"
      ],
      "metadata": {
        "id": "p8cqWO_Ul9BG"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def process_df(df, answer_col='answer', timeout=30):\n",
        "  \"\"\"Download the images referenced by one IRFL understanding-task table.\n",
        "\n",
        "  Args:\n",
        "    df: DataFrame with 'distractors', answer_col, 'phrase' and 'figurative_type'\n",
        "      columns; the two URL columns hold stringified Python lists of image URLs.\n",
        "    answer_col: column holding the correct image URL list; only its first URL\n",
        "      is used. It must NOT be 'distractors', otherwise every answer is a distractor.\n",
        "    timeout: per-request timeout in seconds, so one dead URL cannot hang the loop.\n",
        "\n",
        "  Returns:\n",
        "    (distractors, answers, texts, types): parallel lists where distractors[i] is a\n",
        "    list of PIL images and answers[i] a single PIL image. Rows whose URLs cannot be\n",
        "    parsed or downloaded are skipped (and reported), so the lists may be shorter\n",
        "    than len(df).\n",
        "  \"\"\"\n",
        "  import ast\n",
        "  import io\n",
        "  from tqdm.auto import tqdm\n",
        "\n",
        "  def _open_image(url):\n",
        "    # Read the body fully so the HTTP response is closed before returning; PIL\n",
        "    # still decodes lazily from the in-memory bytes.\n",
        "    with urlopen(url, timeout=timeout) as resp:\n",
        "      return Image.open(io.BytesIO(resp.read()))\n",
        "\n",
        "  distractors, answers, texts, types, skipped = [], [], [], [], []\n",
        "  rows = zip(df['distractors'], df[answer_col], df['phrase'], df['figurative_type'])\n",
        "  for i, (d_urls, a_urls, text, fig_type) in enumerate(tqdm(rows, total=len(df))):\n",
        "    try:\n",
        "      # literal_eval parses the list literal without executing arbitrary code.\n",
        "      distractor = [_open_image(url) for url in ast.literal_eval(d_urls)]\n",
        "      answer = _open_image(ast.literal_eval(a_urls)[0])\n",
        "    except Exception as err:  # best-effort: skip bad rows, but let Ctrl-C through\n",
        "      skipped.append((i, repr(err)))\n",
        "      continue\n",
        "    # Append only after every image in the row loaded, keeping the lists aligned.\n",
        "    distractors.append(distractor)\n",
        "    answers.append(answer)\n",
        "    texts.append(text)\n",
        "    types.append(fig_type)\n",
        "\n",
        "  if skipped:\n",
        "    print(f'process_df: skipped {len(skipped)}/{len(df)} rows; first: {skipped[0]}')\n",
        "  return distractors, answers, texts, types\n",
        "\n",
        "\n",
        "\n",
        "def collate_fn(batch):\n",
        "    images = [data[0] for data in batch]\n",
        "    texts = [data[1] for data in batch]\n",
        "    labels = [data[2] for data in batch]\n",
        "\n",
        "    return images, texts, torch.tensor(labels, dtype=int)\n",
        "\n",
        "\n",
        "class FigTypeDataset(Dataset):\n",
        "    def __init__(self, answers, texts, types):\n",
        "        self.types= types\n",
        "        self.images = answers\n",
        "        self.texts = texts\n",
        "\n",
        "        self.type_map = {'idiom': 0, 'simile': 1, 'metaphor': 2}\n",
        "\n",
        "        self.labels = list(map(lambda x: self.type_map[x], self.types))\n",
        "\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.images)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return self.images[idx], self.texts[idx], self.labels[idx]"
      ],
      "metadata": {
        "id": "l70txHYvl4rV"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def process_fn(batch):\n",
        "    images, texts, contrastive_labels = batch\n",
        "    batch = processor(images=images, text=texts, padding=True, return_tensors='pt')\n",
        "\n",
        "    return batch, contrastive_labels\n",
        "\n",
        "\n",
        "def get_embeds(model, processor, train_loader, test_loader):\n",
        "    train_embeds = []\n",
        "    train_labels = []\n",
        "    test_embeds = []\n",
        "    test_labels = []\n",
        "    for i_batch, x in enumerate(train_loader):\n",
        "\n",
        "        inputs, label = process_fn(x)\n",
        "        inputs, label = inputs.to(device), label.to(device)\n",
        "\n",
        "        outputs = model(**inputs)\n",
        "        image_embeds = outputs.image_embeds.detach().cpu().numpy()\n",
        "        text_embeds = outputs.text_embeds.detach().cpu().numpy()\n",
        "\n",
        "        embeds = np.concatenate([image_embeds, text_embeds], axis=1)\n",
        "        train_embeds.append(embeds)\n",
        "        train_labels.append(label.detach().cpu().numpy())\n",
        "\n",
        "    for i_batch, x in enumerate(test_loader):\n",
        "\n",
        "        inputs, label = process_fn(x)\n",
        "        inputs, label = inputs.to(device), label.to(device)\n",
        "\n",
        "        outputs = model(**inputs)\n",
        "        image_embeds = outputs.image_embeds.detach().cpu().numpy()\n",
        "        text_embeds = outputs.text_embeds.detach().cpu().numpy()\n",
        "\n",
        "        embeds = np.concatenate([image_embeds, text_embeds], axis=1)\n",
        "        test_embeds.append(embeds)\n",
        "        test_labels.append(label.detach().cpu().numpy())\n",
        "\n",
        "    train_embeds = np.concatenate(train_embeds, axis=0)\n",
        "    test_embeds = np.concatenate(test_embeds, axis=0)\n",
        "    train_labels = np.concatenate(train_labels, axis=0)\n",
        "    test_labels = np.concatenate(test_labels, axis=0)\n",
        "\n",
        "    return train_embeds, train_labels, test_embeds, test_labels"
      ],
      "metadata": {
        "id": "-xCFTuq-p_Sn"
      },
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# One call per task table; each returns parallel (distractors, answers, texts, types) lists.\n",
        "(distractors_simile, answers_simile,\n",
        " texts_simile, types_simile) = process_df(simile_df)\n",
        "(distractors_idiom, answers_idiom,\n",
        " texts_idiom, types_idiom) = process_df(idiom_df)\n",
        "(distractors_metaphor, answers_metaphor,\n",
        " texts_metaphor, types_metaphor) = process_df(metaphor_df)"
      ],
      "metadata": {
        "id": "wm2643lupH0E",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ab35becd-ab2b-4e59-9fa7-b7edb3330663"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0/277\n",
            "1/277\n",
            "2/277\n",
            "3/277\n",
            "4/277\n",
            "5/277\n",
            "6/277\n",
            "7/277\n",
            "8/277\n",
            "9/277\n",
            "10/277\n",
            "11/277\n",
            "12/277\n",
            "13/277\n",
            "14/277\n",
            "15/277\n",
            "16/277\n",
            "17/277\n",
            "18/277\n",
            "19/277\n",
            "20/277\n",
            "21/277\n",
            "22/277\n",
            "23/277\n",
            "24/277\n",
            "25/277\n",
            "26/277\n",
            "27/277\n",
            "28/277\n",
            "29/277\n",
            "30/277\n",
            "31/277\n",
            "32/277\n",
            "33/277\n",
            "34/277\n",
            "35/277\n",
            "36/277\n",
            "37/277\n",
            "38/277\n",
            "39/277\n",
            "40/277\n",
            "41/277\n",
            "42/277\n",
            "43/277\n",
            "44/277\n",
            "45/277\n",
            "46/277\n",
            "47/277\n",
            "48/277\n",
            "49/277\n",
            "50/277\n",
            "51/277\n",
            "52/277\n",
            "53/277\n",
            "54/277\n",
            "55/277\n",
            "56/277\n",
            "57/277\n",
            "58/277\n",
            "59/277\n",
            "60/277\n",
            "61/277\n",
            "62/277\n",
            "63/277\n",
            "64/277\n",
            "65/277\n",
            "66/277\n",
            "67/277\n",
            "68/277\n",
            "69/277\n",
            "70/277\n",
            "71/277\n",
            "72/277\n",
            "73/277\n",
            "74/277\n",
            "75/277\n",
            "76/277\n",
            "77/277\n",
            "78/277\n",
            "79/277\n",
            "80/277\n",
            "81/277\n",
            "82/277\n",
            "83/277\n",
            "84/277\n",
            "85/277\n",
            "86/277\n",
            "87/277\n",
            "88/277\n",
            "89/277\n",
            "90/277\n",
            "91/277\n",
            "92/277\n",
            "93/277\n",
            "94/277\n",
            "95/277\n",
            "96/277\n",
            "97/277\n",
            "98/277\n",
            "99/277\n",
            "100/277\n",
            "101/277\n",
            "102/277\n",
            "103/277\n",
            "104/277\n",
            "105/277\n",
            "106/277\n",
            "107/277\n",
            "108/277\n",
            "109/277\n",
            "110/277\n",
            "111/277\n",
            "112/277\n",
            "113/277\n",
            "114/277\n",
            "115/277\n",
            "116/277\n",
            "117/277\n",
            "118/277\n",
            "119/277\n",
            "120/277\n",
            "121/277\n",
            "122/277\n",
            "123/277\n",
            "124/277\n",
            "125/277\n",
            "126/277\n",
            "127/277\n",
            "128/277\n",
            "129/277\n",
            "130/277\n",
            "131/277\n",
            "132/277\n",
            "133/277\n",
            "134/277\n",
            "135/277\n",
            "136/277\n",
            "137/277\n",
            "138/277\n",
            "139/277\n",
            "140/277\n",
            "141/277\n",
            "142/277\n",
            "143/277\n",
            "144/277\n",
            "145/277\n",
            "146/277\n",
            "147/277\n",
            "148/277\n",
            "149/277\n",
            "150/277\n",
            "151/277\n",
            "152/277\n",
            "153/277\n",
            "154/277\n",
            "155/277\n",
            "156/277\n",
            "157/277\n",
            "158/277\n",
            "159/277\n",
            "160/277\n",
            "161/277\n",
            "162/277\n",
            "163/277\n",
            "164/277\n",
            "165/277\n",
            "166/277\n",
            "167/277\n",
            "168/277\n",
            "169/277\n",
            "170/277\n",
            "171/277\n",
            "172/277\n",
            "173/277\n",
            "174/277\n",
            "175/277\n",
            "176/277\n",
            "177/277\n",
            "178/277\n",
            "179/277\n",
            "180/277\n",
            "181/277\n",
            "182/277\n",
            "183/277\n",
            "184/277\n",
            "185/277\n",
            "186/277\n",
            "187/277\n",
            "188/277\n",
            "189/277\n",
            "190/277\n",
            "191/277\n",
            "192/277\n",
            "193/277\n",
            "194/277\n",
            "195/277\n",
            "196/277\n",
            "197/277\n",
            "198/277\n",
            "199/277\n",
            "200/277\n",
            "201/277\n",
            "202/277\n",
            "203/277\n",
            "204/277\n",
            "205/277\n",
            "206/277\n",
            "207/277\n",
            "208/277\n",
            "209/277\n",
            "210/277\n",
            "211/277\n",
            "212/277\n",
            "213/277\n",
            "214/277\n",
            "215/277\n",
            "216/277\n",
            "217/277\n",
            "218/277\n",
            "219/277\n",
            "220/277\n",
            "221/277\n",
            "222/277\n",
            "223/277\n",
            "224/277\n",
            "225/277\n",
            "226/277\n",
            "227/277\n",
            "228/277\n",
            "229/277\n",
            "230/277\n",
            "231/277\n",
            "232/277\n",
            "233/277\n",
            "234/277\n",
            "235/277\n",
            "236/277\n",
            "237/277\n",
            "238/277\n",
            "239/277\n",
            "240/277\n",
            "241/277\n",
            "242/277\n",
            "243/277\n",
            "244/277\n",
            "245/277\n",
            "246/277\n",
            "247/277\n",
            "248/277\n",
            "249/277\n",
            "250/277\n",
            "251/277\n",
            "252/277\n",
            "253/277\n",
            "254/277\n",
            "255/277\n",
            "256/277\n",
            "257/277\n",
            "258/277\n",
            "259/277\n",
            "260/277\n",
            "261/277\n",
            "262/277\n",
            "263/277\n",
            "264/277\n",
            "265/277\n",
            "266/277\n",
            "267/277\n",
            "268/277\n",
            "269/277\n",
            "270/277\n",
            "271/277\n",
            "272/277\n",
            "273/277\n",
            "274/277\n",
            "275/277\n",
            "276/277\n",
            "0/200\n",
            "1/200\n",
            "2/200\n",
            "3/200\n",
            "4/200\n",
            "5/200\n",
            "6/200\n",
            "7/200\n",
            "8/200\n",
            "9/200\n",
            "10/200\n",
            "11/200\n",
            "12/200\n",
            "13/200\n",
            "14/200\n",
            "15/200\n",
            "16/200\n",
            "17/200\n",
            "18/200\n",
            "19/200\n",
            "20/200\n",
            "21/200\n",
            "22/200\n",
            "23/200\n",
            "24/200\n",
            "25/200\n",
            "26/200\n",
            "27/200\n",
            "28/200\n",
            "29/200\n",
            "30/200\n",
            "31/200\n",
            "32/200\n",
            "33/200\n",
            "34/200\n",
            "35/200\n",
            "36/200\n",
            "37/200\n",
            "38/200\n",
            "39/200\n",
            "40/200\n",
            "41/200\n",
            "42/200\n",
            "43/200\n",
            "44/200\n",
            "45/200\n",
            "46/200\n",
            "47/200\n",
            "48/200\n",
            "49/200\n",
            "50/200\n",
            "51/200\n",
            "52/200\n",
            "53/200\n",
            "54/200\n",
            "55/200\n",
            "56/200\n",
            "57/200\n",
            "58/200\n",
            "59/200\n",
            "60/200\n",
            "61/200\n",
            "62/200\n",
            "63/200\n",
            "64/200\n",
            "65/200\n",
            "66/200\n",
            "67/200\n",
            "68/200\n",
            "69/200\n",
            "70/200\n",
            "71/200\n",
            "72/200\n",
            "73/200\n",
            "74/200\n",
            "75/200\n",
            "76/200\n",
            "77/200\n",
            "78/200\n",
            "79/200\n",
            "80/200\n",
            "81/200\n",
            "82/200\n",
            "83/200\n",
            "84/200\n",
            "85/200\n",
            "86/200\n",
            "87/200\n",
            "88/200\n",
            "89/200\n",
            "90/200\n",
            "91/200\n",
            "92/200\n",
            "93/200\n",
            "94/200\n",
            "95/200\n",
            "96/200\n",
            "97/200\n",
            "98/200\n",
            "99/200\n",
            "100/200\n",
            "101/200\n",
            "102/200\n",
            "103/200\n",
            "104/200\n",
            "105/200\n",
            "106/200\n",
            "107/200\n",
            "108/200\n",
            "109/200\n",
            "110/200\n",
            "111/200\n",
            "112/200\n",
            "113/200\n",
            "114/200\n",
            "115/200\n",
            "116/200\n",
            "117/200\n",
            "118/200\n",
            "119/200\n",
            "120/200\n",
            "121/200\n",
            "122/200\n",
            "123/200\n",
            "124/200\n",
            "125/200\n",
            "126/200\n",
            "127/200\n",
            "128/200\n",
            "129/200\n",
            "130/200\n",
            "131/200\n",
            "132/200\n",
            "133/200\n",
            "134/200\n",
            "135/200\n",
            "136/200\n",
            "137/200\n",
            "138/200\n",
            "139/200\n",
            "140/200\n",
            "141/200\n",
            "142/200\n",
            "143/200\n",
            "144/200\n",
            "145/200\n",
            "146/200\n",
            "147/200\n",
            "148/200\n",
            "149/200\n",
            "150/200\n",
            "151/200\n",
            "152/200\n",
            "153/200\n",
            "154/200\n",
            "155/200\n",
            "156/200\n",
            "157/200\n",
            "158/200\n",
            "159/200\n",
            "160/200\n",
            "161/200\n",
            "162/200\n",
            "163/200\n",
            "164/200\n",
            "165/200\n",
            "166/200\n",
            "167/200\n",
            "168/200\n",
            "169/200\n",
            "170/200\n",
            "171/200\n",
            "172/200\n",
            "173/200\n",
            "174/200\n",
            "175/200\n",
            "176/200\n",
            "177/200\n",
            "178/200\n",
            "179/200\n",
            "180/200\n",
            "181/200\n",
            "182/200\n",
            "183/200\n",
            "184/200\n",
            "185/200\n",
            "186/200\n",
            "187/200\n",
            "188/200\n",
            "189/200\n",
            "190/200\n",
            "191/200\n",
            "192/200\n",
            "193/200\n",
            "194/200\n",
            "195/200\n",
            "196/200\n",
            "197/200\n",
            "198/200\n",
            "199/200\n",
            "0/333\n",
            "1/333\n",
            "2/333\n",
            "3/333\n",
            "4/333\n",
            "5/333\n",
            "6/333\n",
            "7/333\n",
            "8/333\n",
            "9/333\n",
            "10/333\n",
            "11/333\n",
            "12/333\n",
            "13/333\n",
            "14/333\n",
            "15/333\n",
            "16/333\n",
            "17/333\n",
            "18/333\n",
            "19/333\n",
            "20/333\n",
            "21/333\n",
            "22/333\n",
            "23/333\n",
            "24/333\n",
            "25/333\n",
            "26/333\n",
            "27/333\n",
            "28/333\n",
            "29/333\n",
            "30/333\n",
            "31/333\n",
            "32/333\n",
            "33/333\n",
            "34/333\n",
            "35/333\n",
            "36/333\n",
            "37/333\n",
            "38/333\n",
            "39/333\n",
            "40/333\n",
            "41/333\n",
            "42/333\n",
            "43/333\n",
            "44/333\n",
            "45/333\n",
            "46/333\n",
            "47/333\n",
            "48/333\n",
            "49/333\n",
            "50/333\n",
            "51/333\n",
            "52/333\n",
            "53/333\n",
            "54/333\n",
            "55/333\n",
            "56/333\n",
            "57/333\n",
            "58/333\n",
            "59/333\n",
            "60/333\n",
            "61/333\n",
            "62/333\n",
            "63/333\n",
            "64/333\n",
            "65/333\n",
            "66/333\n",
            "67/333\n",
            "68/333\n",
            "69/333\n",
            "70/333\n",
            "71/333\n",
            "72/333\n",
            "73/333\n",
            "74/333\n",
            "75/333\n",
            "76/333\n",
            "77/333\n",
            "78/333\n",
            "79/333\n",
            "80/333\n",
            "81/333\n",
            "82/333\n",
            "83/333\n",
            "84/333\n",
            "85/333\n",
            "86/333\n",
            "87/333\n",
            "88/333\n",
            "89/333\n",
            "90/333\n",
            "91/333\n",
            "92/333\n",
            "93/333\n",
            "94/333\n",
            "95/333\n",
            "96/333\n",
            "97/333\n",
            "98/333\n",
            "99/333\n",
            "100/333\n",
            "101/333\n",
            "102/333\n",
            "103/333\n",
            "104/333\n",
            "105/333\n",
            "106/333\n",
            "107/333\n",
            "108/333\n",
            "109/333\n",
            "110/333\n",
            "111/333\n",
            "112/333\n",
            "113/333\n",
            "114/333\n",
            "115/333\n",
            "116/333\n",
            "117/333\n",
            "118/333\n",
            "119/333\n",
            "120/333\n",
            "121/333\n",
            "122/333\n",
            "123/333\n",
            "124/333\n",
            "125/333\n",
            "126/333\n",
            "127/333\n",
            "128/333\n",
            "129/333\n",
            "130/333\n",
            "131/333\n",
            "132/333\n",
            "133/333\n",
            "134/333\n",
            "135/333\n",
            "136/333\n",
            "137/333\n",
            "138/333\n",
            "139/333\n",
            "140/333\n",
            "141/333\n",
            "142/333\n",
            "143/333\n",
            "144/333\n",
            "145/333\n",
            "146/333\n",
            "147/333\n",
            "148/333\n",
            "149/333\n",
            "150/333\n",
            "151/333\n",
            "152/333\n",
            "153/333\n",
            "154/333\n",
            "155/333\n",
            "156/333\n",
            "157/333\n",
            "158/333\n",
            "159/333\n",
            "160/333\n",
            "161/333\n",
            "162/333\n",
            "163/333\n",
            "164/333\n",
            "165/333\n",
            "166/333\n",
            "167/333\n",
            "168/333\n",
            "169/333\n",
            "170/333\n",
            "171/333\n",
            "172/333\n",
            "173/333\n",
            "174/333\n",
            "175/333\n",
            "176/333\n",
            "177/333\n",
            "178/333\n",
            "179/333\n",
            "180/333\n",
            "181/333\n",
            "182/333\n",
            "183/333\n",
            "184/333\n",
            "185/333\n",
            "186/333\n",
            "187/333\n",
            "188/333\n",
            "189/333\n",
            "190/333\n",
            "191/333\n",
            "192/333\n",
            "193/333\n",
            "194/333\n",
            "195/333\n",
            "196/333\n",
            "197/333\n",
            "198/333\n",
            "199/333\n",
            "200/333\n",
            "201/333\n",
            "202/333\n",
            "203/333\n",
            "204/333\n",
            "205/333\n",
            "206/333\n",
            "207/333\n",
            "208/333\n",
            "209/333\n",
            "210/333\n",
            "211/333\n",
            "212/333\n",
            "213/333\n",
            "214/333\n",
            "215/333\n",
            "216/333\n",
            "217/333\n",
            "218/333\n",
            "219/333\n",
            "220/333\n",
            "221/333\n",
            "222/333\n",
            "223/333\n",
            "224/333\n",
            "225/333\n",
            "226/333\n",
            "227/333\n",
            "228/333\n",
            "229/333\n",
            "230/333\n",
            "231/333\n",
            "232/333\n",
            "233/333\n",
            "234/333\n",
            "235/333\n",
            "236/333\n",
            "237/333\n",
            "238/333\n",
            "239/333\n",
            "240/333\n",
            "241/333\n",
            "242/333\n",
            "243/333\n",
            "244/333\n",
            "245/333\n",
            "246/333\n",
            "247/333\n",
            "248/333\n",
            "249/333\n",
            "250/333\n",
            "251/333\n",
            "252/333\n",
            "253/333\n",
            "254/333\n",
            "255/333\n",
            "256/333\n",
            "257/333\n",
            "258/333\n",
            "259/333\n",
            "260/333\n",
            "261/333\n",
            "262/333\n",
            "263/333\n",
            "264/333\n",
            "265/333\n",
            "266/333\n",
            "267/333\n",
            "268/333\n",
            "269/333\n",
            "270/333\n",
            "271/333\n",
            "272/333\n",
            "273/333\n",
            "274/333\n",
            "275/333\n",
            "276/333\n",
            "277/333\n",
            "278/333\n",
            "279/333\n",
            "280/333\n",
            "281/333\n",
            "282/333\n",
            "283/333\n",
            "284/333\n",
            "285/333\n",
            "286/333\n",
            "287/333\n",
            "288/333\n",
            "289/333\n",
            "290/333\n",
            "291/333\n",
            "292/333\n",
            "293/333\n",
            "294/333\n",
            "295/333\n",
            "296/333\n",
            "297/333\n",
            "298/333\n",
            "299/333\n",
            "300/333\n",
            "301/333\n",
            "302/333\n",
            "303/333\n",
            "304/333\n",
            "305/333\n",
            "306/333\n",
            "307/333\n",
            "308/333\n",
            "309/333\n",
            "310/333\n",
            "311/333\n",
            "312/333\n",
            "313/333\n",
            "314/333\n",
            "315/333\n",
            "316/333\n",
            "317/333\n",
            "318/333\n",
            "319/333\n",
            "320/333\n",
            "321/333\n",
            "322/333\n",
            "323/333\n",
            "324/333\n",
            "325/333\n",
            "326/333\n",
            "327/333\n",
            "328/333\n",
            "329/333\n",
            "330/333\n",
            "331/333\n",
            "332/333\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Pool the idiom / simile / metaphor examples into single combined lists.\n",
        "distractors = distractors_idiom + distractors_simile + distractors_metaphor\n",
        "answers = answers_idiom + answers_simile + answers_metaphor\n",
        "texts = texts_idiom + texts_simile + texts_metaphor\n",
        "types = types_idiom + types_simile + types_metaphor\n",
        "\n",
        "# FigTypeDataset(answers, texts, types) consumes these together and presumably pairs them\n",
        "# by index -- fail fast if the per-category lists ever fall out of step.\n",
        "assert len(answers) == len(texts) == len(types), (len(answers), len(texts), len(types))"
      ],
      "metadata": {
        "id": "tdzj1UEjpNUN"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "batch_size = 16\n",
        "\n",
        "dataset = FigTypeDataset(answers, texts, types)\n",
        "\n",
        "# 80/20 train/test split with a fixed-seed generator so the partition (and therefore\n",
        "# every score reported below) is reproducible across kernel restarts.\n",
        "n_train = int(0.8 * len(dataset))\n",
        "split_generator = torch.Generator().manual_seed(42)\n",
        "train_dataset, test_dataset = torch.utils.data.random_split(\n",
        "    dataset, [n_train, len(dataset) - n_train], generator=split_generator\n",
        ")\n",
        "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=collate_fn)\n",
        "test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, collate_fn=collate_fn)"
      ],
      "metadata": {
        "id": "IrBVVKsopsR5"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prefer the GPU, but fall back to CPU so Restart & Run All still works without one.\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
      ],
      "metadata": {
        "id": "3ODm8q__VAXs"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LgpMeb4eRI5Z"
      },
      "source": [
        "## FactorCL-SUP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "Ha5JwFzHRKjd",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e7f34872-1447-4e7a-9219-d7d7375c8f14"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  -0.0011595366522669792\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/PIL/Image.py:996: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  1  i_batch:  0  loss:  -0.002346084453165531\n",
            "iter:  2  i_batch:  0  loss:  -0.0030839326791465282\n",
            "iter:  3  i_batch:  0  loss:  -0.0037324423901736736\n",
            "iter:  4  i_batch:  0  loss:  -0.0055291056632995605\n",
            "iter:  5  i_batch:  0  loss:  -0.006623566150665283\n",
            "iter:  6  i_batch:  0  loss:  -0.007975934073328972\n",
            "iter:  7  i_batch:  0  loss:  -0.010226668789982796\n",
            "iter:  8  i_batch:  0  loss:  -0.012427756562829018\n",
            "iter:  9  i_batch:  0  loss:  -0.014025865122675896\n"
          ]
        }
      ],
      "source": [
        "# Seed torch's global RNG (head initialization, train_loader shuffling) for repeatable runs.\n",
        "torch.manual_seed(42)\n",
        "\n",
        "# Reload the pretrained CLIP backbone so each method below starts from the same weights.\n",
        "model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "\n",
        "# NOTE(review): [512,512] and 3 look like head layer sizes / number of figurative types -- confirm against FactorCLSUP.\n",
        "factorcl_sup = FactorCLSUP(model, processor, [512,512], 3, device, lr=1e-6).to(device)\n",
        "factorcl_sup.train()\n",
        "\n",
        "train_sup_model(factorcl_sup, train_loader, num_epoch=10, num_club_iter=1)\n",
        "\n",
        "# Linear-probe evaluation: logistic regression on backbone embeddings, scored on the held-out split.\n",
        "model.eval()\n",
        "train_embeds, train_labels, test_embeds, test_labels = get_embeds(model, processor, train_loader, test_loader)\n",
        "\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Held-out linear-probe accuracy for the FactorCL-SUP backbone.\n",
        "score"
      ],
      "metadata": {
        "id": "oTEJEZ1wdE_M",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "741211ce-5596-4143-afb4-29614994952e"
      },
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "1.0"
            ]
          },
          "metadata": {},
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## FactorCL-SSL"
      ],
      "metadata": {
        "id": "MqKnax76Rh-g"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Seed torch's global RNG (head initialization, train_loader shuffling) for repeatable runs.\n",
        "torch.manual_seed(42)\n",
        "\n",
        "# Reload the pretrained CLIP backbone so each method below starts from the same weights.\n",
        "model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "\n",
        "# NOTE(review): [512,512] and 3 look like head layer sizes / number of figurative types -- confirm against FactorCLSSL.\n",
        "factorcl_ssl = FactorCLSSL(model, processor, [512,512], 3, device, lr=1e-6).to(device)\n",
        "factorcl_ssl.train()\n",
        "\n",
        "train_ssl_model(factorcl_ssl, train_loader, num_epoch=10, num_club_iter=1)\n",
        "\n",
        "# Linear-probe evaluation: logistic regression on backbone embeddings, scored on the held-out split.\n",
        "model.eval()\n",
        "train_embeds, train_labels, test_embeds, test_labels = get_embeds(model, processor, train_loader, test_loader)\n",
        "\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ],
      "metadata": {
        "id": "ofhQhcPFRYah"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Held-out linear-probe accuracy for the FactorCL-SSL backbone.\n",
        "score"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EnptiDeNRY8h",
        "outputId": "066af378-521a-4a55-925d-343d700d7487"
      },
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.9324324324324325"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8bvQGiz2nq_J"
      },
      "source": [
        "## SimCLR"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "T7cS_0F7IYnr",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "80dd1088-9926-44b9-e524-d841dc9aaae2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  0  i_batch:  0  loss:  28.418033599853516\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/PIL/Image.py:996: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "iter:  1  i_batch:  0  loss:  27.95946502685547\n",
            "iter:  2  i_batch:  0  loss:  27.49124526977539\n",
            "iter:  3  i_batch:  0  loss:  27.05012321472168\n",
            "iter:  4  i_batch:  0  loss:  26.655895233154297\n",
            "iter:  5  i_batch:  0  loss:  26.386932373046875\n",
            "iter:  6  i_batch:  0  loss:  26.163578033447266\n",
            "iter:  7  i_batch:  0  loss:  25.994287490844727\n",
            "iter:  8  i_batch:  0  loss:  25.879667282104492\n",
            "iter:  9  i_batch:  0  loss:  25.78017234802246\n"
          ]
        }
      ],
      "source": [
        "# Seed torch's global RNG (projection-head initialization, train_loader shuffling) for repeatable runs.\n",
        "torch.manual_seed(42)\n",
        "\n",
        "# Reload the pretrained CLIP backbone so each method starts from the same weights.\n",
        "model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "\n",
        "# Set use_label=True for SupCon\n",
        "# NOTE(review): 0.5 is presumably the contrastive temperature -- confirm against SupConModel.\n",
        "simclr_model = SupConModel(model, processor, 0.5, [512,512], [512,512], use_label=False).to(device)\n",
        "simclr_model.train()\n",
        "\n",
        "optimizer = optim.Adam(simclr_model.parameters(), lr=1e-6)\n",
        "\n",
        "train_supcon(simclr_model, train_loader, optimizer, num_epoch=10)\n",
        "\n",
        "# Linear-probe evaluation: logistic regression on backbone embeddings, scored on the held-out split.\n",
        "model.eval()\n",
        "train_embeds, train_labels, test_embeds, test_labels = get_embeds(model, processor, train_loader, test_loader)\n",
        "\n",
        "clf = LogisticRegression(max_iter=200).fit(train_embeds, train_labels)\n",
        "score = clf.score(test_embeds, test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Held-out linear-probe accuracy for the SimCLR backbone.\n",
        "score"
      ],
      "metadata": {
        "id": "zxpn6JcEWQc1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "aab0cb94-7836-4527-a647-8d6a70425644"
      },
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.8918918918918919"
            ]
          },
          "metadata": {},
          "execution_count": 24
        }
      ]
    }
  ]
}