{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4b0b5f9b-3e01-43a3-bd3b-addff770ac5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Subset, DataLoader, Dataset\n",
    "from wilds.datasets.fmow_dataset import FMoWDataset\n",
    "from PIL import Image\n",
    "from tqdm import trange\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import open_clip\n",
    "\n",
    "torch.set_num_threads(5)   \n",
    "torch.set_num_interop_threads(5)   \n",
    "\n",
    "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n",
    "model,_, preprocess =  open_clip.create_model_and_transforms(\"ViT-L-14\", pretrained='laion2b_s32b_b82k') \n",
    "model = model.to(device)\n",
    "tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
    "\n",
    "batch_size = 256\n",
    "\n",
    "root_dir = r\"../../../Dataset/data\"\n",
    "\n",
    "def get_transform():\n",
    "    transform = transforms.Compose([\n",
    "            transforms.Resize((224,224)),\n",
    "            transforms.ToTensor(),\n",
    "        ])\n",
    "    return transform\n",
    "    \n",
    "train_transform = get_transform()    \n",
    "dataset = FMoWDataset(root_dir=root_dir, download=False)\n",
    "train_data = dataset.get_subset('train')\n",
    "val_data = dataset.get_subset('val')\n",
    "test_data = dataset.get_subset('test')\n",
    "\n",
    "\n",
    "# set group str\n",
    "metadata_str = dataset.metadata_map['region']\n",
    "meadata_idx = dataset.metadata_fields.index('region')\n",
    "metadata = dataset._metadata_array = dataset._metadata_array[:, meadata_idx]\n",
    "train_data.group_str = val_data.group_str = test_data.group_str = lambda id: metadata_str[id]\n",
    "\n",
    "# set n_groups\n",
    "n_groups = len(metadata_str)\n",
    "setattr(train_data, 'n_groups', n_groups)\n",
    "setattr(val_data, 'n_groups', n_groups)\n",
    "setattr(test_data, 'n_groups', n_groups)\n",
    "\n",
    "# set group counts\n",
    "split_dict = dataset.split_dict\n",
    "split_array = dataset.split_array\n",
    "y_array = dataset.y_array\n",
    "\n",
    "train_data.get_group_array = metadata[split_array == split_dict['train']]\n",
    "val_data.get_group_array = metadata[split_array == split_dict['val']]\n",
    "test_data.get_group_array = metadata[split_array == split_dict['test']]\n",
    "\n",
    "class ConfounderDataset(Dataset):\n",
    "    def __init__(self):\n",
    "        self.dataset = dataset\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def __len__(self):\n",
    "        if self.dataset == \"train\":\n",
    "            return len(train_data)\n",
    "        if self.dataset == \"val\":\n",
    "            return len(val_data)\n",
    "        if self.dataset == \"test\":\n",
    "            return len(test_data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if self.dataset == \"train\":\n",
    "            img, y, _ = train_data[idx]\n",
    "            a = self.train_a[idx]\n",
    "            x = preprocess(img)\n",
    "            return x, y, a\n",
    "            \n",
    "        if self.dataset == \"val\":\n",
    "            img, y, _ = val_data[idx]\n",
    "            a = self.val_a[idx]\n",
    "            x = preprocess(img)\n",
    "            return x, y, a\n",
    "\n",
    "        if self.dataset == \"test\":\n",
    "            img, y, _ = test_data[idx]\n",
    "            a = self.test_a[idx]\n",
    "            x = preprocess(img)\n",
    "            return x, y, a\n",
    "            \n",
    "\n",
    "\n",
    "class FMOW(ConfounderDataset):\n",
    "    def __init__(self, dataset):\n",
    "        self.dataset = dataset\n",
    "        train_data.get_group_array = metadata[split_array == split_dict['train']]\n",
    "        val_data.get_group_array = metadata[split_array == split_dict['val']]\n",
    "        test_data.get_group_array = metadata[split_array == split_dict['test']]\n",
    "        self.train_a = train_data.get_group_array\n",
    "        self.train_y = train_data.y_array\n",
    "        self.val_a = val_data.get_group_array\n",
    "        self.val_y = val_data.y_array\n",
    "        self.test_a = test_data.get_group_array\n",
    "        self.test_y = test_data.y_array\n",
    "        \n",
    "\n",
    "\n",
    "\n",
    "training_dataset = FMOW('train')\n",
    "test_dataset = FMOW('test')\n",
    "\n",
    "training_data_loader  = torch.utils.data.DataLoader(dataset = training_dataset,\n",
    "                                                batch_size= batch_size,\n",
    "                                                shuffle=False,\n",
    "                                                num_workers=16,\n",
    "                                                drop_last=True)\n",
    "\n",
    "test_data_loader  = torch.utils.data.DataLoader(dataset = test_dataset,\n",
    "                                                batch_size= batch_size,\n",
    "                                                shuffle=False,\n",
    "                                                num_workers=16,\n",
    "                                                drop_last=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "804491e4-c282-4612-9142-a553af3247fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Scale: 100%|################################################################| 305/305 [21:15<00:00,  4.18s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.14512545\n",
      "0.1435668\n",
      "0.16128178\n",
      "0.13966314\n",
      "0.1632383\n",
      "['A satellite image of airport.', 'A satellite image of airport hangar.', 'A satellite image of airport terminal.', 'A satellite image of amusement park.', 'A satellite image of aquaculture.', 'A satellite image of archaeological site.', 'A satellite image of barn.', 'A satellite image of border checkpoint.', 'A satellite image of burial site.', 'A satellite image of car dealership.', 'A satellite image of construction site.', 'A satellite image of crop field.', 'A satellite image of dam.', 'A satellite image of debris or rubble.', 'A satellite image of educational institution.', 'A satellite image of electric substation.', 'A satellite image of factory or powerplant.', 'A satellite image of fire station.', 'A satellite image of flooded road.', 'A satellite image of fountain.', 'A satellite image of gas station.', 'A satellite image of golf course.', 'A satellite image of ground transportation station.', 'A satellite image of helipad.', 'A satellite image of hospital.', 'A satellite image of impoverished settlement.', 'A satellite image of interchange.', 'A satellite image of lake or pond.', 'A satellite image of lighthouse.', 'A satellite image of military facility.', 'A satellite image of multi-unit residential.', 'A satellite image of nuclear powerplant.', 'A satellite image of office building.', 'A satellite image of oil or gas facility.', 'A satellite image of park.', 'A satellite image of parking lot or garage.', 'A satellite image of place of worship.', 'A satellite image of police station.', 'A satellite image of port.', 'A satellite image of prison.', 'A satellite image of race track.', 'A satellite image of railway bridge.', 'A satellite image of recreational facility.', 'A satellite image of road bridge.', 'A satellite image of runway.', 'A satellite image of shipyard.', 'A satellite image of shopping mall.', 'A satellite image of single-unit residential.', 'A satellite image of smokestack.', 'A satellite image of solar farm.', 'A satellite image of space facility.', 'A satellite image of stadium.', 'A satellite image of storage tank.', 'A satellite image of surface mine.', 'A satellite image of swimming pool.', 'A satellite image of toll booth.', 'A satellite image of tower.', 'A satellite image of tunnel opening.', 'A satellite image of waste disposal.', 'A satellite image of water treatment facility.', 'A satellite image of wind farm.', 'A satellite image of zoo.']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Zero Shot Testing: 100%|################################################################| 87/87 [07:34<00:00,  5.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy for s=0: 0.2037\n",
      "Accuracy for s=1: 0.2750\n",
      "Accuracy for s=2: 0.2019\n",
      "Accuracy for s=3: 0.3062\n",
      "Accuracy for s=4: 0.4242\n",
      "Total accuracy: 0.2662\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from torch import nn\n",
    "\n",
    "texts = [\"Over Europe.\", \"Over Asia\", \"Over America.\", \"Over Africa\", \"Over Oceania\"]\n",
    "text = tokenizer(texts).to(device)\n",
    "text_features = model.encode_text(text)\n",
    "Eurobg = text_features[0].unsqueeze(0)\n",
    "Asiabg = text_features[1].unsqueeze(0)\n",
    "Americabg = text_features[2].unsqueeze(0)\n",
    "Africabg = text_features[3].unsqueeze(0)\n",
    "Oceaniabg = text_features[4].unsqueeze(0)\n",
    "\n",
    "\n",
    "def process_group(mask, spu, z):\n",
    "    mask = mask.to(device)\n",
    "    subset = z[mask]\n",
    "    inner_result = torch.mm(subset / subset.norm(dim=1, keepdim=True), spu.t())\n",
    "    return inner_result.detach().cpu().numpy()            \n",
    "    \n",
    "def compute_scale(vlm, spu_v0, spu_v1, spu_v2, spu_v3, spu_v4):\n",
    "    vlm = vlm.to(device)\n",
    "    scale_0 = []\n",
    "    scale_1 = []\n",
    "    scale_2 = []\n",
    "    scale_3 = []\n",
    "    scale_4 = []\n",
    "\n",
    "    \n",
    "    spu0 = spu_v0  / spu_v0.norm(dim=1, keepdim=True)\n",
    "    spu1 = spu_v1 / spu_v1.norm(dim=1, keepdim=True)\n",
    "    spu2 = spu_v2  / spu_v2.norm(dim=1, keepdim=True)\n",
    "    spu3 = spu_v3 / spu_v3.norm(dim=1, keepdim=True)\n",
    "    spu4 = spu_v4  / spu_v4.norm(dim=1, keepdim=True)\n",
    "\n",
    "    \n",
    "    for step, (test_input, _, sensitive ) in enumerate(tqdm(training_data_loader, desc=\"Computing Scale\", dynamic_ncols=False, ascii=True)):\n",
    "        with torch.no_grad():\n",
    "            \n",
    "            \n",
    "            # put image into the image encoder\n",
    "            test_input = test_input.to(device)\n",
    "            z = vlm.encode_image(test_input)\n",
    "            if a ==True:\n",
    "                sensitive = sensitive\n",
    "            else:\n",
    "                if partial_a == False:\n",
    "                    sensitive = inference_a(vlm, landbg, waterbg,z )\n",
    "                elif partial_a == True:\n",
    "                    sensitive = supervised_inference_a(img)\n",
    "            \n",
    "            \n",
    "            mask_0 = sensitive == 0\n",
    "            scale_0.extend(process_group(mask_0, spu0, z))\n",
    "            \n",
    "            mask_1 = sensitive == 1\n",
    "            scale_1.extend(process_group(mask_1, spu1, z))\n",
    "\n",
    "            mask_2 = sensitive == 2\n",
    "            scale_2.extend(process_group(mask_2, spu2, z))\n",
    "            \n",
    "            mask_3 = sensitive == 3\n",
    "            scale_3.extend(process_group(mask_3, spu3, z))\n",
    "\n",
    "            mask_4 = sensitive == 4\n",
    "            scale_4.extend(process_group(mask_4, spu4, z))\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "    scale_0 = np.array(scale_0)\n",
    "    scale_1 = np.array(scale_1)\n",
    "    scale_2 = np.array(scale_2)\n",
    "    scale_3 = np.array(scale_3)\n",
    "    scale_4 = np.array(scale_4)\n",
    "    \n",
    "    print(np.mean(scale_0))\n",
    "    print(np.mean(scale_1))\n",
    "    print(np.mean(scale_2))\n",
    "    print(np.mean(scale_3))\n",
    "    print(np.mean(scale_4))\n",
    "\n",
    "    return torch.tensor(np.mean(scale_0)), torch.tensor(np.mean(scale_1)), torch.tensor(np.mean(scale_2)), torch.tensor(np.mean(scale_3)), torch.tensor(np.mean(scale_4))\n",
    "\n",
    "\n",
    "\n",
    "def test_epoch(vlm,   dataloader):\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "    scale_0, scale_1, scale_2, scale_3, scale_4 = compute_scale(model, Eurobg, Asiabg,Americabg, Africabg, Oceaniabg)\n",
    "    texts = [\"airport\", \"airport_hangar\", \"airport_terminal\", \"amusement_park\", \"aquaculture\", \"archaeological_site\", \"barn\", \"border_checkpoint\", \n",
    "         \"burial_site\", \"car_dealership\", \"construction_site\", \"crop_field\", \"dam\", \"debris_or_rubble\", \"educational_institution\", \n",
    "         \"electric_substation\", \"factory_or_powerplant\", \"fire_station\", \"flooded_road\", \"fountain\", \"gas_station\", \"golf_course\", \n",
    "         \"ground_transportation_station\", \"helipad\", \"hospital\", \"impoverished_settlement\", \"interchange\", \"lake_or_pond\", \"lighthouse\", \n",
    "         \"military_facility\", \"multi-unit_residential\", \"nuclear_powerplant\", \"office_building\", \"oil_or_gas_facility\", \"park\", \n",
    "         \"parking_lot_or_garage\", \"place_of_worship\", \"police_station\", \"port\", \"prison\", \"race_track\", \"railway_bridge\", \"recreational_facility\", \n",
    "         \"road_bridge\", \"runway\", \"shipyard\", \"shopping_mall\", \"single-unit_residential\", \"smokestack\", \"solar_farm\", \"space_facility\", \"stadium\", \n",
    "         \"storage_tank\", \"surface_mine\", \"swimming_pool\", \"toll_booth\", \"tower\", \"tunnel_opening\", \"waste_disposal\", \"water_treatment_facility\", \n",
    "         \"wind_farm\", \"zoo\"\n",
    "        ]\n",
    "    expanded_texts =  [\"A satellite image of \" + text.replace('_', ' ') + \".\" for text in texts]\n",
    "    print(expanded_texts)\n",
    "    text_label_tokened = tokenizer(expanded_texts).to(device)\n",
    "    text_embeddings = vlm.encode_text(text_label_tokened)\n",
    "    norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)\n",
    "\n",
    "   \n",
    "    \n",
    "    vlm = vlm.to(device)\n",
    "    vlm.eval()   \n",
    "    test_pred = []\n",
    "    test_gt = []\n",
    "    sense_gt = []\n",
    "    cos = nn.CosineSimilarity(dim = 0)\n",
    "\n",
    "    correct_0, total_0 = 0, 0\n",
    "    correct_1, total_1 = 0, 0\n",
    "    correct_2, total_2 = 0, 0\n",
    "    correct_3, total_3 = 0, 0\n",
    "    correct_4, total_4 = 0, 0\n",
    "    \n",
    "    total_correct = 0\n",
    "    total_count = 0\n",
    "    num_classes = 5  # Assuming there are 5 classes as indicated by your masks\n",
    "\n",
    "    # Initialize arrays to hold the counts\n",
    "    correct = [0] * num_classes\n",
    "    total = [0] * num_classes\n",
    "\n",
    "    for step, (test_input, test_target, sensitive_real) in enumerate(tqdm(dataloader, desc=\"Zero Shot Testing\", dynamic_ncols=False, ascii=True)):\n",
    "        with torch.no_grad():\n",
    "            gt = test_target.detach().cpu().numpy()\n",
    "            sen = sensitive_real.detach().cpu().numpy()\n",
    "            test_gt.extend(gt)\n",
    "            sense_gt.extend(sen)\n",
    "            # put image into the image encoder\n",
    "            test_input = test_input.to(device)\n",
    "\n",
    "            text_label_tokened\n",
    "            z = vlm.encode_image(test_input)\n",
    "            z = z/ z.norm(dim=1, keepdim=True)\n",
    "            \n",
    "            if a == True:\n",
    "                sensitive = sensitive_real\n",
    "            if a == False:\n",
    "                if partial_a == False:\n",
    "                    sensitive = inference_a(vlm, landbg, waterbg,z )\n",
    "                    sensitive = torch.tensor(sensitive)\n",
    "                elif partial_a == True:\n",
    "                    sensitive = supervised_inference_a(img)\n",
    "\n",
    "            \n",
    "            mask_0 = sensitive == 0\n",
    "            mask_0 = mask_0.to(device)\n",
    "            z[mask_0] -= scale_0 * Eurobg/ Eurobg.norm(dim=1, keepdim=True)\n",
    "                \n",
    "            mask_1 = sensitive == 1\n",
    "            mask_1 = mask_1.to(device)\n",
    "            z[mask_1] -= scale_1 * Asiabg/ Asiabg.norm(dim=1, keepdim=True)\n",
    "\n",
    "            mask_2 = sensitive == 2\n",
    "            mask_2 = mask_2.to(device)\n",
    "            z[mask_2] -= scale_2 * Americabg/ Americabg.norm(dim=1, keepdim=True)\n",
    "\n",
    "            mask_3 = sensitive == 3\n",
    "            mask_3 = mask_3.to(device)\n",
    "            z[mask_3] -= scale_3 * Africabg/ Africabg.norm(dim=1, keepdim=True)\n",
    "\n",
    "            mask_4 = sensitive == 4\n",
    "            mask_4 = mask_4.to(device)\n",
    "            z[mask_4] -= scale_4 * Oceaniabg/ Oceaniabg.norm(dim=1, keepdim=True)\n",
    "            \n",
    "            \n",
    " \n",
    "            img_embeddings = z\n",
    "            norm_img_embeddings = img_embeddings / img_embeddings.norm(dim=1, keepdim=True)\n",
    "            \n",
    "            cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())\n",
    "            logits_per_image = cosine_similarity             \n",
    "            probs = logits_per_image.softmax(dim=1)\n",
    "            _, predic = torch.max(probs.data, 1)\n",
    "            predic = predic.detach().cpu()\n",
    "            \n",
    "            test_pred.extend(predic.numpy())\n",
    "            label = test_target.squeeze().detach().cpu()\n",
    "            \n",
    "            \n",
    "            # Compute the masks and update the counts in a loop\n",
    "            for i in range(num_classes):\n",
    "                mask = (sensitive_real == i)\n",
    "                correct_predictions = (predic[mask] == label[mask]).float().sum().item()\n",
    "                count = mask.float().sum().item()\n",
    "            \n",
    "                correct[i] += correct_predictions\n",
    "                total[i] += count\n",
    "            \n",
    "                # Update total correct and count\n",
    "                total_correct += correct_predictions\n",
    "                total_count += count\n",
    "            \n",
    "            # Compute accuracies\n",
    "    accuracies = [correct[i] / total[i] if total[i] != 0 else 0 for i in range(num_classes)]\n",
    "    total_accuracy = total_correct / total_count if total_count != 0 else 0\n",
    "    for i in range(num_classes):\n",
    "        print(f'Accuracy for s={i}: {accuracies[i]:.4f}')\n",
    "    print(f'Total accuracy: {total_accuracy:.4f}')\n",
    "\n",
    "a = True\n",
    "partial_a = False\n",
    "    \n",
    "\n",
    "model = model.to(device)\n",
    "#inference_a_test(model, landbg, waterbg)\n",
    "test_epoch(model, test_data_loader)\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (DLcourse)",
   "language": "python",
   "name": "dlcourse"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
