{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80698da1-d154-4da0-ae92-2e9ab1fcf816",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from functools import lru_cache\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from os.path import join\n",
    "from IPython.display import display\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rc('text', usetex=False)\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from scipy.stats import ttest_ind\n",
    "import scipy\n",
    "import joblib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.ensemble import IsolationForest\n",
    "from skimage.segmentation import mark_boundaries\n",
    "import pandas as pd\n",
    "pd.options.display.float_format = '{:,.3f}'.format\n",
    "pd.set_option('display.max_rows', 128)\n",
    "\n",
    "\n",
    "from spurious_ml.datasets import add_spurious_correlation\n",
    "from spurious_ml.models.torch_utils import archs\n",
    "from spurious_ml.variables import auto_var\n",
    "from params import *\n",
    "from utils import params_to_dataframe\n",
    "\n",
    "fontsize=15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bab15bd0-6d9a-47fd-b01f-69b652e290b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mlp_pred_fn(X, model, device=\"cuda\"):\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.reshape(len(X), -1)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "def cnn_pred_fn(X, model, device=\"cuda\"):\n",
    "    if len(X.shape) == 4:\n",
    "        X = X.transpose(0, 3, 1, 2)\n",
    "        \n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=128)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    #for (x, ) in loader:\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "        #fetX.append(model.feature_extractor(x.to(device)).cpu().detach().flatten(1).numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "class CLF():\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "        \n",
    "    def predict(self, X, device=\"cuda\"):\n",
    "        return pred_fn(X, self.model, device=device)\n",
    "    \n",
    "@lru_cache(maxsize=None)\n",
    "def evaluate(ds_name, model_path, arch, spurious_version):\n",
    "    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    n_classes = len(np.unique(trny))\n",
    "    n_channels = trnX.shape[-1]\n",
    "    res = torch.load(model_path)\n",
    "    model = getattr(archs, arch)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)\n",
    "    model.load_state_dict(res['model_state_dict'])\n",
    "    \n",
    "    if 'MLP' in arch:\n",
    "        pred_fn = mlp_pred_fn\n",
    "    else:\n",
    "        pred_fn = cnn_pred_fn\n",
    "    \n",
    "    #trn_preds = pred_fn(trnX, model)\n",
    "    tst_preds = pred_fn(tstX, model)\n",
    "    tst_preds = scipy.special.softmax(tst_preds, axis=1)\n",
    "    modified_tstX = np.copy(tstX)\n",
    "    modified_tstX = add_spurious_correlation(modified_tstX, spurious_version)\n",
    "    mod_tst_preds = pred_fn(modified_tstX, model)\n",
    "    mod_tst_preds = scipy.special.softmax(mod_tst_preds, axis=1)\n",
    "    return mod_tst_preds, tst_preds, tsty\n",
    "    #return ((mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > 0).mean(), (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]).mean()\n",
    "    #return ((mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > 1e-9).mean(), (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]).mean()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ceef40d-e645-4fd1-878c-d738d7ca9ba5",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dset = \"mnist\"\n",
    "spurious_version, i, tar_cls, seed = \"v10\", 20, 1, 0\n",
    "folder, lr, arch, momentum, optimizer = \"train_classifier\", 0.01, \"CNN002\", 0., \"adam\"\n",
    "ds_name = f\"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}\"\n",
    "model_path = f\"../models/{folder}/128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e47642bd-d69b-48b2-b55e-8aa0e76ac071",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
