{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from PIL import Image\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import open_clip\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from sklearn.metrics import accuracy_score\n",
    "from util import *\n",
    "from wilds.datasets.wilds_dataset import WILDSDataset\n",
    "from wilds.common.grouper import CombinatorialGrouper\n",
    "from wilds.common.metrics.all_metrics import Accuracy\n",
    "from wilds import get_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load OpenCLIP model and tokenizer\n",
    "model_name = 'ViT-B-32'\n",
    "model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='openai')\n",
    "tokenizer = open_clip.get_tokenizer(model_name)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = model.to(device)\n",
    "embedding_dim = model.visual.output_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = get_dataset('iwildcam', download=True)\n",
    "test_data = dataset.get_subset('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "test_loader = DataLoader(\n",
    "    test_data, \n",
    "    batch_size=256, \n",
    "    shuffle=False, \n",
    "    num_workers=8,\n",
    "    collate_fn=lambda batch: (\n",
    "        torch.stack([preprocess(img) for img, _, _ in batch]),  # inputs\n",
    "        torch.tensor([label for _, label, _ in batch]),  # labels\n",
    "       torch.stack([attr for _, _, attr in batch])  # attributes\n",
    "    )\n",
    "    )\n",
    "num_images = len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 168/168 [00:37<00:00,  4.50it/s]\n"
     ]
    }
   ],
   "source": [
    "# Tensors to store embeddings, labels, and gender attributes\n",
    "all_image_features = torch.empty((num_images, embedding_dim), dtype=torch.float32, device=device)\n",
    "all_labels = torch.empty((num_images,), dtype=torch.long, device=device)\n",
    "num_attributes = test_data[0][2].shape[0]  # Assuming test_data returns (image, label, attributes)\n",
    "\n",
    "# Preallocate tensor for all attributes\n",
    "all_attributes = torch.empty((num_images, num_attributes), dtype=torch.long, device=device)\n",
    "start_idx = 0\n",
    "\n",
    "for inputs, labels, attributes in tqdm(test_loader):\n",
    "    # Move the inputs, labels, and attributes to the correct device (GPU)\n",
    "    inputs = inputs.to(device)\n",
    "    labels = labels.to(device)\n",
    "    attributes = attributes.to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Encode the images to get their embeddings\n",
    "        image_features = model.encode_image(inputs)  # (batch_size, embedding_dim)\n",
    "        image_features = image_features.to(dtype=torch.float32, device=device)\n",
    "\n",
    "    # Store the batch embeddings, labels, and attributes in the preallocated tensors\n",
    "    end_idx = start_idx + image_features.size(0)\n",
    "    all_image_features[start_idx:end_idx] = image_features\n",
    "    all_labels[start_idx:end_idx] = labels\n",
    "    all_attributes[start_idx:end_idx] = attributes\n",
    "    start_idx = end_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1359385/2110215827.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  features = torch.tensor(all_image_features).to(device)\n"
     ]
    }
   ],
   "source": [
    "# Move tensors to device\n",
    "labels = all_labels.to(device)\n",
    "features = torch.tensor(all_image_features).to(device)\n",
    "\n",
    "# Get unique labels from our labels tensor\n",
    "unique_labels = sorted(torch.unique(labels).cpu().numpy())\n",
    "# Create simple descriptions for each category\n",
    "class_names = [f'a wild animal of category {i}' for i in unique_labels]\n",
    "\n",
    "# Tokenize and encode text descriptions\n",
    "text_inputs = tokenizer(class_names).to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Get text embeddings\n",
    "    text_embeddings = model.encode_text(text_inputs)\n",
    "    \n",
    "    # Compute similarity scores\n",
    "    similarity = features @ text_embeddings.T\n",
    "    \n",
    "    # Get predictions\n",
    "    preds = similarity.argmax(dim=-1)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = all_image_features @ text_embeddings.T\n",
    "preds = predictions.argmax(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = preds.to('cpu')\n",
    "all_labels = all_labels.to('cpu')\n",
    "all_attributes = all_attributes.to('cpu')\n",
    "result, results_str = dataset.eval(preds, all_labels, all_attributes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average acc: 0.062\n",
      "Recall macro: 0.002\n",
      "F1 macro: 0.001\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(results_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(42)\n",
    "all_image_features = all_image_features.to('cpu')\n",
    "k = 50\n",
    "L = glaplacian(all_image_features)\n",
    "U, S, Vt = randomized_svd(L, n_components=k)\n",
    "S_diag = np.diag(S)\n",
    "Z = U @ S_diag\n",
    "Z_rotated, Rz = varimax_with_rotation(Z)\n",
    "sign_Z = np.diag(np.where(np.mean(Z_rotated**3, axis=0) >= 0, 1, -1))\n",
    "Z_hat = Z_rotated @ sign_Z\n",
    "Y_hat = sign_Z @ Rz.T @ Vt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## reconstructed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.03834918557640626\n",
      "Average acc: 0.038\n",
      "Recall macro: 0.001\n",
      "F1 macro: 0.001\n",
      "\n"
     ]
    }
   ],
   "source": [
    "reconstructed = torch.tensor(Z_hat @ Y_hat)\n",
    "reconstructed = reconstructed.to(torch.float)\n",
    "reconstructed = reconstructed.to('cpu')\n",
    "text_embeddings = text_embeddings.to('cpu')\n",
    "recons_preds = reconstructed @ text_embeddings.T\n",
    "recons_preds = recons_preds.argmax(dim=-1)\n",
    "correct = (recons_preds == all_labels).sum().item()\n",
    "total = all_labels.size(0)\n",
    "accuracy = correct / total\n",
    "print(accuracy)\n",
    "# recons_preds = recons_preds.to('cuda')\n",
    "# all_labels = all_labels.to('cuda')\n",
    "_, results_str = dataset.eval(recons_preds, all_labels, all_attributes)\n",
    "print(results_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## spurious"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = open_clip.get_tokenizer(model_name)\n",
    "\n",
    "text_descriptions = [\n",
    "    'A natural scene or environment showing only background elements like vegetation, terrain, or sky with no living creatures visible',\n",
    "    'Wildlife captured in any lighting condition - including animals photographed during day, night, or by camera traps, showing any species like birds, mammals, or reptiles'\n",
    "]\n",
    "\n",
    "inputs = tokenizer(text_descriptions).to(device)\n",
    "    \n",
    "with torch.no_grad():\n",
    "    # Get text embeddings\n",
    "    des_embeddings = model.encode_text(inputs)\n",
    "    des_embeddings_np = des_embeddings.to('cpu').numpy()\n",
    "concept_des = Y_hat @ des_embeddings_np.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_features = text_embeddings.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1879600850646164\n",
      "Average acc: 0.188\n",
      "Recall macro: 0.006\n",
      "F1 macro: 0.003\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1359385/1119039344.py:10: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  all_labels = torch.tensor(all_labels).to(device)\n"
     ]
    }
   ],
   "source": [
    "spurious_rows = np.where(concept_des[:, 0] > concept_des[:, 1])[0].tolist()\n",
    "Z_mask = Z_hat.copy()\n",
    "cols_to_ablate = spurious_rows\n",
    "# col_means = Z_hat[:, cols_to_ablate].mean(axis=0)\n",
    "Z_mask[:, cols_to_ablate] = 0\n",
    "reconstructed = torch.tensor(Z_mask @ Y_hat).to(device)\n",
    "reconstructed = reconstructed.to(torch.float)\n",
    "recons_preds = reconstructed @ text_features.T\n",
    "recons_preds = recons_preds.argmax(dim=-1)\n",
    "all_labels = torch.tensor(all_labels).to(device)\n",
    "correct = (recons_preds == all_labels).sum().item()\n",
    "total = all_labels.size(0)\n",
    "accuracy = correct / total\n",
    "print(accuracy)\n",
    "recons_preds = recons_preds.to('cpu')\n",
    "all_labels = all_labels.to('cpu')\n",
    "_, results_str = dataset.eval(recons_preds, all_labels, all_attributes)\n",
    "print(results_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "2020-exotda",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
