{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision as tv\n",
    "from torch.utils.data import Subset, Dataset, DataLoader\n",
    "\n",
    "from source.constants import RESULTS_PATH\n",
    "from source.data.face_detection import get_fair_face, get_utk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds of the trained runs to evaluate; each seed has its own results directory.\n",
    "method_seeds = [42, 142, 242, 342, 442]\n",
    "# `dseed` tag used in the results directory names.\n",
    "dseed = 42\n",
    "\n",
    "# Architecture to evaluate; one of: \"resnet18\", \"resnet34\", \"resnet50\", \"regnet\", \"efficientnet\".\n",
    "model = \"regnet\"\n",
    "\n",
    "# Target index forwarded to the dataset loaders (options: 0, 1, 2, 3).\n",
    "target = 0\n",
    "\n",
    "# NOTE(review): hardcoded GPU index - adjust for the local machine.\n",
    "device = \"cuda:7\"\n",
    "# 1024 ~ 10GB VRAM / 2048 ~ 15GB VRAM for resnet18 ~ 1GB more for resnet34, 1024 ~ 17GB for resnet50\n",
    "batch_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets for the configured target, loaded with binarize=True.\n",
    "ff_train_ds, ff_test_ds = get_fair_face(target=target, binarize=True, augment=False)\n",
    "utk_test_ds = get_utk(target=target, binarize=True)\n",
    "\n",
    "# The fair/val index files are read from the first method seed's run directory.\n",
    "# NOTE(review): assumes every method seed with this dseed shares the same split - confirm.\n",
    "run_path = os.path.join(\n",
    "    RESULTS_PATH,\n",
    "    f\"fairface_target{target}_{model}_mseed{method_seeds[0]}_dseed{dseed}\",\n",
    ")\n",
    "fair_inds, val_inds = (torch.load(os.path.join(run_path, fname)) for fname in (\"fair_inds.pt\", \"val_inds.pt\"))\n",
    "\n",
    "# Both subsets are drawn from ff_train_ds.\n",
    "fair_ds, val_ds = (Subset(ff_train_ds, indices=inds) for inds in (fair_inds, val_inds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def evaluate(networks: List, ds: Dataset):\n",
    "    \"\"\"Return the softmax outputs of every network on every sample of `ds`.\n",
    "\n",
    "    Args:\n",
    "        networks: models to evaluate; each one is switched to eval mode in place.\n",
    "        ds: dataset whose items unpack as (x, _, _); only x is fed to the networks.\n",
    "\n",
    "    Returns:\n",
    "        CPU tensor stacked over networks along dim 0; entry i concatenates, in\n",
    "        dataset order, the softmax (over dim 1) of network i's outputs.\n",
    "        Presumably shaped (len(networks), len(ds), num_classes) - confirm.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `networks` is empty.\n",
    "\n",
    "    Reads the notebook globals `batch_size` and `device`.\n",
    "    \"\"\"\n",
    "    networks = list(networks)\n",
    "    if not networks:\n",
    "        # torch.stack([]) would otherwise fail with an unhelpful message\n",
    "        raise ValueError(\"evaluate() needs at least one network\")\n",
    "\n",
    "    # shuffle=False gives a fixed sample order, so one loader serves every network\n",
    "    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False)\n",
    "    probits = list()\n",
    "    for network in tqdm(networks):\n",
    "        network.eval()\n",
    "        batch_probits = list()\n",
    "        for x, _, _ in loader:\n",
    "            x = x.to(device)\n",
    "            # call the module rather than .forward() so registered hooks run\n",
    "            batch_probits.append(torch.softmax(network(x), dim=1).cpu())\n",
    "        probits.append(torch.concat(batch_probits, dim=0))\n",
    "    return torch.stack(probits, dim=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _build_network(name: str) -> nn.Module:\n",
    "    \"\"\"Build an untrained torchvision backbone whose last layer is a 2-output linear head.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `name` is not a supported architecture.\n",
    "    \"\"\"\n",
    "    if name == \"resnet18\":\n",
    "        network = tv.models.resnet18(weights=None)\n",
    "        network.fc = nn.Linear(in_features=512, out_features=2)\n",
    "    elif name == \"resnet34\":\n",
    "        network = tv.models.resnet34(weights=None)\n",
    "        network.fc = nn.Linear(in_features=512, out_features=2)\n",
    "    elif name == \"resnet50\":\n",
    "        network = tv.models.resnet50(weights=None)\n",
    "        network.fc = nn.Linear(in_features=2048, out_features=2)\n",
    "    elif name == \"efficientnet\":\n",
    "        network = tv.models.efficientnet_v2_s(weights=None)\n",
    "        network.classifier = nn.Sequential(\n",
    "            nn.Dropout(p=0.2),\n",
    "            nn.Linear(in_features=1280, out_features=2),\n",
    "        )\n",
    "    elif name == \"regnet\":\n",
    "        network = tv.models.regnet_y_800mf(weights=None)\n",
    "        network.fc = nn.Linear(in_features=784, out_features=2)\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown model: {name!r}\")\n",
    "    return network\n",
    "\n",
    "\n",
    "# Splits to evaluate, in order; each key is the prefix of its output file.\n",
    "eval_splits = {\n",
    "    \"fair\": fair_ds,\n",
    "    \"val\": val_ds,\n",
    "    \"ff_test\": ff_test_ds,\n",
    "    \"utk_test\": utk_test_ds,\n",
    "}\n",
    "\n",
    "for mseed in method_seeds:\n",
    "\n",
    "    path = os.path.join(RESULTS_PATH, f\"fairface_target{target}_{model}_mseed{mseed}_dseed{dseed}\")\n",
    "\n",
    "    # load networks: one per checkpoint, in sorted filename order\n",
    "    models_dir = os.path.join(path, \"models\")\n",
    "    model_files = sorted(glob.glob(os.path.join(models_dir, \"*.pt\")))\n",
    "    if not model_files:\n",
    "        raise FileNotFoundError(f\"No checkpoints (*.pt) found in {models_dir}\")\n",
    "    networks = list()\n",
    "    for model_file in model_files:\n",
    "        network = _build_network(model)\n",
    "        network.load_state_dict(torch.load(model_file, map_location=device))\n",
    "        # a fresh module is built per checkpoint, so no deepcopy is needed\n",
    "        networks.append(network.to(device))\n",
    "\n",
    "    # evaluate and save the stacked probits of each split\n",
    "    for split, split_ds in eval_splits.items():\n",
    "        torch.save(evaluate(networks, split_ds), os.path.join(path, f\"{split}_probits_t{target}.pt\"))\n",
    "    print(\"Evaluated method seed\", mseed)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
