{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2923bca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from torchmetrics import Accuracy\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.datasets import ImageFolder\n",
    "import os\n",
    "from fastai.vision.all import untar_data, URLs\n",
    "import pandas as pd\n",
    "from pytorch_lightning import Trainer\n",
    "from dime.data_utils import HistopathologyDownsampledDataset\n",
    "from dime.utils import MaskLayer2d\n",
    "from dime import MaskingPretrainer\n",
    "from dime import CMIEstimator, MaskLayer\n",
    "from dime.resnet_imagenet import resnet18, resnet34, Predictor, ValueNetwork, ResNet18Backbone, resnet50\n",
    "from dime.vit import PredictorViT, ValueNetworkViT\n",
    "import timm\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60a534bf",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "108302ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared preprocessing: convert each image to a tensor and flatten it to 1-D.\n",
    "flatten_transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])\n",
    "\n",
    "# Load the MNIST training split (downloaded to /tmp/mnist/ if missing).\n",
    "mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True, transform=flatten_transform)\n",
    "np.random.seed(0)\n",
    "\n",
    "# Load the MNIST test split that the evaluation cells iterate over.\n",
    "test_dataset = MNIST('/tmp/mnist/', download=True, train=False, transform=flatten_transform)\n",
    "\n",
    "# GPU used for the networks and for the Lightning Trainer (devices=[device.index]).\n",
    "device = torch.device('cuda:2')\n",
    "\n",
    "# drop_last=False keeps the final partial batch: 10,000 test images do not divide evenly\n",
    "# by 128, so dropping it would silently exclude 16 images from every reported accuracy.\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128, shuffle=False, pin_memory=True,\n",
    "    drop_last=False, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "769c93ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 10000\n",
       "    Root location: /tmp/mnist/\n",
       "    Split: Test\n",
       "    StandardTransform\n",
       "Transform: Compose(\n",
       "               ToTensor()\n",
       "               Lambda()\n",
       "           )"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: show the test dataset summary (size, root, split, transforms) as the cell output.\n",
    "test_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6bf3c02",
   "metadata": {},
   "source": [
    "# Set up networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "256bc905",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 16bit None Automatic Mixed Precision (AMP)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "# Accuracy over the 10 MNIST classes; also passed to CMIEstimator as val_loss_fn.\n",
    "acc_metric = Accuracy(task='multiclass', num_classes=10)\n",
    "\n",
    "# Network sizes.\n",
    "d_in, d_out = 784, 10\n",
    "hidden, dropout = 512, 0.3\n",
    "\n",
    "\n",
    "def build_mlp(out_features):\n",
    "    \"\"\"Build a Linear-ReLU-Dropout MLP with two hidden layers and move it to `device`.\n",
    "\n",
    "    The input width is `2 * d_in` (presumably the features concatenated with their mask,\n",
    "    given MaskLayer(append=True) - confirm); the output width is `out_features`.\n",
    "    \"\"\"\n",
    "    layers = [\n",
    "        nn.Linear(d_in * 2, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, out_features),\n",
    "    ]\n",
    "    return nn.Sequential(*layers).to(device)\n",
    "\n",
    "\n",
    "# Outcome predictor: one output per class.\n",
    "predictor = build_mlp(d_out)\n",
    "\n",
    "# CMI predictor: one output per input feature.\n",
    "value_network = build_mlp(d_in)\n",
    "\n",
    "# Reuse the predictor's two hidden Linear modules inside the value network, so both\n",
    "# networks hold the very same layer objects (and therefore the same weights).\n",
    "for shared_idx in (0, 3):\n",
    "    value_network[shared_idx] = predictor[shared_idx]\n",
    "\n",
    "mask_layer = MaskLayer(append=True, mask_size=d_in)\n",
    "\n",
    "# Lightning trainer pinned to the GPU chosen in `device`, running in 16-bit precision.\n",
    "trainer = Trainer(accelerator='gpu', devices=[device.index], precision=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0660b31",
   "metadata": {},
   "source": [
    "# Evaluate Penalized Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50c76708",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Penalty weights (lambda), log-spaced; the same sweep is reused for every trial.\n",
    "lamda_values = list(np.geomspace(0.00016, 0.28, num=10))\n",
    "\n",
    "# Create the output directory up front so that opening a result file does not fail\n",
    "# when `results/` is missing.\n",
    "os.makedirs('results', exist_ok=True)\n",
    "\n",
    "for trial in range(0, 5):\n",
    "    results_dict = {\"acc\": {}}\n",
    "    # NOTE(review): every trial loads the same placeholder path; presumably this should\n",
    "    # point at the checkpoint trained for the current trial - confirm.\n",
    "    path = \"<path_to_trained_model>\"\n",
    "\n",
    "    # Restore the trained estimator around the networks defined above and move it to the GPU.\n",
    "    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(\n",
    "        path,\n",
    "        value_network=value_network,\n",
    "        predictor=predictor,\n",
    "        mask_layer=mask_layer,\n",
    "        lr=1e-3,\n",
    "        min_lr=1e-6,\n",
    "        max_features=50,\n",
    "        eps=0.05,\n",
    "        loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "        val_loss_fn=acc_metric,\n",
    "        eps_decay=0.2,\n",
    "        eps_steps=10,\n",
    "        patience=3,\n",
    "        feature_costs=None,\n",
    "    ).to(device)\n",
    "\n",
    "    # Per-lambda results for this trial.\n",
    "    avg_num_features_lamda = []\n",
    "    accuracy_scores_lamda = []\n",
    "    all_masks_lamda = []\n",
    "\n",
    "    for lamda in lamda_values:\n",
    "        metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, lam=lamda)\n",
    "\n",
    "        y = metric_dict['y']\n",
    "        pred = metric_dict['pred']\n",
    "        accuracy_score = acc_metric(pred, y)\n",
    "        final_masks = np.array(metric_dict['mask'])\n",
    "        # Mean over samples of the per-row mask sum, reported below as 'Avg. num features'.\n",
    "        avg_num_features = np.mean(np.sum(final_masks, axis=1))\n",
    "\n",
    "        accuracy_scores_lamda.append(accuracy_score)\n",
    "        avg_num_features_lamda.append(avg_num_features)\n",
    "        # Results are keyed by the average feature count, not by lambda.\n",
    "        results_dict['acc'][avg_num_features] = accuracy_score\n",
    "\n",
    "        print(f\"Lambda={lamda}, Acc={accuracy_score}, Avg. num features={avg_num_features}\")\n",
    "        all_masks_lamda.append(final_masks)\n",
    "\n",
    "    # NOTE(review): `trial-4` gives file suffixes -4..0 for range(0, 5); kept unchanged because\n",
    "    # the readers of these files are not visible here - confirm the intended numbering.\n",
    "    with open(f'results/mnist_lamda_ours_trial_{trial-4}.pkl', 'wb') as f:\n",
    "        pickle.dump(results_dict, f)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79adee46",
   "metadata": {},
   "source": [
    "# Evaluate Budget Constrained Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d9a94c19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:07<00:00, 10.04it/s]\n",
      "Budget=10, Acc=0.7473999857902527, Avg. num features=10.0\n"
     ]
    }
   ],
   "source": [
    "results_dict = {\"acc\": {}}\n",
    "# for trial in range(0, 5):\n",
    "# Placeholder for the checkpoint to evaluate; replace with a real path before running.\n",
    "path = \"<path_to_trained_model>\"\n",
    "\n",
    "# Restore the trained estimator around the networks defined above.\n",
    "greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(\n",
    "    path,\n",
    "    value_network=value_network,\n",
    "    predictor=predictor,\n",
    "    mask_layer=mask_layer,\n",
    "    lr=1e-3,\n",
    "    min_lr=1e-6,\n",
    "    max_features=50,\n",
    "    eps=0.05,\n",
    "    loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "    val_loss_fn=acc_metric,\n",
    "    eps_decay=0.2,\n",
    "    eps_steps=10,\n",
    "    patience=3,\n",
    "    feature_costs=None,\n",
    ")\n",
    "\n",
    "# Per-budget results.\n",
    "avg_num_features_budget = []\n",
    "accuracy_scores_budget = []\n",
    "all_masks_budget = []\n",
    "\n",
    "max_budget_values = [10] #[3] + list(range(5, 30, 5))\n",
    "for budget in max_budget_values:\n",
    "    metric_dict_budget = greedy_cmi_estimator.inference(\n",
    "        trainer, test_dataloader, feature_costs=None, budget=budget)\n",
    "\n",
    "    y = metric_dict_budget['y']\n",
    "    pred = metric_dict_budget['pred']\n",
    "    accuracy_score = acc_metric(pred, y)\n",
    "    final_masks = np.array(metric_dict_budget['mask'])\n",
    "    # Mean over samples of the per-row mask sum, reported below as 'Avg. num features'.\n",
    "    avg_num_features = np.mean(np.sum(final_masks, axis=1))\n",
    "\n",
    "    accuracy_scores_budget.append(accuracy_score)\n",
    "    avg_num_features_budget.append(avg_num_features)\n",
    "    # Results are keyed by the average feature count, not by the budget.\n",
    "    results_dict['acc'][avg_num_features] = accuracy_score\n",
    "    print(f\"Budget={budget}, Acc={accuracy_score}, Avg. num features={avg_num_features}\")\n",
    "\n",
    "    # Record the masks of every budget (inside the loop), not only those of the last one.\n",
    "    all_masks_budget.append(final_masks)\n",
    "\n",
    "# with open(f'results/mnist_ours_trial_{trial-4}.pkl', 'wb') as f:\n",
    "#     pickle.dump(results_dict, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b998db",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:test_DIME_2]",
   "language": "python",
   "name": "conda-env-test_DIME_2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "vscode": {
   "interpreter": {
    "hash": "8d82fba152697c33d0ffd056b79e1954c89584403222891f7cd7fd247ae84ab2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
