{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3bGElT9XAYa3"
      },
      "outputs": [],
      "source": [
        "!pip install transformers\n",
        "!pip install timm"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "id": "ZUqsu4KE69Lv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# !pip install pyunpack\n",
        "# !pip install patool\n",
        "# from pyunpack import Archive\n",
        "# Archive('/content/drive/MyDrive/CrisisHateMM_01.zip').extractall('/content/')"
      ],
      "metadata": {
        "id": "Il4aVFYziC3H"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import os\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from tqdm import tqdm\n",
        "from time import sleep\n",
        "from transformers import AdamW\n",
        "import torchvision.datasets as datasets\n",
        "from transformers import CLIPProcessor, CLIPModel\n",
        "from transformers import AutoProcessor, BlipModel\n",
        "import torch.nn.functional as F\n",
        "from sklearn.metrics import classification_report, accuracy_score\n",
        "from PIL import Image, ImageDraw\n",
        "import cv2 as cv\n",
        "import numpy as np\n",
        "import sys\n",
        "sys.path.extend(\"/content/drive/MyDrive/Scripts/cf_gen\")\n",
        "from torchvision import transforms\n",
        "import re\n",
        "import random\n",
        "import pandas as pd\n",
        "import ast\n",
        "import timm\n",
        "from sklearn.metrics import roc_auc_score"
      ],
      "metadata": {
        "id": "qLaClBsSiC9q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "batch_size = 4\n",
        "n_classes = 3\n",
        "lr = 1e-3\n",
        "epochs = 4\n",
        "image_dim = 224\n",
        "root_dir = '/content/CrisisHateMM_01'\n",
        "device = 'cuda'\n",
        "n_classes = 2\n",
        "\n",
        "trans = transforms.Compose(\n",
        "    [\n",
        "        transforms.Resize((image_dim, image_dim)),\n",
        "        transforms.ToTensor(),\n",
        "    ]\n",
        ")\n",
        "\n",
        "def seed_everything(seed: int):\n",
        "\n",
        "    random.seed(seed)\n",
        "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed(seed)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = True\n",
        "    g = torch.Generator()\n",
        "    g.manual_seed(seed)\n",
        "\n",
        "seed_everything(510)"
      ],
      "metadata": {
        "id": "Ew990AKWiHae"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class EmbedDataset(torch.utils.data.Dataset):\n",
        "\n",
        "    def __init__(self, img_embeds, text_embeds, labels):\n",
        "        self.img_embeds = img_embeds\n",
        "        self.text_embeds = text_embeds\n",
        "        self.labels = labels\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "\n",
        "        img_embed = self.img_embeds[idx]\n",
        "        text_embed = self.text_embeds[idx]\n",
        "        label = self.labels[idx]\n",
        "\n",
        "        return (img_embed, text_embed, label)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.text_embeds)\n",
        "\n",
        "class IT_Dataset(torch.utils.data.Dataset):\n",
        "\n",
        "    def __init__(self, images, texts):\n",
        "        self.images = images\n",
        "        self.texts = texts\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "\n",
        "        image = self.images[idx]\n",
        "        text = self.texts[idx]\n",
        "\n",
        "        return (image, text)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.texts)"
      ],
      "metadata": {
        "id": "Oiv0eoq2iIvO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class classifier(nn.Module):\n",
        "\n",
        "    def __init__(self):\n",
        "        super(classifier, self).__init__()\n",
        "\n",
        "        self.fc1 = nn.Linear(1280, 128)\n",
        "        self.fc2 = nn.Linear(128, n_classes)\n",
        "        self.dropout = nn.Dropout(0.25)\n",
        "\n",
        "    def forward(self, img_embed, text_embed):\n",
        "\n",
        "        concat = torch.cat((img_embed, text_embed), 1).to(device)\n",
        "\n",
        "        linear_out = self.dropout(F.relu(self.fc1(concat)))\n",
        "        final_out = self.fc2(linear_out)\n",
        "\n",
        "        return final_out\n",
        "\n",
        "class text_classifier(nn.Module):\n",
        "\n",
        "    def __init__(self):\n",
        "        super(text_classifier, self).__init__()\n",
        "\n",
        "        self.fc1 = nn.Linear(512, 128)\n",
        "        self.fc2 = nn.Linear(128, n_classes)\n",
        "        self.dropout = nn.Dropout(0.25)\n",
        "\n",
        "    def forward(self, img_embed, text_embed):\n",
        "\n",
        "        linear_out = self.dropout(F.relu(self.fc1(text_embed.to(device))))\n",
        "        final_out = self.fc2(linear_out)\n",
        "\n",
        "        return final_out\n",
        "\n",
        "class img_classifier(nn.Module):\n",
        "\n",
        "    def __init__(self):\n",
        "        super(img_classifier, self).__init__()\n",
        "\n",
        "        self.fc1 = nn.Linear(640, 128)\n",
        "        self.fc2 = nn.Linear(128, n_classes)\n",
        "        self.dropout = nn.Dropout(0.25)\n",
        "\n",
        "    def forward(self, img_embed, text_embed):\n",
        "\n",
        "        linear_out = self.dropout(F.relu(self.fc1(img_embed.to(device))))\n",
        "        final_out = self.fc2(linear_out)\n",
        "\n",
        "        return final_out"
      ],
      "metadata": {
        "id": "wbSNTRUmiTWm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def train(model, train_loader, val_loader):\n",
        "\n",
        "    optimizer = AdamW(model.parameters(), lr = lr, eps=1e-8)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "\n",
        "        train_epoch_loss = 0\n",
        "        train_epoch_accuracy = 0\n",
        "\n",
        "        model.train()\n",
        "\n",
        "        for (image_embed, text_embed, labels) in tqdm(train_loader):\n",
        "\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            y_pred = []\n",
        "            y_true = []\n",
        "\n",
        "            output = model(image_embed, text_embed)\n",
        "\n",
        "            loss = criterion(output, labels.cuda())\n",
        "\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            _, preds = output.data.max(1)\n",
        "            y_pred.extend(preds.tolist())\n",
        "            y_true.extend(labels.tolist())\n",
        "\n",
        "            acc = accuracy_score(y_true, y_pred)\n",
        "            train_epoch_accuracy += acc / len(train_loader)\n",
        "            train_epoch_loss += loss / len(train_loader)\n",
        "\n",
        "        val_epoch_loss = 0\n",
        "        val_epoch_accuracy = 0\n",
        "\n",
        "        model.eval()\n",
        "\n",
        "        for (image_embed, text_embed, labels) in tqdm(val_loader):\n",
        "            with torch.no_grad():\n",
        "\n",
        "                y_pred = []\n",
        "                y_true = []\n",
        "\n",
        "                output = model(image_embed, text_embed)\n",
        "\n",
        "                loss = criterion(output, labels.cuda())\n",
        "\n",
        "                _, preds = output.data.max(1)\n",
        "                y_pred.extend(preds.tolist())\n",
        "                y_true.extend(labels.tolist())\n",
        "\n",
        "                acc = accuracy_score(y_true, y_pred)\n",
        "                val_epoch_accuracy += acc / len(val_loader)\n",
        "                val_epoch_loss += loss / len(val_loader)\n",
        "\n",
        "\n",
        "        print(f\"Epoch : {epoch+1} - train_loss : {train_epoch_loss:.4f} - train_acc: {train_epoch_accuracy:.4f} - val_loss : {val_epoch_loss:.4f} - val_acc: {val_epoch_accuracy:.4f}\\n\")\n",
        "\n",
        "\n",
        "def test(model, test_loader):\n",
        "\n",
        "    y_pred = []\n",
        "    y_true = []\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    for (image_embed, text_embed, labels) in tqdm(test_loader):\n",
        "        with torch.no_grad():\n",
        "\n",
        "                output = model(image_embed, text_embed)\n",
        "                _, preds = output.data.max(1)\n",
        "                y_pred.extend(preds.tolist())\n",
        "                y_true.extend(labels.tolist())\n",
        "\n",
        "    print(classification_report(y_true, y_pred))\n",
        "    print(roc_auc_score(y_true, y_pred))\n"
      ],
      "metadata": {
        "id": "8yWZEXM1iVHf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "images = torch.load(os.path.join('/content/drive/MyDrive', 'images_inpainted.pt'))\n",
        "# images = torch.load(os.path.join('/content/drive/MyDrive', 'images_og.pt'))"
      ],
      "metadata": {
        "id": "8x0EvOXvjxwm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df = pd.read_csv('/content/drive/MyDrive/TLP.csv')\n",
        "\n",
        "texts = list(df['text'])\n",
        "labels = list(df['label'])\n",
        "\n",
        "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "embed_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(device)\n",
        "\n",
        "\n",
        "\n",
        "# processor = AutoProcessor.from_pretrained(\"Salesforce/blip-image-captioning-base\")\n",
        "# embed_model = BlipModel.from_pretrained(\"Salesforce/blip-image-captioning-base\").to(device)\n",
        "\n",
        "# from transformers import AutoProcessor, AlignModel, Blip2Model, AutoTokenizer\n",
        "\n",
        "# embed_model = AlignModel.from_pretrained(\"kakaobrain/align-base\").to(device)\n",
        "# processor = AutoProcessor.from_pretrained(\"kakaobrain/align-base\")\n",
        "\n",
        "\n",
        "# model = Blip2Model.from_pretrained(\"Salesforce/blip2-opt-2.7b\", torch_dtype=torch.float16)\n",
        "# tokenizer = AutoTokenizer.from_pretrained(\"Salesforce/blip2-opt-2.7b\")\n",
        "# processor = AutoProcessor.from_pretrained(\"Salesforce/blip2-opt-2.7b\")"
      ],
      "metadata": {
        "id": "ECKi0kNHiWij"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_embeds(it_dataloader, processor, embed_model):\n",
        "\n",
        "    embed_model.eval()\n",
        "    image_embeds = []\n",
        "    text_embeds = []\n",
        "\n",
        "    for batch in tqdm(it_dataloader):\n",
        "\n",
        "      images, texts = batch\n",
        "\n",
        "      inputs = processor(text=texts, images=images, return_tensors=\"pt\", padding=True, truncation = True).to(device)\n",
        "      clip_output = embed_model(**inputs)\n",
        "\n",
        "      i_e = clip_output.vision_model_output.pooler_output.detach().cpu().numpy()\n",
        "      t_e = clip_output.text_model_output.pooler_output.detach().cpu().numpy()\n",
        "\n",
        "      # inputs = tokenizer(text=texts, return_tensors=\"pt\", padding=True, truncation = True).to(device)\n",
        "      # t_e = embed_model.get_text_features(**inputs).detach().cpu().numpy()\n",
        "      # inputs = processor(images=images, return_tensors=\"pt\", padding=True, truncation = True).to(device)\n",
        "      # i_e = embed_model.get_image_features(**inputs).detach().cpu().numpy()\n",
        "\n",
        "      for i in i_e:\n",
        "        image_embeds.append(i)\n",
        "\n",
        "      for i in t_e:\n",
        "        text_embeds.append(i)\n",
        "\n",
        "    return image_embeds, text_embeds"
      ],
      "metadata": {
        "id": "yMWFSlY4k_f3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "it_dataset = IT_Dataset(images, texts)\n",
        "it_dataloader = torch.utils.data.DataLoader(it_dataset,batch_size=batch_size,pin_memory=True,num_workers=0,shuffle=False,drop_last=True)\n",
        "\n",
        "image_embeds, text_embeds = get_embeds(it_dataloader, processor, embed_model)"
      ],
      "metadata": {
        "id": "WMBUuLOyimJU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "embed_dataset = EmbedDataset(image_embeds, text_embeds, labels)\n",
        "train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(embed_dataset, [0.7, 0.15, 0.15])"
      ],
      "metadata": {
        "id": "yV7G1T5GwxkT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,pin_memory=True,num_workers=0,shuffle=True,drop_last=True)\n",
        "test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,pin_memory=True,num_workers=0,shuffle=False,drop_last=True)\n",
        "val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,pin_memory=True,num_workers=0,shuffle=True,drop_last=True)"
      ],
      "metadata": {
        "id": "8LetfUFrioMR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "batch_size = 4\n",
        "lr = 1e-4\n",
        "epochs = 10\n",
        "image_dim = 224\n",
        "device = 'cuda'\n",
        "n_classes = 2\n",
        "\n",
        "# model = classifier().to(device)\n",
        "model = img_classifier().to(device)\n",
        "# model = text_classifier().to(device)\n",
        "\n",
        "for model in [classifier().to(device), img_classifier().to(device), text_classifier().to(device)]:\n",
        "    train(model, train_loader, val_loader)\n",
        "    test(model, test_loader)"
      ],
      "metadata": {
        "id": "MnspVGCfiqco"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torchvision.transforms as T\n",
        "from PIL import Image\n",
        "\n",
        "transform = T.ToPILImage()\n",
        "img = transform(images[100])\n",
        "display(img)"
      ],
      "metadata": {
        "id": "VyjmBZeZ41HV"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}