{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "maWjqhvjvpYk"
      },
      "source": [
        "# Fisher-CG PINN Scaling\n",
        "\n",
        "This notebook reproduces the reviewer-response experiment in a Colab GPU runtime.\n",
        "\n",
        "It compares:\n",
        "\n",
        "- `Adam` in both Phase I and Phase II\n",
        "- `Adam -> CG-Fisher` with `lr = 1.0` and `20` CG iterations in Phase II\n",
        "\n",
        "Goals:\n",
        "\n",
        "1. Check whether `CG-Fisher` still beats Adam after the Phase I -> Phase II switch.\n",
        "2. Report both step-wise and time-wise comparisons to Adam.\n",
        "3. Report the steady-state per-iteration cost ratio `CG-Fisher / Adam` from the\n",
        "   last `100` Phase II epochs, after JIT compilation and startup transients.\n",
        "\n",
        "Paper-matched Poisson PINN setting:\n",
        "\n",
        "- 2D Poisson equation on `[0, 1]^2`\n",
        "- `1000` interior collocation points per epoch\n",
        "- `200` boundary points per epoch\n",
        "- resample every epoch\n",
        "- `1000` Adam warmup epochs for Phase I\n",
        "- `1000` Phase II epochs\n",
        "- `20` CG iterations for the preconditioned method\n",
        "- baseline architecture `50 x 2` with `tanh`\n",
        "\n",
        "Timing note:\n",
        "\n",
        "- the notebook explicitly JIT-compiles each optimizer step before timing\n",
        "- the reviewer-facing per-step ratio uses the last `100` Phase II epochs in full mode"
      ],
      "id": "maWjqhvjvpYk"
    },
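    {
      "cell_type": "markdown",
      "metadata": {
        "id": "breakEvenMd01"
      },
      "source": [
        "Goals 2 and 3 are linked by simple arithmetic. If a CG-Fisher step costs `cost_ratio` times an Adam step, CG-Fisher can only win on wall-clock time if it reaches Adam's final Phase II loss within about `phase2_steps / cost_ratio` of its own steps. The next cell is a minimal sketch of that break-even count, assuming constant per-step costs; `break_even_steps` is a hypothetical helper, and the ratio used is illustrative, not a measured value."
      ],
      "id": "breakEvenMd01"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "breakEvenCode01"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def break_even_steps(phase2_steps: int, cost_ratio: float) -> int:\n",
        "    # Hypothetical helper: largest number of CG-Fisher steps that fit into Adam's full\n",
        "    # Phase II time budget, assuming constant per-step costs\n",
        "    # (cost_ratio = CG-Fisher step time / Adam step time).\n",
        "    return math.floor(phase2_steps / cost_ratio)\n",
        "\n",
        "\n",
        "# Illustrative ratio only: a 10x costlier step against 1000 Adam steps.\n",
        "print(break_even_steps(1000, 10.0))  # 100"
      ],
      "id": "breakEvenCode01"
    },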
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SxwKKpXqvpYl"
      },
      "source": [
        "## Colab Setup\n",
        "\n",
        "Use a GPU runtime before running the next cell.\n",
        "\n",
        "This notebook installs `opttx` from PyPI, not from a Git checkout.\n",
        "\n",
        "`opttx` is published on PyPI:\n",
        "https://pypi.org/project/opttx/\n",
        "\n",
        "For GPU JAX wheels, this notebook follows the current official JAX installation page.\n",
        "\n",
        "Sources:\n",
        "\n",
        "- OptTx on PyPI: https://pypi.org/project/opttx/\n",
        "- JAX installation docs: https://docs.jax.dev/en/latest/installation.html\n",
        "- JAX GPU memory allocation docs: https://docs.jax.dev/en/latest/gpu_memory_allocation.html"
      ],
      "id": "SxwKKpXqvpYl"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vljas-0EvpYl"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
        "\n",
        "!pip install -q -U pip\n",
        "!pip install -q -U \"jax[cuda13]\" opttx"
      ],
      "id": "Vljas-0EvpYl"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Rm_MumZTvpYl",
        "outputId": "35f8e7e8-bfe3-47c9-de96-16f382711cb3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "JAX version: 0.9.2\n",
            "Devices: [CudaDevice(id=0)]\n",
            "Default backend: gpu\n"
          ]
        }
      ],
      "source": [
        "import math\n",
        "import time\n",
        "from dataclasses import dataclass\n",
        "from typing import Any, Dict, List, Sequence, Tuple\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from flax import linen as nn\n",
        "from IPython.display import display\n",
        "\n",
        "from opttx import Adam, CGOptimizer, Objective, TermSpec, TrainState\n",
        "\n",
        "jax.config.update(\"jax_enable_x64\", True)\n",
        "\n",
        "print(\"JAX version:\", jax.__version__)\n",
        "print(\"Devices:\", jax.devices())\n",
        "print(\"Default backend:\", jax.default_backend())\n",
        "if jax.default_backend() == \"cpu\":\n",
        "    print(\"Warning: notebook is running on CPU, not GPU.\")"
      ],
      "id": "Rm_MumZTvpYl"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nVERFBnevpYl",
        "outputId": "b15e344c-1a3b-4f3b-baa0-098a21167f5f"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'use_float64': True,\n",
              " 'timing_seed': 41,\n",
              " 'phase1_seed': 42,\n",
              " 'phase2_seed': 43,\n",
              " 'resample_each_epoch': True,\n",
              " 'n_interior_batch': 1000,\n",
              " 'n_boundary_batch': 200,\n",
              " 'weight_pde': 1.0,\n",
              " 'weight_bc': 100.0,\n",
              " 'timing_warmup_steps': 1,\n",
              " 'timing_steps': 5,\n",
              " 'phase1_steps': 1000,\n",
              " 'phase2_steps': 1000,\n",
              " 'phase2_tail_steps': 100,\n",
              " 'adam_lr': 0.001,\n",
              " 'cgf_lr': 1.0,\n",
              " 'cg_iters': 20,\n",
              " 'damping': 0.001,\n",
              " 'depth_values': [2, 3, 4, 5],\n",
              " 'width_values': [30, 40, 50, 60, 70],\n",
              " 'depth_width': 50,\n",
              " 'width_depth': 2}"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ],
      "source": [
        "RUN_MODE = \"full\"  # \"full\" or \"quick\"\n",
        "\n",
        "if RUN_MODE == \"full\":\n",
        "    CONFIG = {\n",
        "        \"use_float64\": True,\n",
        "        \"timing_seed\": 41,\n",
        "        \"phase1_seed\": 42,\n",
        "        \"phase2_seed\": 43,\n",
        "        \"resample_each_epoch\": True,\n",
        "        \"n_interior_batch\": 1000,\n",
        "        \"n_boundary_batch\": 200,\n",
        "        \"weight_pde\": 1.0,\n",
        "        \"weight_bc\": 100.0,\n",
        "        \"timing_warmup_steps\": 1,\n",
        "        \"timing_steps\": 5,\n",
        "        \"phase1_steps\": 1000,\n",
        "        \"phase2_steps\": 1000,\n",
        "        \"phase2_tail_steps\": 100,\n",
        "        \"adam_lr\": 1e-3,\n",
        "        \"cgf_lr\": 1.0,\n",
        "        \"cg_iters\": 20,\n",
        "        \"damping\": 1e-3,\n",
        "        \"depth_values\": [2, 3, 4, 5],\n",
        "        \"depth_width\": 50,\n",
        "    }\n",
        "else:\n",
        "    CONFIG = {\n",
        "        \"use_float64\": True,\n",
        "        \"timing_seed\": 41,\n",
        "        \"phase1_seed\": 42,\n",
        "        \"phase2_seed\": 43,\n",
        "        \"resample_each_epoch\": True,\n",
        "        \"n_interior_batch\": 128,\n",
        "        \"n_boundary_batch\": 64,\n",
        "        \"weight_pde\": 1.0,\n",
        "        \"weight_bc\": 100.0,\n",
        "        \"timing_warmup_steps\": 0,\n",
        "        \"timing_steps\": 2,\n",
        "        \"phase1_steps\": 50,\n",
        "        \"phase2_steps\": 100,\n",
        "        \"phase2_tail_steps\": 20,\n",
        "        \"adam_lr\": 1e-3,\n",
        "        \"cgf_lr\": 1.0,\n",
        "        \"cg_iters\": 5,\n",
        "        \"damping\": 1e-3,\n",
        "        \"depth_values\": [2, 3, 4],\n",
        "        \"depth_width\": 50,\n",
        "    }\n",
        "\n",
        "CONFIG"
      ],
      "id": "nVERFBnevpYl"
    },
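    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fisherCgSketchMd01"
      },
      "source": [
        "The `cgf_lr`, `damping`, and `cg_iters` entries above parameterize the CG-Fisher step. For orientation, the next cell sketches what a damped Fisher-CG step with such settings plausibly computes: solve `(F + damping * I) d = g` with at most `cg_iters` conjugate-gradient iterations, using Jacobian-vector products instead of forming `F`, then update `params -= cgf_lr * d`. This is a generic truncated natural-gradient sketch on a hypothetical toy least-squares model (`residual_fn`, `fisher_cg_step`, and `toy_loss` are illustrative names), not OptTx's `CGOptimizer` implementation. Treating the Fisher matrix of a squared-error loss as the Gauss-Newton matrix `J^T J / n` is an assumption of the sketch."
      ],
      "id": "fisherCgSketchMd01"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fisherCgSketchCode01"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax.scipy.sparse.linalg import cg\n",
        "\n",
        "jax.config.update('jax_enable_x64', True)\n",
        "\n",
        "\n",
        "def residual_fn(params, x, y):\n",
        "    # Hypothetical toy model: linear-regression residuals r = x @ w + b - y.\n",
        "    return x @ params['w'] + params['b'] - y\n",
        "\n",
        "\n",
        "def fisher_cg_step(params, x, y, lr=1.0, damping=1e-3, cg_iters=20):\n",
        "    # Sketch only: one damped Fisher (Gauss-Newton) CG step for 0.5 * mean(r**2).\n",
        "    n = y.shape[0]\n",
        "\n",
        "    def r_of_p(p):\n",
        "        return residual_fn(p, x, y)\n",
        "\n",
        "    r, vjp_fn = jax.vjp(r_of_p, params)\n",
        "    (g,) = vjp_fn(r / n)  # gradient g = J^T r / n\n",
        "\n",
        "    def damped_fvp(v):\n",
        "        # (J^T J / n + damping * I) v via one JVP and one VJP; F is never formed.\n",
        "        _, jv = jax.jvp(r_of_p, (params,), (v,))\n",
        "        (jtjv,) = vjp_fn(jv / n)\n",
        "        return jax.tree_util.tree_map(lambda a, b: a + damping * b, jtjv, v)\n",
        "\n",
        "    # Truncated CG solve of (F + damping * I) d = g; may stop early once tol is met.\n",
        "    d, _ = cg(damped_fvp, g, maxiter=cg_iters)\n",
        "    return jax.tree_util.tree_map(lambda p, s: p - lr * s, params, d)\n",
        "\n",
        "\n",
        "# Noiseless toy problem: one step from zero should nearly solve it.\n",
        "x = jax.random.normal(jax.random.PRNGKey(0), (64, 2))\n",
        "y = x @ jnp.array([2.0, -1.0]) + 0.5\n",
        "params0 = {'w': jnp.zeros(2), 'b': jnp.array(0.0)}\n",
        "\n",
        "\n",
        "def toy_loss(p):\n",
        "    return 0.5 * jnp.mean(residual_fn(p, x, y) ** 2)\n",
        "\n",
        "\n",
        "params1 = fisher_cg_step(params0, x, y)\n",
        "print('loss before:', float(toy_loss(params0)), 'after:', float(toy_loss(params1)))"
      ],
      "id": "fisherCgSketchCode01"
    },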
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6x2qKlpcvpYl"
      },
      "outputs": [],
      "source": [
        "class PoissonPINN(nn.Module):\n",
        "    hidden_layers: Sequence[int]\n",
        "    dtype: jnp.dtype = jnp.float64\n",
        "\n",
        "    @nn.compact\n",
        "    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
        "        for width in self.hidden_layers:\n",
        "            x = nn.Dense(width, dtype=self.dtype, param_dtype=self.dtype)(x)\n",
        "            x = nn.tanh(x)\n",
        "        return nn.Dense(1, dtype=self.dtype, param_dtype=self.dtype)(x)\n",
        "\n",
        "    def laplacian(self, x: jnp.ndarray) -> jnp.ndarray:\n",
        "        def single_point_laplacian(point: jnp.ndarray) -> jnp.ndarray:\n",
        "            x_val, y_val = point[0], point[1]\n",
        "\n",
        "            def u_at_x(t: jax.Array) -> jax.Array:\n",
        "                pt = jnp.array([t, y_val], dtype=point.dtype)\n",
        "                return self(pt[None, :])[0, 0]\n",
        "\n",
        "            def u_at_y(t: jax.Array) -> jax.Array:\n",
        "                pt = jnp.array([x_val, t], dtype=point.dtype)\n",
        "                return self(pt[None, :])[0, 0]\n",
        "\n",
        "            u_xx = jax.grad(jax.grad(u_at_x))(x_val)\n",
        "            u_yy = jax.grad(jax.grad(u_at_y))(y_val)\n",
        "            return -(u_xx + u_yy)\n",
        "\n",
        "        return jax.vmap(single_point_laplacian)(x)[:, None]\n",
        "\n",
        "\n",
        "def true_solution(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n",
        "    return jnp.sin(2 * jnp.pi * x) * jnp.sin(2 * jnp.pi * y)\n",
        "\n",
        "\n",
        "def source_term(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n",
        "    return 8 * (jnp.pi**2) * jnp.sin(2 * jnp.pi * x) * jnp.sin(2 * jnp.pi * y)\n",
        "\n",
        "\n",
        "def generate_batch(\n",
        "    key: jax.Array,\n",
        "    n_interior: int,\n",
        "    n_boundary_total: int,\n",
        "    dtype: jnp.dtype,\n",
        ") -> Dict[str, Tuple[jnp.ndarray, jnp.ndarray]]:\n",
        "    if n_boundary_total % 4 != 0:\n",
        "        raise ValueError(\"n_boundary_total must be divisible by 4\")\n",
        "\n",
        "    key_int_x, key_int_y, key_b0, key_b1, key_b2, key_b3 = jax.random.split(key, 6)\n",
        "\n",
        "    x_int = jax.random.uniform(key_int_x, (n_interior,), dtype=dtype)\n",
        "    y_int = jax.random.uniform(key_int_y, (n_interior,), dtype=dtype)\n",
        "    interior_points = jnp.stack([x_int, y_int], axis=1)\n",
        "    interior_targets = source_term(x_int, y_int)[:, None]\n",
        "\n",
        "    n_side = n_boundary_total // 4\n",
        "    x_bottom = jax.random.uniform(key_b0, (n_side,), dtype=dtype)\n",
        "    x_top = jax.random.uniform(key_b1, (n_side,), dtype=dtype)\n",
        "    y_left = jax.random.uniform(key_b2, (n_side,), dtype=dtype)\n",
        "    y_right = jax.random.uniform(key_b3, (n_side,), dtype=dtype)\n",
        "\n",
        "    bottom = jnp.stack([x_bottom, jnp.zeros(n_side, dtype=dtype)], axis=1)\n",
        "    top = jnp.stack([x_top, jnp.ones(n_side, dtype=dtype)], axis=1)\n",
        "    left = jnp.stack([jnp.zeros(n_side, dtype=dtype), y_left], axis=1)\n",
        "    right = jnp.stack([jnp.ones(n_side, dtype=dtype), y_right], axis=1)\n",
        "\n",
        "    boundary_points = jnp.concatenate([bottom, top, left, right], axis=0)\n",
        "    boundary_targets = true_solution(boundary_points[:, 0], boundary_points[:, 1])[:, None]\n",
        "\n",
        "    return {\n",
        "        \"interior\": (interior_points, interior_targets),\n",
        "        \"boundary\": (boundary_points, boundary_targets),\n",
        "    }\n",
        "\n",
        "\n",
        "def pde_loss_fn(pred: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray]) -> jax.Array:\n",
        "    _, target = batch\n",
        "    return jnp.mean((pred - target) ** 2)\n",
        "\n",
        "\n",
        "def bc_loss_fn(pred: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray]) -> jax.Array:\n",
        "    _, target = batch\n",
        "    return jnp.mean((pred - target) ** 2)\n",
        "\n",
        "\n",
        "def count_parameters(params: Any) -> int:\n",
        "    return int(sum(leaf.size for leaf in jax.tree_util.tree_leaves(params)))\n",
        "\n",
        "\n",
        "def build_problem(width: int, depth: int, cfg: Dict[str, Any]):\n",
        "    dtype = jnp.float64 if cfg[\"use_float64\"] else jnp.float32\n",
        "    hidden_layers = tuple([width] * depth)\n",
        "    model = PoissonPINN(hidden_layers=hidden_layers, dtype=dtype)\n",
        "\n",
        "    init_key = jax.random.PRNGKey(cfg[\"phase1_seed\"])\n",
        "    init_batch = generate_batch(\n",
        "        jax.random.PRNGKey(cfg[\"timing_seed\"]),\n",
        "        n_interior=cfg[\"n_interior_batch\"],\n",
        "        n_boundary_total=cfg[\"n_boundary_batch\"],\n",
        "        dtype=dtype,\n",
        "    )\n",
        "    variables = model.init(init_key, init_batch[\"interior\"][0])\n",
        "\n",
        "    def model_apply(variables: Dict[str, Any], batch_item: Any, method: str | None = None):\n",
        "        points, _ = batch_item\n",
        "        if method == \"laplacian\":\n",
        "            return model.apply(variables, points, method=lambda m, x: m.laplacian(x))\n",
        "        return model.apply(variables, points)\n",
        "\n",
        "    objective = Objective(\n",
        "        terms=[\n",
        "            TermSpec(name=\"pde\", batch_key=\"interior\", loss_fn=pde_loss_fn, method=\"laplacian\"),\n",
        "            TermSpec(name=\"bc\", batch_key=\"boundary\", loss_fn=bc_loss_fn),\n",
        "        ],\n",
        "        loss_weights={\n",
        "            \"pde\": cfg[\"weight_pde\"],\n",
        "            \"bc\": cfg[\"weight_bc\"],\n",
        "        },\n",
        "    )\n",
        "    state = TrainState.create(apply_fn=model_apply, params=variables[\"params\"])\n",
        "    metadata = {\n",
        "        \"width\": width,\n",
        "        \"depth\": depth,\n",
        "        \"param_count\": count_parameters(state.params),\n",
        "        \"dtype\": str(dtype),\n",
        "    }\n",
        "    return objective, state, dtype, metadata\n",
        "\n",
        "\n",
        "def prepare_batch_schedule(cfg: Dict[str, Any], dtype: jnp.dtype, seed: int, total_steps: int):\n",
        "    keys = jax.random.split(jax.random.PRNGKey(seed), max(total_steps, 1))\n",
        "    return [generate_batch(k, cfg[\"n_interior_batch\"], cfg[\"n_boundary_batch\"], dtype) for k in keys]\n",
        "\n",
        "\n",
        "def create_optimizer(name: str, objective: Objective, cfg: Dict[str, Any]):\n",
        "    if name == \"adam\":\n",
        "        return Adam(objective, learning_rate=cfg[\"adam_lr\"])\n",
        "    if name == \"cg_fisher\":\n",
        "        return CGOptimizer(\n",
        "            objective,\n",
        "            learning_rate=cfg[\"cgf_lr\"],\n",
        "            damping=cfg[\"damping\"],\n",
        "            cg_iters=cfg[\"cg_iters\"],\n",
        "            curvature_type=\"fisher\",\n",
        "        )\n",
        "    raise ValueError(name)\n",
        "\n",
        "\n",
        "def compile_step(jit_step: Any, state: TrainState, batch: Dict[str, Any]) -> None:\n",
        "    _, metrics = jit_step(state, batch)\n",
        "    jax.block_until_ready(metrics[\"loss\"])\n",
        "\n",
        "\n",
        "def run_trace(\n",
        "    optimizer: Any,\n",
        "    initial_state: TrainState,\n",
        "    batch_schedule: Sequence[Dict[str, Any]],\n",
        "    *,\n",
        "    timing_warmup_steps: int = 0,\n",
        ") -> Dict[str, Any]:\n",
        "    example_batch = batch_schedule[0]\n",
        "    state = optimizer.init(initial_state.replace(step=jnp.array(0), opt_state=None), example_batch=example_batch)\n",
        "    jit_step = jax.jit(optimizer.step)\n",
        "    compile_step(jit_step, state, example_batch)\n",
        "\n",
        "    cursor = 0\n",
        "    for _ in range(timing_warmup_steps):\n",
        "        state, metrics = jit_step(state, batch_schedule[cursor])\n",
        "        jax.block_until_ready(metrics[\"loss\"])\n",
        "        cursor += 1\n",
        "\n",
        "    losses = []\n",
        "    losses_pde = []\n",
        "    losses_bc = []\n",
        "    times = []\n",
        "\n",
        "    for batch in batch_schedule[cursor:]:\n",
        "        t0 = time.perf_counter()\n",
        "        state, metrics = jit_step(state, batch)\n",
        "        jax.block_until_ready(metrics[\"loss\"])\n",
        "        times.append(time.perf_counter() - t0)\n",
        "        losses.append(float(metrics[\"loss\"]))\n",
        "        losses_pde.append(float(metrics.get(\"loss/pde\", math.nan)))\n",
        "        losses_bc.append(float(metrics.get(\"loss/bc\", math.nan)))\n",
        "\n",
        "    return {\n",
        "        \"losses\": losses,\n",
        "        \"losses_pde\": losses_pde,\n",
        "        \"losses_bc\": losses_bc,\n",
        "        \"times\": times,\n",
        "        \"total_time\": float(sum(times)),\n",
        "        \"final_loss\": float(losses[-1]),\n",
        "        \"final_pde_loss\": float(losses_pde[-1]),\n",
        "        \"final_bc_loss\": float(losses_bc[-1]),\n",
        "    }\n",
        "\n",
        "\n",
        "def summarize_times(times: Sequence[float]) -> Dict[str, float]:\n",
        "    arr = np.asarray(times, dtype=np.float64)\n",
        "    return {\n",
        "        \"step_time_mean\": float(arr.mean()),\n",
        "        \"step_time_std\": float(arr.std()),\n",
        "        \"step_time_min\": float(arr.min()),\n",
        "        \"step_time_max\": float(arr.max()),\n",
        "    }\n",
        "\n",
        "\n",
        "def summarize_tail_times(times: Sequence[float], tail_steps: int) -> Dict[str, float]:\n",
        "    arr = np.asarray(times, dtype=np.float64)\n",
        "    tail = arr[-min(len(arr), int(tail_steps)) :]\n",
        "    return {\n",
        "        \"tail_steps_used\": int(len(tail)),\n",
        "        \"tail_step_time_mean\": float(tail.mean()),\n",
        "        \"tail_step_time_std\": float(tail.std()),\n",
        "        \"tail_step_time_min\": float(tail.min()),\n",
        "        \"tail_step_time_max\": float(tail.max()),\n",
        "    }\n",
        "\n",
        "\n",
        "def cumulative_times(times: Sequence[float]) -> np.ndarray:\n",
        "    return np.cumsum(np.asarray(times, dtype=np.float64))\n",
        "\n",
        "\n",
        "def first_time_to_reach(losses: Sequence[float], times: Sequence[float], target: float) -> float | None:\n",
        "    cum = cumulative_times(times)\n",
        "    for idx, loss in enumerate(losses):\n",
        "        if float(loss) <= target:\n",
        "            return float(cum[idx])\n",
        "    return None\n",
        "\n",
        "\n",
        "def first_step_to_reach(losses: Sequence[float], target: float) -> int | None:\n",
        "    for idx, loss in enumerate(losses):\n",
        "        if float(loss) <= target:\n",
        "            return idx + 1\n",
        "    return None\n",
        "\n",
        "\n",
        "def best_loss_within_budget(losses: Sequence[float], times: Sequence[float], budget: float) -> float | None:\n",
        "    cum = cumulative_times(times)\n",
        "    valid = np.nonzero(cum <= budget)[0]\n",
        "    if len(valid) == 0:\n",
        "        return None\n",
        "    valid_losses = np.asarray(losses, dtype=np.float64)[valid]\n",
        "    return float(valid_losses.min())\n",
        "\n",
        "\n",
        "def warmup_checkpoint(base_state: TrainState, objective: Objective, cfg: Dict[str, Any], dtype: jnp.dtype) -> TrainState:\n",
        "    optimizer = Adam(objective, learning_rate=cfg[\"adam_lr\"])\n",
        "    batch_schedule = prepare_batch_schedule(cfg, dtype, cfg[\"phase1_seed\"], cfg[\"phase1_steps\"])\n",
        "    example_batch = batch_schedule[0]\n",
        "    state = optimizer.init(base_state.replace(step=jnp.array(0), opt_state=None), example_batch=example_batch)\n",
        "    jit_step = jax.jit(optimizer.step)\n",
        "    compile_step(jit_step, state, example_batch)\n",
        "\n",
        "    for batch in batch_schedule:\n",
        "        state, metrics = jit_step(state, batch)\n",
        "        jax.block_until_ready(metrics[\"loss\"])\n",
        "\n",
        "    return state.replace(opt_state=None, step=jnp.array(0))\n",
        "\n",
        "\n",
        "def measure_step_times(objective: Objective, state: TrainState, cfg: Dict[str, Any], dtype: jnp.dtype) -> Dict[str, Any]:\n",
        "    batch_schedule = prepare_batch_schedule(\n",
        "        cfg,\n",
        "        dtype,\n",
        "        cfg[\"timing_seed\"],\n",
        "        cfg[\"timing_warmup_steps\"] + cfg[\"timing_steps\"],\n",
        "    )\n",
        "    traces = {}\n",
        "    for name in (\"adam\", \"cg_fisher\"):\n",
        "        trace = run_trace(\n",
        "            create_optimizer(name, objective, cfg),\n",
        "            state,\n",
        "            batch_schedule,\n",
        "            timing_warmup_steps=cfg[\"timing_warmup_steps\"],\n",
        "        )\n",
        "        traces[name] = {**trace, **summarize_times(trace[\"times\"])}\n",
        "    traces[\"cg_fisher\"][\"ratio_vs_adam\"] = (\n",
        "        traces[\"cg_fisher\"][\"step_time_mean\"] / traces[\"adam\"][\"step_time_mean\"]\n",
        "    )\n",
        "    return traces\n",
        "\n",
        "\n",
        "def compare_adam_vs_cg(objective: Objective, checkpoint_state: TrainState, cfg: Dict[str, Any], dtype: jnp.dtype) -> Dict[str, Any]:\n",
        "    batch_schedule = prepare_batch_schedule(cfg, dtype, cfg[\"phase2_seed\"], cfg[\"phase2_steps\"])\n",
        "\n",
        "    adam_trace = run_trace(create_optimizer(\"adam\", objective, cfg), checkpoint_state, batch_schedule)\n",
        "    cg_trace = run_trace(create_optimizer(\"cg_fisher\", objective, cfg), checkpoint_state, batch_schedule)\n",
        "    tail_steps = int(cfg.get(\"phase2_tail_steps\", 100))\n",
        "    adam_tail = summarize_tail_times(adam_trace[\"times\"], tail_steps)\n",
        "    cg_tail = summarize_tail_times(cg_trace[\"times\"], tail_steps)\n",
        "\n",
        "    adam_budget = adam_trace[\"total_time\"]\n",
        "    cg_budget = cg_trace[\"total_time\"]\n",
        "\n",
        "    cg_best_within_adam = best_loss_within_budget(cg_trace[\"losses\"], cg_trace[\"times\"], adam_budget)\n",
        "    cg_time_to_adam_final = first_time_to_reach(cg_trace[\"losses\"], cg_trace[\"times\"], adam_trace[\"final_loss\"])\n",
        "    cg_step_to_adam_final = first_step_to_reach(cg_trace[\"losses\"], adam_trace[\"final_loss\"])\n",
        "    adam_time_to_cg_final = first_time_to_reach(adam_trace[\"losses\"], adam_trace[\"times\"], cg_trace[\"final_loss\"])\n",
        "    adam_step_to_cg_final = first_step_to_reach(adam_trace[\"losses\"], cg_trace[\"final_loss\"])\n",
        "\n",
        "    return {\n",
        "        \"adam\": adam_trace,\n",
        "        \"cg_fisher\": cg_trace,\n",
        "        \"cg_beats_adam_by_steps\": cg_trace[\"final_loss\"] < adam_trace[\"final_loss\"],\n",
        "        \"cg_beats_adam_by_time\": (\n",
        "            cg_time_to_adam_final is not None and cg_time_to_adam_final <= adam_budget\n",
        "        ),\n",
        "        \"cg_final_loss_ratio_vs_adam\": float(cg_trace[\"final_loss\"] / adam_trace[\"final_loss\"]),\n",
        "        \"cg_steps_to_reach_adam_final_loss\": cg_step_to_adam_final,\n",
        "        \"cg_time_to_reach_adam_final_loss\": cg_time_to_adam_final,\n",
        "        \"adam_steps_to_reach_cg_final_loss\": adam_step_to_cg_final,\n",
        "        \"adam_time_to_reach_cg_final_loss\": adam_time_to_cg_final,\n",
        "        \"cg_best_loss_within_adam_time_budget\": cg_best_within_adam,\n",
        "        \"cg_matches_adam_final_within_adam_time_budget\": (\n",
        "            cg_best_within_adam is not None and cg_best_within_adam <= adam_trace[\"final_loss\"]\n",
        "        ),\n",
        "        \"phase2_tail_steps\": tail_steps,\n",
        "        \"adam_phase2_tail_step_time_mean\": adam_tail[\"tail_step_time_mean\"],\n",
        "        \"cg_phase2_tail_step_time_mean\": cg_tail[\"tail_step_time_mean\"],\n",
        "        \"cg_phase2_tail_over_adam\": float(\n",
        "            cg_tail[\"tail_step_time_mean\"] / adam_tail[\"tail_step_time_mean\"]\n",
        "        ),\n",
        "        \"adam_total_time\": adam_budget,\n",
        "        \"cg_total_time\": cg_budget,\n",
        "    }\n",
        "\n",
        "\n",
        "def run_architecture_sweep(\n",
        "    cfg: Dict[str, Any],\n",
        "    *,\n",
        "    sweep_key: str,\n",
        "    values: Sequence[int],\n",
        "    fixed_width: int | None = None,\n",
        "    fixed_depth: int | None = None,\n",
        ") -> Dict[str, Any]:\n",
        "    entries = []\n",
        "\n",
        "    for value in values:\n",
        "        if sweep_key == \"depth\":\n",
        "            depth = int(value)\n",
        "            width = int(fixed_width)\n",
        "        elif sweep_key == \"width\":\n",
        "            depth = int(fixed_depth)\n",
        "            width = int(value)\n",
        "        else:\n",
        "            raise ValueError(sweep_key)\n",
        "\n",
        "        objective, state, dtype, metadata = build_problem(width, depth, cfg)\n",
        "        timing = measure_step_times(objective, state, cfg, dtype)\n",
        "        checkpoint = warmup_checkpoint(state, objective, cfg, dtype)\n",
        "        comparison = compare_adam_vs_cg(objective, checkpoint, cfg, dtype)\n",
        "\n",
        "        entries.append({\n",
        "            sweep_key: int(value),\n",
        "            \"width\": width,\n",
        "            \"depth\": depth,\n",
        "            \"param_count\": metadata[\"param_count\"],\n",
        "            \"adam_step_time_mean\": timing[\"adam\"][\"step_time_mean\"],\n",
        "            \"cg_step_time_mean\": timing[\"cg_fisher\"][\"step_time_mean\"],\n",
        "            \"cg_step_over_adam\": timing[\"cg_fisher\"][\"ratio_vs_adam\"],\n",
        "            \"comparison\": comparison,\n",
        "        })\n",
        "\n",
        "    return {\n",
        "        \"experiment\": sweep_key,\n",
        "        \"entries\": entries,\n",
        "    }\n",
        "\n",
        "\n",
        "def summarize_results(results: Dict[str, Any]) -> pd.DataFrame:\n",
        "    rows = []\n",
        "    for entry in results[\"entries\"]:\n",
        "        cmp_result = entry[\"comparison\"]\n",
        "        rows.append({\n",
        "            \"depth\": entry[\"depth\"],\n",
        "            \"width\": entry[\"width\"],\n",
        "            \"param_count\": entry[\"param_count\"],\n",
        "            \"adam_final\": cmp_result[\"adam\"][\"final_loss\"],\n",
        "            \"cg_final\": cmp_result[\"cg_fisher\"][\"final_loss\"],\n",
        "            \"cg_final_over_adam\": cmp_result[\"cg_final_loss_ratio_vs_adam\"],\n",
        "            \"cg_beats_adam_by_steps\": cmp_result[\"cg_beats_adam_by_steps\"],\n",
        "            \"cg_beats_adam_by_time\": cmp_result[\"cg_beats_adam_by_time\"],\n",
        "            \"cg_matches_adam_final_within_adam_time_budget\": cmp_result[\"cg_matches_adam_final_within_adam_time_budget\"],\n",
        "            \"cg_steps_to_reach_adam_final\": cmp_result[\"cg_steps_to_reach_adam_final_loss\"],\n",
        "            \"cg_time_to_reach_adam_final\": cmp_result[\"cg_time_to_reach_adam_final_loss\"],\n",
        "            \"adam_total_time\": cmp_result[\"adam_total_time\"],\n",
        "            \"cg_total_time\": cmp_result[\"cg_total_time\"],\n",
        "            \"adam_probe_step_time\": entry[\"adam_step_time_mean\"],\n",
        "            \"cg_probe_step_time\": entry[\"cg_step_time_mean\"],\n",
        "            \"cg_probe_over_adam\": entry[\"cg_step_over_adam\"],\n",
        "            \"phase2_tail_steps\": cmp_result[\"phase2_tail_steps\"],\n",
        "            \"adam_phase2_tail_step_time\": cmp_result[\"adam_phase2_tail_step_time_mean\"],\n",
        "            \"cg_phase2_tail_step_time\": cmp_result[\"cg_phase2_tail_step_time_mean\"],\n",
        "            \"cg_phase2_tail_over_adam\": cmp_result[\"cg_phase2_tail_over_adam\"],\n",
        "        })\n",
        "    return pd.DataFrame(rows)"
      ],
      "id": "6x2qKlpcvpYl"
    },
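    {
      "cell_type": "markdown",
      "metadata": {
        "id": "manufacturedCheckMd01"
      },
      "source": [
        "As a sanity check on the manufactured solution, the next cell uses the same nested `jax.grad` construction as `PoissonPINN.laplacian` to verify that `-(u_xx + u_yy)` of `u(x, y) = sin(2 pi x) sin(2 pi y)` equals the source `f = 8 pi^2 sin(2 pi x) sin(2 pi y)`; analytically, each second derivative contributes `-4 pi^2 u`. The cell is self-contained and redefines the two functions (`u_exact`, `f_source`, and `neg_laplacian` are illustrative names) rather than reusing the helpers above."
      ],
      "id": "manufacturedCheckMd01"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "manufacturedCheckCode01"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "jax.config.update('jax_enable_x64', True)\n",
        "\n",
        "\n",
        "def u_exact(x, y):\n",
        "    return jnp.sin(2 * jnp.pi * x) * jnp.sin(2 * jnp.pi * y)\n",
        "\n",
        "\n",
        "def f_source(x, y):\n",
        "    return 8 * jnp.pi**2 * jnp.sin(2 * jnp.pi * x) * jnp.sin(2 * jnp.pi * y)\n",
        "\n",
        "\n",
        "def neg_laplacian(x, y):\n",
        "    # Second derivatives via nested jax.grad, one coordinate at a time.\n",
        "    u_xx = jax.grad(jax.grad(u_exact, argnums=0), argnums=0)(x, y)\n",
        "    u_yy = jax.grad(jax.grad(u_exact, argnums=1), argnums=1)(x, y)\n",
        "    return -(u_xx + u_yy)\n",
        "\n",
        "\n",
        "pts = jax.random.uniform(jax.random.PRNGKey(0), (256, 2))\n",
        "lhs = jax.vmap(neg_laplacian)(pts[:, 0], pts[:, 1])\n",
        "rhs = f_source(pts[:, 0], pts[:, 1])\n",
        "max_err = float(jnp.max(jnp.abs(lhs - rhs)))\n",
        "print('max |-lap(u) - f| =', max_err)\n",
        "assert max_err < 1e-8"
      ],
      "id": "manufacturedCheckCode01"
    },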
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JZpKKh-lvpYl"
      },
      "source": [
        "## Run The Sweep\n",
        "\n",
        "The sweep uses `50 x 2 -> 50 x 5`."
      ],
      "id": "JZpKKh-lvpYl"
    },
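    {
      "cell_type": "markdown",
      "metadata": {
        "id": "paramCountMd01"
      },
      "source": [
        "The `param_count` column follows from counting weights and biases in each `Dense` layer of `PoissonPINN`: with input dimension 2, `depth` hidden layers of width `w`, and a scalar output, the count is `(2w + w) + (depth - 1)(w^2 + w) + (w + 1)`. The next cell evaluates this closed form for the width-50 depth sweep; `dense_mlp_param_count` is a hypothetical helper, not part of the notebook, and its values agree with the `param_count` column reported below."
      ],
      "id": "paramCountMd01"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "paramCountCode01"
      },
      "outputs": [],
      "source": [
        "def dense_mlp_param_count(width: int, depth: int, in_dim: int = 2, out_dim: int = 1) -> int:\n",
        "    # First hidden layer: in_dim * width weights + width biases.\n",
        "    first = in_dim * width + width\n",
        "    # Each further hidden layer: width * width weights + width biases.\n",
        "    hidden = (depth - 1) * (width * width + width)\n",
        "    # Output layer: width * out_dim weights + out_dim biases.\n",
        "    last = width * out_dim + out_dim\n",
        "    return first + hidden + last\n",
        "\n",
        "\n",
        "counts = [dense_mlp_param_count(50, d) for d in (2, 3, 4, 5)]\n",
        "print(counts)  # [2751, 5301, 7851, 10401]\n",
        "assert counts == [2751, 5301, 7851, 10401]"
      ],
      "id": "paramCountCode01"
    },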
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 194
        },
        "id": "PqUXKHJyvpYl",
        "outputId": "fa024e90-7118-45bc-cf0a-284df17de137"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "   depth  width  param_count  adam_final  cg_final  cg_final_over_adam  \\\n",
              "0      2     50         2751    1.421152  0.006711            0.004722   \n",
              "1      3     50         5301    0.422066  0.007711            0.018269   \n",
              "2      4     50         7851    0.597357  0.005621            0.009411   \n",
              "3      5     50        10401    0.851788  0.003289            0.003862   \n",
              "\n",
              "   cg_beats_adam_by_steps  cg_beats_adam_by_time  \\\n",
              "0                    True                   True   \n",
              "1                    True                   True   \n",
              "2                    True                   True   \n",
              "3                    True                   True   \n",
              "\n",
              "   cg_matches_adam_final_within_adam_time_budget  \\\n",
              "0                                           True   \n",
              "1                                           True   \n",
              "2                                           True   \n",
              "3                                           True   \n",
              "\n",
              "   cg_steps_to_reach_adam_final  cg_time_to_reach_adam_final  adam_total_time  \\\n",
              "0                            12                     0.471104         3.541279   \n",
              "1                            28                     1.944377         5.494027   \n",
              "2                            21                     2.082808         7.346237   \n",
              "3                            19                     2.449647         9.242818   \n",
              "\n",
              "   cg_total_time  adam_probe_step_time  cg_probe_step_time  \\\n",
              "0      39.081856              0.003703            0.039212   \n",
              "1      69.296378              0.005601            0.069193   \n",
              "2      99.060169              0.007481            0.099101   \n",
              "3     128.951853              0.009377            0.128884   \n",
              "\n",
              "   cg_probe_over_adam  phase2_tail_steps  adam_phase2_tail_step_time  \\\n",
              "0           10.589548                100                    0.003545   \n",
              "1           12.354791                100                    0.005504   \n",
              "2           13.246201                100                    0.007348   \n",
              "3           13.744293                100                    0.009258   \n",
              "\n",
              "   cg_phase2_tail_step_time  cg_phase2_tail_over_adam  \n",
              "0                  0.039176                 11.051247  \n",
              "1                  0.069258                 12.583073  \n",
              "2                  0.099046                 13.480001  \n",
              "3                  0.128949                 13.928690  "
            ],
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>depth</th>\n",
              "      <th>width</th>\n",
              "      <th>param_count</th>\n",
              "      <th>adam_final</th>\n",
              "      <th>cg_final</th>\n",
              "      <th>cg_final_over_adam</th>\n",
              "      <th>cg_beats_adam_by_steps</th>\n",
              "      <th>cg_beats_adam_by_time</th>\n",
              "      <th>cg_matches_adam_final_within_adam_time_budget</th>\n",
              "      <th>cg_steps_to_reach_adam_final</th>\n",
              "      <th>cg_time_to_reach_adam_final</th>\n",
              "      <th>adam_total_time</th>\n",
              "      <th>cg_total_time</th>\n",
              "      <th>adam_probe_step_time</th>\n",
              "      <th>cg_probe_step_time</th>\n",
              "      <th>cg_probe_over_adam</th>\n",
              "      <th>phase2_tail_steps</th>\n",
              "      <th>adam_phase2_tail_step_time</th>\n",
              "      <th>cg_phase2_tail_step_time</th>\n",
              "      <th>cg_phase2_tail_over_adam</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>2</td>\n",
              "      <td>50</td>\n",
              "      <td>2751</td>\n",
              "      <td>1.421152</td>\n",
              "      <td>0.006711</td>\n",
              "      <td>0.004722</td>\n",
              "      <td>True</td>\n",
              "      <td>True</td>\n",
              "      <td>True</td>\n",
              "      <td>12</td>\n",
              "      <td>0.471104</td>\n",
              "      <td>3.541279</td>\n",
              "      <td>39.081856</td>\n",
              "      <td>0.003703</td>\n",
              "      <td>0.039212</td>\n",
              "      <td>10.589548</td>\n",
              "      <td>100</td>\n",
              "      <td>0.003545</td>\n",
              "      <td>0.039176</td>\n",
              "      <td>11.051247</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>3</td>\n",
              "      <td>50</td>\n",
              "      <td>5301</td>\n",
              "      <td>0.422066</td>\n",
              "      <td>0.007711</td>\n",
              "      <td>0.018269</td>\n",
              "      <td>True</td>\n",
              "      <td>True</td>\n",
              "      <td>True</td>\n",
              "      <td>28</td>\n",
              "      <td>1.944377</td>\n",
              "      <td>5.494027</td>\n",
              "      <td>69.296378</td>\n",
              "      <td>0.005601</td>\n",
              "      <td>0.069193</td>\n",
              "      <td>12.354791</td>\n",
              "      <td>100</td>\n",
              "      <td>0.005504</td>\n",
              "      <td>0.069258</td>\n",
              "      <td>12.583073</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>4</td>\n",
              "      <td>50</td>\n",
              "      <td>7851</td>\n",
              "      <td>0.597357</td>\n",
              "      <td>0.005621</td>\n",
              "      <td>0.009411</td>\n",
              "      <td>True</td>\n",
              "      <td>True</td>\n",
              "      <td>True</td>\n",
              "      <td>21</td>\n",
              "      <td>2.082808</td>\n",
              "      <td>7.346237</td>\n",
              "      <td>99.060169</td>\n",
              "      <td>0.007481</td>\n",
              "      <td>0.099101</td>\n",
              "      <td>13.246201</td>\n",
              "      <td>100</td>\n",
              "      <td>0.007348</td>\n",
              "      <td>0.099046</td>\n",
              "      <td>13.480001</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>5</td>\n",
              "      <td>50</td>\n",
              "      <td>10401</td>\n",
              "      <td>0.851788</td>\n",
              "      <td>0.003289</td>\n",
              "      <td>0.003862</td>\n",
              "      <td>True</td>\n",
              "      <td>True</td>\n",
              "      <td>True</td>\n",
              "      <td>19</td>\n",
              "      <td>2.449647</td>\n",
              "      <td>9.242818</td>\n",
              "      <td>128.951853</td>\n",
              "      <td>0.009377</td>\n",
              "      <td>0.128884</td>\n",
              "      <td>13.744293</td>\n",
              "      <td>100</td>\n",
              "      <td>0.009258</td>\n",
              "      <td>0.128949</td>\n",
              "      <td>13.928690</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "depth_results = run_architecture_sweep(\n",
        "    CONFIG,\n",
        "    sweep_key=\"depth\",\n",
        "    values=CONFIG[\"depth_values\"],\n",
        "    fixed_width=CONFIG[\"depth_width\"],\n",
        ")\n",
        "\n",
        "depth_df = summarize_results(depth_results)\n",
        "display(depth_df)"
      ],
      "id": "PqUXKHJyvpYl"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QfLe8S-EvpYm",
        "outputId": "98c5d6df-9a1a-4f4b-91a1-c921b7a27000"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "=== Sweep ===\n",
            "step-wise wins: 4/4\n",
            "time-wise wins: 4/4\n",
            "reaches Adam final within Adam time budget: 4/4\n",
            "CG / Adam per-step ratio over last 100 Phase II epochs: median=13.03, min=11.05, max=13.93\n",
            "CG steps to reach Adam final: min=12, max=28\n",
            "\n"
          ]
        }
      ],
      "source": [
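        "def print_takeaway(df: pd.DataFrame, label: str) -> None:\n",
        "    \"\"\"Print a summary of CG-Fisher vs Adam across the rows of a sweep.\n",
        "\n",
        "    Reports step-wise, time-wise, and time-budget win counts, the\n",
        "    steady-state CG / Adam per-step cost ratio over the Phase II tail,\n",
        "    and the range of CG steps needed to reach Adam's final loss.\n",
        "    \"\"\"\n",
        "    # Boolean columns: summing counts the configurations where CG-Fisher wins.\n",
        "    step_wins = int(df[\"cg_beats_adam_by_steps\"].sum())\n",
        "    time_wins = int(df[\"cg_beats_adam_by_time\"].sum())\n",
        "    match_budget = int(df[\"cg_matches_adam_final_within_adam_time_budget\"].sum())\n",
        "    # The tail length is assumed identical for every row; read it from the first.\n",
        "    tail_steps = int(df[\"phase2_tail_steps\"].iloc[0])\n",
        "    median_ratio = float(df[\"cg_phase2_tail_over_adam\"].median())\n",
        "    ratio_min = float(df[\"cg_phase2_tail_over_adam\"].min())\n",
        "    ratio_max = float(df[\"cg_phase2_tail_over_adam\"].max())\n",
        "    # Non-numeric entries (e.g. runs where CG never reached Adam's final loss)\n",
        "    # are coerced to NaN and dropped before taking the min/max.\n",
        "    step_series = pd.to_numeric(df[\"cg_steps_to_reach_adam_final\"], errors=\"coerce\").dropna()\n",
        "    steps_min = None if step_series.empty else int(step_series.min())\n",
        "    steps_max = None if step_series.empty else int(step_series.max())\n",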
        "\n",
        "    print(f\"=== {label} ===\")\n",
        "    print(f\"step-wise wins: {step_wins}/{len(df)}\")\n",
        "    print(f\"time-wise wins: {time_wins}/{len(df)}\")\n",
        "    print(f\"reaches Adam final within Adam time budget: {match_budget}/{len(df)}\")\n",
        "    print(\n",
        "        f\"CG / Adam per-step ratio over last {tail_steps} Phase II epochs: \"\n",
        "        f\"median={median_ratio:.2f}, min={ratio_min:.2f}, max={ratio_max:.2f}\"\n",
        "    )\n",
        "    print(f\"CG steps to reach Adam final: min={steps_min}, max={steps_max}\")\n",
        "    print()\n",
        "\n",
        "\n",
        "print_takeaway(depth_df, \"Sweep\")"
      ],
      "id": "QfLe8S-EvpYm"
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.12"
    },
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "L4"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 5
}