{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision as tv\n",
    "from torch.utils.data import Subset, Dataset, DataLoader\n",
    "\n",
    "from source.constants import RESULTS_PATH\n",
    "from source.data.medical_imaging import get_chexpert, TransformWrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_seeds = [42, 142, 242, 342, 442]\n",
    "dseed = 42\n",
    "\n",
    "model = [\"resnet18\", \"resnet34\", \"resnet50\", \"regnet\", \"efficientnet\"][3]\n",
    "\n",
    "device = \"cuda:7\"\n",
    "batch_size = 2048 # 1024 ~ 10GB VRAM / 2048 ~ 15GB VRAM for resnet18 ~ 1GB more for resnet34, 1024 ~ 17GB for resnet50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# patients general 65401\n",
      "# patients with race 58010\n",
      "loading images to RAM\n",
      "images loaded to RAM\n",
      "loading images to RAM\n",
      "images loaded to RAM\n",
      "loading images to RAM\n",
      "images loaded to RAM\n"
     ]
    }
   ],
   "source": [
    "train_ds, r_val_ds, r_test_ds = get_chexpert()\n",
    "\n",
    "run_path = os.path.join(RESULTS_PATH, f\"chexpert_{model}_mseed{method_seeds[0]}_dseed{dseed}\")\n",
    "fair_inds = torch.load(os.path.join(run_path, \"fair_inds.pt\"))\n",
    "val_inds = torch.load(os.path.join(run_path, \"val_inds.pt\"))\n",
    "\n",
    "fair_ds = TransformWrapper(Subset(train_ds, indices=fair_inds))\n",
    "val_ds = TransformWrapper(Subset(train_ds, indices=val_inds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def evaluate(networks: List, ds: Dataset):\n",
    "    probits = list()\n",
    "    for network in tqdm(networks):\n",
    "        network.eval()\n",
    "        probits_ = list()\n",
    "        for x, _, _ in DataLoader(ds, batch_size = batch_size, shuffle=False, drop_last=False):\n",
    "            x = x.to(device)\n",
    "\n",
    "            probits_.append(torch.softmax(network.forward(x), dim=1).cpu())\n",
    "        probits_ = torch.concat(probits_, dim=0)\n",
    "        probits.append(probits_)\n",
    "    return torch.stack(probits, dim=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [06:50<00:00, 41.09s/it]\n",
      "100%|██████████| 10/10 [06:30<00:00, 39.04s/it]\n",
      "100%|██████████| 10/10 [00:03<00:00,  3.25it/s]\n",
      "100%|██████████| 10/10 [00:08<00:00,  1.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluated method seed 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [06:33<00:00, 39.34s/it]\n",
      "100%|██████████| 10/10 [09:02<00:00, 54.21s/it]\n",
      "100%|██████████| 10/10 [00:03<00:00,  3.18it/s]\n",
      "100%|██████████| 10/10 [00:10<00:00,  1.05s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluated method seed 142\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [07:30<00:00, 45.07s/it]\n",
      "100%|██████████| 10/10 [07:48<00:00, 46.87s/it]\n",
      "100%|██████████| 10/10 [00:03<00:00,  3.05it/s]\n",
      "100%|██████████| 10/10 [00:09<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluated method seed 242\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [08:08<00:00, 48.90s/it]\n",
      "100%|██████████| 10/10 [07:35<00:00, 45.50s/it]\n",
      "100%|██████████| 10/10 [00:03<00:00,  3.23it/s]\n",
      "100%|██████████| 10/10 [00:09<00:00,  1.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluated method seed 342\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [06:54<00:00, 41.50s/it]\n",
      "100%|██████████| 10/10 [07:17<00:00, 43.71s/it]\n",
      "100%|██████████| 10/10 [00:03<00:00,  3.30it/s]\n",
      "100%|██████████| 10/10 [00:09<00:00,  1.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluated method seed 442\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for mseed in method_seeds:\n",
    "\n",
    "    path = os.path.join(RESULTS_PATH, f\"chexpert_{model}_mseed{mseed}_dseed{dseed}\")\n",
    "\n",
    "    # load networks\n",
    "    networks = list()\n",
    "    model_files = glob.glob(os.path.join(path, \"models\", \"*.pt\"))\n",
    "    for model_file in sorted(model_files):\n",
    "        if model == \"resnet18\":\n",
    "            network = tv.models.resnet18(weights=None) \n",
    "            network.fc = nn.Linear(in_features=512, out_features=2)\n",
    "        elif model == \"resnet34\":\n",
    "            network = tv.models.resnet34(weights=None) \n",
    "            network.fc = nn.Linear(in_features=512, out_features=2)\n",
    "        elif model == \"resnet50\":\n",
    "            network = tv.models.resnet50(weights=None) \n",
    "            network.fc = nn.Linear(in_features=2048, out_features=2)\n",
    "        elif args.network == \"efficientnet\":\n",
    "            network = tv.models.efficientnet_v2_s(weights=None)\n",
    "            network.classifier = nn.Sequential(\n",
    "                nn.Dropout(p=0.2),\n",
    "                nn.Linear(in_features=1280, out_features=2)\n",
    "                )\n",
    "        elif model == \"regnet\":\n",
    "            network = tv.models.regnet_y_800mf(weights=None)\n",
    "            network.fc = nn.Linear(in_features=784, out_features=2)\n",
    "\n",
    "        network.load_state_dict(torch.load(model_file, map_location=device))\n",
    "        network.to(device)\n",
    "        networks.append(copy.deepcopy(network))\n",
    "\n",
    "    # evaluate\n",
    "    torch.save(evaluate(networks, fair_ds), os.path.join(path, f\"fair_probits.pt\"))\n",
    "    torch.save(evaluate(networks, val_ds), os.path.join(path, f\"val_probits.pt\"))\n",
    "    torch.save(evaluate(networks, r_val_ds), os.path.join(path, f\"r_val_probits.pt\"))\n",
    "    torch.save(evaluate(networks, r_test_ds), os.path.join(path, f\"r_test_probits.pt\"))\n",
    "    print(\"Evaluated method seed\", mseed)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
