{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set working directory one level up (as if mnist_poc folder never existed)\n",
    "import os\n",
    "\n",
    "os.chdir(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handy imports\n",
    "from pc_e import PCE\n",
    "\n",
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "# Track layerwise energies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class TrackedEnergies(PCE):\n",
    "    def __init__(self, architecture, iters, e_lr, w_lr):\n",
    "        super().__init__(architecture, iters, e_lr, w_lr)\n",
    "        self.log_E_errors = []\n",
    "        self.log_E_states = []\n",
    "\n",
    "    # ERROR OPTIMIZATION\n",
    "    def minimize_error_energy(self, x, y):\n",
    "        self.log_E_errors.clear()\n",
    "        return super().minimize_error_energy(x, y)\n",
    "\n",
    "    def E_errors_layerwise(self, x: torch.Tensor, y: torch.Tensor):\n",
    "        E_errors = [0.5 * torch.linalg.vector_norm(e, ord=2, dim=None) ** 2 for e in self.errors]\n",
    "\n",
    "        return E_errors + [self.class_loss(self.y_pred(x), y)]\n",
    "\n",
    "    def E(self, x, y):\n",
    "        E_layers = self.E_errors_layerwise(x, y)\n",
    "        self.log_E_errors.append([E.detach() for E in E_layers])\n",
    "        return sum(E_layers)\n",
    "\n",
    "    # STATE OPTIMIZATION\n",
    "    def minimize_state_energy(self, x, y, iters, s_lr):\n",
    "        self.log_E_states.clear()\n",
    "        return super().minimize_state_energy(x, y, iters, s_lr)\n",
    "\n",
    "    def E_states_only_layerwise(self, x: torch.Tensor, y: torch.Tensor, states: list[torch.Tensor]):\n",
    "        def half_mse_loss(y_pred, y):\n",
    "            return 0.5 * F.mse_loss(y_pred, y, reduction=\"sum\")\n",
    "\n",
    "        losses = [half_mse_loss] * len(states) + [self.class_loss]\n",
    "        states = [x] + states + [y]\n",
    "\n",
    "        return list(\n",
    "            loss(layer(s_i), s_ip1)\n",
    "            for s_i, s_ip1, layer, loss in zip(states[:-1], states[1:], self.layers, losses)\n",
    "        )\n",
    "\n",
    "    def E_states_only(self, x, y, states):\n",
    "        E_layers = self.E_states_only_layerwise(x, y, states)\n",
    "        self.log_E_states.append([E.detach() for E in E_layers])\n",
    "        return sum(E_layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "# Load a single data item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datamodules import EMNIST\n",
    "from lightning import seed_everything\n",
    "\n",
    "seed_everything(42)  # always same batch & weights (later on)\n",
    "\n",
    "batch_size = 1\n",
    "dm = EMNIST(batch_size)\n",
    "dataset_name = dm.dataset_name  # needs to happen before setup!\n",
    "print(\"Training on\", dm.dataset_name)\n",
    "dm.setup(\"fit\")\n",
    "dl = dm.train_dataloader()\n",
    "batch = next(iter(dl))\n",
    "batch = dm.on_after_batch_transfer(batch, 0)\n",
    "x, y = batch[\"img\"], batch[\"y\"]\n",
    "print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(x[0].reshape(28, 28).T, cmap=\"gray\")\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "# Load untrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from get_arch import get_architecture\n",
    "\n",
    "architecture = get_architecture(dataset=\"EMNIST-deep\", use_CELoss=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "pc = TrackedEnergies(architecture, iters=8, e_lr=0.1, w_lr=None)\n",
    "pc.minimize_error_energy(x, y)\n",
    "pc.minimize_state_energy(x, y, iters=64, s_lr=0.1)\n",
    "print(\"All done here!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "\n",
    "\n",
    "def make_double_figure(list1, list2):\n",
    "    list1 = np.array(list1)\n",
    "    list2 = np.array(list2)\n",
    "\n",
    "    # Hide zeros from plt's logscale\n",
    "    list1[list1 == 0] = np.nan\n",
    "    list2[list2 == 0] = np.nan\n",
    "\n",
    "    log1 = np.log10(list1.T)\n",
    "    log2 = np.log10(list2.T)\n",
    "\n",
    "    # Compute shared color scale limits\n",
    "    combined = np.concatenate([log1[~np.isnan(log1)], log2[~np.isnan(log2)]])\n",
    "    vmin, vmax = np.min(combined), np.max(combined)\n",
    "\n",
    "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), sharey=True)\n",
    "\n",
    "    cmap = plt.get_cmap(\"inferno\")\n",
    "    cmap = LinearSegmentedColormap.from_list(\"short_inferno\", cmap(np.linspace(0.0, 0.9, num=256)))\n",
    "    im1 = ax1.imshow(log1, aspect=\"auto\", interpolation=\"nearest\", vmin=vmin, vmax=vmax, cmap=cmap)\n",
    "    im2 = ax2.imshow(log2, aspect=\"auto\", interpolation=\"nearest\", vmin=vmin, vmax=vmax, cmap=cmap)\n",
    "\n",
    "    # ax1.set_title(r\"$\\bf{State\\ Optimization}$ (standard)\", fontsize=14)\n",
    "    # ax2.set_title(r\"$\\bf{Error\\ Optimization}$ (ours)\", fontsize=14)\n",
    "    for ax in [ax1, ax2]:\n",
    "        ax.set_xlabel(\"Time (optimization steps)\")\n",
    "    ax1.set_ylabel(\"Input Layer $i$ Output\")  # will move these to correct position in Inkscape\n",
    "    ax1.set_yticks([0, 5, 10, 15, 19])\n",
    "    ax2.tick_params(axis=\"y\", which=\"both\", left=False, labelleft=False)\n",
    "\n",
    "    # Create shared colorbar between plots\n",
    "    divider = make_axes_locatable(ax2)\n",
    "    cax = divider.append_axes(\"left\", size=\"5%\", pad=0.05)\n",
    "    cbar = fig.colorbar(im2, cax=cax, orientation=\"vertical\", ticks=np.linspace(vmin, vmax, num=6))\n",
    "    cbar.ax.tick_params(labelsize=11)\n",
    "\n",
    "    formatter = ticker.FuncFormatter(lambda x, _: f\"$10^{{{int(x)}}}$\")\n",
    "    cbar.ax.yaxis.set_major_formatter(formatter)\n",
    "    cax.yaxis.set_ticks_position(\"left\")\n",
    "    cax.yaxis.set_label_position(\"left\")\n",
    "    cax.invert_xaxis()\n",
    "\n",
    "    plt.tight_layout()\n",
    "    fig.savefig(\"mnist_poc/fig1.svg\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "make_double_figure(pc.log_E_states, pc.log_E_errors)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
