{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af6d7eb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shenyu/miniconda3/envs/sean/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/shenyu/miniconda3/envs/sean/lib/python3.8/site-packages/torchvision/image.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEb'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
      "  warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Datasets and loaders ready.\n"
     ]
    }
   ],
   "source": [
    "###### data loader####\n",
    "import os\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "import torchvision.transforms as tfms\n",
    "from PIL import Image\n",
    "import random\n",
    "from tqdm import trange\n",
    "from sklearn.metrics import accuracy_score, precision_score\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from torch import nn, optim\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import open_clip\n",
    "from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8\n",
    "from sklearn.model_selection import train_test_split\n",
    "import os.path as osp\n",
    "\n",
    "torch.set_num_threads(5)   # Sets the number of threads used for intra-operations\n",
    "torch.set_num_interop_threads(5)   # Sets the number of threads used for inter-operations\n",
    "\n",
    "import open_clip\n",
    "\n",
    "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
    "logabs = lambda x: torch.log(torch.abs(x))\n",
    "batch_size = 256\n",
    "\n",
    "\n",
    "def seed_everything(seed):\n",
    "    \"\"\"\n",
    "    Changes the seed for reproducibility. \n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')\n",
    "model = model.to(device)\n",
    "model = model.eval()\n",
    "tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')\n",
    "\n",
    "\n",
    "seed_everything(1024)\n",
    "\n",
    "\n",
    "\n",
    "class COVIDChestXrayDataset(Dataset):\n",
    "    def __init__(self, data_dir, split_type):\n",
    "        super().__init__()\n",
    "        self.data_dir = data_dir\n",
    "        self.images_dir = os.path.join(self.data_dir, 'images')\n",
    "        self.metadata = pd.read_csv(os.path.join(self.data_dir, 'metadata.csv'))\n",
    "        view_filter = ['AP', 'AP Erect', 'PA', 'AP Supine']\n",
    "\n",
    "        # Filter dataset\n",
    "        dset = self.metadata[self.metadata['view'].isin(view_filter)]\n",
    "        \n",
    "        # Creating splits\n",
    "        male_covid = dset[(dset['finding'] == 'Pneumonia/Viral/COVID-19') & (dset['sex'] == 'M')]\n",
    "        female_covid = dset[(dset['finding'] == 'Pneumonia/Viral/COVID-19') & (dset['sex'] == 'F')]\n",
    "        male_noncovid = dset[(dset['finding'] != 'Pneumonia/Viral/COVID-19') & (dset['sex'] == 'M')]\n",
    "        female_noncovid = dset[(dset['finding'] != 'Pneumonia/Viral/COVID-19') & (dset['sex'] == 'F')]\n",
    "\n",
    "        self.split_data = {\n",
    "            'train': self.build_split([male_covid, female_covid, male_noncovid, female_noncovid], 76),\n",
    "            'val': self.build_split([male_covid, female_covid, male_noncovid, female_noncovid], [183, 92, 107, 76], [46, 24, 27, 19]),\n",
    "            'test': self.build_split([male_covid, female_covid, male_noncovid, female_noncovid], [183 + 46, 92 + 24, 107 + 27, 76 + 19])\n",
    "        }\n",
    "\n",
    "        self.data = self.split_data[split_type]\n",
    "        self.transform = self.get_transform()\n",
    "\n",
    "    def build_split(self, groups, ranges, counts=None):\n",
    "        if isinstance(ranges, int):\n",
    "            ranges = [ranges] * len(groups)\n",
    "        if counts is None:\n",
    "            counts = [len(g) - r for g, r in zip(groups, ranges)]  # Calculate remaining data for test set\n",
    "        split = []\n",
    "        for group, start, count in zip(groups, ranges, counts):\n",
    "            end = start + count\n",
    "            split.extend(group.iloc[start:end].apply(lambda x: [os.path.join(self.images_dir, x['filename']), int('COVID-19' in x['finding']), int(x['sex'] == 'M')], axis=1).tolist())\n",
    "        return split\n",
    "\n",
    "    def get_transform(self):\n",
    "        return tfms.Compose([\n",
    "            tfms.Resize((224, 224)),\n",
    "            tfms.ToTensor()\n",
    "        ])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img_filename, y, a = self.data[idx]\n",
    "        image = Image.open(img_filename).convert('RGB') \n",
    "        img = preprocess(image)\n",
    "        img_for_res = self.transform(image)\n",
    "        return img, y, a, img_for_res\n",
    "\n",
    "# Example usage\n",
    "data_dir = '../covid-chestxray-dataset'\n",
    "train_dataset = COVIDChestXrayDataset(data_dir, 'train')\n",
    "val_dataset = COVIDChestXrayDataset(data_dir, 'val')\n",
    "test_dataset = COVIDChestXrayDataset(data_dir, 'test')\n",
    "\n",
    "# Example DataLoader setup\n",
    "\n",
    "batch_size = 2  # Define or adjust your batch size\n",
    "training_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)\n",
    "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)\n",
    "test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)\n",
    "\n",
    "print('Datasets and loaders ready.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c46aa40",
   "metadata": {},
   "outputs": [],
   "source": [
    "spurious_text = [\"An X-ray image from a male\",  \"An X-ray image from a female\"] \n",
    "\n",
    "texts = tokenizer(spurious_text).to(device)\n",
    "null_image = torch.rand((1,3,224,224)).to(device)\n",
    "model = model.to(device)\n",
    "_, spurious_embedding, _ = model(null_image, texts)\n",
    "\n",
    "female = spurious_embedding[1].unsqueeze(0).to(device)\n",
    "male = spurious_embedding[0].unsqueeze(0).to(device)\n",
    "\n",
    "no_patch = female\n",
    "patch = male"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5a11eb7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Scale: 100%|███████████████████████████████████████████████████████████████████████| 207/207 [00:27<00:00,  7.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3436713\n",
      "0.3295871\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Zero Shot Testing: 100%|███████████████████████████████████████████████████████████████████████| 72/72 [00:11<00:00,  6.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy for y=0, s=0: 0.5217391304347826\n",
      "Accuracy for y=0, s=1: 0.5294117647058824\n",
      "Accuracy for y=1, s=0: 0.5517241379310345\n",
      "Accuracy for y=1, s=1: 0.7586206896420928\n",
      "DP 0.132943143812709\n",
      "EOP 0.2068965517241379\n",
      "EoD 0.10728459299761883\n",
      "acc 0.625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def inference_a_test(vlm, spu_v0, spu_v1):\n",
    "    correct_00, total_00 = 0, 0\n",
    "    correct_01, total_01 = 0, 0\n",
    "    correct_10, total_10 = 0, 0\n",
    "    correct_11, total_11 = 0, 0\n",
    "    \n",
    "    for step, (test_input, test_target, sensitive, _) in enumerate(tqdm(test_data_loader, desc=\"Testing\")):\n",
    "        with torch.no_grad():\n",
    "            test_target = test_target.to(device)\n",
    "            sensitive = sensitive.to(device)\n",
    "            test_target = test_target.squeeze()\n",
    "            test_input = test_input.to(device)\n",
    "            z = vlm.encode_image(test_input)\n",
    "            infered_a = inference_a(vlm, no_patch, patch,z )\n",
    "            \n",
    "            mask_00 = ((test_target == 0) & (sensitive == 0))\n",
    "            mask_01 = ((test_target == 0) & (sensitive == 1))\n",
    "            mask_10 = ((test_target == 1) & (sensitive == 0))\n",
    "            mask_11 = ((test_target == 1) & (sensitive == 1))\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "            correct_00 += (infered_a[mask_00] == sensitive[mask_00]).float().sum().item()\n",
    "            total_00 += mask_00.float().sum().item()\n",
    "\n",
    "            correct_01 += (infered_a[mask_01] == sensitive[mask_01]).float().sum().item()\n",
    "            total_01 += mask_01.float().sum().item()\n",
    "\n",
    "            correct_10 += (infered_a[mask_10] == sensitive[mask_10]).float().sum().item()\n",
    "            total_10 += mask_10.float().sum().item()\n",
    "\n",
    "            correct_11 += (infered_a[mask_11] == sensitive[mask_11]).float().sum().item()\n",
    "            total_11 += mask_11.float().sum().item() \n",
    "    acc_00 = correct_00 / total_00\n",
    "    acc_01 = correct_01 / total_01\n",
    "    acc_10 = correct_10 / total_10\n",
    "    acc_11 = correct_11 / (total_11+1e-9)\n",
    "\n",
    "    print(f'Accuracy for y=0, s=0: {acc_00}')\n",
    "    print(f'Accuracy for y=0, s=1: {acc_01}')\n",
    "    print(f'Accuracy for y=1, s=0: {acc_10}')\n",
    "    print(f'Accuracy for y=1, s=1: {acc_11}')   \n",
    "\n",
    "            \n",
    "\n",
    "\n",
    "\n",
    "def inference_a(vlm, spu_v0, spu_v1, z):\n",
    "    text_embeddings = torch.cat((spu_v0, spu_v1), dim=0)\n",
    "    norm_img_embeddings = z \n",
    "    norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)\n",
    "    cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())\n",
    "    logits_per_image = cosine_similarity \n",
    "    probs = logits_per_image.softmax(dim=1)\n",
    "    _, predic = torch.max(probs.data, 1)\n",
    "    return predic\n",
    "\n",
    "            \n",
    "def supervised_inference_a(img):\n",
    "    resnet18 = models.resnet18(pretrained=False)\n",
    "    num_classes = 2 \n",
    "    resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)\n",
    "    res_model = resnet18\n",
    "    res_model.load_state_dict(torch.load('res_net.pth'))\n",
    "    res_model = res_model.to(device)\n",
    "    res_model.eval()\n",
    "    img = img.to(device)\n",
    "    test_pred_ = res_model(img)\n",
    "    _, predic = torch.max(test_pred_.data, 1)\n",
    "    return predic            \n",
    "            \n",
    "    \n",
    "def compute_scale(vlm, spu_v0, spu_v1):\n",
    "    vlm = vlm.to(device)\n",
    "    scale_0 = []\n",
    "    scale_1 = []\n",
    "    spu0 = spu_v0  / spu_v0.norm(dim=1, keepdim=True)\n",
    "    spu1 = spu_v1 / spu_v1.norm(dim=1, keepdim=True)\n",
    "    #spu0 =  spu_v0 - spu_v1\n",
    "    #spu0 = spu0 / spu0.norm(dim=1, keepdim=True)\n",
    "    \n",
    "    #spu1 =  spu_v1 - spu_v0\n",
    "    #spu1 = spu1 / spu1.norm(dim=1, keepdim=True)\n",
    "    \n",
    "    for step, (test_input, _, sensitive, img) in enumerate(tqdm(training_data_loader, desc=\"Computing Scale\")):\n",
    "        with torch.no_grad():\n",
    "            \n",
    "            \n",
    "            # put image into the image encoder\n",
    "            test_input = test_input.to(device)\n",
    "            z = vlm.encode_image(test_input)\n",
    "            if a ==True:\n",
    "                sensitive = sensitive\n",
    "            else:\n",
    "                if partial_a == False:\n",
    "                    sensitive = inference_a(vlm, no_patch, patch,z )\n",
    "                elif partial_a == True:\n",
    "                    sensitive = supervised_inference_a(img)\n",
    "            \n",
    "            \n",
    "            mask_0 = sensitive == 0\n",
    "            mask_0 = mask_0.to(device)\n",
    "            h = z[mask_0]\n",
    "            inner_no_patch = torch.mm(h/ h.norm(dim=1, keepdim=True), spu0.t())\n",
    "            scale_0.extend(inner_no_patch.detach().cpu().numpy())\n",
    "                \n",
    "            mask_1 = sensitive == 1\n",
    "            mask_1 = mask_1.to(device)\n",
    "            g = z[mask_1]\n",
    "            inner_patch = torch.mm(g/ g.norm(dim=1, keepdim=True), spu1.t())\n",
    "            scale_1.extend(inner_patch.detach().cpu().numpy())\n",
    "    scale_0 = np.array(scale_0)\n",
    "    scale_1 = np.array(scale_1)\n",
    "    print(np.mean(scale_0))\n",
    "    print(np.mean(scale_1))\n",
    "    return torch.tensor(np.mean(scale_0)), torch.tensor(np.mean(scale_1))\n",
    "\n",
    "\n",
    "\n",
    "def test_epoch(vlm,   dataloader):\n",
    "    scale_0, scale_1 = compute_scale(model, no_patch, patch)\n",
    "\n",
    "    texts_label = [\"an X-ray image of a chest without Pneumonia\", \"an X-ray image of a chest with Pneumonia\"] \n",
    "    text_label_tokened = tokenizer(texts_label).to(device)\n",
    "    \n",
    "    vlm = vlm.to(device)\n",
    "    vlm.eval()   \n",
    "    test_pred = []\n",
    "    test_gt = []\n",
    "    sense_gt = []\n",
    "    female_predic = []\n",
    "    female_gt = []\n",
    "    male_predic = []\n",
    "    male_gt = []\n",
    "    correct_00, total_00 = 0, 0\n",
    "    correct_01, total_01 = 0, 0\n",
    "    correct_10, total_10 = 0, 0\n",
    "    correct_11, total_11 = 0, 0\n",
    "    cos = nn.CosineSimilarity(dim = 0)\n",
    "    feature_a0 = []\n",
    "    feature_a1 = []\n",
    "\n",
    "    for step, (test_input, test_target, sensitive_real,img) in enumerate(tqdm(dataloader, desc=\"Zero Shot Testing\")):\n",
    "        test_target = test_target.squeeze()\n",
    "        with torch.no_grad():\n",
    "            gt = test_target.detach().cpu().numpy()\n",
    "            sen = sensitive_real.detach().cpu().numpy()\n",
    "            test_gt.extend(gt)\n",
    "            sense_gt.extend(sen)\n",
    "            # put image into the image encoder\n",
    "            test_input = test_input.to(device)\n",
    "\n",
    "            z = vlm.encode_image(test_input)\n",
    "            z = z/ z.norm(dim=1, keepdim=True)\n",
    "            \n",
    "            if a == True:\n",
    "                sensitive = sensitive_real\n",
    "            if a == False:\n",
    "                if partial_a == False:\n",
    "                    sensitive = inference_a(vlm, no_patch, patch,z )\n",
    "                    sensitive = torch.tensor(sensitive)\n",
    "                elif partial_a == True:\n",
    "                    sensitive = supervised_inference_a(img)\n",
    "            \n",
    "            mask_0 = sensitive == 0\n",
    "            mask_0 = mask_0.to(device)\n",
    "            z[mask_0] -= scale_0 * no_patch/ no_patch.norm(dim=1, keepdim=True)\n",
    "                \n",
    "            mask_1 = sensitive == 1\n",
    "            mask_1 = mask_1.to(device)\n",
    "            z[mask_1] -= scale_1 * patch/ patch.norm(dim=1, keepdim=True)\n",
    "            \n",
    "        \n",
    "            \n",
    "            \n",
    "            feature_a0.extend(z[mask_0].detach().cpu().numpy())\n",
    "            feature_a1.extend(z[mask_1].detach().cpu().numpy())\n",
    "            \n",
    "            text_embeddings = vlm.encode_text(text_label_tokened)\n",
    "            img_embeddings = z\n",
    "            norm_img_embeddings = img_embeddings / img_embeddings.norm(dim=1, keepdim=True)\n",
    "            norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)\n",
    "            cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())\n",
    "                    \n",
    "            logits_per_image = cosine_similarity \n",
    "            probs = logits_per_image.softmax(dim=1)\n",
    "            _, predic = torch.max(probs.data, 1)\n",
    "            predic = predic.detach().cpu()\n",
    "            test_pred.extend(predic.numpy())\n",
    "            label = test_target.squeeze().detach().cpu()\n",
    "            mask_00 = ((label == 0) & (sensitive_real == 0))\n",
    "            mask_01 = ((label == 0) & (sensitive_real == 1))\n",
    "            mask_10 = ((label == 1) & (sensitive_real == 0))\n",
    "            mask_11 = ((label == 1) & (sensitive_real == 1))\n",
    "\n",
    "\n",
    "            correct_00 += (predic[mask_00] == label[mask_00]).float().sum().item()\n",
    "            total_00 += mask_00.float().sum().item()\n",
    "\n",
    "            correct_01 += (predic[mask_01] == label[mask_01]).float().sum().item()\n",
    "            total_01 += mask_01.float().sum().item()\n",
    "\n",
    "            correct_10 += (predic[mask_10] == label[mask_10]).float().sum().item()\n",
    "            total_10 += mask_10.float().sum().item()\n",
    "\n",
    "            correct_11 += (predic[mask_11] == label[mask_11]).float().sum().item()\n",
    "            total_11 += mask_11.float().sum().item() \n",
    "    acc_00 = correct_00 / total_00\n",
    "    acc_01 = correct_01 / total_01\n",
    "    acc_10 = correct_10 / total_10\n",
    "    acc_11 = correct_11 / (total_11+1e-9)\n",
    "\n",
    "    print(f'Accuracy for y=0, s=0: {acc_00}')\n",
    "    print(f'Accuracy for y=0, s=1: {acc_01}')\n",
    "    print(f'Accuracy for y=1, s=0: {acc_10}')\n",
    "    print(f'Accuracy for y=1, s=1: {acc_11}')       \n",
    "    \n",
    "    feature_a0 = np.array(feature_a0)\n",
    "    feature_a1 = np.array(feature_a1)\n",
    "    a0_tensor = torch.from_numpy(np.mean(feature_a0,0))\n",
    "    a1_tensor = torch.from_numpy(np.mean(feature_a1,0))\n",
    "\n",
    "    for i in range(len(sense_gt)):\n",
    "        if sense_gt[i] == 0:\n",
    "            female_predic.append(test_pred[i])\n",
    "            female_gt.append(test_gt[i])\n",
    "        else:\n",
    "            male_predic.append(test_pred[i])\n",
    "            male_gt.append(test_gt[i])\n",
    "    female_CM = confusion_matrix(female_gt, female_predic)    \n",
    "    male_CM = confusion_matrix(male_gt, male_predic) \n",
    "    female_dp = (female_CM[1][1]+female_CM[0][1])/(female_CM[0][0]+female_CM[0][1]+female_CM[1][0]+female_CM[1][1])\n",
    "    male_dp = (male_CM[1][1]+male_CM[0][1])/(male_CM[0][0]+male_CM[0][1]+male_CM[1][0]+male_CM[1][1])\n",
    "    female_TPR = female_CM[1][1]/(female_CM[1][1]+female_CM[1][0])\n",
    "    male_TPR = male_CM[1][1]/(male_CM[1][1]+male_CM[1][0])\n",
    "    female_FPR = female_CM[0][1]/(female_CM[0][1]+female_CM[0][0])\n",
    "    male_FPR = male_CM[0][1]/(male_CM[0][1]+male_CM[0][0])\n",
    "    acc = accuracy_score(test_gt, test_pred)\n",
    "    #print('Female TPR', female_TPR)\n",
    "    #print('male TPR', male_TPR)\n",
    "    print('DP',abs(female_dp - male_dp))\n",
    "    print('EOP', abs(female_TPR - male_TPR))\n",
    "    print('EoD',0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR)))\n",
    "    print('acc', accuracy_score(test_gt, test_pred))\n",
    "\n",
    "a = True\n",
    "partial_a = False\n",
    "    \n",
    "\n",
    "model = model.to(device)\n",
    "#inference_a_test(model, no_patch, patch)\n",
    "test_epoch(model, test_data_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:sean]",
   "language": "python",
   "name": "conda-env-sean-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
