{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Package requirements**:\n",
    "\n",
    "- xgboost >= 1.4.0\n",
    "- numpy, pandas, sklearn, h5py\n",
    "- pytorch (for FMoW model evaluation and CLIP featurization)\n",
    "- wilds (with locally modified code that exposes additional metadata fields)\n",
    "- clip\n",
    "- tqdm\n",
    "\n",
    "**helper.py**: adjusting fine-tuning parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess\n",
    "In this section we collect per-example accuracy and loss on FMoW for ImageNet-pretrained models and CLIP-based models. We consider models from three sources:\n",
    "\n",
    "1. DenseNet-121 models released with the WILDS paper (ERM / ERM_ID / groupDRO / IRM)\n",
    "2. CLIP ViT-B/32 zero-shot, and a logistic-regression linear probe on CLIP ViT-B/32 features\n",
    "3. ImageNet-pretrained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "from collections import OrderedDict\n",
    "import fmow\n",
    "import wilds\n",
    "import torchvision.transforms as transforms\n",
    "from wilds.common.data_loaders import get_eval_loader\n",
    "from tqdm import tqdm\n",
    "\n",
    "# FMoW category names. Assumed to be in label-index order: the zero-shot classifier\n",
    "# builds one column per entry and its argmax is compared directly with the labels.\n",
    "classnames = [\"airport\", \"airport_hangar\", \"airport_terminal\", \"amusement_park\", \"aquaculture\",\n",
    "                  \"archaeological_site\", \"barn\", \"border_checkpoint\", \"burial_site\", \"car_dealership\",\n",
    "                  \"construction_site\", \"crop_field\", \"dam\", \"debris_or_rubble\", \"educational_institution\",\n",
    "                  \"electric_substation\", \"factory_or_powerplant\", \"fire_station\", \"flooded_road\", \"fountain\",\n",
    "                  \"gas_station\", \"golf_course\", \"ground_transportation_station\", \"helipad\", \"hospital\",\n",
    "                  \"impoverished_settlement\", \"interchange\", \"lake_or_pond\", \"lighthouse\", \"military_facility\",\n",
    "                  \"multi-unit_residential\", \"nuclear_powerplant\", \"office_building\", \"oil_or_gas_facility\", \"park\",\n",
    "                  \"parking_lot_or_garage\", \"place_of_worship\", \"police_station\", \"port\", \"prison\", \"race_track\",\n",
    "                  \"railway_bridge\", \"recreational_facility\", \"road_bridge\", \"runway\", \"shipyard\", \"shopping_mall\",\n",
    "                  \"single-unit_residential\", \"smokestack\", \"solar_farm\", \"space_facility\", \"stadium\", \"storage_tank\",\n",
    "                  \"surface_mine\", \"swimming_pool\", \"toll_booth\", \"tower\", \"tunnel_opening\", \"waste_disposal\",\n",
    "                  \"water_treatment_facility\", \"wind_farm\", \"zoo\"]\n",
    "\n",
    "# Preprocessing for the DenseNet checkpoints: tensor conversion plus per-channel\n",
    "# normalization with the commonly used ImageNet mean/std.\n",
    "transform_steps = [transforms.ToTensor(),\n",
    "                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]\n",
    "transform = transforms.Compose(transform_steps)\n",
    "\n",
    "# Fall back to CPU when no GPU is visible, the same way the CLIP cell picks its device.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "def fmow_eval(loader, model, device):\n",
    "    \"\"\"Evaluate `model` on every batch of `loader` with gradients disabled.\n",
    "\n",
    "    `loader` yields (images, labels, metadata) batches; metadata is ignored.\n",
    "    Returns (accuracy in percent, per-example cross-entropy losses as a numpy array,\n",
    "    per-example correctness flags as a boolean numpy array), arrays in loader order.\n",
    "    Raises ValueError if the loader yields no batches.\n",
    "    \"\"\"\n",
    "    criterion = nn.CrossEntropyLoss(reduction='none')\n",
    "    model.eval()\n",
    "    loss = []\n",
    "    correct_arr = []\n",
    "    with torch.no_grad():\n",
    "        for images, labels, _ in tqdm(loader):\n",
    "            inputs, targets = images.to(device), labels.to(device)\n",
    "            outputs = model(inputs)\n",
    "            loss.append(criterion(outputs, targets))\n",
    "            _, predicted = outputs.max(1)\n",
    "            correct_arr.append(predicted.eq(targets))\n",
    "    if not correct_arr:\n",
    "        raise ValueError('fmow_eval: loader produced no batches')\n",
    "    correct_np = torch.cat(correct_arr).cpu().numpy()\n",
    "    # Derive accuracy from the per-example flags so the two outputs always agree.\n",
    "    acc = 100. * correct_np.sum() / correct_np.size\n",
    "    return acc, torch.cat(loss).cpu().numpy(), correct_np\n",
    "\n",
    "\n",
    "def extract_meta(loader):\n",
    "    \"\"\"Concatenate the metadata tensors of every batch in `loader` into one numpy array.\n",
    "\n",
    "    Rows come out in loader order, so (as long as the loader is not shuffled) they line\n",
    "    up with the per-example arrays `fmow_eval` returns for the same loader.\n",
    "    \"\"\"\n",
    "    all_metadata = []\n",
    "    with torch.no_grad():\n",
    "        for _, _, metadata in tqdm(loader):\n",
    "            all_metadata.append(metadata)\n",
    "    return torch.cat(all_metadata).cpu().numpy()\n",
    "\n",
    "\n",
    "## load current model\n",
    "def get_model(m_name):\n",
    "    \"\"\"Build a DenseNet-121 with a 62-way linear head and load a WILDS checkpoint into it.\n",
    "\n",
    "    Reads ./fmow_pretrained_models/<m_name>/best_model.pth, takes its 'algorithm' state\n",
    "    dict and strips the 'model.' prefix from keys. Keys the network does not expect are\n",
    "    ignored (strict=False), but missing network weights raise RuntimeError so a\n",
    "    mismatched checkpoint cannot silently leave layers randomly initialised.\n",
    "    Returns the model moved to the global `device`.\n",
    "    \"\"\"\n",
    "    model = torchvision.models.densenet121()\n",
    "    setattr(model, 'classifier', nn.Linear(1024, 62))\n",
    "    ckpt_path = os.path.join('./fmow_pretrained_models', m_name, 'best_model.pth')\n",
    "    # Load onto CPU first so GPU-saved checkpoints also open on CPU-only machines.\n",
    "    state_dict = torch.load(ckpt_path, map_location='cpu')['algorithm']\n",
    "    prefix = 'model.'\n",
    "    new_state_dict = OrderedDict(\n",
    "        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()\n",
    "    )\n",
    "    result = model.load_state_dict(new_state_dict, strict=False)\n",
    "    # BatchNorm's num_batches_tracked counters do not affect eval-mode outputs, so tolerate them.\n",
    "    missing = [k for k in result.missing_keys if not k.endswith('num_batches_tracked')]\n",
    "    if missing:\n",
    "        raise RuntimeError('checkpoint ' + ckpt_path + ' is missing model weights: ' + ', '.join(missing[:10]))\n",
    "    model = model.to(device)\n",
    "    return model\n",
    "\n",
    "## Loss helpers (used for the CLIP linear-probe evaluation)\n",
    "def one_hot_y(y, num_class=62):\n",
    "    \"\"\"Return one-hot float row(s) of width `num_class` for integer label(s) `y`.\"\"\"\n",
    "    identity = np.eye(num_class)\n",
    "    return identity[y]\n",
    "\n",
    "def categorical_cross_entropy(y_true, y_pred):\n",
    "    \"\"\"Per-example cross-entropy -log p[y] for integer labels and predicted probabilities.\n",
    "\n",
    "    y_true: integer class indices, one per row of y_pred.\n",
    "    y_pred: (n_examples, n_classes) probability matrix; the class count is taken from its\n",
    "    width rather than hard-coded. Probabilities are clipped to [1e-9, 1 - 1e-9] so the log\n",
    "    stays finite.\n",
    "    \"\"\"\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_pred = np.clip(np.asarray(y_pred, dtype=float), 1e-9, 1 - 1e-9)\n",
    "    # Pick each row's true-class probability directly instead of multiplying by a one-hot mask.\n",
    "    return -np.log(y_pred[np.arange(len(y_true)), y_true])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one evaluation loader per split and collect its per-example metadata.\n",
    "# meta: dict, split name -> DataFrame of metadata (columns = dataset._metadata_fields)\n",
    "# meta_loader: dict, split name -> evaluation loader for that split (reused by the DenseNet cell)\n",
    "# NOTE(review): `path` is not defined anywhere in this notebook; set it to the WILDS data root first.\n",
    "dataset = wilds.get_dataset(dataset='fmow', root_dir=path)\n",
    "\n",
    "split_type = ['val', 'ood_val', 'test']\n",
    "meta = dict.fromkeys(split_type)\n",
    "meta_loader = dict.fromkeys(split_type)\n",
    "for split in split_type:\n",
    "    # `transform` = ToTensor + Normalize, as used for the DenseNet checkpoints evaluated later.\n",
    "    cur_subset = dataset.get_subset(split, transform=transform)\n",
    "    cur_loader = get_eval_loader(\"standard\", cur_subset, num_workers=16, batch_size=64)\n",
    "    # Full pass over the split (images are decoded too) just to gather the metadata rows.\n",
    "    cur_meta = extract_meta(cur_loader)\n",
    "    meta[split] = pd.DataFrame(cur_meta, columns=dataset._metadata_fields)\n",
    "    meta_loader[split] = cur_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every WILDS DenseNet checkpoint under ./fmow_pretrained_models/ on each split\n",
    "# and add per-example '<model>_acc' / '<model>_loss' columns to meta[split].\n",
    "model_root = './fmow_pretrained_models/'\n",
    "# Keep only sub-directories that actually hold a checkpoint: this skips stray files such\n",
    "# as .DS_Store without dropping a model whose name merely contains 'DS'.\n",
    "model_list = sorted(\n",
    "    name for name in os.listdir(model_root)\n",
    "    if os.path.isfile(os.path.join(model_root, name, 'best_model.pth'))\n",
    ")\n",
    "\n",
    "for m_name in model_list:\n",
    "    print('curmodel: ' + m_name)\n",
    "    cur_model = get_model(m_name)\n",
    "    for split in split_type:\n",
    "        cur_loader = meta_loader[split]\n",
    "        acc, loss_array, acc_array = fmow_eval(cur_loader, cur_model, device)\n",
    "        print('  ' + split + ' accuracy: %.2f%%' % acc)\n",
    "        meta[split][m_name + '_acc'] = acc_array\n",
    "        meta[split][m_name + '_loss'] = loss_array\n",
    "\n",
    "# save meta data and losses to csv\n",
    "for split in split_type:\n",
    "    meta[split].to_csv('meta_' + split + '.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clip zero-shot\n",
    "import os\n",
    "import clip \n",
    "import torch\n",
    "\n",
    "# Rebinds `device` (as a string) for the CLIP cells that follow.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# `model` is CLIP ViT-B/32; `preprocess` is the image transform returned by clip.load,\n",
    "# used instead of the ImageNet `transform` when building the CLIP loaders.\n",
    "model, preprocess = clip.load('ViT-B/32', device)\n",
    "\n",
    "# NOTE(review): reloads the same WILDS dataset as the preprocessing cell; `path` must be defined.\n",
    "dataset = wilds.get_dataset(dataset='fmow', root_dir=path)\n",
    "\n",
    "def zeroshot_classifier(model, classnames, templates, device, skip_renorm=False):\n",
    "    \"\"\"Build CLIP zero-shot classifier weights from prompt-ensembled text embeddings.\n",
    "\n",
    "    For each class name every template is filled in, encoded with the text encoder and\n",
    "    L2-normalised; the embeddings are averaged and the mean is re-normalised unless\n",
    "    `skip_renorm`. Returns a (embed_dim, len(classnames)) tensor, one column per class,\n",
    "    so `normalised_image_features @ weights` gives per-class scores.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        zeroshot_weights = []\n",
    "        for classname in tqdm(classnames):\n",
    "            texts = [template(classname) for template in templates]  # format with class\n",
    "            tokens = clip.tokenize(texts).to(device)\n",
    "            class_embeddings = model.encode_text(tokens)  # embed with text encoder\n",
    "            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)\n",
    "            class_embedding = class_embeddings.mean(dim=0)\n",
    "            if not skip_renorm:\n",
    "                class_embedding /= class_embedding.norm()\n",
    "            zeroshot_weights.append(class_embedding)\n",
    "        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)\n",
    "    return zeroshot_weights\n",
    "\n",
    "# Prompt templates for the zero-shot classifier; each is filled in with a class name.\n",
    "fmow_template = [\n",
    "    lambda c : f\"satellite photo of a {c}.\",\n",
    "    lambda c : f\"an image of a {c}.\",\n",
    "    lambda c : f\"satellite photo of an {c}.\",\n",
    "    lambda c : f\"an image of an {c}.\",\n",
    "    lambda c : f\"satellite photo of a {c} in asia.\",\n",
    "    lambda c : f\"an image of a {c} in asia.\",\n",
    "    lambda c : f\"satellite photo of a {c} in africa.\",\n",
    "    lambda c : f\"an image of a {c} in africa.\",\n",
    "    lambda c : f\"satellite photo of a {c} in the americas.\",\n",
    "    lambda c : f\"an image of a {c} in the americas.\",\n",
    "    lambda c : f\"satellite photo of a {c} in europe.\",\n",
    "    lambda c : f\"an image of a {c} in europe.\",\n",
    "    lambda c : f\"satellite photo of a {c} in oceania.\",\n",
    "    lambda c : f\"an image of a {c} in oceania.\",\n",
    "    lambda c: f\"a photo of a {c}.\",\n",
    "    lambda c: f\"a picture of a {c}.\",\n",
    "    lambda c: f\"{c}.\",\n",
    "]\n",
    "\n",
    "template = list(fmow_template)\n",
    "classifier = zeroshot_classifier(model, classnames, template, device)\n",
    "# np.save does not create missing directories, so ensure the output folder exists.\n",
    "os.makedirs('./clip_folder', exist_ok=True)\n",
    "np.save('./clip_folder/class_token.npy', classifier.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-shot CLIP evaluation on each split. Also saves the raw CLIP image features and the\n",
    "# labels (clip_features_<split>.npy / clip_labels_<split>.npy) for the linear-probe cell.\n",
    "# `classifier` here is the zero-shot weight matrix built in the previous cell.\n",
    "split_type = ['val', 'ood_val', 'test']\n",
    "meta_clip = dict.fromkeys(split_type)\n",
    "meta_clip_loader = dict.fromkeys(split_type)\n",
    "criterion = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "for split in split_type:\n",
    "    cur_subset = dataset.get_subset(split, transform=preprocess)\n",
    "    cur_loader = get_eval_loader(\"standard\", cur_subset, num_workers=16, batch_size=64)\n",
    "    cur_meta = extract_meta(cur_loader)\n",
    "    meta_clip[split] = pd.DataFrame(cur_meta, columns=dataset._metadata_fields)\n",
    "    meta_clip_loader[split] = cur_loader\n",
    "\n",
    "    correct_arr = []\n",
    "    loss = []\n",
    "    all_features = []  # raw (un-normalised) image embeddings, for the linear probe\n",
    "    all_labels = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, targets, _ in tqdm(cur_loader):\n",
    "            images = images.to(device)\n",
    "            image_features = model.encode_image(images)\n",
    "            all_features.append(image_features.cpu())\n",
    "            all_labels.append(targets)\n",
    "            targets = targets.to(device)\n",
    "\n",
    "            # Out-of-place on purpose: an in-place /= would also rescale the tensor already\n",
    "            # stored in all_features, so the saved raw features would come out normalised.\n",
    "            normed_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
    "            logits = 100.0 * normed_features @ classifier\n",
    "            _, predicted = logits.max(1)\n",
    "            correct_arr.append(predicted.eq(targets))\n",
    "            loss.append(criterion(logits, targets))\n",
    "    meta_clip[split]['clip_acc'] = torch.cat(correct_arr).cpu().numpy()\n",
    "    meta_clip[split]['clip_loss'] = torch.cat(loss).cpu().numpy()\n",
    "    np.save('clip_features_' + split + '.npy', torch.cat(all_features).numpy())\n",
    "    np.save('clip_labels_' + split + '.npy', torch.cat(all_labels).cpu().numpy())\n",
    "\n",
    "# Merge the zero-shot columns into the per-split metadata tables and save them.\n",
    "# NOTE(review): 'meta_<split>_cvar.csv' is not written by this notebook (the DenseNet cell\n",
    "# writes 'meta_<split>.csv'); presumably it comes from a separate step. Columns are\n",
    "# aligned by row index, so both tables must list the examples in the same order.\n",
    "split_type = ['val', 'ood_val', 'test']\n",
    "meta = dict.fromkeys(split_type)\n",
    "for split in split_type:\n",
    "    meta[split] = pd.read_csv('meta_' + split + '_cvar.csv')\n",
    "    meta[split]['clip_acc'] = meta_clip[split]['clip_acc']\n",
    "    meta[split]['clip_loss'] = meta_clip[split]['clip_loss']\n",
    "    meta[split].to_csv('clip_meta_' + split + '_cvar.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clip L2 fine-tune; clip linear probe\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import log_loss\n",
    "split_type = ['val', 'ood_val', 'test']\n",
    "features = dict.fromkeys(split_type)\n",
    "labels = dict.fromkeys(split_type)\n",
    "meta = dict.fromkeys(split_type)\n",
    "\n",
    "import torch\n",
    "\n",
    "# NOTE(review): features and labels are both loaded from the same `path`; point each one\n",
    "# at its own file (CLIP train-split image features and the matching integer labels).\n",
    "train_features = torch.load(path)\n",
    "train_labels = torch.load(path)\n",
    "\n",
    "classifier = LogisticRegression(random_state=0, C=0.316, max_iter=100, verbose=1)\n",
    "classifier.fit(train_features, train_labels)\n",
    "\n",
    "num_classes = len(classnames)\n",
    "for split in split_type:\n",
    "    meta[split] = pd.read_csv(path + split + '_clip.csv')\n",
    "    # Raw CLIP features and labels as saved by the zero-shot evaluation cell; the labels\n",
    "    # must come from their own file, not the feature file.\n",
    "    features[split] = np.load('clip_features_' + split + '.npy')\n",
    "    labels[split] = np.load('clip_labels_' + split + '.npy')\n",
    "\n",
    "    predictions = classifier.predict(features[split])\n",
    "    # predict_proba columns follow classifier.classes_ (only labels seen during training),\n",
    "    # so scatter them into a full class-index grid before indexing by the true label.\n",
    "    proba = np.zeros((len(features[split]), num_classes))\n",
    "    proba[:, classifier.classes_.astype(int)] = classifier.predict_proba(features[split])\n",
    "    meta[split]['clip_linear_probe'] = predictions\n",
    "    meta[split]['clip_linear_probe_acc'] = (predictions == labels[split])\n",
    "    meta[split]['clip_linear_probe_loss'] = categorical_cross_entropy(labels[split], proba)\n",
    "\n",
    "    meta[split].to_csv(split + '_clip.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# other imagenet pretrained models: merge their per-example loss/acc columns into the\n",
    "# aggregate tables and write the tables back in place.\n",
    "import os \n",
    "\n",
    "split_type = ['val', 'ood_val', 'test']\n",
    "meta = dict.fromkeys(split_type)\n",
    "\n",
    "for split in split_type:\n",
    "    meta[split] = pd.read_csv(path + split + '_aggregate.csv') \n",
    "    \n",
    "model_list = ['imgnt_resnet18','imgnt_resnet50','imgnt_alexnet','imgnt_dpn68','imgnt_vgg11']\n",
    "\n",
    "for split in split_type:\n",
    "    # Relative to the working directory, like ./fmow_pretrained_models/ and ./clip_folder/.\n",
    "    cur_file = pd.read_csv('./imagenet_pretrained_models/' + split + '.csv')\n",
    "    # Columns are aligned by row index; a length mismatch would otherwise silently insert NaNs.\n",
    "    assert len(cur_file) == len(meta[split]), split + ': row count mismatch between result tables'\n",
    "    for m in model_list:\n",
    "        meta[split][m+'_loss'] = cur_file[m+'_loss']\n",
    "        meta[split][m+'_acc'] = cur_file[m+'_acc']\n",
    "\n",
    "# Written only after every split has been read and merged successfully.\n",
    "for split in split_type:\n",
    "    meta[split].to_csv(path + split + '_aggregate.csv', index=False) "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
