{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set working directory one level up (as if mnist_poc folder never existed).\n",
    "import os\n",
    "\n",
    "# Remember the directory the kernel started in, so that re-running this cell always\n",
    "# lands in the same parent directory instead of climbing one more level each time.\n",
    "# On a fresh kernel this is equivalent to a plain os.chdir(\"..\").\n",
    "if \"_NOTEBOOK_DIR\" not in globals():\n",
    "    _NOTEBOOK_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_DIR, os.pardir))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "## Define architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from lightning import seed_everything\n",
    "\n",
    "seed_everything(42)\n",
    "\n",
    "\n",
    "# Use proper initialization for Linear\n",
    "class MyLinear(nn.Linear):\n",
    "    def reset_parameters(self):\n",
    "        gain = nn.init.calculate_gain(\"linear\")\n",
    "        # nn.init.xavier_uniform_(self.weight, gain)\n",
    "        nn.init.orthogonal_(self.weight, gain)\n",
    "        if self.bias is not None:\n",
    "            nn.init.zeros_(self.bias)\n",
    "\n",
    "\n",
    "# 20 bias-free layers: 784 -> 128, eighteen 128 -> 128, then 128 -> 10.\n",
    "# Elements are instantiated front to back, i.e. in the same order as the layers appear.\n",
    "architecture = [\n",
    "    MyLinear(28 * 28, 128, bias=False),\n",
    "    *(MyLinear(128, 128, bias=False) for _ in range(18)),\n",
    "    MyLinear(128, 10, bias=False),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3",
   "metadata": {},
   "source": [
    "## Pretrain architecture for a few steps\n",
    "Overall, this deep linear architecture is really not great. It has no output activation, which leads to poor results and a strong tendency toward instability. That's why we only pretrain on a few batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datamodules import EMNIST\n",
    "from lightning import Trainer\n",
    "from pc_variants import BackpropMSE\n",
    "\n",
    "# Pretraining run: fit a BackpropMSE module built from `architecture`, then evaluate it\n",
    "# on the test split. The trainer is configured for a single CPU device without a logger.\n",
    "# 0: load dataset as Lightning DataModule\n",
    "datamodule = EMNIST(batch_size=64)\n",
    "print(\"Training on\", datamodule.dataset_name)\n",
    "\n",
    "# 1: Set up Lightning trainer\n",
    "# NOTE(review): the markdown above says pretraining covers only \"a few batches\", but the\n",
    "# only training limit visible here is max_epochs=2 (no limit_train_batches) - confirm.\n",
    "trainer = Trainer(\n",
    "    accelerator=\"cpu\",\n",
    "    devices=1,\n",
    "    logger=False,\n",
    "    max_epochs=2,\n",
    "    inference_mode=False,  # inference_mode would interfere with the state backward pass\n",
    "    limit_predict_batches=1,  # enable 1-batch prediction\n",
    ")\n",
    "\n",
    "# 2: Train model weights with backprop (fast & neutral method)\n",
    "# iters and e_lr are passed as None; presumably this variant does not use them - TODO confirm.\n",
    "# Later cells reuse `architecture`, so the layers are presumably trained in place - TODO confirm.\n",
    "pc = BackpropMSE(architecture, iters=None, e_lr=None, w_lr=0.001)\n",
    "trainer.fit(pc, datamodule=datamodule)\n",
    "trainer.test(pc, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5",
   "metadata": {},
   "source": [
    "## Get a single batch x,y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed before drawing the batch (keeps any seeded randomness in the loader repeatable)\n",
    "seed_everything(42)\n",
    "\n",
    "batch_size = 64\n",
    "dm = EMNIST(batch_size)\n",
    "print(\"Training on\", dm.dataset_name)\n",
    "dm.setup(\"fit\")\n",
    "dl = dm.train_dataloader()\n",
    "\n",
    "# Take the first training batch and apply the datamodule's after-transfer hook by hand\n",
    "batch = dm.on_after_batch_transfer(next(iter(dl)), 0)\n",
    "x = batch[\"img\"]\n",
    "y = batch[\"y\"]\n",
    "print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "## Calculate analytical solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from analytical_solution import get_final_states\n",
    "\n",
    "# Analytical solution for every state given the batch (x, y); later cells compare the\n",
    "# iterative methods against it.\n",
    "true_optimum = get_final_states(architecture, x, y)\n",
    "for i, s_i in enumerate(true_optimum, start=1):\n",
    "    # Doubled braces emit literal `{` and `}` around the index (prints e.g. x^{1}),\n",
    "    # instead of relying on the repr of a one-element set literal.\n",
    "    print(f\"x^{{{i}}} shape: {s_i.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "## Do mini hyperparameter search to find best learning rate for states and errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tracked_pce import TrackedPCE\n",
    "\n",
    "\n",
    "def get_conv_score(states):\n",
    "    \"\"\"\n",
    "    Mean relative distance between the last logged states and the analytical optimum.\n",
    "\n",
    "    Parameters:\n",
    "    - states: Log indexed by time; only the final entry `states[-1]` is scored.\n",
    "\n",
    "    Returns:\n",
    "    - Python float: the sum over paired layers of mean(||optimum - state|| / ||optimum||),\n",
    "      with norms taken along dim 1, divided by `len(true_optimum)`.\n",
    "\n",
    "    Reads the notebook-level `true_optimum`. Layers are paired with `zip`, so surplus\n",
    "    entries on either side are ignored.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    for target, reached in zip(true_optimum, states[-1]):\n",
    "        rel_dist = torch.norm(target - reached, dim=1) / torch.norm(target, dim=1)\n",
    "        total = total + rel_dist.mean()\n",
    "    return total.item() / len(true_optimum)\n",
    "\n",
    "\n",
    "# Sweep 1: keep the error learning rate whose final logged entry scores lowest with\n",
    "# get_conv_score. Because the comparison is `<=`, ties go to the later candidate, and a\n",
    "# NaN score never wins (if every run yields NaN, best_e_lr stays None).\n",
    "print(\"Starting hyperparameter sweep for Error Optimization...\")\n",
    "best_score = float(\"inf\")\n",
    "best_e_lr = None\n",
    "for e_lr in [0.01, 0.05, 0.1, 0.3]:\n",
    "    # A new TrackedPCE instance is built for every candidate\n",
    "    pc = TrackedPCE(architecture, iters=256, e_lr=e_lr, w_lr=0.001)\n",
    "    pc.minimize_error_energy(x, y)\n",
    "    score = get_conv_score(pc.log_errors)\n",
    "    print(e_lr, score)\n",
    "    if score <= best_score:\n",
    "        best_score = score\n",
    "        best_e_lr = e_lr\n",
    "\n",
    "# Sweep 2: same selection rule for the state learning rate.\n",
    "# NOTE(review): no new instance is created here - every candidate runs on the `pc` left\n",
    "# over from the last iteration of the loop above (built with e_lr=0.3, not best_e_lr).\n",
    "# Presumably minimize_state_energy restarts from scratch and resets log_states on each\n",
    "# call - TODO confirm, otherwise candidates are not independent.\n",
    "print(\"Starting hyperparameter sweep for State Optimization...\")\n",
    "best_score = float(\"inf\")\n",
    "best_s_lr = None\n",
    "for s_lr in [0.1, 0.3, 0.5]:\n",
    "    pc.minimize_state_energy(x, y, iters=4096, s_lr=s_lr)\n",
    "    score = get_conv_score(pc.log_states)\n",
    "    print(s_lr, score)\n",
    "    if score <= best_score:\n",
    "        best_score = score\n",
    "        best_s_lr = s_lr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "## Rerun with optimal hyperparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final runs with the learning rates selected by the sweeps. The logs kept on this `pc`\n",
    "# (log_errors / log_states) are what the plotting cell below reads.\n",
    "print(f\"Running Error Optimization with {best_e_lr=}\")\n",
    "# New TrackedPCE instance; same iteration budget (256) as used in the sweep\n",
    "pc = TrackedPCE(architecture, iters=256, e_lr=best_e_lr, w_lr=0.001)\n",
    "pc.minimize_error_energy(x, y)\n",
    "print(f\"Running State Optimization with {best_s_lr=}\")\n",
    "# Same 4096-step budget as used in the sweep\n",
    "pc.minimize_state_energy(x, y, iters=4096, s_lr=best_s_lr)\n",
    "print(\"All done here!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "## Make plot to compare activation convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "def _pad_to_length(curve, length):\n",
    "    \"\"\"Extend `curve` along axis 0 to `length` entries by repeating its last entry.\"\"\"\n",
    "    missing = length - curve.shape[0]\n",
    "    if missing <= 0:\n",
    "        return curve\n",
    "    return np.concatenate([curve, np.repeat(curve[-1:], missing, axis=0)], axis=0)\n",
    "\n",
    "\n",
    "def _distance_quartiles(values, optimum):\n",
    "    \"\"\"Per time step, the batch quartiles (q1, median, q3) of the L2 distance to `optimum`.\"\"\"\n",
    "    distances = np.linalg.norm(\n",
    "        np.subtract(values, optimum, dtype=np.float64), ord=2, axis=-1\n",
    "    )\n",
    "    q1, median, q3 = np.percentile(distances, [25, 50, 75], axis=-1)\n",
    "    return q1, median, q3\n",
    "\n",
    "\n",
    "def plot_optimization_comparison(true_optimum, method1, method2, layers):\n",
    "    \"\"\"\n",
    "    Plot any number of layers side-by-side, comparing analytical optimum and two iterative methods.\n",
    "\n",
    "    For every requested layer, the L2 distance to the optimum is computed per sample and\n",
    "    time step. The median over the batch is drawn as a line and the interquartile range\n",
    "    as a band, on log-log axes.\n",
    "\n",
    "    Parameters:\n",
    "    - true_optimum: List of analytical solutions for each layer.\n",
    "    - method1: List of intermediate values for Error optim (indexed by time). The last\n",
    "      entry of every time step is dropped before stacking.\n",
    "    - method2: List of intermediate values for State optim (indexed by time).\n",
    "    - layers: List of integers specifying the layers to be plotted.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: If `layers` is empty.\n",
    "\n",
    "    The shorter of the two logs is extended by holding its final value, so both curves\n",
    "    span the same number of steps regardless of which method ran longer.\n",
    "    \"\"\"\n",
    "    if len(layers) == 0:\n",
    "        raise ValueError(\"You must specify at least one layer to plot.\")\n",
    "\n",
    "    # Create a figure with subplots for each layer\n",
    "    fig, axes = plt.subplots(1, len(layers), figsize=(4 * len(layers), 4), sharey=True)\n",
    "    axes = np.atleast_1d(axes)  # Ensure axes is always iterable, even for one layer\n",
    "\n",
    "    # Stack in torch and convert once. detach()/cpu() make the conversion work even if\n",
    "    # the logged tensors still track gradients or live on another device.\n",
    "    # Resulting shape: (time_steps, states, batch_size, state_dim)\n",
    "    method1_array = (\n",
    "        torch.stack([torch.stack(timestep[:-1]) for timestep in method1]).detach().cpu().numpy()\n",
    "    )\n",
    "    method2_array = (\n",
    "        torch.stack([torch.stack(timestep) for timestep in method2]).detach().cpu().numpy()\n",
    "    )\n",
    "\n",
    "    for ax, layer in zip(axes, layers):\n",
    "        # Optimum for the current layer. squeeze(0) only removes the batch axis when it\n",
    "        # has size 1; otherwise the array broadcasts against the logged values.\n",
    "        optimum = true_optimum[layer].squeeze(0)\n",
    "        if isinstance(optimum, torch.Tensor):\n",
    "            optimum = optimum.detach().cpu().numpy()\n",
    "        else:\n",
    "            optimum = np.array(optimum)\n",
    "\n",
    "        # Distance to the optimum, summarised over the batch for every time step\n",
    "        method1_curves = _distance_quartiles(method1_array[:, layer, :, :], optimum)\n",
    "        method2_curves = _distance_quartiles(method2_array[:, layer, :, :], optimum)\n",
    "\n",
    "        # Hold the final value of whichever run is shorter. Padding the 1-D curves gives\n",
    "        # the same result as padding the raw values, without copying the large arrays.\n",
    "        n_steps = max(method1_curves[1].shape[0], method2_curves[1].shape[0])\n",
    "        method1_q1, method1_median, method1_q3 = (\n",
    "            _pad_to_length(curve, n_steps) for curve in method1_curves\n",
    "        )\n",
    "        method2_q1, method2_median, method2_q3 = (\n",
    "            _pad_to_length(curve, n_steps) for curve in method2_curves\n",
    "        )\n",
    "\n",
    "        # Plot median and IQR (step 0 has no position on the logarithmic x-axis)\n",
    "        time_steps = np.arange(n_steps)\n",
    "        ax.plot(time_steps, method2_median, \"C3-\", label=\"State Optimization\")\n",
    "        ax.fill_between(\n",
    "            time_steps,\n",
    "            method2_q1,\n",
    "            method2_q3,\n",
    "            color=\"C3\",\n",
    "            alpha=0.3,\n",
    "            label=\"IQR (State Optimization)\",\n",
    "        )\n",
    "        ax.plot(time_steps, method1_median, \"b-\", label=\"Error Optimization (ours)\")\n",
    "        ax.fill_between(\n",
    "            time_steps,\n",
    "            method1_q1,\n",
    "            method1_q3,\n",
    "            color=\"b\",\n",
    "            alpha=0.3,\n",
    "            label=\"IQR (Error Optimization)\",\n",
    "        )\n",
    "\n",
    "        # Add titles and labels (raw strings keep the LaTeX backslashes literal)\n",
    "        ax.set_title(rf\"Convergence of $\\mathbf{{s_{{{layer}}}}}$ to optimum\", fontsize=15)\n",
    "        ax.set_xlabel(\"Optimization steps (log scale)\")\n",
    "        ax.set_xscale(\"log\")\n",
    "        ax.set_yscale(\"log\")\n",
    "        ax.set_ylim(7e-5, 1)\n",
    "\n",
    "    # Set shared y-axis label\n",
    "    axes[0].set_ylabel(r\"$\\|\\mathbf{s_i} - \\mathbf{s_i^*}\\|$\")\n",
    "\n",
    "    # Add one figure-level legend with a line handle per method\n",
    "    handles = [\n",
    "        plt.Line2D([0], [0], color=\"b\", linestyle=\"-\"),\n",
    "        plt.Line2D([0], [0], color=\"C3\", linestyle=\"-\"),\n",
    "    ]\n",
    "    fig.legend(\n",
    "        handles,\n",
    "        [\"Error Optimization (ours)\", \"State Optimization\"],\n",
    "        loc=\"upper center\",\n",
    "        ncol=2,\n",
    "        prop={\"size\": 12},\n",
    "    )\n",
    "\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.90])  # Make space for the global legend\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare both logged runs against the analytical optimum at state indices 0, 9 and 18\n",
    "plot_optimization_comparison(\n",
    "    true_optimum, method1=pc.log_errors, method2=pc.log_states, layers=[0, 9, 18]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
