{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a5720aa2-4365-4bd3-b2e8-c52217de46ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "# Move the working directory up to the parent directory exactly once per\n",
    "# kernel session. A bare os.chdir(\"..\") climbs one more level every time the\n",
    "# cell is re-run, which breaks the relative paths and local imports used by\n",
    "# the rest of the notebook. The sentinel below records the directory the\n",
    "# kernel started in and makes re-running this cell a no-op.\n",
    "if \"_NOTEBOOK_START_DIR\" not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "    os.chdir(\"..\")\n",
    "\n",
    "# Project-local modules. They are imported after the chdir above, in the same\n",
    "# order as before (presumably they are resolved from the parent directory --\n",
    "# confirm against the repository layout).\n",
    "import utils\n",
    "import data_utils\n",
    "\n",
    "import cbm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b486050c-7a36-4339-b9ff-e6b8426dbbd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Model variants and datasets to evaluate. Every (dataset, model_type) pair is\n",
    "# looked up under load_dir in a folder named \"<dataset>_<model_type>\".\n",
    "model_types = [\"sparse_ft\", \"lf_cbm\"]\n",
    "datasets = [\"cifar10\", \"cifar100\", \"cub\", \"places365\", \"imagenet\"]\n",
    "\n",
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs on\n",
    "# machines where CUDA is not available.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Directory (relative to the working directory) that holds the saved models.\n",
    "load_dir = \"saved_models/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "309e2bb6-d79d-4de3-bdf4-6b9c3b466682",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:16<00:00,  1.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10 - sparse_ft - Accuracy: 82.95% \n",
      "\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:15<00:00,  1.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10 - lf_cbm - Accuracy: 86.37% \n",
      "\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:15<00:00,  1.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar100 - sparse_ft - Accuracy: 58.35% \n",
      "\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:15<00:00,  1.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar100 - lf_cbm - Accuracy: 65.27% \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 12/12 [00:31<00:00,  2.67s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cub - sparse_ft - Accuracy: 75.96% \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 12/12 [00:31<00:00,  2.63s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cub - lf_cbm - Accuracy: 74.59% \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 73/73 [01:19<00:00,  1.09s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "places365 - sparse_ft - Accuracy: 38.46% \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 73/73 [01:18<00:00,  1.08s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "places365 - lf_cbm - Accuracy: 43.71% \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [05:34<00:00,  3.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "imagenet - sparse_ft - Accuracy: 74.35% \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [05:29<00:00,  3.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "imagenet - lf_cbm - Accuracy: 71.98% \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every saved model on the validation split of its dataset and print\n",
    "# the resulting accuracy as a percentage.\n",
    "for dataset_name in datasets:\n",
    "    for model_type in model_types:\n",
    "        # Each saved model lives in \"<load_dir>/<dataset>_<model_type>\". The two\n",
    "        # components are passed to os.path.join separately so the path is\n",
    "        # correct whether or not load_dir ends with a path separator.\n",
    "        load_path = os.path.join(load_dir, \"{}_{}\".format(dataset_name, model_type))\n",
    "\n",
    "        # args.txt is parsed as JSON; the \"dataset\" and \"backbone\" entries are\n",
    "        # the ones used below.\n",
    "        with open(os.path.join(load_path, \"args.txt\"), \"r\") as f:\n",
    "            args = json.load(f)\n",
    "        dataset = args[\"dataset\"]\n",
    "\n",
    "        # Only the preprocessing transform returned for the backbone is needed\n",
    "        # here; the first return value is discarded.\n",
    "        _, target_preprocess = data_utils.get_target_model(args[\"backbone\"], device)\n",
    "\n",
    "        # \"lf_cbm\" checkpoints use the CBM loader, everything else the standard one.\n",
    "        if model_type == \"lf_cbm\":\n",
    "            model = cbm.load_cbm(load_path, device)\n",
    "        else:\n",
    "            model = cbm.load_std(load_path, device)\n",
    "\n",
    "        # Validation data is requested under the name \"<dataset>_val\".\n",
    "        val_d_probe = dataset+\"_val\"\n",
    "        val_data_t = data_utils.get_data(val_d_probe, preprocess=target_preprocess)\n",
    "\n",
    "        accuracy = utils.get_accuracy_cbm(model, val_data_t, device)\n",
    "        print(\"{} - {} - Accuracy: {:.2f}% \\n\".format(dataset_name, model_type, accuracy*100))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:jovyan-lf_cbm]",
   "language": "python",
   "name": "conda-env-jovyan-lf_cbm-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
