{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from PIL import Image\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import open_clip\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from sklearn.metrics import accuracy_score\n",
    "from util import *\n",
    "from wilds.datasets.wilds_dataset import WILDSDataset\n",
    "from wilds.common.grouper import CombinatorialGrouper\n",
    "from wilds.common.metrics.all_metrics import Accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## modify dataset class for customed spurious attribute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CelebADataset(WILDSDataset):\n",
    "\n",
    "    _dataset_name = 'celebA'\n",
    "    _versions_dict = {\n",
    "        '1.0': {\n",
    "            'download_url': 'https://worksheets.codalab.org/rest/bundles/0xfe55077f5cd541f985ebf9ec50473293/contents/blob/',\n",
    "            'compressed_size': 1_308_557_312}}\n",
    "\n",
    "    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'):\n",
    "        self._version = version\n",
    "        self._data_dir = self.initialize_data_dir(root_dir, download)\n",
    "        target_name = 'Blond_Hair'\n",
    "        confounder_names = ['Eyeglasses']\n",
    "\n",
    "        # Read in attributes\n",
    "        attrs_df = pd.read_csv(\n",
    "            os.path.join(self.data_dir, 'list_attr_celeba.csv'))\n",
    "\n",
    "        # Split out filenames and attribute names\n",
    "        # Note: idx and filenames are off by one.\n",
    "        self._input_array = attrs_df['image_id'].values\n",
    "        self._original_resolution = (178, 218)\n",
    "        attrs_df = attrs_df.drop(labels='image_id', axis='columns')\n",
    "        attr_names = attrs_df.columns.copy()\n",
    "        def attr_idx(attr_name):\n",
    "            return attr_names.get_loc(attr_name)\n",
    "\n",
    "        # Then cast attributes to numpy array and set them to 0 and 1\n",
    "        # (originally, they're -1 and 1)\n",
    "        attrs_df = attrs_df.values\n",
    "        attrs_df[attrs_df == -1] = 0\n",
    "\n",
    "        # Get the y values\n",
    "        target_idx = attr_idx(target_name)\n",
    "        self._y_array = torch.LongTensor(attrs_df[:, target_idx])\n",
    "        self._y_size = 1\n",
    "        self._n_classes = 2\n",
    "\n",
    "        # Get metadata\n",
    "        confounder_idx = [attr_idx(a) for a in confounder_names]\n",
    "        confounders = attrs_df[:, confounder_idx]\n",
    "\n",
    "        self._metadata_array = torch.cat(\n",
    "            (torch.LongTensor(confounders), self._y_array.reshape((-1, 1))),\n",
    "            dim=1)\n",
    "        confounder_names = [s.lower() for s in confounder_names]\n",
    "        self._metadata_fields = confounder_names + ['y']\n",
    "        self._metadata_map = {\n",
    "            'y': ['not blond', '    blond'] # Padding for str formatting\n",
    "        }\n",
    "\n",
    "        self._eval_grouper = CombinatorialGrouper(\n",
    "            dataset=self,\n",
    "            groupby_fields=(confounder_names + ['y']))\n",
    "\n",
    "        # Extract splits\n",
    "        self._split_scheme = split_scheme\n",
    "        if self._split_scheme != 'official':\n",
    "            raise ValueError(f'Split scheme {self._split_scheme} not recognized')\n",
    "        split_df = pd.read_csv(\n",
    "            os.path.join(self.data_dir, 'list_eval_partition.csv'))\n",
    "        self._split_array = split_df['partition'].values\n",
    "\n",
    "        super().__init__(root_dir, download, split_scheme)\n",
    "\n",
    "    def get_input(self, idx):\n",
    "       # Note: idx and filenames are off by one.\n",
    "       img_filename = os.path.join(\n",
    "           self.data_dir,\n",
    "           'img_align_celeba',\n",
    "           self._input_array[idx])\n",
    "       x = Image.open(img_filename).convert('RGB')\n",
    "       return x\n",
    "\n",
    "    def eval(self, y_pred, y_true, metadata, prediction_fn=None):\n",
    "        metric = Accuracy(prediction_fn=prediction_fn)\n",
    "        return self.standard_group_eval(\n",
    "            metric,\n",
    "            self._eval_grouper,\n",
    "            y_pred, y_true, metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load OpenCLIP model and tokenizer\n",
    "model_name = 'ViT-B-32'\n",
    "model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='openai')\n",
    "tokenizer = open_clip.get_tokenizer(model_name)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = model.to(device)\n",
    "embedding_dim = model.visual.output_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = CelebADataset(root_dir='data', download=False)\n",
    "dataset.confounder_names = ['Eyeglasses']\n",
    "dataset._metadata_fields = ['eyeglasses', 'y']  # Update metadata fields\n",
    "# Reload attribute file\n",
    "attrs_df = pd.read_csv(f\"{dataset.data_dir}/list_attr_celeba.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "attrs_df.iloc[:, 1:] = (attrs_df.iloc[:, 1:] == 1).astype(int)\n",
    "def attr_idx(attr_name):\n",
    "    return attrs_df.columns.get_loc(attr_name)\n",
    "\n",
    "confounder_idx = attr_idx('Eyeglasses')\n",
    "dataset._metadata_array[:, 0] = torch.LongTensor(attrs_df.iloc[:, confounder_idx].values)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = dataset.get_subset('test')\n",
    "test_loader = DataLoader(\n",
    "    test_data, \n",
    "    batch_size=256, \n",
    "    shuffle=False, \n",
    "    num_workers=8,\n",
    "    collate_fn=lambda batch: (\n",
    "        torch.stack([preprocess(img) for img, _, _ in batch]),  # inputs\n",
    "        torch.tensor([label for _, label, _ in batch]),  # labels\n",
    "       torch.stack([attr for _, _, attr in batch])  # attributes\n",
    "    )\n",
    "    )\n",
    "num_images = len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19962"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 78/78 [00:14<00:00,  5.40it/s]\n"
     ]
    }
   ],
   "source": [
    "# Tensors to store embeddings, labels, and gender attributes\n",
    "all_image_features = torch.empty((num_images, embedding_dim), dtype=torch.float32, device=device)\n",
    "all_labels = torch.empty((num_images,), dtype=torch.long, device=device)\n",
    "num_attributes = test_data[0][2].shape[0]  # Assuming test_data returns (image, label, attributes)\n",
    "\n",
    "# Preallocate tensor for all attributes\n",
    "all_celeba_attributes = torch.empty((num_images, num_attributes), dtype=torch.long, device=device)\n",
    "# all_celeba_attributes = torch.empty((num_images,), dtype=torch.long, device=device)  # Store attribute as long tensor\n",
    "# all_celeba_attributes = torch.empty((num_images, 3), dtype=torch.long, device=device)\n",
    "start_idx = 0\n",
    "\n",
    "for inputs, labels, celeba_attributes in tqdm(test_loader):\n",
    "    # Move the inputs, labels, and attributes to the correct device (GPU)\n",
    "    inputs = inputs.to(device)\n",
    "    labels = labels.to(device)\n",
    "    celeba_attributes = celeba_attributes.to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Encode the images to get their embeddings\n",
    "        image_features = model.encode_image(inputs)  # (batch_size, embedding_dim)\n",
    "        image_features = image_features.to(dtype=torch.float32, device=device)\n",
    "\n",
    "    # Store the batch embeddings, labels, and attributes in the preallocated tensors\n",
    "    end_idx = start_idx + image_features.size(0)\n",
    "    all_image_features[start_idx:end_idx] = image_features\n",
    "    all_labels[start_idx:end_idx] = labels\n",
    "    all_celeba_attributes[start_idx:end_idx] = celeba_attributes\n",
    "    start_idx = end_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute predictions\n",
    "class_names = ['person with dark hair', 'person with blond hair']\n",
    "text_inputs = tokenizer(class_names).to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    text_features = model.encode_text(text_inputs).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_image_features = all_image_features.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = all_image_features @ text_features.T\n",
    "preds = predictions.argmax(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = preds.to('cpu')\n",
    "all_labels = all_labels.to('cpu')\n",
    "all_celeba_attributes = all_celeba_attributes.to('cpu')\n",
    "result, results_str = dataset.eval(preds, all_labels, all_celeba_attributes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average acc: 0.812\n",
      "  y = not blond, eyeglasses = 0  [n =  16075]:\tacc = 0.790\n",
      "  y = not blond, eyeglasses = 1  [n =   1227]:\tacc = 0.742\n",
      "  y =     blond, eyeglasses = 0  [n =   2598]:\tacc = 0.976\n",
      "  y =     blond, eyeglasses = 1  [n =     62]:\tacc = 0.887\n",
      "Worst-group acc: 0.742\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(results_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(1234)\n",
    "k = 50\n",
    "L = glaplacian(all_image_features)\n",
    "U, S, Vt = randomized_svd(L, n_components=k)\n",
    "S_diag = np.diag(S)\n",
    "Z = U @ S_diag\n",
    "Z_rotated, Rz = varimax_with_rotation(Z)\n",
    "sign_Z = np.diag(np.where(np.mean(Z_rotated**3, axis=0) >= 0, 1, -1))\n",
    "Z_hat = Z_rotated @ sign_Z\n",
    "Y_hat = sign_Z @ Rz.T @ Vt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## reconstructed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7807333934475503\n",
      "Average acc: 0.781\n",
      "  y = not blond, eyeglasses = 0  [n =  16075]:\tacc = 0.749\n",
      "  y = not blond, eyeglasses = 1  [n =   1227]:\tacc = 0.770\n",
      "  y =     blond, eyeglasses = 0  [n =   2598]:\tacc = 0.981\n",
      "  y =     blond, eyeglasses = 1  [n =     62]:\tacc = 0.919\n",
      "Worst-group acc: 0.749\n",
      "\n"
     ]
    }
   ],
   "source": [
    "reconstructed = torch.tensor(Z_hat @ Y_hat)\n",
    "reconstructed = reconstructed.to(torch.float)\n",
    "reconstructed = reconstructed.to('cpu')\n",
    "recons_preds = reconstructed @ text_features.T\n",
    "recons_preds = recons_preds.argmax(dim=-1)\n",
    "correct = (recons_preds == all_labels).sum().item()\n",
    "total = all_labels.size(0)\n",
    "accuracy = correct / total\n",
    "print(accuracy)\n",
    "recons_preds = recons_preds.to('cpu')\n",
    "all_labels = all_labels.to('cpu')\n",
    "_, results_str = dataset.eval(recons_preds, all_labels, all_celeba_attributes)\n",
    "print(results_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## spurious"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = open_clip.get_tokenizer(model_name)\n",
    "# text_descriptions = [\n",
    "#     # Spurious Concept: Eyewear-related features\n",
    "#     \"Male facial characteristics including jawline structure and facial hair, seen across all hair colors\",\n",
    "#     # Useful Concept: Hair color-related features\n",
    "#     \"Dark-haired individuals with varying facial structures and features typical across both men and women\"\n",
    "# ]\n",
    "\n",
    "text_descriptions = [\n",
    "    # Spurious concept: Person wearing glasses\n",
    "    \"A photo focusing on person's eyewear like eyeglasses\",\n",
    "    # Useful concept: Hair color\n",
    "    \"A photo focusing on person's hair color, texture and shape\"\n",
    "]\n",
    "# text_descriptions = [\n",
    "#     # Spurious Concept: Eyewear-related features\n",
    "#     \"A photo focusing on the person's eyewear.\",\n",
    "#     # Useful Concept: Hair color-related features\n",
    "#     \"A photo focusing on the person’s hair color.\"\n",
    "# ]\n",
    "inputs = tokenizer(text_descriptions).to(device)\n",
    "    \n",
    "with torch.no_grad():\n",
    "    # Get text embeddings\n",
    "    des_embeddings = model.encode_text(inputs)\n",
    "    des_embeddings_np = des_embeddings.to('cpu').numpy()\n",
    "concept_des = Y_hat @ des_embeddings_np.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_features = text_features.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8257689610259493\n",
      "Average acc: 0.826\n",
      "  y = not blond, eyeglasses = 0  [n =  16075]:\tacc = 0.813\n",
      "  y = not blond, eyeglasses = 1  [n =   1227]:\tacc = 0.751\n",
      "  y =     blond, eyeglasses = 0  [n =   2598]:\tacc = 0.941\n",
      "  y =     blond, eyeglasses = 1  [n =     62]:\tacc = 0.871\n",
      "Worst-group acc: 0.751\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2330857/3186743575.py:10: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  all_labels = torch.tensor(all_labels).to(device)\n"
     ]
    }
   ],
   "source": [
    "spurious_rows = np.where(concept_des[:, 0] > concept_des[:, 1])[0].tolist()\n",
    "Z_mask = Z_hat.copy()\n",
    "cols_to_ablate = spurious_rows\n",
    "# col_means = Z_hat[:, cols_to_ablate].mean(axis=0)\n",
    "Z_mask[:, cols_to_ablate] = 0\n",
    "reconstructed = torch.tensor(Z_mask @ Y_hat).to(device)\n",
    "reconstructed = reconstructed.to(torch.float)\n",
    "recons_preds = reconstructed @ text_features.T\n",
    "recons_preds = recons_preds.argmax(dim=-1)\n",
    "all_labels = torch.tensor(all_labels).to(device)\n",
    "correct = (recons_preds == all_labels).sum().item()\n",
    "total = all_labels.size(0)\n",
    "accuracy = correct / total\n",
    "print(accuracy)\n",
    "recons_preds = recons_preds.to('cpu')\n",
    "all_labels = all_labels.to('cpu')\n",
    "# _, results_str = dataset.eval(recons_preds, all_labels, group_mask)\n",
    "_, results_str = dataset.eval(recons_preds, all_labels, all_celeba_attributes)\n",
    "print(results_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "2020-exotda",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
