{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the models, remove scaling and permutation symmetries, and plot the resulting posteriors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from permute import permute\n",
    "from scale import scale\n",
    "from utils import (\n",
    "    to_vector,\n",
    "    to_state_dict,\n",
    ")\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the models: a sequence of checkpoints (state dicts), one per trained network.\n",
    "\n",
    "# You may replace this by your own models if you trained them.\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "# map_location=\"cpu\" lets checkpoints saved on a GPU load on CPU-only machines\n",
    "# (the plotting cell calls .numpy() on these tensors, which requires CPU tensors).\n",
    "models = torch.load(\"models.pt\", map_location=\"cpu\")\n",
    "# models = torch.load(\"your_models.pt\", map_location=\"cpu\")\n",
    "\n",
    "# Second output of to_vector(); the permutation cell passes it to to_state_dict()\n",
    "# to rebuild state dicts from flat parameter vectors.\n",
    "model_shape = to_vector(models[0])[1]\n",
    "\n",
    "# Last-layer weights (\"2.weight\") of every checkpoint. Keys are normalised like in the\n",
    "# scaling/permutation cells, which strip an optional \"model.\" prefix before use.\n",
    "model_weights = torch.stack(\n",
    "    [{k.replace(\"model.\", \"\"): v for k, v in m.items()}[\"2.weight\"].squeeze() for m in models]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the scaling symmetry: run scale() on every checkpoint and keep the\n",
    "# successful results as flat parameter vectors.\n",
    "\n",
    "scaled_chkpts = []\n",
    "for checkpoint in tqdm(models):\n",
    "    # Drop the \"model.\" prefix so the keys match the plain nn.Sequential below.\n",
    "    new_weights = {name.replace(\"model.\", \"\"): checkpoint[name] for name in checkpoint.keys()}\n",
    "\n",
    "    model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)).float()\n",
    "    model.load_state_dict(new_weights)\n",
    "    scaled_weights = scale(model)\n",
    "\n",
    "    # scale() returns None on failure; such checkpoints are skipped (the next cell counts them).\n",
    "    if scaled_weights is None:\n",
    "        continue\n",
    "\n",
    "    flat_params = [w.reshape(-1) for w in scaled_weights.values()]\n",
    "    scaled_chkpts.append(torch.concat(flat_params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report how many checkpoints scale() handled and how many it rejected.\n",
    "n_scaled = len(scaled_chkpts)\n",
    "n_failed = len(models) - n_scaled\n",
    "print(f\"Number of scaled checkpoints: {n_scaled} - Failed to scale: {n_failed}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the permutation symmetry from every scaled checkpoint.\n",
    "\n",
    "permuted_chkpts = []\n",
    "for checkpoint in tqdm(scaled_chkpts):\n",
    "    # Rebuild a state dict from the flat vector and drop the \"model.\" prefix.\n",
    "    checkpoint = to_state_dict(checkpoint, model_shape)\n",
    "    new_weights = {key.replace(\"model.\", \"\"): checkpoint[key] for key in checkpoint.keys()}\n",
    "    model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)).float()\n",
    "    model.load_state_dict(new_weights)\n",
    "    # Order the two hidden units by their last-layer weight (2.weight has shape (1, 2)) and\n",
    "    # pass that as the custom permutation: starting the sort from the last layer improves\n",
    "    # the visuals compared with the plain permute(model)[0].\n",
    "    order = new_weights[\"2.weight\"].sort().indices[0, :]\n",
    "    permuted_weights = permute(model, custom_pis=[order, None])[0]\n",
    "    # Detach so the stored vectors carry no autograd history from the model parameters.\n",
    "    permuted_chkpts.append(\n",
    "        torch.concat([w.detach().reshape(-1) for w in permuted_weights.values()])\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the figure style; these rc settings apply to every plot drawn afterwards.\n",
    "# text.usetex renders all text with LaTeX, so a LaTeX installation is required.\n",
    "\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"axes.axisbelow\": True,  # draw ticks and grid lines below the plotted data\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.serif\": [\"Cambria Math\"],\n",
    "        \"text.usetex\": True,\n",
    "    }\n",
    ")\n",
    "\n",
    "# Light-to-#c96004 sequential colormap used for the hexbin densities.\n",
    "cmap = sns.light_palette(\"#c96004\", as_cmap=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the last-layer weights of every checkpoint in three panels: the original\n",
    "# posterior, after scaling removal, and after scaling & permutation removal.\n",
    "\n",
    "\n",
    "def plot_posterior(ax, lin0, lin1, title, diagonal=False):\n",
    "    \"\"\"Draw a hexbin density with a faint red scatter overlay of (lin0, lin1) on ``ax``.\n",
    "\n",
    "    Only points with both coordinates strictly inside (-4, 4) are plotted.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on.\n",
    "        lin0: 1-D array of x coordinates, one per checkpoint.\n",
    "        lin1: 1-D array of y coordinates, one per checkpoint.\n",
    "        title: Panel caption, drawn below the axes.\n",
    "        diagonal: If True, overlay the dashed y = x line.\n",
    "    \"\"\"\n",
    "    # Same as (-4 < x < 4) and (-4 < y < 4); NaNs are excluded too.\n",
    "    inside = (np.abs(lin0) < 4) & (np.abs(lin1) < 4)\n",
    "    x, y = lin0[inside], lin1[inside]\n",
    "\n",
    "    ax.set_aspect(\"equal\")\n",
    "    ax.hexbin(\n",
    "        x=x,\n",
    "        y=y,\n",
    "        cmap=cmap,\n",
    "        mincnt=3,\n",
    "        gridsize=85,\n",
    "        alpha=1,\n",
    "        norm=matplotlib.colors.Normalize(vmin=0, vmax=30, clip=True),\n",
    "        antialiased=True,\n",
    "    )\n",
    "    # Overlay the first 1000 retained points as faint red dots.\n",
    "    sns.scatterplot(\n",
    "        x=x[:1000],\n",
    "        y=y[:1000],\n",
    "        s=2,\n",
    "        alpha=0.1,\n",
    "        color=\"red\",\n",
    "        ax=ax,\n",
    "        antialiased=True,\n",
    "    )\n",
    "    if diagonal:\n",
    "        ax.plot([-4, 4], [-4, 4], color=\"#c96004\", linestyle=\"--\", alpha=0.5)\n",
    "    ax.grid(True, linestyle=\"--\", alpha=0.7)\n",
    "    ax.set_xlim(-4.01, 4.01)\n",
    "    ax.set_ylim(-4.01, 4.01)\n",
    "    ax.set_xticks([-4, -2, 0, 2, 4])\n",
    "    ax.set_yticks([-4, -2, 0, 2, 4])\n",
    "    ax.tick_params(axis=\"both\", labelsize=24)\n",
    "    # A negative y places the title under the panel, like a caption.\n",
    "    ax.set_title(title, fontsize=38, y=-0.12)\n",
    "\n",
    "\n",
    "assert scaled_chkpts, \"scale() failed on every checkpoint: nothing to plot\"\n",
    "\n",
    "# One row per checkpoint. Indices -3 and -2 of the flat vectors are the two entries of\n",
    "# 2.weight, assuming scale()/permute() keep the state-dict order (0.weight, 0.bias,\n",
    "# 2.weight, 2.bias) of the nn.Sequential used above.\n",
    "original = model_weights.detach().cpu().numpy()\n",
    "scaled = torch.stack(scaled_chkpts).detach().numpy()\n",
    "permuted = torch.stack(permuted_chkpts).detach().numpy()\n",
    "\n",
    "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(24, 8))\n",
    "plot_posterior(ax1, original[:, 0], original[:, 1], \"Original Posterior\")\n",
    "plot_posterior(ax2, scaled[:, -3], scaled[:, -2], \"After Scaling Removal\")\n",
    "# Raw string so the backslash in \"\\&\" (needed by LaTeX) is not parsed as an invalid Python escape.\n",
    "plot_posterior(\n",
    "    ax3, permuted[:, -3], permuted[:, -2], r\"After Scaling \\& Permutation Removal\", diagonal=True\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "sns.despine(fig, top=True, right=True, left=True, bottom=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "uncertainty",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
