{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f082db0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from torchmetrics import Accuracy\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import ImageFolder\n",
    "import os\n",
    "from fastai.vision.all import untar_data, URLs\n",
    "from pytorch_lightning import Trainer\n",
    "from dime.utils import MaskLayer2d\n",
    "from dime import CMIEstimator, MaskingPretrainer\n",
    "from dime.resnet_imagenet import resnet18, Predictor, ValueNetwork, ResNet18Backbone\n",
    "from dime.vit import PredictorViT, ValueNetworkViT\n",
    "import timm\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d029cbb1",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb784356",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation setup: image/mask geometry, target GPU, and the ImageNette validation loader.\n",
    "image_size = 224\n",
    "mask_width = 14\n",
    "mask_type = 'zero'\n",
    "# True division makes patch_size the float 16.0 (224 / 14).\n",
    "mask_layer = MaskLayer2d(append=False, mask_width=mask_width, patch_size=image_size/mask_width)\n",
    "\n",
    "# Hard-coded GPU index; the Trainer built in the model-loading cell reads device.index.\n",
    "# NOTE(review): adjust the index on machines with fewer than 8 GPUs.\n",
    "device = torch.device('cuda:7')\n",
    "\n",
    "# Per-channel (mean, std) applied by transforms.Normalize below.\n",
    "# NOTE(review): these look like the commonly quoted CIFAR-10 statistics rather than\n",
    "# ImageNet ones; presumably they match what the checkpoint was trained with - confirm\n",
    "# before changing them.\n",
    "norm_constants = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
    "# Inverse of the normalization above: mean' = -mean/std and std' = 1/std.\n",
    "unnorm_constants = ((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010))\n",
    "mbsize = 32\n",
    "# Resolve the dataset through fastai rather than a machine-specific absolute path.\n",
    "# untar_data returns the local folder and only downloads when it is not already present.\n",
    "dataset_path = str(untar_data(URLs.IMAGENETTE_320))\n",
    "# Deterministic preprocessing: resize, center crop to image_size, tensor conversion, normalization.\n",
    "transforms_test = transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(image_size),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(*norm_constants),\n",
    "    ])\n",
    "\n",
    "# The 'val' split of the dataset folder serves as the test set; order is kept fixed (shuffle=False).\n",
    "test_dataset = ImageFolder(os.path.join(dataset_path, 'val'), transforms_test)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=mbsize, pin_memory=True, num_workers=4, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b18d5edd",
   "metadata": {},
   "source": [
    "# Load Pretrained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd81fb41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy over 10 classes; handed to the estimator as val_loss_fn and reused by the\n",
    "# evaluation cells to score predictions.\n",
    "acc_metric = Accuracy(task='multiclass', num_classes=10)\n",
    "\n",
    "arch = 'vit'\n",
    "use_entropy = False\n",
    "if arch != 'resnet':\n",
    "    # ViT path: predictor and value network are built around the same backbone object.\n",
    "    backbone = timm.create_model('vit_small_patch16_224', pretrained=True)\n",
    "    predictor = PredictorViT(backbone)\n",
    "    value_network = ValueNetworkViT(backbone)\n",
    "else:\n",
    "    # ResNet path: the value network uses stride 0.5 when the mask is 14 cells wide, else 1.\n",
    "    backbone = ResNet18Backbone(resnet18(pretrained=True))\n",
    "    predictor = Predictor(backbone)\n",
    "    block_layer_stride = 0.5 if mask_width == 14 else 1\n",
    "    value_network = ValueNetwork(backbone, block_layer_stride=block_layer_stride)\n",
    "\n",
    "# Placeholder: replace with the path of the trained CMIEstimator checkpoint.\n",
    "trained_model_path = '<path_to_trained_model>'\n",
    "\n",
    "# Modules and hyperparameters forwarded as keyword arguments when loading the checkpoint.\n",
    "estimator_kwargs = dict(\n",
    "    value_network=value_network,\n",
    "    predictor=predictor,\n",
    "    mask_layer=mask_layer,\n",
    "    lr=1e-5,\n",
    "    min_lr=1e-8,\n",
    "    max_features=50,\n",
    "    eps=0.05,\n",
    "    loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "    val_loss_fn=acc_metric,\n",
    "    eps_decay=0.2,\n",
    "    eps_steps=10,\n",
    "    patience=3,\n",
    "    feature_costs=None,\n",
    ")\n",
    "greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(trained_model_path, **estimator_kwargs)\n",
    "\n",
    "# 16-bit precision GPU trainer pinned to the index of `device` from the dataset cell.\n",
    "trainer = Trainer(accelerator='gpu', devices=[device.index], precision=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a19ebda",
   "metadata": {},
   "source": [
    "# Evaluate Penalized Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49ae20b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep over penalty values: one inference pass per lambda (passed as `lam`), recording\n",
    "# the accuracy, the final masks and the mean of the per-row mask sums.\n",
    "results_dict = {\"acc\": {}}\n",
    "avg_num_features_lamda, accuracy_scores_lamda, all_masks_lamda = [], [], []\n",
    "\n",
    "# Alternative sweeps kept for reference:\n",
    "#   list(np.geomspace(0.01, 0.3, num=10)), [0.01], list(np.geomspace(0.001, 0.3, num=10))\n",
    "lamda_values = [0.4]\n",
    "for lamda in lamda_values:\n",
    "    metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, lam=lamda)\n",
    "\n",
    "    y, pred = metric_dict['y'], metric_dict['pred']\n",
    "    accuracy_score = acc_metric(pred, y)\n",
    "    final_masks = np.array(metric_dict['mask'])\n",
    "    # Row sums of the masks, averaged over rows; reported below as the average number of features.\n",
    "    mean_selected = np.mean(np.sum(final_masks, axis=1))\n",
    "\n",
    "    all_masks_lamda.append(final_masks)\n",
    "    avg_num_features_lamda.append(mean_selected)\n",
    "    accuracy_scores_lamda.append(accuracy_score)\n",
    "    # Results are keyed by the mean feature count, not by lambda.\n",
    "    results_dict['acc'][mean_selected] = accuracy_score\n",
    "\n",
    "    print(f\"Lambda={lamda}, Acc={accuracy_score}, Avg. num features={mean_selected}\")\n",
    "\n",
    "# Saving is disabled for this sweep:\n",
    "# with open(f'results/imagenette_ours_trial_2_penalized.pkl', 'wb') as f:\n",
    "#     pickle.dump(results_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8af2052",
   "metadata": {},
   "source": [
    "# Evaluate Budget Constrained Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8286d647",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Sweep over feature budgets 1..14: one inference pass per budget, recording the accuracy,\n",
    "# the final masks and the mean of the per-row mask sums, then pickle the accuracy table.\n",
    "results_dict = {\"acc\": {}}\n",
    "\n",
    "avg_num_features_budget = []\n",
    "accuracy_scores_budget = []\n",
    "all_masks_budget = []\n",
    "\n",
    "max_budget_values = list(range(1, 15, 1))\n",
    "for budget in max_budget_values:\n",
    "    metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, budget=budget)\n",
    "\n",
    "    y = metric_dict_budget['y']\n",
    "    pred = metric_dict_budget['pred']\n",
    "    accuracy_score = acc_metric(pred, y)\n",
    "    final_masks = np.array(metric_dict_budget['mask'])\n",
    "    # Computed once and reused for the list, the results key and the log line.\n",
    "    mean_selected = np.mean(np.sum(final_masks, axis=1))\n",
    "\n",
    "    accuracy_scores_budget.append(accuracy_score)\n",
    "    avg_num_features_budget.append(mean_selected)\n",
    "    # Results are keyed by the mean feature count, not by the budget.\n",
    "    results_dict['acc'][mean_selected] = accuracy_score\n",
    "    print(f\"Budget={budget}, Acc={accuracy_score}, Avg. num features={mean_selected}\")\n",
    "\n",
    "    all_masks_budget.append(final_masks)\n",
    "\n",
    "results_path = 'results/imagenette_ours_trial_2.pkl'\n",
    "# Create the output folder first: without it open() raises FileNotFoundError and the\n",
    "# results of every inference pass above would be lost.\n",
    "os.makedirs(os.path.dirname(results_path), exist_ok=True)\n",
    "with open(results_path, 'wb') as f:\n",
    "    pickle.dump(results_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bafd0ca4",
   "metadata": {},
   "source": [
    "# Evaluate Confidence Constrained Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3ead263",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep over confidence thresholds: one inference pass per value, recording the accuracy,\n",
    "# the final masks and the mean of the per-row mask sums.\n",
    "avg_num_features_confidence, accuracy_scores_confidence, all_masks_confidence = [], [], []\n",
    "confidence_values = list(np.arange(0.1, 1, 0.1))\n",
    "\n",
    "for confidence in confidence_values:\n",
    "    metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, confidence=confidence)\n",
    "    y, pred = metric_dict['y'], metric_dict['pred']\n",
    "    accuracy_score = acc_metric(pred, y)\n",
    "\n",
    "    final_masks = np.array(metric_dict['mask'])\n",
    "    # Row sums of the masks, averaged over rows; reported below as the average number of features.\n",
    "    mean_selected = np.mean(np.sum(final_masks, axis=1))\n",
    "\n",
    "    all_masks_confidence.append(final_masks)\n",
    "    avg_num_features_confidence.append(mean_selected)\n",
    "    accuracy_scores_confidence.append(accuracy_score)\n",
    "    print(f\"Confidence={confidence}, Acc={accuracy_score}, Avg. num features={mean_selected}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32e47a51",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:test_DIME_2]",
   "language": "python",
   "name": "conda-env-test_DIME_2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
