{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f57c1688-799b-4487-9410-29d2ee9bae5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be815178-12b9-4758-b4ee-a00a661af1b4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1796af9-f754-4f9e-8095-00c217b88086",
   "metadata": {},
   "outputs": [],
   "source": [
    "FIG_DIR = \"./figs/MLP_weights\"\n",
    "os.makedirs(FIG_DIR, exist_ok=True)  # savefig raises if the output folder is missing\n",
    "\n",
    "vmin, vmax = -0., 5.5  # colour range shared by all three maps\n",
    "\n",
    "\n",
    "def load_max_weight_map(path):\n",
    "    \"\"\"Return a 28x28 map of `hidden.weight` from checkpoint `path`, reduced with max over its first axis.\n",
    "\n",
    "    The checkpoint is loaded onto the CPU so files saved from a GPU run also open on CPU-only machines.\n",
    "    The reduced vector is viewed as 28x28, so the layer must have 784 input columns.\n",
    "    \"\"\"\n",
    "    checkpoint = torch.load(path, map_location=\"cpu\")\n",
    "    hidden_weight = checkpoint['model_state_dict']['hidden.weight'].detach().cpu().numpy()\n",
    "    return hidden_weight.max(0).reshape(28, 28)\n",
    "\n",
    "\n",
    "def save_weight_map(weight_map, out_path):\n",
    "    \"\"\"Draw `weight_map` on the shared [vmin, vmax] scale, save it to `out_path` and print its range.\"\"\"\n",
    "    plt.figure(figsize=(6, 6))\n",
    "    plt.imshow(weight_map, vmin=vmin, vmax=vmax)\n",
    "    plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(out_path, bbox_inches='tight')\n",
    "    print(weight_map.max(), weight_map.min())\n",
    "\n",
    "\n",
    "# mnistv8 checkpoint with n_spurious=2000\n",
    "base_ds = \"mnistv8\"\n",
    "n_spurious = 2000\n",
    "weight = load_max_weight_map(f\"../models/train_classifier/128-{base_ds}-{n_spurious}-0-0-70-0.01-ce-tor-LargeMLP-0.0-adam-0-0.0.pt\")\n",
    "save_weight_map(weight, f\"{FIG_DIR}/{base_ds}_{n_spurious}.png\")\n",
    "\n",
    "# mnist checkpoint used as the reference\n",
    "ori_ds = \"mnist\"\n",
    "ori_weight = load_max_weight_map(f\"../models/train_classifier/128-{ori_ds}-70-0.01-0.01-ce-tor-LargeMLP-0.0-adam-0-0.0.pt\")\n",
    "save_weight_map(ori_weight, f\"{FIG_DIR}/{ori_ds}.png\")\n",
    "\n",
    "# absolute per-pixel difference of the two maps (both are reused, not reloaded);\n",
    "# the printed range describes this difference map, which is what the figure shows\n",
    "save_weight_map(np.abs(weight - ori_weight), f\"{FIG_DIR}/{base_ds}_{n_spurious}_dif.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22fd3409-a4ff-4c9c-a293-46887b0f7547",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "290b9e24-1fa5-425d-b501-a20b8ab42e03",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_ds = \"mnistv3\"\n",
    "optim_tag = \"0.0-adam-0-0.0\"  # optimizer suffix shared by the base and retrained file names\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(12, 12))\n",
    "\n",
    "vmin, vmax = -0., 5.0\n",
    "\n",
    "# rows: n_spurious = 10, 200; columns: train_classifier checkpoint, incremental_retraining checkpoint\n",
    "for row, n_spurious in enumerate([10, 200]):\n",
    "    base_name = f\"128-{base_ds}-{n_spurious}-0-0-70-0.01-ce-tor-LargeMLP-{optim_tag}.pt\"\n",
    "    # the retrained file name embeds the base file name; deriving it from base_name keeps\n",
    "    # the embedded dataset name in sync with base_ds\n",
    "    retrained_name = f\"128-{base_ds}-{n_spurious}-0-0-140-0.01-ce-tor-LargeMLP-{base_name}-{optim_tag}.pt\"\n",
    "    panels = [(\"train_classifier\", base_name), (\"incremental_retraining\", retrained_name)]\n",
    "    for col, (model_dir, file_name) in enumerate(panels):\n",
    "        # map_location lets checkpoints saved from a GPU run load on a CPU-only machine\n",
    "        weights = torch.load(f\"../models/{model_dir}/{file_name}\", map_location=\"cpu\")\n",
    "        weight = weights['model_state_dict']['hidden.weight'].detach().cpu().numpy()\n",
    "        weight = weight.max(0).reshape(28, 28)  # max over the first axis, viewed as 28x28\n",
    "        im = axs[row][col].imshow(weight, vmin=vmin, vmax=vmax)\n",
    "        axs[row][col].set_title(f\"{model_dir}, n_spurious={n_spurious}\")\n",
    "        print(weight.max(), weight.min())\n",
    "\n",
    "# one colorbar for all panels; valid because every panel uses the same vmin/vmax\n",
    "fig.subplots_adjust(right=0.8)\n",
    "cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n",
    "fig.colorbar(im, cax=cbar_ax);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8966ef8-3e7b-435a-9ed8-16d0f676ea3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_ds = \"mnistv3\"\n",
    "optim_tag = \"0.9-sgd-0-0.0\"  # optimizer suffix shared by the base and retrained file names\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(12, 12))\n",
    "\n",
    "vmin, vmax = -0., 0.25\n",
    "\n",
    "# rows: n_spurious = 10, 200; columns: train_classifier checkpoint, incremental_retraining checkpoint\n",
    "for row, n_spurious in enumerate([10, 200]):\n",
    "    base_name = f\"128-{base_ds}-{n_spurious}-0-0-70-0.01-ce-tor-LargeMLP-{optim_tag}.pt\"\n",
    "    # the retrained file name embeds the base file name; deriving it from base_name keeps\n",
    "    # the embedded dataset name in sync with base_ds\n",
    "    retrained_name = f\"128-{base_ds}-{n_spurious}-0-0-140-0.01-ce-tor-LargeMLP-{base_name}-{optim_tag}.pt\"\n",
    "    panels = [(\"train_classifier\", base_name), (\"incremental_retraining\", retrained_name)]\n",
    "    for col, (model_dir, file_name) in enumerate(panels):\n",
    "        # map_location lets checkpoints saved from a GPU run load on a CPU-only machine\n",
    "        weights = torch.load(f\"../models/{model_dir}/{file_name}\", map_location=\"cpu\")\n",
    "        weight = weights['model_state_dict']['hidden.weight'].detach().cpu().numpy()\n",
    "        weight = weight.max(0).reshape(28, 28)  # max over the first axis, viewed as 28x28\n",
    "        im = axs[row][col].imshow(weight, vmin=vmin, vmax=vmax)\n",
    "        axs[row][col].set_title(f\"{model_dir}, n_spurious={n_spurious}\")\n",
    "        print(weight.max(), weight.min())\n",
    "\n",
    "# one colorbar for all panels; valid because every panel uses the same vmin/vmax\n",
    "fig.subplots_adjust(right=0.8)\n",
    "cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n",
    "fig.colorbar(im, cax=cbar_ax);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7c80e1-9670-46e8-87a4-9ef13f0a4ba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# smallest entry of `weight` as left by the most recently executed cell\n",
    "np.min(weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02f100dd-4527-43e2-8611-b7ba2ad86937",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_spurious = 3\n",
    "# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine\n",
    "weights = torch.load(f\"../models/train_classifier/128-mnistv3-{n_spurious}-0-0-70-0.01-ce-tor-LargeMLP-0.0-adam-0-0.0.pt\", map_location=\"cpu\")\n",
    "weight = weights['model_state_dict']['hidden.weight'].detach().cpu().numpy()\n",
    "weight = weight.max(0).reshape(28, 28)\n",
    "# NOTE(review): vmin/vmax are not set in this cell; under Restart & Run All they are (-0., 0.25)\n",
    "# from the SGD grid cell, while the other adam mnistv3 checkpoints use vmax=5.0 -- confirm intent\n",
    "plt.imshow(weight, vmin=vmin, vmax=vmax)\n",
    "print(weight.max(), weight.min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f2c9e45-4b50-49df-8950-0f307d6838ae",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87aa0163-8229-43de-8a21-b49ac64c1765",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8899f262-5ebf-43d6-a14a-9ad90dcc205c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e3fbca3-79b4-44db-a138-df383bec2a80",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34f97ede-1830-4021-a5e8-7e3a8b56efcc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd557cb9-3ca3-4b9f-8efe-38fd2f3b324d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ca686bc-6420-4a52-8106-04dc024a480c",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_ds = \"mnistv8\"\n",
    "n_spurious = 100\n",
    "# NOTE(review): this name has \"-70-0.01-0.01-ce-\" whereas the mnistv8 n_spurious=2000 checkpoint\n",
    "# in the first plotting cell uses \"-70-0.01-ce-\" -- confirm which naming the training run produced\n",
    "weights = torch.load(f\"../models/train_classifier/128-{base_ds}-{n_spurious}-0-0-70-0.01-0.01-ce-tor-LargeMLP-0.0-adam-0-0.0.pt\", map_location=\"cpu\")\n",
    "# keep the full hidden weight matrix (not a 28x28 max map): the cells below slice its columns\n",
    "weight = weights['model_state_dict']['hidden.weight'].detach().cpu().numpy()\n",
    "print(weight.max(), weight.min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4819d61e-14a2-46c8-9a2a-90c05749daa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dimensions of the raw hidden weight matrix loaded above\n",
    "np.shape(weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98567a8b-8be9-478f-bffb-3b8ea24675ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# first 56 input columns of the hidden weight matrix, transposed so each image row is one input column\n",
    "# (if inputs are row-major 28x28 images these are the top two pixel rows -- TODO confirm input layout)\n",
    "cols = weight[:, :56].T\n",
    "fig, ax = plt.subplots(figsize=(100, 5))\n",
    "im = ax.imshow(cols)\n",
    "fig.colorbar(im, ax=ax)\n",
    "fig.savefig(\"temp2.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad3d46a4-d091-448d-85b7-c966b9fbcfcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from spurious_ml.datasets import add_spurious_correlation, add_colored_spurious_correlation\n",
    "from spurious_ml.models.torch_utils import archs\n",
    "from spurious_ml.variables import auto_var\n",
    "from params import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76b8b281-d309-4c03-91ec-44cabdc1d9f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# appears to follow the \"{base_ds}-{n_spurious}-0-0\" prefix of the checkpoint names (mnistv8, 10)\n",
    "ds_name = \"mnistv8-10-0-0\"\n",
    "dataset = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "trnX, trny, tstX, tsty, spurious_ind = dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eb55ba2-2025-4146-b3eb-68845b496304",
   "metadata": {},
   "outputs": [],
   "source": [
    "# freshly initialised LargeMLP sized from the loaded data; no checkpoint weights are loaded into it here\n",
    "n_features = (np.prod(trnX.shape[1:]), )\n",
    "n_channels = trnX.shape[-1]\n",
    "n_classes = np.unique(trny).size\n",
    "model = archs.LargeMLP(n_features=n_features, n_channels=n_channels, n_classes=n_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56e1cdea-c8db-4bbb-98b9-35d10b2efd9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# first example listed in spurious_ind, flattened to one row of prod(trnX.shape[1:]) features\n",
    "# (matching n_features used to build the model) and cast to float32, torch's default parameter dtype;\n",
    "# .float() is a no-op when trnX is already float32\n",
    "x = torch.from_numpy(trnX[spurious_ind[0]].reshape(1, -1)).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efd3da0d-70a6-4474-b301-e3bd0ecf5f1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# .backward() accumulates into .grad, so clear it first to keep this cell re-runnable\n",
    "model.zero_grad()\n",
    "# NOTE(review): the target is hard-coded to class 0 rather than trny[spurious_ind[0]] -- confirm intent\n",
    "target = torch.zeros(1, dtype=torch.long)\n",
    "loss = torch.nn.CrossEntropyLoss()(model(x), target)\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f255e6ee-ecdb-4684-acba-364a44df723a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gradient of the loss w.r.t. the hidden layer's weight (None until a backward pass has run)\n",
    "hidden_layer = model.hidden\n",
    "grad = hidden_layer.weight.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dee8249-6505-45a8-ad71-8297250c1f76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# same view as the weight plot: first 56 input columns of the gradient, transposed\n",
    "grad_cols = grad[:, :56].T.numpy()\n",
    "fig, ax = plt.subplots(figsize=(100, 5))\n",
    "im = ax.imshow(grad_cols)\n",
    "fig.colorbar(im, ax=ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cefa8af6-ed8f-44f0-9b27-c9e1e43b7c47",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
