{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cdfc8ab-96b8-438c-8b02-a93d5c1f0bb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup: autoreload for project modules, imports, and global plotting/display options.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# NOTE(review): %reload_ext directly after %autoreload 2 looks redundant -- kept as-is.\n",
    "%reload_ext autoreload\n",
    "\n",
    "import sys\n",
    "# Presumably makes the parent directory importable for `spurious_ml` / `utils` below.\n",
    "sys.path.append(\"../\")\n",
    "import os\n",
    "from IPython.display import display\n",
    "from functools import partial\n",
    "from collections import OrderedDict\n",
    "from functools import lru_cache\n",
    "\n",
    "\n",
    "import matplotlib\n",
    "# Embed TrueType (Type 42) fonts in exported PDF / PS figures.\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rc('text', usetex=False)\n",
    "\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_rows', 500)\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from spurious_ml.datasets import add_spurious_correlation, add_colored_spurious_correlation\n",
    "from spurious_ml.models.torch_utils import archs, data_augs\n",
    "from spurious_ml.variables import auto_var\n",
    "from utils import params_to_dataframe\n",
    "\n",
    "# NOTE(review): `fontsize` is not used by any cell in this notebook.\n",
    "fontsize=15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1afc05b4-3e57-4df9-b9e0-65bd821c2d1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mlp_get_feature(X, model, device=\"cuda\", verbose=True, layer=\"last\"):\n",
    "    \"\"\"Return hidden-layer activations of an MLP exposing `hidden`, `hidden2` and `hidden3`.\n",
    "\n",
    "    X is flattened per sample. With layer=\"last\" the output of `hidden3` (fed with\n",
    "    ReLU(hidden2)) is returned; with any other value the raw, pre-ReLU output of\n",
    "    `hidden2` is returned. The result is a numpy array with one row per sample.\n",
    "    \"\"\"\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.reshape(len(X), -1)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "\n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    # Inference only: no_grad skips building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for (x, ) in tqdm(loader, desc=\"[pred_fn]\", disable=(not verbose)):\n",
    "            x = x.to(device)\n",
    "            x = F.relu(model.hidden(x))\n",
    "            x = model.hidden2(x)\n",
    "            if layer == \"last\":\n",
    "                x = F.relu(x)\n",
    "                x = model.hidden3(x)\n",
    "            fetX.append(x.cpu().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "def cnn_get_feature(X, model, device=\"cuda\", verbose=True):\n",
    "    \"\"\"Return the output of `classifier.relu2` of a CNN002-style model for NHWC images X.\n",
    "\n",
    "    Images are transposed to NCHW, passed through `feature_extractor`, flattened and run\n",
    "    through the classifier up to `relu2`. Returns a numpy array with one row per sample.\n",
    "    \"\"\"\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.transpose(0, 3, 1, 2)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "\n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    # Inference only: no_grad skips building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for (x, ) in tqdm(loader, desc=\"[pred_fn]\", disable=(not verbose)):\n",
    "            x = x.to(device)\n",
    "            x = model.feature_extractor(x)\n",
    "            # Flatten per sample rather than view(-1, 64 * 4 * 4): a hard-coded size would\n",
    "            # silently regroup rows if the extractor output were not 64x4x4 per image.\n",
    "            x = torch.flatten(x, 1)\n",
    "            x = model.classifier.fc1(x)\n",
    "            x = model.classifier.relu1(x)\n",
    "            x = model.classifier.drop(x)\n",
    "            x = model.classifier.fc2(x)\n",
    "            x = model.classifier.relu2(x)\n",
    "            fetX.append(x.cpu().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "def resnet_get_feature(X, model, device=\"cuda\", verbose=True):\n",
    "    \"\"\"Return the flattened `avgpool` output of a ResNet-style model for NHWC images X.\n",
    "\n",
    "    Runs conv1/bn1/relu/maxpool, layer1-4 and avgpool explicitly; one row per sample.\n",
    "    Note: a later cell redefines `resnet_get_feature` (via `model.get_repr`) and shadows this one.\n",
    "    \"\"\"\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.transpose(0, 3, 1, 2)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "\n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    # Inference only: no_grad skips building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for (x, ) in tqdm(loader, desc=\"[pred_fn]\", disable=(not verbose)):\n",
    "            x = x.to(device)\n",
    "            x = model.conv1(x)\n",
    "            x = model.bn1(x)\n",
    "            x = model.relu(x)\n",
    "            x = model.maxpool(x)\n",
    "\n",
    "            x = model.layer1(x)\n",
    "            x = model.layer2(x)\n",
    "            x = model.layer3(x)\n",
    "            x = model.layer4(x)\n",
    "\n",
    "            x = model.avgpool(x)\n",
    "            x = torch.flatten(x, 1)\n",
    "            fetX.append(x.cpu().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "def mlp_pred_fn(X, model, device=\"cuda\", verbose=True):\n",
    "    \"\"\"Return the raw model outputs (one row per sample) of an MLP on flattened X.\"\"\"\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.reshape(len(X), -1)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "\n",
    "    model.to(device).eval()\n",
    "    preds = []\n",
    "    # Inference only: no_grad skips building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for (x, ) in tqdm(loader, desc=\"[pred_fn]\", disable=(not verbose)):\n",
    "            preds.append(model(x.to(device)).cpu().numpy())\n",
    "    return np.concatenate(preds, axis=0)\n",
    "\n",
    "def cnn_pred_fn(X, model, device=\"cuda\", verbose=True):\n",
    "    \"\"\"Return the raw model outputs (one row per sample) of a convolutional model.\n",
    "\n",
    "    4-D inputs are taken as NHWC and transposed to NCHW; other inputs are used as given.\n",
    "    \"\"\"\n",
    "    if len(X.shape) == 4:\n",
    "        X = X.transpose(0, 3, 1, 2)\n",
    "\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=128)\n",
    "\n",
    "    model.to(device).eval()\n",
    "    preds = []\n",
    "    # Inference only: no_grad skips building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for (x, ) in tqdm(loader, desc=\"[pred_fn]\", disable=(not verbose)):\n",
    "            preds.append(model(x.to(device)).cpu().numpy())\n",
    "    return np.concatenate(preds, axis=0)\n",
    "\n",
    "def get_model_dset(ds_name, model_path, arch):\n",
    "    \"\"\"Load dataset `ds_name` and an `arch` model restored from the checkpoint `model_path`.\n",
    "\n",
    "    Returns (model, trnX, trny, tstX, tsty, spurious_ind). The model is not moved to any\n",
    "    device here; the feature/prediction helpers call `model.to(device)` themselves.\n",
    "    A missing checkpoint surfaces as FileNotFoundError from torch.load.\n",
    "    \"\"\"\n",
    "    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    n_classes = len(np.unique(trny))\n",
    "    n_channels = trnX.shape[-1]\n",
    "    # map_location=\"cpu\": GPU-saved checkpoints also load on CPU-only machines;\n",
    "    # load_state_dict copies the weights into the model's own parameters either way.\n",
    "    res = torch.load(model_path, map_location=\"cpu\")\n",
    "    model = getattr(archs, arch)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)\n",
    "    model.load_state_dict(res['model_state_dict'])\n",
    "    return model, trnX, trny, tstX, tsty, spurious_ind"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "391a6f6f-becc-4d7c-aa02-4591f7aeb511",
   "metadata": {},
   "outputs": [],
   "source": [
    "@lru_cache(maxsize=None)\n",
    "def evaluate(ds_name, model_path, arch, spurious_version, seed):\n",
    "    \"\"\"Compute features and predictions of a trained model on clean / spurious test inputs.\n",
    "\n",
    "    Returns (tst_features, tst_spu_features, only_spu_features, tsty, tst_pred):\n",
    "      tst_features       features of the clean test images,\n",
    "      tst_spu_features   features of the test images with the spurious pattern added,\n",
    "      only_spu_features  features of an all-zero image carrying only the pattern,\n",
    "      tsty, tst_pred     test labels and model outputs on the clean test images.\n",
    "    `seed` only enters the lru_cache key (it is already encoded in ds_name). Cached arrays\n",
    "    are shared between calls, so callers must not modify them in place.\n",
    "\n",
    "    Raises ValueError for an unsupported architecture or dataset family.\n",
    "    \"\"\"\n",
    "    model, trnX, trny, tstX, tsty, spurious_ind = get_model_dset(ds_name, model_path, arch)\n",
    "\n",
    "    if 'MLP' in arch:\n",
    "        feature_fn, pred_fn = mlp_get_feature, mlp_pred_fn\n",
    "    elif 'CNN002' in arch:\n",
    "        feature_fn, pred_fn = cnn_get_feature, cnn_pred_fn\n",
    "    elif 'ResNet' in arch:\n",
    "        feature_fn, pred_fn = resnet_get_feature, cnn_pred_fn\n",
    "    else:\n",
    "        raise ValueError(f\"Unsupported arch: {arch}\")\n",
    "\n",
    "    # One pattern generator per dataset family, used for BOTH the test images and the\n",
    "    # pattern-only probe so the two carry the same pattern (CIFAR uses the colored variant).\n",
    "    if 'mnist' in ds_name or 'fashion' in ds_name:\n",
    "        add_spu, blank_shape = add_spurious_correlation, (1, 28, 28, 1)\n",
    "    elif 'cifar' in ds_name:\n",
    "        add_spu, blank_shape = add_colored_spurious_correlation, (1, 32, 32, 3)\n",
    "    else:\n",
    "        raise ValueError(f\"Unsupported dataset family: {ds_name}\")\n",
    "\n",
    "    tst_features = feature_fn(tstX, model, verbose=0)\n",
    "    tst_pred = pred_fn(tstX, model, verbose=0)\n",
    "    tst_spu_features = feature_fn(add_spu(tstX, spurious_version, 0), model, verbose=0)\n",
    "    spuX = add_spu(np.zeros(blank_shape), spurious_version, 0)\n",
    "    only_spu_features = feature_fn(spuX, model, verbose=0)\n",
    "\n",
    "    return tst_features, tst_spu_features, only_spu_features, tsty, tst_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13bbd192-ea89-428a-aa0b-602f9c893ee4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "all_results = {}\n",
    "# Dataset family for this section (alternatives: \"mnist\", \"fashion\").\n",
    "base_dset = \"cifar10\"\n",
    "\n",
    "def get_fraction_intersect(features, only_spu_fets, require_all=True):\n",
    "    \"\"\"Fraction of rows of `features` whose active units (> 0) include the spurious units.\n",
    "\n",
    "    features: 2-D array with one row of activations per sample.\n",
    "    only_spu_fets: set of column indices activated by the pattern-only probe image.\n",
    "    require_all=True (default here): a row counts if ALL of those units are active, which\n",
    "        is vacuously true when `only_spu_fets` is empty; require_all=False: if ANY is active.\n",
    "    Returns nan for empty `features` instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    features = np.asarray(features)\n",
    "    if len(features) == 0:\n",
    "        return float(\"nan\")\n",
    "    cols = np.asarray(sorted(only_spu_fets), dtype=np.intp)\n",
    "    active = features[:, cols] > 0\n",
    "    hits = active.all(axis=1) if require_all else active.any(axis=1)\n",
    "    return float(hits.mean())\n",
    "\n",
    "#n_samples = [3, 5, 10, 20, 100, 5000]\n",
    "n_samples = [3, 5, 10, 100, 500]\n",
    "# Runs whose clean test accuracy is below this are treated as failed trainings and skipped.\n",
    "min_test_acc = 0.6\n",
    "\n",
    "# Settings also explored: optimizer in ['sgd', 'adam'],\n",
    "# spurious_version in ['v1', 'v3', 'v8', 'v19', 'v20', 'v30'],\n",
    "# arch in ['MLP', 'LargeMLP', 'LargeMLPv2', 'CNN002'].\n",
    "for optimizer in ['sgd']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for spurious_version in ['v1']:\n",
    "            for arch in ['ResNet50']:\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "                # Checkpoint filenames start with 64 for CIFAR runs and with 128 otherwise.\n",
    "                path_prefix = 64 if \"cifar\" in base_dset else 128\n",
    "\n",
    "                ret, spu_ret, clsret, clsspu_ret = [], [], [], []\n",
    "                for i in n_samples:\n",
    "                    trets, tspu_rets, tclsret, tclsspu_ret = [], [], [], []\n",
    "                    for seed in range(5):\n",
    "                        ds_name = f\"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}\"\n",
    "                        model_path = f\"../models/train_classifier/{path_prefix}-{ds_name}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                        try:\n",
    "                            tst_features, tst_spu_features, only_spu_features, tsty, tst_pred = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                        except FileNotFoundError:\n",
    "                            print(f\"missing {model_path}\")\n",
    "                            continue\n",
    "                        acc = (tst_pred.argmax(1) == tsty).mean()\n",
    "                        if acc < min_test_acc:\n",
    "                            print(f\"Failed to train: {model_path}, {acc}\")\n",
    "                            continue\n",
    "                        # ind selects test samples NOT of the target class; not_ind the target class.\n",
    "                        ind = (tsty != tar_cls)\n",
    "                        not_ind = np.logical_not(ind)\n",
    "                        # Units activated by the image that carries only the spurious pattern.\n",
    "                        fets = set(np.where(only_spu_features > 0)[1])\n",
    "                        trets.append(get_fraction_intersect(tst_features[ind], fets))\n",
    "                        tspu_rets.append(get_fraction_intersect(tst_spu_features[ind], fets))\n",
    "                        tclsret.append(get_fraction_intersect(tst_features[not_ind], fets))\n",
    "                        tclsspu_ret.append(get_fraction_intersect(tst_spu_features[not_ind], fets))\n",
    "                    if trets:\n",
    "                        ret.append(np.mean(trets))\n",
    "                        spu_ret.append(np.mean(tspu_rets))\n",
    "                        clsret.append(np.mean(tclsret))\n",
    "                        clsspu_ret.append(np.mean(tclsspu_ret))\n",
    "                    else:\n",
    "                        # -1 marks \"no usable model for this n\" (real fractions are >= 0).\n",
    "                        ret.append(-1)\n",
    "                        spu_ret.append(-1)\n",
    "                        clsret.append(-1)\n",
    "                        clsspu_ret.append(-1)\n",
    "\n",
    "                for metric, values in [('natural', ret), ('spu', spu_ret), ('cls natural', clsret), ('cls spu', clsspu_ret)]:\n",
    "                    all_results[(optimizer, tar_cls, spurious_version, arch, metric)] = values\n",
    "\n",
    "                # Per n: cosine similarity of (natural, spu, cls natural, cls spu) with [0, 1, 0, 1].\n",
    "                # Rows holding the -1 sentinel (or NaN) get NaN instead of a meaningless score.\n",
    "                stats = np.array([ret, spu_ret, clsret, clsspu_ret], dtype=float).T\n",
    "                target = np.tile([[0, 1, 0, 1]], (len(n_samples), 1))\n",
    "                valid = np.isfinite(stats).all(axis=1) & (stats != -1).all(axis=1)\n",
    "                sim = np.full(len(n_samples), np.nan)\n",
    "                if valid.any():\n",
    "                    sim[valid] = np.diag(cosine_similarity(stats[valid], target[valid]))\n",
    "                all_results[(optimizer, tar_cls, spurious_version, arch, 'sim')] = sim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef86909e-20c9-4218-afde-17ae3a14ff19",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(i, ) for i in n_samples])\n",
    "display(df)\n",
    "\n",
    "# LaTeX table without the optimizer index level.\n",
    "td = df[n_samples].droplevel(0)\n",
    "\n",
    "\n",
    "def latex_float(v):\n",
    "    \"\"\"Format v with 3 decimals, drop the leading zero when |v| < 1 and shorten '.000' to '.00'.\n",
    "\n",
    "    Applied per value, so numbers such as 10.5 are not turned into 1.5 the way a global\n",
    "    text.replace(\"0.\", \".\") over the rendered table would.\n",
    "    \"\"\"\n",
    "    s = f\"{v:.3f}\"\n",
    "    if abs(v) < 1:\n",
    "        s = s.replace(\"0.\", \".\", 1)\n",
    "    return s.replace(\".000\", \".00\")\n",
    "\n",
    "\n",
    "text = td.to_latex(multirow=True, float_format=latex_float)\n",
    "print(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40da528d-af50-4a5d-b5f2-88db68922e59",
   "metadata": {},
   "source": [
    "## Two-class experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7664ca65-60e5-4628-87ea-f34992755b65",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Two-class experiments: fresh result table and dataset family (alternative: \"fashion\").\n",
    "all_results = {}\n",
    "base_dset = \"mnist\"\n",
    "#base_dset = \"fashion\"\n",
    "\n",
    "def get_fraction_intersect(features, only_spu_fets, require_all=False):\n",
    "    \"\"\"Fraction of rows of `features` whose active units (> 0) hit the spurious units.\n",
    "\n",
    "    Redefines the same-named helper of the multi-class section, defaulting to ANY instead of ALL.\n",
    "    features: 2-D array with one row of activations per sample.\n",
    "    only_spu_fets: set of column indices activated by the pattern-only probe image.\n",
    "    require_all=False (default here): a row counts if ANY of those units is active (never\n",
    "        when `only_spu_fets` is empty); require_all=True: only if ALL of them are active.\n",
    "    Returns nan for empty `features` instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    features = np.asarray(features)\n",
    "    if len(features) == 0:\n",
    "        return float(\"nan\")\n",
    "    cols = np.asarray(sorted(only_spu_fets), dtype=np.intp)\n",
    "    active = features[:, cols] > 0\n",
    "    hits = active.all(axis=1) if require_all else active.any(axis=1)\n",
    "    return float(hits.mean())\n",
    "\n",
    "n_samples = [3, 5, 10, 20, 100, 5000]\n",
    "# Runs whose clean test accuracy is below this are treated as failed trainings and skipped.\n",
    "min_test_acc = 0.9\n",
    "\n",
    "# Settings also explored: optimizer in ['sgd', 'adam'], spurious_version 'v30',\n",
    "# arch in ['MLP', 'LargeMLP', 'LargeMLPv2'].\n",
    "for optimizer in ['adam']:\n",
    "    for spurious_version in ['v1', 'v3', 'v8', 'v19', 'v20']:\n",
    "        for arch in ['CNN002', 'LargeMLP']:\n",
    "            momentum = 0. if optimizer == 'adam' else 0.9\n",
    "\n",
    "            ret, spu_ret, clsret, clsspu_ret = [], [], [], []\n",
    "            for i in n_samples:\n",
    "                trets, tspu_rets, tclsret, tclsspu_ret = [], [], [], []\n",
    "                for seed in range(5):\n",
    "                    ds_name = f\"twoclass{base_dset}{spurious_version}-{i}-0-1-{seed}\"\n",
    "                    model_path = f\"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                    try:\n",
    "                        tst_features, tst_spu_features, only_spu_features, tsty, tst_pred = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                    except FileNotFoundError:\n",
    "                        print(f\"missing {model_path}\")\n",
    "                        continue\n",
    "                    if (tst_pred.argmax(1) == tsty).mean() < min_test_acc:\n",
    "                        print(f\"Failed to train: {model_path}\")\n",
    "                        continue\n",
    "                    # ind selects test samples labelled neither 0 nor 1; not_ind selects classes 0/1.\n",
    "                    ind = np.logical_and((tsty != 0), (tsty != 1))\n",
    "                    not_ind = np.logical_not(ind)\n",
    "                    # Units activated by the image that carries only the spurious pattern.\n",
    "                    fets = set(np.where(only_spu_features > 0)[1])\n",
    "                    trets.append(get_fraction_intersect(tst_features[ind], fets))\n",
    "                    tspu_rets.append(get_fraction_intersect(tst_spu_features[ind], fets))\n",
    "                    tclsret.append(get_fraction_intersect(tst_features[not_ind], fets))\n",
    "                    tclsspu_ret.append(get_fraction_intersect(tst_spu_features[not_ind], fets))\n",
    "                if trets:\n",
    "                    ret.append(np.mean(trets))\n",
    "                    spu_ret.append(np.mean(tspu_rets))\n",
    "                    clsret.append(np.mean(tclsret))\n",
    "                    clsspu_ret.append(np.mean(tclsspu_ret))\n",
    "                else:\n",
    "                    # -1 marks \"no usable model for this n\" (real fractions are >= 0).\n",
    "                    ret.append(-1)\n",
    "                    spu_ret.append(-1)\n",
    "                    clsret.append(-1)\n",
    "                    clsspu_ret.append(-1)\n",
    "\n",
    "            for metric, values in [('natural', ret), ('spu', spu_ret), ('cls natural', clsret), ('cls spu', clsspu_ret)]:\n",
    "                all_results[(optimizer, spurious_version, arch, metric)] = values\n",
    "\n",
    "            # Per n: cosine similarity of (natural, spu, cls natural, cls spu) with [0, 1, 0, 1].\n",
    "            # Rows holding the -1 sentinel (or NaN) get NaN instead of a meaningless score.\n",
    "            stats = np.array([ret, spu_ret, clsret, clsspu_ret], dtype=float).T\n",
    "            target = np.tile([[0, 1, 0, 1]], (len(n_samples), 1))\n",
    "            valid = np.isfinite(stats).all(axis=1) & (stats != -1).all(axis=1)\n",
    "            sim = np.full(len(n_samples), np.nan)\n",
    "            if valid.any():\n",
    "                sim[valid] = np.diag(cosine_similarity(stats[valid], target[valid]))\n",
    "            all_results[(optimizer, spurious_version, arch, 'sim')] = sim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee97767c-a684-4aef-80a1-0342e44d855a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-class results: rows are (optimizer, spurious_version, arch, metric), columns n_samples.\n",
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(i, ) for i in n_samples])\n",
    "#print(df[[(\"spu\", i) for i in [\"0\"] + n_samples]].to_latex(multirow=True))\n",
    "display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d8afc41-f30d-41f7-ab0d-3a50b803250d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild `td` from the current (two-class) `df`; otherwise the table left over from the\n",
    "# multi-class section above would be printed. Drops the optimizer index level.\n",
    "td = df[n_samples].droplevel(0)\n",
    "text = td.to_latex(multirow=True, float_format=\"%.3f\")\n",
    "text = text.replace(\"0.\", \".\")\n",
    "text = text.replace(\".000\", \".00\")\n",
    "print(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62de3c1d-695c-4b03-934e-127d377fe2cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `spuX` is only created by later cells, so this fails on a fresh kernel.\n",
    "plt.imshow(spuX[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9428dd38-af2f-421c-aea1-9b32e474e1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of all positive clean-test activations that fall in unit 214 (magic id, presumably hand-picked).\n",
    "# NOTE(review): relies on `tst_features` created by a later cell.\n",
    "(np.where((tst_features > 0))[1] == 214).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf1b1fe-6efb-4b8f-bca4-11318ce49c65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same share for the spurious test images. NOTE(review): `tst_spu_features` comes from a later cell.\n",
    "(np.where((tst_spu_features > 0))[1] == 214).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1477f2e5-b769-4163-928c-6bd2a2a68946",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved result file of one two-class MNIST CNN002 run (contents not inspected here).\n",
    "# NOTE(review): joblib.load unpickles arbitrary objects -- only use on trusted files.\n",
    "res = joblib.load(\"../results/train_classifier/128-twoclassmnistv20-3-0-1-2-70-0.01-0.01-ce-tor-CNN002-0.0-adam-0-0.0.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b585386a-436a-4c61-8cbb-6c2d5ec44052",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input-gradient saliency of hidden3 unit 118 for the first training image (MLP model).\n",
    "# NOTE(review): needs an MLP `model` on CUDA and `trnX`, both created by later cells.\n",
    "X = torch.from_numpy(trnX[0].reshape(1, -1)).float().cuda()\n",
    "X.requires_grad_(True)\n",
    "x = F.relu(model.hidden(X))\n",
    "x = F.relu(model.hidden2(x))\n",
    "x = model.hidden3(x)\n",
    "print(x[0, 118])\n",
    "x[0, 118].backward()\n",
    "# Absolute input gradient displayed as a 28x28 image.\n",
    "plt.imshow(X.grad.abs()[0].reshape(28, 28).detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "010cc421-69d5-41ad-954b-7c9391334371",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels of the most recently loaded training set.\n",
    "trny"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f0ff34e-932c-49a3-ab71-9172359dc90c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show one test image from whatever `tstX` currently holds (possibly its spurious version).\n",
    "plt.imshow(tstX[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005a6658-0fa7-4dc1-9b3e-d235bd842b4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pattern-only probe image. NOTE(review): duplicate cell; `spuX` is created by a later cell.\n",
    "plt.imshow(spuX[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4dc501e-b40d-40cc-b53f-d552ade2e4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights of the `fc` layer attached to feature 118 (requires a model with an `fc` head).\n",
    "model.fc.weight[:, 118]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d8ab9e5-a323-407f-9f8c-eb3636a147ef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "424cc91a-9922-4c45-b0d4-5e7a9dc53c62",
   "metadata": {},
   "outputs": [],
   "source": [
    "def resnet_get_feature(X, model, device=\"cuda\", verbose=True):\n",
    "    \"\"\"Return `model.get_repr` features for NHWC images X as a numpy array (one row per sample).\n",
    "\n",
    "    Redefines (and shadows) the earlier `resnet_get_feature`; `verbose` is accepted so the\n",
    "    `evaluate(...)` calls that pass `verbose=0` keep working after this cell has run.\n",
    "    \"\"\"\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.transpose(0, 3, 1, 2)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "\n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    # Inference only: no_grad skips building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for (x, ) in tqdm(loader, desc=\"[pred_fn]\", disable=(not verbose)):\n",
    "            x = x.to(device)\n",
    "            x = model.get_repr(x)\n",
    "            fetX.append(x.cpu().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0640d1a6-8545-41a5-8198-a51d2df99af7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load an MNIST variant; the 5th return value is what get_model_dset calls spurious_ind.\n",
    "ori_trnX, _, _, _, ind = auto_var.get_var_with_argument(\"dataset\", \"mnistv30-10-1-0\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c01191-089f-4296-8e25-6e458135db06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First training image selected by `ind` from the dataset loaded above.\n",
    "plt.imshow(ori_trnX[ind][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d98e7fe-6abf-4639-b0b7-dd221435082e",
   "metadata": {},
   "outputs": [],
   "source": [
    "arch = \"ResNet50\"\n",
    "\n",
    "# Clean CIFAR-10, indexed below with the spurious dataset's spurious_ind\n",
    "# (assumes both datasets share sample order -- TODO confirm).\n",
    "ori_trnX, _, _, _, _ = auto_var.get_var_with_argument(\"dataset\", \"cifar10\")\n",
    "\n",
    "spurious_version = \"v8\"\n",
    "ds_name = f\"cifar10{spurious_version}-50-0-0\"\n",
    "model_path = f\"../models/train_classifier/64-{ds_name}-70-0.1-ce-tor-{arch}-0.0-adam-0-0.0.pt\"\n",
    "model, trnX, trny, tstX, tsty, spurious_ind = get_model_dset(ds_name, model_path, arch)\n",
    "\n",
    "ori_features = resnet_get_feature(ori_trnX[spurious_ind], model)\n",
    "spu_features = resnet_get_feature(trnX[spurious_ind], model)\n",
    "tst_features = resnet_get_feature(tstX, model)\n",
    "# NOTE: tstX is overwritten with its spurious version; the imshow cells below show that version.\n",
    "tstX = add_colored_spurious_correlation(tstX, spurious_version)\n",
    "tst_spu_features = resnet_get_feature(tstX, model)\n",
    "spuX = np.zeros((1, 32, 32, 3))\n",
    "spuX = add_colored_spurious_correlation(spuX, spurious_version)\n",
    "only_spu_features = resnet_get_feature(spuX, model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efcacb48-0b4a-493f-9b55-774cbff79746",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (row, unit) indices of units activated by the pattern-only probe image.\n",
    "np.where((only_spu_features > 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8bb1aea-ceef-46c8-8851-44d9a92d2bc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (row, unit) indices of active units for the clean images at spurious_ind.\n",
    "np.where((ori_features > 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f59a4a92-0679-4a89-9403-24fcf4d27d9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (row, unit) indices of active units for the same images as stored in the spurious dataset.\n",
    "np.where((spu_features > 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab814db-0347-4e43-bb92-f8a7e89eea85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of clean test images on which unit 349 (magic id) is active.\n",
    "(np.where((tst_features > 0))[1] == 349).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f998c34a-4e21-400c-80cc-ec5cbc116795",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of spurious test images on which unit 349 is active.\n",
    "(np.where((tst_spu_features > 0))[1] == 349).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41b69af-3b73-4321-9a9a-36e1b8744121",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean activation of unit 349 on clean test images.\n",
    "tst_features[:, 349].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ede7a300-76eb-48f8-bccd-b6c134bb5922",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean activation of unit 349 on spurious test images.\n",
    "tst_spu_features[:, 349].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f43ab6ea-428e-4c8f-af58-3b5db9c8c037",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First test image; after the CIFAR cell above `tstX` holds the spurious version.\n",
    "plt.imshow(tstX[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77da36a0-df11-45ff-9da7-db642725c7c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bias of the model's `fc` layer.\n",
    "model.fc.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ea3493b-1638-41a5-a18d-f81f2520bf6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): duplicate of an earlier cell showing the first test image.\n",
    "plt.imshow(tstX[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f494e363-d097-48b9-8f4c-0e7c957c9782",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95d2658b-fa27-47e9-8509-cb84fc5a56d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activations of units 747 and 1369 (magic ids) on the clean images at spurious_ind.\n",
    "ori_features[:, [747, 1369]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "838c340b-56ff-4998-ae1c-b8ed595d0984",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activations of units 747 and 1369 on the same images from the spurious dataset.\n",
    "spu_features[:, [747, 1369]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "832c5e8c-49e9-4d05-b09f-4b513f701076",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa7fe38-336b-4c1a-8f22-873eba00e0db",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0294b1c7-c3af-4b0b-b9cd-6dbfe26a9a36",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bbf39ec7-cb99-4764-b0cb-d748837639a0",
   "metadata": {},
   "source": [
    "# LIME explanations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc31a169-3265-418b-835a-6720dbc7fd0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lime import lime_image\n",
    "from lime.wrappers.scikit_image import SegmentationAlgorithm\n",
    "# label2rgb is needed by the explanation-plotting cell below.\n",
    "from skimage.color import gray2rgb, label2rgb, rgb2gray\n",
    "explainer = lime_image.LimeImageExplainer(verbose = False)\n",
    "# Quickshift superpixel segmentation used by LIME to perturb image regions.\n",
    "segmenter = SegmentationAlgorithm('quickshift', kernel_size=1, max_dist=200, ratio=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39fae16a-b0a5-4622-8814-4b07a6736f4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "arch = \"LargeMLP\"\n",
    "\n",
    "ori_trnX, _, _, _, _ = auto_var.get_var_with_argument(\"dataset\", \"mnist\")\n",
    "\n",
    "spurious_version = \"v8\"\n",
    "#spurious_version = \"v20\"\n",
    "#spurious_version = \"v30\"\n",
    "\n",
    "# Two-class MNIST (multi-class alternative: f\"mnist{spurious_version}-3-0-0\").\n",
    "ds_name = f\"twoclassmnist{spurious_version}-3-0-1-0\"\n",
    "#model_path = f\"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch}-0.9-sgd-0-0.0.pt\"\n",
    "model_path = f\"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch}-0.0-adam-0-0.0.pt\"\n",
    "model, trnX, trny, tstX, tsty, spurious_ind = get_model_dset(ds_name, model_path, arch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ab2fef1-818b-473f-bddb-0c83efb658fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test set with the spurious pattern added (built here so this cell does not rely on a later one).\n",
    "sputstX = add_spurious_correlation(tstX, spurious_version, 0)\n",
    "\n",
    "def classifier_fn(x):\n",
    "    \"\"\"LIME prediction function: batch of RGB images (..., 3) -> class probabilities.\n",
    "\n",
    "    Channel 0 is flattened and fed to the global MLP `model`, on whatever device it sits.\n",
    "    \"\"\"\n",
    "    x = x[:, :, :, 0]\n",
    "    x = torch.from_numpy(x.reshape(len(x), -1)).float()\n",
    "    device = next(model.parameters()).device\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        logits = model(x.to(device))\n",
    "    # LIME expects probabilities; the model is assumed to output logits (\"ce\" training) -- TODO confirm.\n",
    "    return F.softmax(logits, dim=1).cpu().numpy()\n",
    "\n",
    "# gray2rgb on an (H, W, 1) array would give (H, W, 1, 3), so drop the channel axis first.\n",
    "explanation = explainer.explain_instance(gray2rgb(sputstX[0][:, :, 0]), \n",
    "                                         classifier_fn = classifier_fn, \n",
    "                                         top_labels=10, hide_color=0, num_samples=10000, segmentation_fn=segmenter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e72f746-2abc-44d5-ab86-6da8731b970f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First spurious test image as RGB: (H, W, 1) -> (1, H, W) -> gray2rgb -> (1, H, W, 3) -> [0].\n",
    "plt.imshow(gray2rgb(sputstX[0].transpose(2, 0, 1))[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9967123f-04b3-4184-9cef-76d699d1fa85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explain the class LIME ranked first (the model's top prediction); `y_test`, used\n",
    "# previously, is not defined anywhere in this notebook.\n",
    "label = explanation.top_labels[0]\n",
    "temp, mask = explanation.get_image_and_mask(label, positive_only=True, num_features=10, hide_rest=False, min_weight = 0.01)\n",
    "fig, (ax1, ax2) = plt.subplots(1,2, figsize = (8, 4))\n",
    "ax1.imshow(label2rgb(mask,temp, bg_label = 0), interpolation = 'nearest')\n",
    "ax1.set_title('Positive Regions for {}'.format(label))\n",
    "temp, mask = explanation.get_image_and_mask(label, positive_only=False, num_features=10, hide_rest=False, min_weight = 0.01)\n",
    "ax2.imshow(label2rgb(3-mask,temp, bg_label = 0), interpolation = 'nearest')\n",
    "ax2.set_title('Positive/Negative Regions for {}'.format(label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8bd4132-2dd6-4128-a17a-e9827e0a9836",
   "metadata": {},
   "outputs": [],
   "source": [
    "arch = \"LargeMLP\"\n",
    "\n",
    "ori_trnX, _, _, _, _ = auto_var.get_var_with_argument(\"dataset\", \"mnist\")\n",
    "\n",
    "#spurious_version = \"v8\"\n",
    "spurious_version = \"v20\"\n",
    "#spurious_version = \"v30\"\n",
    "\n",
    "# Two-class MNIST (multi-class alternative: f\"mnist{spurious_version}-3-0-0\").\n",
    "ds_name = f\"twoclassmnist{spurious_version}-3-0-1-0\"\n",
    "#model_path = f\"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch}-0.9-sgd-0-0.0.pt\"\n",
    "model_path = f\"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch}-0.0-adam-0-0.0.pt\"\n",
    "model, trnX, trny, tstX, tsty, spurious_ind = get_model_dset(ds_name, model_path, arch)\n",
    "\n",
    "# mlp_get_feature activations for: plain-MNIST images at spurious_ind (assumes both\n",
    "# datasets share sample order -- TODO confirm), the same images from the spurious dataset,\n",
    "# clean and spurious test images, and an all-zero image carrying only the pattern.\n",
    "ori_features = mlp_get_feature(ori_trnX[spurious_ind], model)\n",
    "spu_features = mlp_get_feature(trnX[spurious_ind], model)\n",
    "tst_features = mlp_get_feature(tstX, model)\n",
    "sputstX = add_spurious_correlation(tstX, spurious_version, 0)\n",
    "tst_spu_features = mlp_get_feature(sputstX, model)\n",
    "spuX = np.zeros((1, 28, 28, 1))\n",
    "spuX = add_spurious_correlation(spuX, spurious_version, 0)\n",
    "only_spu_features = mlp_get_feature(spuX, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66fcdb89-047e-4da1-9a90-f99f577d1d6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe-activated units that are never active on the spurious training images.\n",
    "set(np.where((only_spu_features > 0))[1]) - set(np.unique(np.where((spu_features > 0))[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a0107f4-ef66-435a-b5f7-7adebd10e599",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe-activated units that are never active on the clean test images.\n",
    "set(np.where((only_spu_features > 0))[1]) - set(np.unique(np.where((tst_features > 0))[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "851d431c-d20b-4757-9063-1c37471f9b0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Units activated by the pattern-only probe image.\n",
    "np.where((only_spu_features > 0))[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09897845-2aea-461f-92e7-2a4e61c94654",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Units active on at least one plain-MNIST image at spurious_ind.\n",
    "np.unique(np.where((ori_features > 0))[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "317d8b42-80b2-45ac-8597-466d3af0a6ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Units active on at least one spurious training image.\n",
    "np.unique(np.where((spu_features > 0))[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119a4501-46e9-4960-a479-40d2f0ec2604",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of clean test images hitting the probe units; the ALL/ANY semantics depend on\n",
    "# which get_fraction_intersect definition ran last.\n",
    "get_fraction_intersect(tst_features, set(np.where(only_spu_features > 0)[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c64c27a4-d85e-416a-a77e-a6701fcd4b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same fraction for the spurious test images.\n",
    "get_fraction_intersect(tst_spu_features, set(np.where(only_spu_features > 0)[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d845390e-99be-412d-bca1-f23d32331223",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input-gradient map of hidden3 unit 214 for the first spurious test image (MLP model on CUDA).\n",
    "inputs = torch.from_numpy(sputstX[0].reshape(1, -1)).float().to(\"cuda\")\n",
    "inputs.requires_grad_(True)\n",
    "\n",
    "x = F.relu(model.hidden(inputs))\n",
    "x = F.relu(model.hidden2(x))\n",
    "outputs = model.hidden3(x)\n",
    "grad = torch.autograd.grad(outputs[0, 214], inputs, retain_graph=True)[0].detach().cpu().numpy()\n",
    "\n",
    "# Alternative: gradient of output 1 of the full model.\n",
    "#outputs = model(inputs)\n",
    "#grad = torch.autograd.grad(outputs[0, 1], inputs, retain_graph=True)[0].detach().cpu().numpy()\n",
    "\n",
    "plt.imshow(grad.reshape(28, 28))\n",
    "plt.colorbar()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
