{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cdfc8ab-96b8-438c-8b02-a93d5c1f0bb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import os\n",
    "from IPython.display import display\n",
    "from functools import partial\n",
    "from collections import OrderedDict\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rc('text', usetex=False)\n",
    "\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "import torchvision\n",
    "\n",
    "from lucent.modelzoo.util import get_model_layers\n",
    "from lucent.optvis import render, param, transform, objectives\n",
    "from lucent.optvis.param.color import to_valid_rgb\n",
    "from lucent.optvis.param.spatial import rfft2d_freqs\n",
    "from lucent.misc.io import show\n",
    "\n",
    "\n",
    "from spurious_ml.models.torch_utils import archs, data_augs\n",
    "from spurious_ml.variables import auto_var\n",
    "from utils import params_to_dataframe\n",
    "\n",
    "fontsize=15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d900e090-534a-4f19-bc0f-476c97fdbaf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_name = \"cifar10\"\n",
    "arch_name = \"ResNet50\"\n",
    "model_path = f\"../models/train_classifier/64-{ds_name}-70-0.01-ce-tor-{arch_name}-0.9-sgd-0-0.0.pt\"\n",
    "model_path = f\"../models/train_classifier/64-{ds_name}-70-0.1-ce-tor-{arch_name}-0.0-adam-0-0.0.pt\"\n",
    "\n",
    "trnX, trny, tstX, tsty, _ = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "res = torch.load(model_path)\n",
    "model = getattr(archs, arch_name)(n_features=None, n_channels=3, n_classes=10)\n",
    "model.load_state_dict(res['model_state_dict'])\n",
    "model.eval()\n",
    "\n",
    "device = \"cuda\"\n",
    "\n",
    "def get_features(X, model, aug_fn=None):\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.transpose(0, 3, 1, 2)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=128)\n",
    "\n",
    "    device = \"cuda\"\n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    for (x, ) in tqdm(loader, desc=\"predicting\"):\n",
    "        x = x.to(device)\n",
    "        if aug_fn is not None:\n",
    "            x = aug_fn(x)\n",
    "        with torch.no_grad():\n",
    "            fetX.append(model(x).cpu().detach().flatten(1).numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf5783bf-5826-4143-9cb2-609bf3b19e6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fet_model = nn.Sequential(*list(model.children())[:-2])\n",
    "#fet_model = model.features\n",
    "trn_fetX= get_features(trnX, fet_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcaf0d81-6f47-49eb-9575-6579cce1f923",
   "metadata": {},
   "outputs": [],
   "source": [
    "(trn_fetX < 0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c721ea7e-f530-403e-b92b-a47abd2b25ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist((trn_fetX > 0).sum(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1ea83e3-35ae-4407-a19d-cfa766bbe337",
   "metadata": {},
   "source": [
    "# "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "222ed3ff-e15b-48f3-acac-d450a66b5d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.where((trn_fetX > 0).sum(0) > 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b0f70f6-940a-4784-9437-08ac74faaa1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.where(np.logical_and((trn_fetX > 0).sum(0) > 50, (trn_fetX > 0).sum(0) < 80))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfdbaf32-17b2-4fa3-b317-11b84c07d963",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.where(np.logical_and((trn_fetX > 0).sum(0) > 0, (trn_fetX > 0).sum(0) < 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36a0437a-b07a-43cf-b8e4-f8d63446f134",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.where((trn_fetX > 0).sum(0) == 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce48f58b-a8de-4f20-93b3-a3cebc1964ef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53ecb36f-c072-474e-9871-06931ddd33ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2a8356c-722b-4a7f-9a52-4569f109426e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fet_grad_norm = {}\n",
    "for fet_no in tqdm(range(2048)):\n",
    "    idxs = np.where(trn_fetX[:, fet_no] > 0)[0]\n",
    "    norms = []\n",
    "    for i in idxs[:min(100, len(idxs))]:\n",
    "        im = torch.from_numpy(trnX[i: i+1].transpose(0, 3, 1, 2)).to(device).float()\n",
    "        im.requires_grad_(True)\n",
    "        fet_model(im)[:, fet_no].backward()\n",
    "        norms.append(im.grad.norm())\n",
    "    fet_grad_norm[fet_no] = norms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15be37d2-c8a7-401e-b3d7-ee26d55d214f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ttt = {}\n",
    "for fet_no in tqdm(range(2048)):\n",
    "    t = len(np.where(trn_fetX[:, fet_no] > 0)[0])\n",
    "    if t > 0:\n",
    "        ttt.setdefault(t, []).append(np.mean([_.cpu().numpy() for _ in fet_grad_norm[fet_no]]))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f475ac-cb79-415d-90dc-188e36a42a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "ttt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dee3516a-e24e-4a21-9b63-c162a5a032c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "fet_no = 65\n",
    "image_f = lambda : param.image(32)\n",
    "model.to(device).eval()\n",
    "vis_fets = render.render_vis(model, f\"layer4_2_bn3:{fet_no}\", param_f=image_f, show_inline=True, preprocess=False, fixed_image_size=32, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32294277-6790-44b9-8413-b191e11ff5d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "fet_no = 65\n",
    "idxs = np.where(trn_fetX[:, fet_no] > 0)[0][:1]\n",
    "im = torch.from_numpy(trnX[idxs].transpose(0, 3, 1, 2)).to(device).float()\n",
    "image_f = lambda : fft_preprocess_image(im)\n",
    "\n",
    "optimizer = lambda params: torch.optim.SGD(params, lr=1.0, momentum=0.9)\n",
    "transforms = [torchvision.transforms.Lambda(lambda x: x)]\n",
    "transforms = None\n",
    "\n",
    "#image_f = lambda : param.image(32)\n",
    "\n",
    "vis_fets = render_vis(model, f\"layer4_2_relu:{fet_no}\", param_f=image_f, optimizer=optimizer,\n",
    "                      transforms=transforms, show_inline=True, preprocess=False, fixed_image_size=32, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bffdee7f-d28d-4c6d-969a-57f8a48d3f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 1\n",
    "plt.imshow(np.sign(vis_fets[0][i] - trnX[idxs[i]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cb09af2-3f4e-4ef6-97e7-a6e58845976f",
   "metadata": {},
   "outputs": [],
   "source": [
    "i=2\n",
    "plt.imshow(vis_fets[0][i] / vis_fets[0][i].max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9525d626-d31d-4d61-a35a-59490974f964",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(trnX[idxs[i]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d1c6d9-8d2d-4f9e-a08c-1ca3ee2ca23f",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.fft.fftn(im, s=(32, 32), norm='ortho').shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c914490a-a7ca-48b9-808c-0b2c7d4da4e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "param.image(128)[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9303a197-1a78-4e6c-98df-fb55bc66b676",
   "metadata": {},
   "outputs": [],
   "source": [
    "param.image(128)[1]().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3296278b-73ac-4a39-9abb-6a8c06a3bc36",
   "metadata": {},
   "outputs": [],
   "source": [
    "ffted_im = torch.fft.irfftn(torch.fft.rfftn(im.cpu(), s=(32, 32), norm='ortho'), s=(32, 32), norm='ortho')\n",
    "plt.imshow(ffted_im[0].numpy().transpose(1, 2, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "987c5fe3-1787-46b2-a8e0-2c23d94a13cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model(im.cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f2842c1-b496-418b-b77b-750e49f449a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "model(ffted_im.cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "705eec41-95f6-4bd9-b9ac-e67586be4147",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fft_preprocess_image(img, decay_power=1):\n",
    "    batch, channels, h, w = img.shape\n",
    "    freqs = rfft2d_freqs(h, w)\n",
    "    init_val_size = (batch, channels) + freqs.shape\n",
    "\n",
    "    #spectrum_real_imag_t = (torch.randn(*init_val_size) * sd).to(device).requires_grad_(True)\n",
    "    spectrum_real_imag_t = torch.fft.rfftn(img, s=(h, w), norm='backward').to(device).requires_grad_(True)\n",
    "\n",
    "    scale = 1.0 / np.maximum(freqs, 1.0 / max(w, h)) ** decay_power\n",
    "    scale = torch.tensor(scale).float()[None, None, ...].to(device)\n",
    "\n",
    "    def inner():\n",
    "        scaled_spectrum_t = spectrum_real_imag_t\n",
    "        #scaled_spectrum_t = scale * spectrum_real_imag_t\n",
    "        #if type(scaled_spectrum_t) is not torch.complex64:\n",
    "        #    scaled_spectrum_t = torch.view_as_complex(scaled_spectrum_t)\n",
    "        image = torch.fft.irfftn(scaled_spectrum_t, s=(h, w), norm='backward')\n",
    "        image = image[:batch, :channels, :h, :w]\n",
    "        #magic = 4.0 # Magic constant from Lucid library; increasing this seems to reduce saturation\n",
    "        #image = image / magic\n",
    "        return image\n",
    "    return [spectrum_real_imag_t], inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "326e8286-fad1-4f90-9c4f-e1e2f93ffd87",
   "metadata": {},
   "outputs": [],
   "source": [
    "def render_vis(\n",
    "    model,\n",
    "    objective_f,\n",
    "    param_f=None,\n",
    "    optimizer=None,\n",
    "    transforms=None,\n",
    "    thresholds=(512,),\n",
    "    verbose=False,\n",
    "    preprocess=True,\n",
    "    progress=True,\n",
    "    image_name=None,\n",
    "    show_inline=False,\n",
    "    fixed_image_size=None,\n",
    "):\n",
    "    if param_f is None:\n",
    "        param_f = lambda: param.image(128)\n",
    "    # param_f is a function that should return two things\n",
    "    # params - parameters to update, which we pass to the optimizer\n",
    "    # image_f - a function that returns an image as a tensor\n",
    "    params, image_f = param_f()\n",
    "\n",
    "    if optimizer is None:\n",
    "        optimizer = lambda params: torch.optim.Adam(params, lr=5e-2)\n",
    "    optimizer = optimizer(params)\n",
    "\n",
    "    if transforms is None:\n",
    "        transforms = transform.standard_transforms\n",
    "    #transforms = transforms.copy()\n",
    "\n",
    "    # Upsample images smaller than 224\n",
    "    image_shape = image_f().shape\n",
    "    if fixed_image_size is not None:\n",
    "        new_size = fixed_image_size\n",
    "    elif image_shape[2] < 224 or image_shape[3] < 224:\n",
    "        new_size = 224\n",
    "    else:\n",
    "        new_size = None\n",
    "    if new_size:\n",
    "        transforms.append(\n",
    "            torch.nn.Upsample(size=new_size, mode=\"bilinear\", align_corners=True)\n",
    "        )\n",
    "\n",
    "    transform_f = transform.compose(transforms)\n",
    "\n",
    "    hook = hook_model(model, image_f)\n",
    "    objective_f = objectives.as_objective(objective_f)\n",
    "\n",
    "    if verbose:\n",
    "        model(transform_f(image_f()))\n",
    "        print(\"Initial loss: {:.3f}\".format(objective_f(hook)))\n",
    "\n",
    "    images = []\n",
    "    try:\n",
    "        for i in tqdm(range(1, max(thresholds) + 1), disable=(not progress)):\n",
    "            def closure():\n",
    "                optimizer.zero_grad()\n",
    "                try:\n",
    "                    model(transform_f(image_f()))\n",
    "                except RuntimeError as ex:\n",
    "                    if i == 1:\n",
    "                        # Only display the warning message\n",
    "                        # on the first iteration, no need to do that\n",
    "                        # every iteration\n",
    "                        warnings.warn(\n",
    "                            \"Some layers could not be computed because the size of the \"\n",
    "                            \"image is not big enough. It is fine, as long as the non\"\n",
    "                            \"computed layers are not used in the objective function\"\n",
    "                            f\"(exception details: '{ex}')\"\n",
    "                        )\n",
    "                loss = objective_f(hook)\n",
    "                loss.backward()\n",
    "                return loss\n",
    "                \n",
    "            optimizer.step(closure)\n",
    "            if i in thresholds:\n",
    "                image = tensor_to_img_array(image_f())\n",
    "                if verbose:\n",
    "                    print(\"Loss at step {}: {:.3f}\".format(i, objective_f(hook)))\n",
    "                    if show_inline:\n",
    "                        show(image)\n",
    "                images.append(image)\n",
    "    except KeyboardInterrupt:\n",
    "        print(\"Interrupted optimization at step {:d}.\".format(i))\n",
    "        if verbose:\n",
    "            print(\"Loss at step {}: {:.3f}\".format(i, objective_f(hook)))\n",
    "        images.append(tensor_to_img_array(image_f()))\n",
    "\n",
    "    if show_inline:\n",
    "        show(tensor_to_img_array(image_f()))\n",
    "    return images\n",
    "\n",
    "def tensor_to_img_array(tensor):\n",
    "    image = tensor.cpu().detach().numpy()\n",
    "    image = np.transpose(image, [0, 2, 3, 1])\n",
    "    return image\n",
    "\n",
    "class ModuleHook:\n",
    "    def __init__(self, module):\n",
    "        self.hook = module.register_forward_hook(self.hook_fn)\n",
    "        self.module = None\n",
    "        self.features = None\n",
    "\n",
    "    def hook_fn(self, module, input, output):\n",
    "        self.module = module\n",
    "        self.features = output\n",
    "\n",
    "    def close(self):\n",
    "        self.hook.remove()\n",
    "\n",
    "def hook_model(model, image_f):\n",
    "    features = OrderedDict()\n",
    "\n",
    "    # recursive hooking function\n",
    "    def hook_layers(net, prefix=[]):\n",
    "        if hasattr(net, \"_modules\"):\n",
    "            for name, layer in net._modules.items():\n",
    "                if layer is None:\n",
    "                    # e.g. GoogLeNet's aux1 and aux2 layers\n",
    "                    continue\n",
    "                features[\"_\".join(prefix + [name])] = ModuleHook(layer)\n",
    "                hook_layers(layer, prefix=prefix + [name])\n",
    "\n",
    "    hook_layers(model)\n",
    "\n",
    "    def hook(layer):\n",
    "        if layer == \"input\":\n",
    "            out = image_f()\n",
    "        elif layer == \"labels\":\n",
    "            out = list(features.values())[-1].features\n",
    "        else:\n",
    "            assert layer in features, f\"Invalid layer {layer}. Retrieve the list of layers with `lucent.modelzoo.util.get_model_layers(model)`.\"\n",
    "            out = features[layer].features\n",
    "        assert out is not None, \"There are no saved feature maps. Make sure to put the model in eval mode, like so: `model.to(device).eval()`. See README for example.\"\n",
    "        return out\n",
    "\n",
    "    return hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "680c715f-5563-4d47-8433-312fa8a54eea",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(trnX[823])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f4cd3e-0748-4cbc-8f97-400473f8de20",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(trnX[40749])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bad065c8-bffb-40d8-b7fe-112a234675b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "param.image(32, 32, None, 1)[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cb9b7a7-2c6f-4cfd-8723-3d374a9d5b7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.randn(128).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b44505ca-182f-4e7d-95c5-eb348471415f",
   "metadata": {},
   "outputs": [],
   "source": [
    "joblib.load(\"../results/train_classifier/64-cifar10-70-0.1-ce-tor-ResNet50-0.0-adam-0-0.0.pkl\")['tst_acc']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f83834dc-2cfe-46b0-9e28-41d60f95c950",
   "metadata": {},
   "outputs": [],
   "source": [
    "get_model_layers(fet_model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
