{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set working directory one level up (as if mnist_poc folder never existed).\n",
    "# The starting directory is remembered on the first run, so re-running this cell\n",
    "# lands in the same place instead of climbing one more level each time.\n",
    "import os\n",
    "\n",
    "_NOTEBOOK_START_DIR = globals().get(\"_NOTEBOOK_START_DIR\", os.getcwd())\n",
    "os.chdir(os.path.join(_NOTEBOOK_START_DIR, \"..\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handy imports\n",
    "import torch\n",
    "from torch import nn\n",
    "from lightning import seed_everything"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "## First: choose CrossEntropyLoss or MSELoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss switch read by the cells below:\n",
    "#   True  -> plain linear output layer + class_loss replaced via use_CrossEntropyLoss\n",
    "#   False -> Sigmoid output layer + the model's own class_loss\n",
    "#            (MSE according to this notebook's headings - confirm in PCE)\n",
    "USE_CROSSENTROPY_INSTEAD_OF_MSE = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import MethodType\n",
    "\n",
    "\n",
    "def use_CrossEntropyLoss(pc_module):\n",
    "    \"\"\"Swap ``pc_module.class_loss`` for a summed cross-entropy loss.\n",
    "\n",
    "    The replacement is bound to this one instance only, and the same object is\n",
    "    returned. Motivation: CELoss to avoid vanishing grads with state optim.\n",
    "    \"\"\"\n",
    "\n",
    "    def _summed_cross_entropy(self, y_pred, y):\n",
    "        # reduction=\"sum\": the loss is added up over the batch, not averaged\n",
    "        loss = nn.functional.cross_entropy(y_pred, y, reduction=\"sum\")\n",
    "        return loss\n",
    "\n",
    "    # Bind as a method of this instance so `self` is supplied on every call\n",
    "    bound_loss = MethodType(_summed_cross_entropy, pc_module)\n",
    "    pc_module.class_loss = bound_loss\n",
    "    return pc_module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5",
   "metadata": {},
   "source": [
    "## Define architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_everything(42)  # needed for reproducible weights\n",
    "\n",
    "\n",
    "class MyLinear(nn.Linear):\n",
    "    \"\"\"Linear layer initialised with orthogonal weights (ReLU gain) and zero bias.\"\"\"\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        relu_gain = nn.init.calculate_gain(\"relu\")\n",
    "        # Alternative that was tried: nn.init.xavier_uniform_(self.weight, relu_gain)\n",
    "        nn.init.orthogonal_(self.weight, relu_gain)\n",
    "        if self.bias is None:\n",
    "            return\n",
    "        nn.init.zeros_(self.bias)\n",
    "\n",
    "\n",
    "# Layers are created in network order (input, hidden, output) so that the seeded\n",
    "# random number generator is consumed in a fixed sequence.\n",
    "input_block = nn.Sequential(MyLinear(28 * 28, 128), nn.GELU())\n",
    "hidden_blocks = [nn.Sequential(MyLinear(128, 128), nn.GELU()) for _ in range(18)]\n",
    "output_linear = MyLinear(128, 10)\n",
    "if USE_CROSSENTROPY_INSTEAD_OF_MSE:\n",
    "    output_block = output_linear  # for CrossEntropy\n",
    "else:\n",
    "    output_block = nn.Sequential(output_linear, nn.Sigmoid())  # for MSE\n",
    "\n",
    "architecture = [input_block, *hidden_blocks, output_block]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "## Pretrain architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datamodules import EMNIST\n",
    "from lightning import Trainer\n",
    "from pc_e import PCE\n",
    "\n",
    "# 0: load dataset as Lightning DataModule\n",
    "datamodule = EMNIST(batch_size=64)\n",
    "print(\"Training on\", datamodule.dataset_name)\n",
    "\n",
    "# 1: Set up Lightning trainer\n",
    "trainer = Trainer(\n",
    "    accelerator=\"cpu\",  # keep everything on CPU, to make analysis easier...\n",
    "    devices=1,\n",
    "    logger=False,\n",
    "    max_epochs=2,\n",
    "    inference_mode=False,  # inference_mode would interfere with the state backward pass\n",
    "    limit_predict_batches=1,  # enable 1-batch prediction\n",
    ")\n",
    "\n",
    "\n",
    "# 2: Define backprop version of PCE\n",
    "class BackpropMSE(PCE):\n",
    "    \"\"\"Train weights with backprop as neutral baseline (not favoring EO or SO).\n",
    "\n",
    "    Only ``training_step`` is overridden; everything else comes from ``PCE``.\n",
    "    Despite the name, the loss is whatever ``self.class_loss`` is at call time\n",
    "    (it is swapped for cross-entropy below when USE_CROSSENTROPY_INSTEAD_OF_MSE is set).\n",
    "    \"\"\"\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Return ``class_loss(y_pred(x), y)`` divided by ``self.batch_size``.\n",
    "\n",
    "        ``batch`` is indexed with the keys \"img\" (inputs) and \"y\" (targets);\n",
    "        ``batch_idx`` is not used.\n",
    "        \"\"\"\n",
    "        x, y = batch[\"img\"], batch[\"y\"]\n",
    "        self.forward(x)  # sets all errors to 0\n",
    "        return self.class_loss(self.y_pred(x), y) / self.batch_size\n",
    "\n",
    "\n",
    "# 3: Train model weights with backprop (fast & neutral method)\n",
    "# iters / e_lr are passed as None; presumably unused by this backprop-only training - confirm in PCE\n",
    "pc = BackpropMSE(architecture, iters=None, e_lr=None, w_lr=0.0003)\n",
    "if USE_CROSSENTROPY_INSTEAD_OF_MSE:\n",
    "    pc = use_CrossEntropyLoss(pc)\n",
    "trainer.fit(pc, datamodule=datamodule)\n",
    "trainer.test(pc, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "## Find the easiest and the most difficult training sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "seed_everything(42)  # needed for reproducible batch selection\n",
    "\n",
    "# One sample per batch, so every loss value belongs to exactly one training sample\n",
    "batch_size = 1\n",
    "dm = EMNIST(batch_size)\n",
    "dm.setup(\"fit\")\n",
    "dl = dm.train_dataloader()\n",
    "\n",
    "# Running extremes: (loss, batch) of the easiest and the hardest sample seen so far\n",
    "min_loss, min_batch = float(\"inf\"), None\n",
    "max_loss, max_batch = float(\"-inf\"), None\n",
    "\n",
    "with torch.no_grad():  # only the loss values are needed here, no gradients\n",
    "    for batch in tqdm(dl):\n",
    "        batch = dm.on_after_batch_transfer(batch, 0)\n",
    "        loss = pc.training_step(batch, None)\n",
    "\n",
    "        if loss < min_loss:\n",
    "            min_loss, min_batch = loss, batch\n",
    "        if loss > max_loss:\n",
    "            max_loss, max_batch = loss, batch\n",
    "\n",
    "print(f\"Done! {min_loss=}, {max_loss=}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "## Get a single batch x,y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually select which batch you want to run:\n",
    "#   min_batch = sample with the lowest loss found above (\"easiest\")\n",
    "#   max_batch = sample with the highest loss found above (\"hardest\")\n",
    "# x, y = min_batch[\"img\"], min_batch[\"y\"]\n",
    "x, y = max_batch[\"img\"], max_batch[\"y\"]\n",
    "print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "## Instantiate PCE model with state / error tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mnist_poc.tracked_pce import TrackedPCE\n",
    "\n",
    "# Rebinds `pc` to a new model built from the same `architecture` list\n",
    "# (presumably reusing the layer modules trained above - confirm in PCE).\n",
    "# e_lr starts as None; pc.e_lr is assigned by the sweep before error optimization runs.\n",
    "pc = TrackedPCE(architecture, iters=256, e_lr=None, w_lr=0.001)\n",
    "if USE_CROSSENTROPY_INSTEAD_OF_MSE:\n",
    "    pc = use_CrossEntropyLoss(pc)  # the loss override is per instance, so apply it again"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "## Run a mini hyperparameter search to find the best learning rates for states and errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pick_best_lr(candidates, evaluate):\n",
    "    \"\"\"Return the candidate learning rate with the lowest score.\n",
    "\n",
    "    Parameters:\n",
    "    - candidates: Iterable of learning rates, tried in the given order.\n",
    "    - evaluate: Callable mapping a learning rate to its final energy (lower is better).\n",
    "\n",
    "    Every (learning rate, score) pair is printed. On ties the later candidate wins.\n",
    "\n",
    "    Raises:\n",
    "    - RuntimeError: If no candidate was accepted (e.g. every score is NaN), instead of\n",
    "      silently handing None to the cells that use the result.\n",
    "    \"\"\"\n",
    "    best_lr, best_score = None, float(\"inf\")\n",
    "    for lr in candidates:\n",
    "        score = evaluate(lr)\n",
    "        print(lr, score)\n",
    "        if score <= best_score:  # a NaN score never passes this comparison\n",
    "            best_lr, best_score = lr, score\n",
    "    if best_lr is None:\n",
    "        raise RuntimeError(\"Hyperparameter sweep failed: no learning rate gave a valid score.\")\n",
    "    return best_lr\n",
    "\n",
    "\n",
    "def error_optim_score(e_lr):\n",
    "    \"\"\"Run Error Optimization on (x, y) with the given e_lr and return pc.E(x, y).\"\"\"\n",
    "    pc.e_lr = e_lr\n",
    "    pc.minimize_error_energy(x, y)\n",
    "    return pc.E(x, y)\n",
    "\n",
    "\n",
    "def state_optim_score(s_lr):\n",
    "    \"\"\"Run State Optimization on (x, y) with the given s_lr and return the energy of its final states.\"\"\"\n",
    "    final_states = pc.minimize_state_energy(x, y, iters=4096, s_lr=s_lr)\n",
    "    return pc.E_states_only(x, y, final_states)\n",
    "\n",
    "\n",
    "print(\"Starting hyperparameter sweep for Error Optimization...\")\n",
    "best_e_lr = pick_best_lr([0.001, 0.005, 0.01, 0.05, 0.1, 0.3], error_optim_score)\n",
    "\n",
    "print(\"Starting hyperparameter sweep for State Optimization...\")\n",
    "best_s_lr = pick_best_lr([0.01, 0.05, 0.1, 0.3, 0.5], state_optim_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "## Rerun with optimal hyperparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast everything to float64 (needed for easy inputs)\n",
    "dtype = torch.float64\n",
    "pc.to(dtype)  # converts the model itself\n",
    "x64 = x.to(dtype)  # Tensor.to returns a new tensor; x and y stay as they were\n",
    "y64 = y.to(dtype)\n",
    "\n",
    "print(f\"Running Error Optimization with {best_e_lr=}\")\n",
    "pc.e_lr = best_e_lr\n",
    "pc.iters = 4096 * 4\n",
    "pc.minimize_error_energy(x64, y64)\n",
    "print(f\"Running State Optimization with {best_s_lr=}\")\n",
    "pc.minimize_state_energy(x64, y64, iters=131072, s_lr=best_s_lr)\n",
    "\n",
    "# Cast the model back to default (for further experimenting in the notebook).\n",
    "# x and y need no cast back: only the copies x64 / y64 were ever converted.\n",
    "dtype = torch.get_default_dtype()\n",
    "pc.to(dtype)\n",
    "print(\"All done here!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {},
   "source": [
    "## Make plot to compare activation convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "def plot_optimization_comparison(true_optimum, method1, method2, layers):\n",
    "    \"\"\"\n",
    "    Plot any number of layers side-by-side, showing how far two iterative methods are from the optimum.\n",
    "\n",
    "    Each subplot is log-log: x is the optimization step, y is the L2 norm (over the last axis)\n",
    "    of the difference between a method's value and the optimum of that layer.\n",
    "\n",
    "    Parameters:\n",
    "    - true_optimum: List of reference solutions, one tensor per layer (squeezed along dim 0 before use).\n",
    "    - method1: List of intermediate values for Error optim (indexed by time). Each timestep is a\n",
    "      list of per-layer tensors; its last entry is dropped before stacking.\n",
    "    - method2: List of intermediate values for State optim (indexed by time). Each timestep is a\n",
    "      list of per-layer tensors.\n",
    "    - layers: List of integers specifying the layers to be plotted.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: If `layers` is empty.\n",
    "    \"\"\"\n",
    "    if len(layers) == 0:\n",
    "        raise ValueError(\"You must specify at least one layer to plot.\")\n",
    "\n",
    "    # Create a figure with one subplot per layer\n",
    "    fig, axes = plt.subplots(1, len(layers), figsize=(4 * len(layers), 4), sharey=True)\n",
    "    axes = np.atleast_1d(axes)  # Ensure axes is always iterable, even for one layer\n",
    "\n",
    "    # Stack the per-layer tensors of every timestep, then convert to numpy arrays\n",
    "    method1 = [torch.stack(timestep[:-1]) for timestep in method1]\n",
    "    method2 = [torch.stack(timestep) for timestep in method2]\n",
    "\n",
    "    method1_array = np.array(method1)  # Shape: (time_steps_method1, states, batch_size, state_dim)\n",
    "    method2_array = np.array(method2)  # Shape: (time_steps_method2, states, batch_size, state_dim)\n",
    "\n",
    "    for ax, layer in zip(axes, layers):\n",
    "        # Extract the true optimum for the current layer\n",
    "        optimum = np.array(true_optimum[layer].squeeze(0))  # Shape: (components,)\n",
    "\n",
    "        # Extract the corresponding components for method1 and method2\n",
    "        method1_layer_values = method1_array[:, layer, :, :]\n",
    "        method2_layer_values = method2_array[:, layer, :, :]\n",
    "\n",
    "        # If Error optim logged fewer steps, hold its final value so both curves span the same x-range\n",
    "        missing_steps = method2_layer_values.shape[0] - method1_layer_values.shape[0]\n",
    "        if missing_steps > 0:\n",
    "            padding = np.repeat(method1_layer_values[-1:], missing_steps, axis=0)\n",
    "            method1_layer_values = np.concatenate([method1_layer_values, padding], axis=0)\n",
    "\n",
    "        # Distance to the optimum: L2 norm across state components, computed in float64\n",
    "        method1_distance = np.linalg.norm(\n",
    "            np.subtract(method1_layer_values, optimum, dtype=np.float64), ord=2, axis=-1\n",
    "        )\n",
    "        method2_distance = np.linalg.norm(\n",
    "            np.subtract(method2_layer_values, optimum, dtype=np.float64), ord=2, axis=-1\n",
    "        )\n",
    "        # np.squeeze drops size-1 axes (e.g. a batch axis of length 1)\n",
    "        method1_distance = np.squeeze(method1_distance)\n",
    "        method2_distance = np.squeeze(method2_distance)\n",
    "\n",
    "        # Plot the distance curves of both methods\n",
    "        ax.plot(method2_distance, \"C3-\", label=\"State Optimization\")\n",
    "        ax.plot(method1_distance, \"b-\", label=\"Error Optimization (ours)\")\n",
    "\n",
    "        # Add titles and labels (raw strings keep the LaTeX backslashes literal)\n",
    "        ax.set_title(rf\"Convergence of $\\mathbf{{s_{{{layer}}}}}$ to optimum\", fontsize=15)\n",
    "        ax.set_xlabel(\"Optimization steps (log scale)\")\n",
    "        ax.set_xscale(\"log\")\n",
    "        ax.set_yscale(\"log\")\n",
    "        ax.set_ylim(7e-8, 3)\n",
    "\n",
    "    # Set shared y-axis label\n",
    "    axes[0].set_ylabel(r\"$\\|\\mathbf{s_i} - \\mathbf{s_i^*}\\|$\")\n",
    "\n",
    "    # Add one figure-level legend, using proxy lines in the same colours as the curves\n",
    "    handles = [\n",
    "        plt.Line2D([0], [0], color=\"b\", linestyle=\"-\"),\n",
    "        plt.Line2D([0], [0], color=\"C3\", linestyle=\"-\"),\n",
    "    ]\n",
    "    fig.legend(\n",
    "        handles,\n",
    "        [\"Error Optimization (ours)\", \"State Optimization\"],\n",
    "        loc=\"upper center\",\n",
    "        ncol=2,\n",
    "        prop={\"size\": 12},\n",
    "    )\n",
    "\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.90])  # Make space for the global legend\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference optimum = last logged Error-optim timestep; curves = full Error / State logs.\n",
    "# layers=[0, 18]: presumably the first and the last stacked layer of this network - confirm\n",
    "plot_optimization_comparison(\n",
    "    pc.log_errors[-1],  # assume error optimization gets to final equilibrium\n",
    "    pc.log_errors,\n",
    "    pc.log_states,\n",
    "    layers=[0, 18],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (easy vs hard input): check what Error Optim gives as y_pred\n",
    "# Prints the last entry of the last logged Error-optim timestep, followed by the target y\n",
    "print(pc.log_errors[-1][-1])\n",
    "print(y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-lightning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
