{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VwdKWjwCdiEo",
        "outputId": "9a82fe8f-b7fb-4345-de9b-5a9b9485fd33"
      },
      "outputs": [],
      "source": [
        "%pip install torch torchvision wilds opencv-python ttach kornia\n",
        "%pip install ftfy regex tqdm\n",
        "%pip install git+https://github.com/openai/CLIP.git\n",
        "%pip install grad-cam"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "GqBklCpugc_l"
      },
      "outputs": [],
      "source": [
        "import io\n",
        "import requests\n",
        "\n",
        "import torch\n",
        "from torchvision import models\n",
        "import clip\n",
        "from PIL import Image\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import cv2\n",
        "import requests\n",
        "import pytorch_grad_cam\n",
        "from pytorch_grad_cam import GradCAM\n",
        "from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n",
        "from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Wrapper Class\n",
        "We need a wrapper class for CLIP to interact with the GradCAM interface; it needs to mimic a classification model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "aOlyjnHThN1X"
      },
      "outputs": [],
      "source": [
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "class CLIPWrapper:\n",
        "  def  __init__(self, base, text, *args, **kwargs):\n",
        "    self.base = base\n",
        "    self.text = clip.tokenize(text).cuda()\n",
        "\n",
        "  def __call__(self, image):\n",
        "    logits_per_image, logits_per_text = self.base(image, self.text)\n",
        "    probs = logits_per_image.softmax(dim=-1)\n",
        "    return probs\n",
        "\n",
        "  def dropoff(self, old_image, new_image):\n",
        "    old_logits = self.base(old_image, self.text)[0][0]\n",
        "    new_logits = self.base(new_image, self.text)[0][0]\n",
        "    return (new_logits - old_logits), (new_logits.softmax(dim=-1) - old_logits.softmax(dim=-1))\n",
        "\n",
        "  def eval(self):\n",
        "    self.base = self.base.eval()\n",
        "    return self\n",
        "\n",
        "  def cuda(self):\n",
        "    self.base = self.base.cuda()\n",
        "    self.text = self.text.cuda()\n",
        "    return self\n",
        "\n",
        "  def __getattr__(self, name):\n",
        "      return getattr(self.base, name)\n",
        "\n",
        "def reshape_transform_vit(tensor):\n",
        "  grid_square = len(tensor) - 1\n",
        "  if grid_square ** 0.5 % 1 == 0:\n",
        "    height = width = int(grid_square**0.5)\n",
        "\n",
        "    result = tensor[1:, :, :].reshape(\n",
        "        height,\n",
        "        width,\n",
        "        tensor.size(2)\n",
        "    )\n",
        "\n",
        "    result = result.permute(2, 0, 1)\n",
        "    return result.unsqueeze(0)\n",
        "\n",
        "def show_gradcam(model_name, captions, image_id, url=True):\n",
        "  model, preprocess = clip.load(model_name, device=device)\n",
        "  model = CLIPWrapper(model, captions)\n",
        "  model.eval()\n",
        "  input_tensor = preprocess(Image.open(requests.get(image_id, stream=True).raw if url else image_id)).to(device).unsqueeze(0)\n",
        "\n",
        "  img = np.array(Image.open(requests.get(image_id, stream=True).raw if url else image_id))\n",
        "  img = cv2.resize(img, (224, 224))\n",
        "  img = np.float32(img) / 255\n",
        "\n",
        "  target_layers = [model.visual.transformer.resblocks[-1].ln_1 if model_name.startswith('ViT') else model.visual.layer4]\n",
        "  with GradCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform_vit if model_name.startswith('ViT') else None) as cam:\n",
        "      grayscale_cams = cam(input_tensor=input_tensor, targets=None)\n",
        "      cam_image = show_cam_on_image(img, grayscale_cams[0, :], use_rgb=True)\n",
        "  cam = np.uint8(255*grayscale_cams[0, :])\n",
        "  cam = cv2.merge([cam, cam, cam])\n",
        "  images = np.hstack((np.uint8(255*img), cam , cam_image))\n",
        "\n",
        "  probs = model(input_tensor)\n",
        "  maxi = probs.argmax()\n",
        "  print(f\"Prediction: {captions[maxi]} ({(100*probs[0][maxi]):0.2f}%)\")\n",
        "  return Image.fromarray(images)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4p4koQc0Gurn",
        "outputId": "c8231ff1-0e1e-49eb-cb90-886023323238"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|███████████████████████████████████████| 338M/338M [00:03<00:00, 89.1MiB/s]\n"
          ]
        }
      ],
      "source": [
        "model_name = \"ViT-B/32\"\n",
        "clip_model, preprocess = clip.load(model_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Downloading FairFace\n",
        "We also restrict our analysis to middle age persons."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eOpuhwrpqSjn",
        "outputId": "670b0050-5832-4b55-e218-1a95a17797b6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86\n",
            "To: /content/fairface-img-margin025-trainval.zip\n",
            "100% 578M/578M [00:13<00:00, 42.0MB/s]\n"
          ]
        }
      ],
      "source": [
        "!gdown 1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86\n",
        "!unzip /content/fairface-img-margin025-trainval.zip\n",
        "!gdown 1i1L3Yqwaio7YSOCj7ftgk8ZZchPG7dmH"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "B7eosinmtBdw"
      },
      "outputs": [],
      "source": [
        "data = pd.read_csv('/content/fairface_label_train.csv')\n",
        "data = data.replace('Latino_Hispanic', 'Latino Hispanic')\n",
        "races = data.race.unique()\n",
        "middle_age = data[data.age.isin(['20-29', '30-39', '40-49'])]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# GradCAM visualization\n",
        "As in the paper we choose Indian men with a high confidence of being classified as smart and analyze the corresponding CAMs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 763
        },
        "id": "3EI6ir8SNQ5E",
        "outputId": "795b01ac-5cdd-4a6b-9a8c-84738e56487e"
      },
      "outputs": [],
      "source": [
        "model = CLIPWrapper(clip_model, [f'a {race} man' for race in races])\n",
        "model = CLIPWrapper(clip_model, ['a smart man', 'a silly man'])\n",
        "target_layers = [model.visual.transformer.resblocks[-1].ln_1 if model_name.startswith('ViT') else model.visual.layer4]\n",
        "\n",
        "dropoffs = []\n",
        "df = middle_age[(middle_age.race == 'Indian') & (middle_age.gender == 'Male')]\n",
        "ids = df.sample(n=100).index.tolist()\n",
        "result = []\n",
        "\n",
        "for id in ids:\n",
        "  if len(result) >= 5:\n",
        "    break\n",
        "  image_id = data.iloc[id].file\n",
        "\n",
        "  old = cv2.resize(np.array(Image.open(image_id)), (224, 224))\n",
        "  probs = model(preprocess(Image.fromarray(old)).to(device).unsqueeze(0))\n",
        "  if probs[0][0] < 2/3:\n",
        "    continue\n",
        "  print(id)\n",
        "  print(probs, probs[0][0])\n",
        "  input_tensor = preprocess(Image.open(image_id)).to(device).unsqueeze(0)\n",
        "\n",
        "  cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True, reshape_transform=reshape_transform_vit if model_name.startswith('ViT') else None)\n",
        "  grayscale_cams = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(0)])\n",
        "\n",
        "  img = np.array(Image.open(image_id))\n",
        "  img = cv2.resize(img, (224, 224))\n",
        "  img = np.float32(img) / 255\n",
        "\n",
        "  cam_image = show_cam_on_image(img, grayscale_cams[0, :], use_rgb=True)\n",
        "  cam = np.uint8(255*grayscale_cams[0, :])\n",
        "  cam = cv2.merge([cam, cam, cam])\n",
        "  result.append(np.vstack((np.uint8(255*img), cam_image)))\n",
        "\n",
        "display(Image.fromarray(np.hstack(result)))\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
