{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af6d7eb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shenyu/miniconda3/envs/DLcourse/lib/python3.7/site-packages/torchvision/transforms/transforms.py:891: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.\n",
      "  \"Argument interpolation should be of type InterpolationMode instead of int. \"\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>image_id</th>\n",
       "      <th>partition</th>\n",
       "      <th>5_o_Clock_Shadow</th>\n",
       "      <th>Arched_Eyebrows</th>\n",
       "      <th>Attractive</th>\n",
       "      <th>Bags_Under_Eyes</th>\n",
       "      <th>Bald</th>\n",
       "      <th>Bangs</th>\n",
       "      <th>Big_Lips</th>\n",
       "      <th>Big_Nose</th>\n",
       "      <th>...</th>\n",
       "      <th>Smiling</th>\n",
       "      <th>Straight_Hair</th>\n",
       "      <th>Wavy_Hair</th>\n",
       "      <th>Wearing_Earrings</th>\n",
       "      <th>Wearing_Hat</th>\n",
       "      <th>Wearing_Lipstick</th>\n",
       "      <th>Wearing_Necklace</th>\n",
       "      <th>Wearing_Necktie</th>\n",
       "      <th>Young</th>\n",
       "      <th>split</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>000001.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>000002.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>000003.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>000004.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>000005.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>202594</th>\n",
       "      <td>202595.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>202595</th>\n",
       "      <td>202596.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>202596</th>\n",
       "      <td>202597.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>202597</th>\n",
       "      <td>202598.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>202598</th>\n",
       "      <td>202599.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>202599 rows × 43 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          image_id  partition  5_o_Clock_Shadow  Arched_Eyebrows  Attractive  \\\n",
       "0       000001.jpg          0                 0                1           1   \n",
       "1       000002.jpg          0                 0                0           0   \n",
       "2       000003.jpg          0                 0                0           0   \n",
       "3       000004.jpg          0                 0                0           1   \n",
       "4       000005.jpg          0                 0                1           1   \n",
       "...            ...        ...               ...              ...         ...   \n",
       "202594  202595.jpg          2                 0                0           1   \n",
       "202595  202596.jpg          2                 0                0           0   \n",
       "202596  202597.jpg          2                 0                0           0   \n",
       "202597  202598.jpg          2                 0                1           1   \n",
       "202598  202599.jpg          2                 0                1           1   \n",
       "\n",
       "        Bags_Under_Eyes  Bald  Bangs  Big_Lips  Big_Nose  ...  Smiling  \\\n",
       "0                     0     0      0         0         0  ...        1   \n",
       "1                     1     0      0         0         1  ...        1   \n",
       "2                     0     0      0         1         0  ...        0   \n",
       "3                     0     0      0         0         0  ...        0   \n",
       "4                     0     0      0         1         0  ...        0   \n",
       "...                 ...   ...    ...       ...       ...  ...      ...   \n",
       "202594                0     0      0         1         0  ...        0   \n",
       "202595                0     0      1         1         0  ...        1   \n",
       "202596                0     0      0         0         0  ...        1   \n",
       "202597                0     0      0         1         0  ...        1   \n",
       "202598                0     0      0         0         0  ...        0   \n",
       "\n",
       "        Straight_Hair  Wavy_Hair  Wearing_Earrings  Wearing_Hat  \\\n",
       "0                   1          0                 1            0   \n",
       "1                   0          0                 0            0   \n",
       "2                   0          1                 0            0   \n",
       "3                   1          0                 1            0   \n",
       "4                   0          0                 0            0   \n",
       "...               ...        ...               ...          ...   \n",
       "202594              0          0                 0            0   \n",
       "202595              1          0                 0            0   \n",
       "202596              0          0                 0            0   \n",
       "202597              0          1                 1            0   \n",
       "202598              0          1                 0            0   \n",
       "\n",
       "        Wearing_Lipstick  Wearing_Necklace  Wearing_Necktie  Young  split  \n",
       "0                      1                 0                0      1      0  \n",
       "1                      0                 0                0      1      0  \n",
       "2                      0                 0                0      1      0  \n",
       "3                      1                 1                0      1      0  \n",
       "4                      1                 0                0      1      0  \n",
       "...                  ...               ...              ...    ...    ...  \n",
       "202594                 1                 0                0      1      2  \n",
       "202595                 0                 0                0      1      2  \n",
       "202596                 0                 0                0      1      2  \n",
       "202597                 1                 0                0      1      2  \n",
       "202598                 1                 0                0      1      2  \n",
       "\n",
       "[202599 rows x 43 columns]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import clip\n",
    "import torchvision.datasets\n",
    "import math\n",
    "import torchvision.transforms as tvt\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import wget\n",
    "import zipfile\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as tfms\n",
    "from torch.utils.data import DataLoader, Subset, Dataset\n",
    "from torchvision.utils import make_grid\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from sklearn.metrics import accuracy_score, precision_score\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import open_clip\n",
    "\n",
    "# Compute device: GPU index 1 when CUDA is available, otherwise CPU.\n",
    "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Configure thread pools before any parallel work starts:\n",
    "# set_num_interop_threads may only be called before inter-op work has begun.\n",
    "torch.set_num_threads(5)   # threads used for intra-op parallelism\n",
    "torch.set_num_interop_threads(5)   # threads used for inter-op parallelism\n",
    "\n",
    "# CLIP backbone (ViT-L-14, LAION-2B weights) used for zero-shot classification\n",
    "# and zero-shot sensitive-attribute inference.\n",
    "# Alternatives tried previously: open_clip 'ViT-B-32' (pretrained='openai') and clip 'RN50'.\n",
    "model, _, preprocess = open_clip.create_model_and_transforms(\"ViT-L-14\", pretrained='laion2b_s32b_b82k')\n",
    "model = model.to(device)\n",
    "tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
    "\n",
    "def seed_everything(seed):\n",
    "    \"\"\"Seed every RNG this notebook relies on and make cuDNN deterministic.\n",
    "\n",
    "    Seeds Python's `random`, NumPy's global RNG and torch, then turns off\n",
    "    cuDNN autotuning so convolution algorithm selection is deterministic.\n",
    "    \"\"\"\n",
    "    for seeder in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seeder(seed)\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "\n",
    "seed_everything(4096)\n",
    "\n",
    "device_id = 1\n",
    "image_size = 64   # NOTE(review): not used by the transforms in this cell (they resize to 256 / crop 224)\n",
    "batch_size = 512\n",
    "\n",
    "# CelebA image directory and per-image metadata (attributes plus the 'split' column).\n",
    "root_dir = '../celeba/datasets/celeba/img_align_celeba/'\n",
    "csv_file = '../celeba/datasets/celeba/metadata.csv'\n",
    "# Binary attributes presumably use -1/1 in the CSV; map -1 -> 0 so labels are 0/1\n",
    "# class indices. Non-mutating so re-running the cell is idempotent.\n",
    "data_frame = pd.read_csv(csv_file).replace(-1, 0)\n",
    "\n",
    "\n",
    "\n",
    "class CustomDataset(Dataset):\n",
    "    \"\"\"One CelebA split yielding (CLIP input, target, sensitive attribute, ResNet input).\n",
    "\n",
    "    Args:\n",
    "        csv_file: metadata DataFrame (despite the name, not a path). Its first\n",
    "            column must hold the image file name, and it needs a 'split' column.\n",
    "        y: column name of the target attribute.\n",
    "        a: column name of the sensitive attribute.\n",
    "        root_dir: directory containing the image files.\n",
    "        split: value of the 'split' column to keep (0/1/2 in this notebook).\n",
    "        transform: transform producing the ResNet input.\n",
    "        clip_transform: transform producing the CLIP input; when None, the\n",
    "            notebook-level `preprocess` from open_clip is used (looked up lazily).\n",
    "    \"\"\"\n",
    "    def __init__(self, csv_file, y, a, root_dir, split, transform, clip_transform=None):\n",
    "        self.data_frame = csv_file[csv_file['split'] == split].reset_index(drop=True)\n",
    "        self.root_dir = root_dir\n",
    "        self.transform = transform\n",
    "        self.clip_transform = clip_transform\n",
    "        self.targets = self.data_frame[y].values\n",
    "        self.biases = self.data_frame[a].values\n",
    "        self.i = list(range(self.targets.shape[0]))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data_frame)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return (clip_img, target, sensitive, resnet_img) for row `idx`.\"\"\"\n",
    "        img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])\n",
    "        # Load eagerly so the file handle is closed, and force RGB so both\n",
    "        # 3-channel pipelines (e.g. the 3-channel Normalize) get a consistent mode.\n",
    "        with Image.open(img_name) as im:\n",
    "            image = im.convert('RGB')\n",
    "        clip_transform = self.clip_transform if self.clip_transform is not None else preprocess\n",
    "        targets = torch.tensor(self.targets[idx])  # class index of the target attribute\n",
    "        biases = torch.tensor(self.biases[idx])  # class index of the sensitive attribute\n",
    "        img = clip_transform(image)\n",
    "        img_for_res = self.transform(image)\n",
    "        return img, targets, biases, img_for_res\n",
    "\n",
    "target = 'Blond_Hair'   # attribute to classify\n",
    "sensitive = 'Male'      # sensitive attribute\n",
    "\n",
    "# ResNet training augmentation (ImageNet normalization statistics).\n",
    "transform = tvt.Compose([\n",
    "    tvt.Resize((256, 256)),\n",
    "    tvt.RandomResizedCrop(\n",
    "        (224, 224),\n",
    "        scale=(0.7, 1.0),\n",
    "        ratio=(0.75, 1.3333333333333333),\n",
    "        # Enum instead of the deprecated int code 2 (bilinear), which raised a UserWarning.\n",
    "        interpolation=tvt.InterpolationMode.BILINEAR),\n",
    "    tvt.RandomHorizontalFlip(),\n",
    "    tvt.ToTensor(),\n",
    "    tvt.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Deterministic evaluation transform for the ResNet branch.\n",
    "valid_transform = tvt.Compose([\n",
    "    tvt.Resize((256, 256)),\n",
    "    tvt.CenterCrop((224, 224)),\n",
    "    tvt.ToTensor(),\n",
    "    tvt.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Datasets and loaders per split: 0 = train (shuffled), 1 = validation, 2 = test.\n",
    "train_set = CustomDataset(csv_file=data_frame, y=target, a=sensitive, root_dir=root_dir, split=0, transform=transform)\n",
    "training_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "valid_set = CustomDataset(csv_file=data_frame, y=target, a=sensitive, root_dir=root_dir, split=1, transform=valid_transform)\n",
    "valid_data_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)\n",
    "test_set = CustomDataset(csv_file=data_frame, y=target, a=sensitive, root_dir=root_dir, split=2, transform=valid_transform)\n",
    "test_data_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "print('Done')\n",
    "data_frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bb2035a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull one training batch for a quick sanity check. The sensitive labels are\n",
    "# bound to `a_batch` (not `a`) so they do not overwrite the global boolean flag\n",
    "# `a` that compute_scale / test_epoch read.\n",
    "t = iter(training_data_loader)\n",
    "image, y, a_batch, _ = next(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b48dc353",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([512, 3, 224, 224])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (batch, channels, height, width) of the CLIP-preprocessed training batch\n",
    "image.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "03b15ca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text prompts for the sensitive attribute: index 0 -> male, index 1 -> female.\n",
    "# (An earlier variant used \"A photo of a male\" / \"A photo of a female\".)\n",
    "texts = [\"male\", \"female\"]\n",
    "with torch.no_grad():  # fixed directions only; no autograd graph needed\n",
    "    text = tokenizer(texts).to(device)\n",
    "    text_features = model.encode_text(text)\n",
    "# text_features already lives on `device` (model and tokens are both there).\n",
    "male = text_features[0].unsqueeze(0)  # (1, embed_dim)\n",
    "female = text_features[1].unsqueeze(0)  # (1, embed_dim)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5a11eb7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Scale: 100%|███████████████████████████████████████████████████████████████████████| 318/318 [30:45<00:00,  5.80s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.20750567\n",
      "0.19806518\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Zero Shot Testing: 100%|███████████████████████████████████████████████████████████████████████| 39/39 [03:34<00:00,  5.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy for y=0, s=0: 0.8467287805876933\n",
      "Accuracy for y=0, s=1: 0.8460517584605176\n",
      "Accuracy for y=1, s=0: 0.9661290322580646\n",
      "Accuracy for y=1, s=1: 0.8888888888888888\n",
      "DP 0.14677849604709628\n",
      "EOP 0.07724014336917573\n",
      "EoD 0.038958582748175694\n",
      "acc 0.8616872056908126\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def training_a(save_path='res_net_celebA.pth', epoch=200, schedule=False):\n",
    "    \"\"\"Fine-tune an ImageNet ResNet-18 to predict the sensitive attribute.\n",
    "\n",
    "    Trains on the ResNet inputs of `training_data_loader` with SGD. Afterwards it\n",
    "    reports per-(y, a) accuracy on the test split via super_a_test and saves the\n",
    "    state_dict.\n",
    "\n",
    "    Args:\n",
    "        save_path: file the state_dict is written to. NOTE(review):\n",
    "            supervised_inference_a loads 'res_net.pth', so keep the paths in sync.\n",
    "        epoch: number of training epochs.\n",
    "        schedule: if True, apply cosine-annealing LR decay over `epoch` epochs.\n",
    "            Previously this branch referenced an undefined `scheduler`.\n",
    "    \"\"\"\n",
    "    weight_decay = 1e-3\n",
    "    init_lr = 1e-4\n",
    "    momentum_decay = 0.9\n",
    "    num_classes = 2\n",
    "\n",
    "    res_model = models.resnet18(pretrained=True)\n",
    "    res_model.fc = nn.Linear(res_model.fc.in_features, num_classes)\n",
    "    res_model = res_model.to(device)\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.SGD(res_model.parameters(), lr=init_lr, momentum=momentum_decay, weight_decay=weight_decay)\n",
    "    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch) if schedule else None\n",
    "\n",
    "    for epoches in range(epoch):\n",
    "        res_model.train()\n",
    "        with tqdm(training_data_loader, unit='batch') as tepoch:\n",
    "            tepoch.set_description(f'epoch {epoches:2d}')\n",
    "            for _, _, sensitive, train_input in tepoch:  # (img_clip, y, a, x_for_resnet)\n",
    "                train_input = train_input.to(device)\n",
    "                # Class-index targets give the same loss as one-hot float targets\n",
    "                # and also work on torch versions without probability targets.\n",
    "                train_target = sensitive.long().to(device)\n",
    "                outputs = res_model(train_input)\n",
    "                loss = criterion(outputs, train_target)\n",
    "                tepoch.set_postfix(ut_loss=loss.item())\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "        if scheduler is not None:\n",
    "            scheduler.step()\n",
    "    super_a_test(res_model)\n",
    "    torch.save(res_model.state_dict(), save_path)\n",
    "    \n",
    "def super_a_test(model):\n",
    "    \"\"\"Print how accurately `model` predicts the sensitive attribute, per (y, a) group.\n",
    "\n",
    "    Each test batch's ResNet input (4th loader element) goes through `model`.\n",
    "    The argmax prediction is compared with the true sensitive label inside each\n",
    "    of the four (target, sensitive) groups.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping (y, s) -> accuracy. A group with no test samples maps to\n",
    "        nan instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    groups = [(0, 0), (0, 1), (1, 0), (1, 1)]\n",
    "    correct = dict.fromkeys(groups, 0)\n",
    "    total = dict.fromkeys(groups, 0)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _, test_target, sensitive, test_input in tqdm(test_data_loader, desc='Testing'):\n",
    "            logits = model(test_input.to(device))\n",
    "            predic = logits.argmax(dim=1).cpu()\n",
    "            for y_val, s_val in groups:\n",
    "                mask = (test_target == y_val) & (sensitive == s_val)\n",
    "                correct[(y_val, s_val)] += (predic[mask] == sensitive[mask]).sum().item()\n",
    "                total[(y_val, s_val)] += mask.sum().item()\n",
    "\n",
    "    accs = {g: (correct[g] / total[g] if total[g] else float('nan')) for g in groups}\n",
    "    for (y_val, s_val), acc in accs.items():\n",
    "        print(f'Accuracy for y={y_val}, s={s_val}: {acc}')\n",
    "    return accs\n",
    "\n",
    "    \n",
    "\n",
    "def inference_a_test(vlm, spu_v0, spu_v1):\n",
    "    \"\"\"Print zero-shot accuracy of inferring the sensitive attribute, per (y, a) group.\n",
    "\n",
    "    For each test batch, inference_a classifies the CLIP image embedding against\n",
    "    the text directions `spu_v0` (class 0) and `spu_v1` (class 1). These are the\n",
    "    arguments passed in, not the globals `female`/`male`. The prediction is then\n",
    "    compared with the true sensitive label inside each (target, sensitive) group.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping (y, s) -> accuracy; a group with no samples maps to nan.\n",
    "    \"\"\"\n",
    "    vlm.eval()\n",
    "    groups = [(0, 0), (0, 1), (1, 0), (1, 1)]\n",
    "    correct = dict.fromkeys(groups, 0)\n",
    "    total = dict.fromkeys(groups, 0)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for test_input, test_target, sensitive, _ in tqdm(test_data_loader, desc='Testing'):\n",
    "            z = vlm.encode_image(test_input.to(device))\n",
    "            infered_a = inference_a(vlm, spu_v0, spu_v1, z).cpu()\n",
    "            for y_val, s_val in groups:\n",
    "                mask = (test_target == y_val) & (sensitive == s_val)\n",
    "                correct[(y_val, s_val)] += (infered_a[mask] == sensitive[mask]).sum().item()\n",
    "                total[(y_val, s_val)] += mask.sum().item()\n",
    "\n",
    "    accs = {g: (correct[g] / total[g] if total[g] else float('nan')) for g in groups}\n",
    "    for (y_val, s_val), acc in accs.items():\n",
    "        print(f'Accuracy for y={y_val}, s={s_val}: {acc}')\n",
    "    return accs\n",
    "\n",
    "            \n",
    "\n",
    "\n",
    "\n",
    "def inference_a(vlm, spu_v0, spu_v1, z):\n",
    "    \"\"\"Zero-shot sensitive-attribute prediction by the most similar text direction.\n",
    "\n",
    "    Args:\n",
    "        vlm: unused; kept for call compatibility.\n",
    "        spu_v0, spu_v1: 2-D text embeddings (one row each in this notebook). Their\n",
    "            rows are concatenated in this order to form the candidate classes.\n",
    "        z: image embeddings, (N, D).\n",
    "\n",
    "    Returns:\n",
    "        Tensor (N,) holding the index of the most cosine-similar text row.\n",
    "    \"\"\"\n",
    "    text_embeddings = torch.cat((spu_v0, spu_v1), dim=0)\n",
    "    norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)\n",
    "    # Normalize the image side too, so this really is cosine similarity. Neither\n",
    "    # this positive per-row scale nor a softmax changes the argmax, but softmax\n",
    "    # can collapse near-equal logits into float ties, so it is omitted.\n",
    "    norm_img_embeddings = z / z.norm(dim=1, keepdim=True)\n",
    "    cosine_similarity = norm_img_embeddings @ norm_text_embeddings.t()\n",
    "    return cosine_similarity.argmax(dim=1)\n",
    "\n",
    "            \n",
    "_SUPERVISED_A_MODELS = {}  # (weights_path, device) -> loaded ResNet-18 in eval mode\n",
    "\n",
    "\n",
    "def supervised_inference_a(img, weights_path='res_net.pth'):\n",
    "    \"\"\"Predict the sensitive attribute with the supervised ResNet-18 classifier.\n",
    "\n",
    "    The network used to be rebuilt, and its weights re-read from disk, on every\n",
    "    call (i.e. once per batch). It is now loaded once per (weights_path, device)\n",
    "    and cached in `_SUPERVISED_A_MODELS`. Clear that dict after retraining so\n",
    "    new weights are picked up.\n",
    "\n",
    "    Args:\n",
    "        img: batch of ResNet-preprocessed images (4th element of the loader tuples).\n",
    "        weights_path: state_dict file to load. NOTE(review): training_a saves to\n",
    "            'res_net_celebA.pth' by default; confirm which file is intended.\n",
    "\n",
    "    Returns:\n",
    "        Tensor (N,) of predicted class indices, on `device`.\n",
    "    \"\"\"\n",
    "    key = (weights_path, str(device))\n",
    "    res_model = _SUPERVISED_A_MODELS.get(key)\n",
    "    if res_model is None:\n",
    "        res_model = models.resnet18(pretrained=False)\n",
    "        res_model.fc = nn.Linear(res_model.fc.in_features, 2)\n",
    "        # map_location='cpu' avoids requiring the GPU the weights were saved from.\n",
    "        res_model.load_state_dict(torch.load(weights_path, map_location='cpu'))\n",
    "        res_model = res_model.to(device).eval()\n",
    "        _SUPERVISED_A_MODELS[key] = res_model\n",
    "    with torch.no_grad():\n",
    "        logits = res_model(img.to(device))\n",
    "    return logits.argmax(dim=1)\n",
    "            \n",
    "    \n",
    "def compute_scale(vlm, spu_v0, spu_v1, use_true_a=None, use_partial_a=None):\n",
    "    \"\"\"Mean cosine between training image embeddings and their group's text direction.\n",
    "\n",
    "    Iterates over `training_data_loader` and gives each image a sensitive label:\n",
    "    ground truth, zero-shot CLIP inference, or the supervised ResNet. It then\n",
    "    averages cos(z, spu_v0) over group 0 and cos(z, spu_v1) over group 1.\n",
    "\n",
    "    Args:\n",
    "        vlm: CLIP model exposing encode_image.\n",
    "        spu_v0, spu_v1: 2-D text embeddings for sensitive classes 0 and 1\n",
    "            (test_epoch passes female, male). Zero-shot inference uses these\n",
    "            arguments, not the globals.\n",
    "        use_true_a: use ground-truth sensitive labels. Defaults to the notebook\n",
    "            global `a`; make sure `a` holds a bool (not a label tensor) at call\n",
    "            time, or pass the flag explicitly.\n",
    "        use_partial_a: when not using ground truth, True selects the supervised\n",
    "            ResNet and False selects zero-shot inference. Defaults to the global\n",
    "            `partial_a`.\n",
    "\n",
    "    Returns:\n",
    "        (scale_0, scale_1) as 0-dim tensors (nan if a group never occurs).\n",
    "    \"\"\"\n",
    "    if use_true_a is None:\n",
    "        use_true_a = a == True\n",
    "    if use_partial_a is None:\n",
    "        use_partial_a = partial_a\n",
    "\n",
    "    vlm = vlm.to(device)\n",
    "    spu0 = spu_v0 / spu_v0.norm(dim=1, keepdim=True)\n",
    "    spu1 = spu_v1 / spu_v1.norm(dim=1, keepdim=True)\n",
    "    scale_0 = []\n",
    "    scale_1 = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for clip_input, _, sensitive, img in tqdm(training_data_loader, desc='Computing Scale'):\n",
    "            z = vlm.encode_image(clip_input.to(device))\n",
    "            if not use_true_a:\n",
    "                if use_partial_a:\n",
    "                    sensitive = supervised_inference_a(img)\n",
    "                else:\n",
    "                    sensitive = inference_a(vlm, spu_v0, spu_v1, z)\n",
    "            sensitive = sensitive.to(device)\n",
    "\n",
    "            h = z[sensitive == 0]\n",
    "            scale_0.extend(torch.mm(h / h.norm(dim=1, keepdim=True), spu0.t()).detach().cpu().numpy())\n",
    "            g = z[sensitive == 1]\n",
    "            scale_1.extend(torch.mm(g / g.norm(dim=1, keepdim=True), spu1.t()).detach().cpu().numpy())\n",
    "\n",
    "    scale_0 = np.array(scale_0)\n",
    "    scale_1 = np.array(scale_1)\n",
    "    print(np.mean(scale_0))\n",
    "    print(np.mean(scale_1))\n",
    "    return torch.tensor(np.mean(scale_0)), torch.tensor(np.mean(scale_1))\n",
    "\n",
    "\n",
    "\n",
    "def _safe_div(num, den):\n",
    "    \"\"\"Return ``num / den``, or NaN when ``den`` is zero (e.g. an empty subgroup).\"\"\"\n",
    "    return num / den if den else float('nan')\n",
    "\n",
    "\n",
    "def _group_rates(gt, pred):\n",
    "    \"\"\"Return (positive-prediction rate, TPR, FPR) for one sensitive group.\n",
    "\n",
    "    ``labels=[0, 1]`` forces a 2x2 confusion matrix even when the group contains\n",
    "    only one class, so unpacking it below never goes out of range.\n",
    "    \"\"\"\n",
    "    cm = confusion_matrix(gt, pred, labels=[0, 1])\n",
    "    (tn, fp), (fn, tp) = cm\n",
    "    return _safe_div(tp + fp, cm.sum()), _safe_div(tp, tp + fn), _safe_div(fp, fp + tn)\n",
    "\n",
    "\n",
    "def test_epoch(vlm, dataloader):\n",
    "    \"\"\"Zero-shot evaluation of ``vlm`` with per-group attribute-direction removal.\n",
    "\n",
    "    Each image embedding from ``vlm.encode_image`` has a scaled, unit-normalised\n",
    "    attribute direction subtracted (``female`` for sensitive group 0, ``male`` for\n",
    "    group 1, scales from ``compute_scale``). The result is classified by cosine\n",
    "    similarity against two text prompts (dark vs. blonde hair).\n",
    "\n",
    "    The sensitive attribute used for debiasing is the ground truth when the global\n",
    "    ``a`` is truthy, otherwise ``supervised_inference_a(img)`` if ``partial_a`` is\n",
    "    truthy, else ``inference_a(vlm, female, male, z)``. Group accuracies and the\n",
    "    fairness metrics are always computed against the ground-truth attribute.\n",
    "\n",
    "    Also reads notebook globals ``device``, ``tokenizer``, ``female`` and ``male``.\n",
    "\n",
    "    Args:\n",
    "        vlm: model exposing ``encode_image`` and ``encode_text``.\n",
    "        dataloader: yields ``(images, targets, sensitive_real, img)`` batches.\n",
    "\n",
    "    Returns:\n",
    "        dict of the printed metrics: ``acc_00`` .. ``acc_11`` (accuracy for label y,\n",
    "        sensitive s; NaN for an empty group), ``DP``, ``EOP``, ``EoD`` and ``acc``.\n",
    "    \"\"\"\n",
    "    # Switch to eval mode first so the scales are estimated in the same mode\n",
    "    # that is used for evaluation.\n",
    "    vlm = vlm.to(device)\n",
    "    vlm.eval()\n",
    "    # Use the model under evaluation (previously the global `model` was used here).\n",
    "    scale_0, scale_1 = compute_scale(vlm, female, male)\n",
    "    # Unit attribute directions, computed once rather than per batch.\n",
    "    female_dir = female / female.norm(dim=1, keepdim=True)\n",
    "    male_dir = male / male.norm(dim=1, keepdim=True)\n",
    "\n",
    "    texts_label = ['a photo of a celebrity with dark hair', 'a photo of a celebrity with blonde hair']\n",
    "    text_label_tokened = tokenizer(texts_label).to(device)\n",
    "\n",
    "    groups = [(0, 0), (0, 1), (1, 0), (1, 1)]  # (label y, sensitive s)\n",
    "    correct = {g: 0.0 for g in groups}\n",
    "    total = {g: 0.0 for g in groups}\n",
    "    test_pred, test_gt, sense_gt = [], [], []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # The prompts are fixed, so embed them once instead of once per batch.\n",
    "        text_embeddings = vlm.encode_text(text_label_tokened)\n",
    "        norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)\n",
    "\n",
    "        for test_input, test_target, sensitive_real, img in tqdm(dataloader, desc=\"Zero Shot Testing\"):\n",
    "            # Flatten so (B,) and (B, 1) targets behave the same, including B == 1\n",
    "            # (a bare .squeeze() turns a single-sample batch into a 0-d tensor).\n",
    "            label = test_target.detach().cpu().reshape(-1)\n",
    "            sensitive_real = sensitive_real.detach().cpu().reshape(-1)\n",
    "            test_gt.extend(label.numpy())\n",
    "            sense_gt.extend(sensitive_real.numpy())\n",
    "\n",
    "            z = vlm.encode_image(test_input.to(device))\n",
    "\n",
    "            if a:\n",
    "                sensitive = sensitive_real\n",
    "            elif partial_a:\n",
    "                sensitive = supervised_inference_a(img)\n",
    "            else:\n",
    "                sensitive = inference_a(vlm, female, male, z)\n",
    "            # as_tensor accepts lists/arrays and avoids torch.tensor's copy warning on tensors.\n",
    "            sensitive = torch.as_tensor(sensitive)\n",
    "\n",
    "            mask_0 = (sensitive == 0).to(device)\n",
    "            mask_1 = (sensitive == 1).to(device)\n",
    "            z[mask_0] -= scale_0 * female_dir\n",
    "            z[mask_1] -= scale_1 * male_dir\n",
    "\n",
    "            norm_img_embeddings = z / z.norm(dim=1, keepdim=True)\n",
    "            cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())\n",
    "            # softmax is monotonic, so the argmax of the raw similarities is the prediction.\n",
    "            predic = cosine_similarity.argmax(dim=1).cpu()\n",
    "            test_pred.extend(predic.numpy())\n",
    "\n",
    "            for y, s in groups:\n",
    "                mask = (label == y) & (sensitive_real == s)\n",
    "                correct[(y, s)] += (predic[mask] == label[mask]).float().sum().item()\n",
    "                total[(y, s)] += mask.float().sum().item()\n",
    "\n",
    "    metrics = {}\n",
    "    for y, s in groups:\n",
    "        acc_ys = _safe_div(correct[(y, s)], total[(y, s)])\n",
    "        metrics[f'acc_{y}{s}'] = acc_ys\n",
    "        print(f'Accuracy for y={y}, s={s}: {acc_ys}')\n",
    "\n",
    "    test_gt = np.asarray(test_gt)\n",
    "    test_pred = np.asarray(test_pred)\n",
    "    # Sensitive value 0 is treated as \"female\", every other value as \"male\".\n",
    "    is_female = np.asarray(sense_gt) == 0\n",
    "    female_dp, female_TPR, female_FPR = _group_rates(test_gt[is_female], test_pred[is_female])\n",
    "    male_dp, male_TPR, male_FPR = _group_rates(test_gt[~is_female], test_pred[~is_female])\n",
    "\n",
    "    metrics['DP'] = abs(female_dp - male_dp)\n",
    "    metrics['EOP'] = abs(female_TPR - male_TPR)\n",
    "    metrics['EoD'] = 0.5 * (abs(female_FPR - male_FPR) + abs(female_TPR - male_TPR))\n",
    "    metrics['acc'] = accuracy_score(test_gt, test_pred)\n",
    "    for metric_name in ('DP', 'EOP', 'EoD', 'acc'):\n",
    "        print(metric_name, metrics[metric_name])\n",
    "    return metrics\n",
    "\n",
    "# Debiasing switches read by test_epoch:\n",
    "#   a=True  -> use the ground-truth sensitive attribute from the dataloader;\n",
    "#   a=False -> infer it: supervised_inference_a if partial_a else inference_a.\n",
    "a, partial_a = True, False\n",
    "\n",
    "model = model.to(device)\n",
    "test_epoch(model, test_data_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:DLcourse]",
   "language": "python",
   "name": "conda-env-DLcourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
