{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [
        {
          "file_id": "13IY7lcl3cYVAiau8xEKNFa-BSyAej-iT",
          "timestamp": 1723232395890
        }
      ],
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python/gpu:ml_notebook",
        "kind": "private"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4Di34zgQaOzr"
      },
      "outputs": [],
      "source": [
        "import scipy as sp\n",
        "import numpy as np\n",
        "import jax.numpy as jnp\n",
        "from collections.abc import Callable\n",
        "import functools\n",
        "import jax\n",
        "import jax.scipy as jsp\n",
        "import jaxopt\n",
        "from typing import TypeVar, Any\n",
        "import abc\n",
        "import flax\n",
        "import equinox\n",
        "from jax import config\n",
        "import time\n",
        "\n",
        "config.update('jax_enable_x64', True)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prior Work: Explicit Optimization"
      ],
      "metadata": {
        "id": "953qpe0t8Jj4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_circulant_idx(n: int) -> jnp.ndarray:\n",
        "  \"\"\"Computes a symmetric Circulant matrix where entries are integer indices [0, ..., n-1].\n",
        "\n",
        "  In particular, T_{ij} = | i - j |.  For example, get_circulant_idx(4) returns:\n",
        "  ```\n",
        "  [0 1 2 3]\n",
        "  [1 0 1 2]\n",
        "  [2 1 0 1]\n",
        "  [3 2 1 0]\n",
        "  ```\n",
        "\n",
        "  Args:\n",
        "    n: the size of the Toeplitz matrix\n",
        "\n",
        "  Returns:\n",
        "    A symmetric Circulant matrix with integer indices as entries.\n",
        "  \"\"\"\n",
        "  return jnp.array(sp.linalg.toeplitz(np.arange(n)))\n",
        "\n",
        "\n",
        "def banded_symmetric_mask(n: int, num_bands: int) -> jnp.ndarray:\n",
        "  \"\"\"Returns n x n symmetric {0, 1} matrix with 2b - 1 bands of 1s.\"\"\"\n",
        "  b = num_bands\n",
        "  if b < 1:\n",
        "    raise ValueError(f'num_bands must be >= 0, found {num_bands}')\n",
        "  return (jnp.tri(n, k=b - 1) - jnp.tri(n, k=-b)).astype(jnp.int32)\n",
        "\n",
        "\n",
        "def get_orthogonal_mask(n: int, epochs: int = 1) -> jnp.ndarray:\n",
        "  \"\"\"Computes a mask that imposes orthognality constraints on the optimization.\n",
        "\n",
        "  This is specific to the fixed-epoch-order (k, b)-participation schema of\n",
        "  https://arxiv.org/pdf/2211.06530.pdf, where participations are separated by\n",
        "  exactly b-1 steps, and b = n / epochs.\n",
        "\n",
        "  This mask sets entry M_{ij} = 0 if i == j (mod n/epochs) and M_{ij} = 1\n",
        "  otherwise.  Sensitivity for any matrix with 0s in these entries is easy to\n",
        "  calculate as only a function of the diagonal.  Moreover, the sensitivity is\n",
        "  equal for all possible {-1,1} participation vectors.\n",
        "\n",
        "  Args:\n",
        "    n: the size of the mask\n",
        "    epochs: The number of epochs\n",
        "\n",
        "  Returns:\n",
        "    A 0/1 mask\n",
        "  \"\"\"\n",
        "  mask = np.ones((n, n))\n",
        "  for i in range(n // epochs):\n",
        "    mask[i :: n // epochs, i :: n // epochs] = np.eye(epochs)\n",
        "  return jnp.array(mask)\n",
        "\n",
        "\n",
        "class MatrixFactorizer:\n",
        "  \"\"\"Class for factorizing matrices.\"\"\"\n",
        "\n",
        "  def __init__(\n",
        "      self,\n",
        "      iterations: int,\n",
        "      epochs: int = 1,\n",
        "      bands: int | None = None,\n",
        "      circulant: bool = False,\n",
        "      equal_norm: bool = False,\n",
        "  ):\n",
        "    \"\"\"Initializes this MatrixFactorizer object.\n",
        "\n",
        "    Note: Although this implementation of MatrixFactorizer supports optimization\n",
        "      of structured matrices, it does nothing to exploit their structure to\n",
        "      speed up optimization.\n",
        "\n",
        "    Note: Currently, this class only supports the canonical (k,b) participation\n",
        "      pattern where a user contributes k times total, every b iterations.\n",
        "      Currently, this class specifically supports the fixed-epoch-order\n",
        "      (k, b)-participation schema of https://arxiv.org/pdf/2211.06530.pdf,\n",
        "      where participations are separated by exactly b-1 steps, and\n",
        "      b = num_iterations / num_epochs.\n",
        "\n",
        "    Note: In this class we are always imposing certain orthogonality constraints\n",
        "    to make the optimization problem easier to solve and tractable for large\n",
        "    numbers of epochs.  Specifically, we require that X_{ij} = 0 if i != j and\n",
        "    a user can appear in both iteration i and iteration j.  This ensures that\n",
        "    columns i and j of C are orthogonal, and that their L2^2 sensitivites simply\n",
        "    add up. We have found empirically that this structure is optimal for the\n",
        "    Prefix workload and we conjecture it may be optimal more generally.\n",
        "\n",
        "    Args:\n",
        "      iterations: The number of iterations (size of the workload matrix).\n",
        "      epochs: A positive integer in the range [1,n] that evenly divides n. Is\n",
        "        used to determine the participation pattern and define the sensitivity\n",
        "        of a given matrix.\n",
        "      bands: [Optional] A positive integer in the range [1,n]. A structural\n",
        "        constraint to place on the matrix X.  If set, X has 2*bands - 1 bands\n",
        "        with (possible) nonzeros, so e.g. bands=1 gives a diagonal matrix, and\n",
        "        bands=iterations allows a dense X.\n",
        "      circulant: [Optional] A flag to indicate whether X should be constrained to\n",
        "        have symmetric Circulant structure (all entries in a each band have the same\n",
        "        value). Note: This implies C will be \"almost\" Toeplitz, but with columns\n",
        "        having equal norms (hence, the equal_norm option is a no-op when\n",
        "        toeplitz=True).\n",
        "      equal_norm: [Optional] A flag to indicate if columns of C should have\n",
        "        equal norm (i.e., X_ii = 1/epochs).\n",
        "    \"\"\"\n",
        "    self._n = iterations\n",
        "    self._k = epochs\n",
        "    self._equal_norm = equal_norm\n",
        "    self._circulant = circulant\n",
        "    if circulant:\n",
        "      self._Tidx = get_circulant_idx(self._n)\n",
        "    # These masks determine which entries of X are allowed to be non-zero.\n",
        "    orth_mask = get_orthogonal_mask(self._n, epochs)\n",
        "    if bands is not None:\n",
        "      banded_mask = banded_symmetric_mask(self._n, bands)\n",
        "      self._mask = orth_mask * banded_mask\n",
        "    else:\n",
        "      self._mask = orth_mask\n",
        "\n",
        "  @functools.partial(jax.jit, static_argnums=(0,))\n",
        "  def project_update(self, dX: jnp.ndarray) -> jnp.ndarray:\n",
        "    r\"\"\"Project dX so that X + alpha*dX satisfies constraints for any alpha.\n",
        "\n",
        "    Note: this function assumes that X already satisfies the constraints.\n",
        "\n",
        "    This function does multiple things:\n",
        "      1. To ensure the sensitivity of the resulting mechanism remains 1:\n",
        "        a. It sets $dX[i,j] = 0$ if $i \\neq j$ and a user can appear in both\n",
        "          round $i$ and $j$.\n",
        "        b. It normalizs $sum_i dX[i,i] = 0$, where sum is taken over rounds a\n",
        "          single user can participate in.  This ensures that sum_i X[i,i]\n",
        "          remains equal to 1.\n",
        "      2. If banded constraints are given, sets dX[i,j] = 0 if |i-j| > # bands.\n",
        "      3. If Toeplitz structure is needed, ensures dX is a toeplitz matrix.  This\n",
        "        also ensures X is Toeplitz because Toeplitz + Toeplitz = Toeplitz.\n",
        "\n",
        "    Args:\n",
        "      dX: an n x n matrix, representing the gradient with respect to X.\n",
        "\n",
        "    Returns:\n",
        "      an n x n matrix, representing the projected gradient.\n",
        "    \"\"\"\n",
        "    if self._equal_norm:\n",
        "      diag = 0\n",
        "    else:  # Implement 1(b) from above:\n",
        "      dsum = jnp.diag(dX).reshape(self._k, -1).sum(axis=0) / self._k\n",
        "      diag = jnp.diag(dX) - jnp.kron(jnp.ones(self._k), dsum)\n",
        "    dX = dX.at[jnp.diag_indices(self._n)].set(diag)\n",
        "    # Implement 1(a) and 2 from above:\n",
        "    dX = dX * self._mask\n",
        "    if self._circulant:  # Implement 3 from above:\n",
        "      # We sum the gradient along each (pair of) diagonals, which by chain rule\n",
        "      # gives the derivate wrt the Toeplitz coefficients.  Toeplitz matrix with\n",
        "      # these coefficients is constructed so the shape is compatible with X.\n",
        "      locs = self._Tidx.flatten()\n",
        "      weights = dX.flatten()\n",
        "      dX = jnp.bincount(locs, weights, length=self._n)[self._Tidx]\n",
        "    return dX\n",
        "\n",
        "  @functools.partial(jax.jit, static_argnums=(0,))\n",
        "  def loss_and_gradient(\n",
        "      self, X: jnp.ndarray, A: jnp.ndarray\n",
        "  ) -> tuple[float, jnp.ndarray]:\n",
        "    r\"\"\"Computes the matrix mechanism total squared error loss and gradient.\n",
        "\n",
        "    This function computes $\\tr[A^T A X^{-1}]$ and the associated gradient\n",
        "    $dX = -X^{-1} A^T A X^{-1}$.  It assumes that $X$ is a symmetric positive\n",
        "    definite matrix.  For efficiency, no error is thrown if this assumption is\n",
        "    not satisfied, but the returned loss or gradient may contain NaN's if this\n",
        "    is the case.\n",
        "\n",
        "    Args:\n",
        "      X: The current iterate, an n x n matrix\n",
        "      A: The workload, an n x n matrix\n",
        "\n",
        "    Returns:\n",
        "      loss: a real-valued number\n",
        "      gradient: the (projected) gradient of the loss w.r.t. X, an n x n matrix\n",
        "    \"\"\"\n",
        "    H = jsp.linalg.solve(X, A.T, assume_a='pos')\n",
        "    return jnp.trace(H @ A), self.project_update(-H @ H.T)  # pytype: disable=bad-return-type  # jnp-type\n",
        "\n",
        "  @functools.partial(jax.jit, static_argnums=(0,))\n",
        "  def _lbfgs_direction(\n",
        "      self, X: jnp.ndarray, dX: jnp.ndarray, X1: jnp.ndarray, dX1: jnp.ndarray\n",
        "  ) -> jnp.ndarray:\n",
        "    \"\"\"Computes the LBFGS search direction.\n",
        "\n",
        "    Given the current/previous iterates (X and X1) and the current/previous\n",
        "    gradients (dX and dX1), compute a search direction (Z) according to the\n",
        "    LBFGS update rule.\n",
        "\n",
        "    Args:\n",
        "      X: The current iterate, an n x n matrix\n",
        "      dX: The current gradient, an n x n matrix\n",
        "      X1: The previous iterate, an n x n matrix\n",
        "      dX1: The previous gradient, an n x n matrix\n",
        "\n",
        "    Returns:\n",
        "      The (negative) search direction, an n x n matrix\n",
        "    \"\"\"\n",
        "    S = X - X1\n",
        "    Y = dX - dX1\n",
        "    rho = 1.0 / jnp.sum(Y * S)\n",
        "    alpha = rho * jnp.sum(S * dX)\n",
        "    gamma = jnp.sum(S * Y) / jnp.sum(Y**2)\n",
        "    Z = gamma * (dX - rho * jnp.sum(S * dX) * Y)\n",
        "    beta = rho * jnp.sum(Y * Z)\n",
        "    Z = Z + S * (alpha - beta)\n",
        "    return Z\n",
        "\n",
        "  def optimize(\n",
        "      self,\n",
        "      A: jnp.ndarray,\n",
        "      iters: int = 1000,\n",
        "      metric_callback: Callable[[int, dict[str, float]], None] | None = None,\n",
        "      initial_X: jnp.ndarray | None = None,\n",
        "  ) -> jnp.ndarray:\n",
        "    \"\"\"Optimize the strategy matrix with an iterative gradient-based method.\n",
        "\n",
        "    This function optimizes the total squared error of the mechanism.  To\n",
        "    optimize maximum per-query error, consider using optimize_max_error instead.\n",
        "\n",
        "    Args:\n",
        "      A: The input workload, a lower triangular n x n matrix.\n",
        "      iters: The number of iterations to run the optimization.\n",
        "      metric_callback: A function for logging that must consume a dictionary of\n",
        "        metrics.\n",
        "      initial_X: Matrix to use as the starting value for initialization. Assumed\n",
        "        to have sensitivity 1 under the (k, b) participation pattern for which\n",
        "        `self` is configured. Sensitivity will remain unchanged for this pattern\n",
        "        during the course of optimization. In (k, b) participation, any entries\n",
        "        in initial_X at index [i, j] with i != j will remain unchanged if a user\n",
        "        can participate in both iteration i and j. Such entries must be set to\n",
        "        0. If `None`, defaults to a normalized identity.\n",
        "      termination_fn: Function which controls early termination. Must take X,\n",
        "        dX, and loss as keyword arguments. Returning true indicates that X\n",
        "        should be immediately returned from the optimization procedure.\n",
        "\n",
        "    Returns:\n",
        "      A matrix X that approximately minimizes the objective tr[A^T A X^{-1}] and\n",
        "      satisfies the sensitivity and structural constraints.\n",
        "    \"\"\"\n",
        "    # TODO: b/296607503 - consider normalizing by n here instead.\n",
        "    normalization_factor = A.mean()\n",
        "    A = (\n",
        "        A / normalization_factor\n",
        "    )  # Helpful for numerical stability in some cases.\n",
        "    if initial_X is None:\n",
        "      X = jnp.eye(self._n, dtype=jnp.float64) / self._k\n",
        "    else:\n",
        "      # It may be desirable to check / raise on sensitivity here. For now,\n",
        "      # assume our callers performed this check.\n",
        "      X = initial_X\n",
        "    if not np.all((1 - self._mask) * X == 0):\n",
        "      raise ValueError(\n",
        "          'Initial X matrix is nonzero in indices i, j where '\n",
        "          'i != j and some user can participate in rounds i and '\n",
        "          'j. Such entries being zero is generally assumed by the '\n",
        "          'optimization code here and downstream consumers in '\n",
        "          'order to easily reason about sensitivity.'\n",
        "      )\n",
        "    loss, dX = self.loss_and_gradient(X, A)\n",
        "    X1 = X  # X at previous iteration\n",
        "    dX1 = dX  # dX at previous iteration\n",
        "    loss1 = loss  # Loss at previous iteration\n",
        "    Z = dX  # The negative search direction (different from dX in general)\n",
        "\n",
        "    for step in range(iters):\n",
        "      step_size = 1.0\n",
        "      for _ in range(30):\n",
        "        X = X1 - step_size * Z\n",
        "        loss, dX = self.loss_and_gradient(X, A)\n",
        "        if jnp.isnan(loss).any() or jnp.isnan(dX).any():\n",
        "          step_size *= 0.25\n",
        "        elif loss < loss1:\n",
        "          loss1 = loss\n",
        "          break\n",
        "\n",
        "      if metric_callback is not None:\n",
        "        metric_callback(step, {'loss': loss})\n",
        "\n",
        "      Z = self._lbfgs_direction(X, dX, X1, dX1)\n",
        "      X1 = X\n",
        "      dX1 = dX\n",
        "    return X"
      ],
      "metadata": {
        "id": "EXAGc4sA8Lae"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "A = jnp.tri(128)\n",
        "\n",
        "# Compare the dense and circulant-constrained factorizations of the workload.\n",
        "for circulant in (False, True):\n",
        "  opt = MatrixFactorizer(128, circulant=circulant)\n",
        "  X = opt.optimize(A, iters=100)\n",
        "  print(opt.loss_and_gradient(X, A)[0])"
      ],
      "metadata": {
        "id": "ZwliRe4787W1",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236378942,
          "user_tz": 420,
          "elapsed": 1078,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "3947fc37-8fe4-4ad5-9d3d-35a693a11cc2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "683.6131153366335\n",
            "728.1238421991167\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Inspect the workload's dtype (the recorded output shows float32).\n",
        "A.dtype"
      ],
      "metadata": {
        "id": "Sl7dUr4A9ma5",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236379084,
          "user_tz": 420,
          "elapsed": 7,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "7ec3fadc-18ab-4927-f43d-f473b8addd47"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "dtype('float32')"
            ]
          },
          "metadata": {},
          "execution_count": 43
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "NLX0Uci9-Uie"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Helper Classes"
      ],
      "metadata": {
        "id": "Wp4eSY_i-519"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class StreamingLinearOperator(abc.ABC):\n",
        "  \"\"\"A linear mapping x -> C x for a lower triangular C matrix.\n",
        "\n",
        "  The linear mapping can be represented as an (init_fn, next_fn) pair, where\n",
        "  init_fn constructs an initial state, and next_fn computes the next state and\n",
        "  output from the current state and input.  Importantly, the output y[i] where\n",
        "  y = C x only depends on the x[i] and state captured from computing y[0], ...,\n",
        "  y[i-1].\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(\n",
        "      self,\n",
        "      init_fn: Callable[[tuple[int, ...]], Any],\n",
        "      next_fn: Callable[[Any, jnp.ndarray], tuple[Any, jnp.ndarray]],\n",
        "  ):\n",
        "    \"\"\"Construct a StreamingLinearOperator object.\n",
        "\n",
        "    This class couples the init/next representation of a lower triangular\n",
        "    matrix. The loop computes the matrix-vector product C x:\n",
        "\n",
        "    shape = (1,)\n",
        "    st = init_fn(shape)\n",
        "    for i in range(n):\n",
        "      st, y[i] = next_fn(st, x[i])\n",
        "\n",
        "    Args:\n",
        "      init_fn: A function that returns the initial state given the expected\n",
        "        shape of inputs to each call to next_fn. For example, if multiplying\n",
        "        this `StreamingLinearOperator` with an n x n matrix X, the shape would be\n",
        "        (n,) corresponding to each row of X.\n",
        "      next_fn: A function that computes the next state and output given the\n",
        "        current state and input.\n",
        "    \"\"\"\n",
        "    self.init = init_fn\n",
        "    self.next = next_fn\n",
        "\n",
        "  def materialize(self, n: int) -> jnp.ndarray:\n",
        "    \"\"\"A utility method to materialize this matrix as an n x n ndarray.\n",
        "\n",
        "    Note `n` needs to be a parameter, because a general `StreamingLinearOperator`\n",
        "    can represent an infinite x infinite matrix (e.g., prefix sum or\n",
        "    SGDM matrices below).\n",
        "\n",
        "    Args:\n",
        "      n: The size of the square matrix to materialize.\n",
        "\n",
        "    Returns:\n",
        "      An n x n materialization of this matrix.\n",
        "    \"\"\"\n",
        "    return self.__matmul__(jnp.eye(n))\n",
        "\n",
        "  def __mul__(A: 'StreamingLinearOperator', B: 'StreamingLinearOperator') -> 'StreamingLinearOperator':\n",
        "    \"\"\"Multiply a StreamingLinearOperator by another StreamingLinearOperator.\n",
        "\n",
        "    Args:\n",
        "      A: The left hand side matrix\n",
        "      B: The right hand side matrix\n",
        "\n",
        "    Returns:\n",
        "      A * B, represented as another StreamingLinearOperator.\n",
        "    \"\"\"\n",
        "    def init_fn(shape=()):\n",
        "      return A.init(shape), B.init(shape)\n",
        "\n",
        "    def next_fn(state, value):\n",
        "      A_state, B_state = state\n",
        "      B_state, inner = B.next(B_state, value)\n",
        "      A_state, outer = A.next(A_state, inner)\n",
        "      return (A_state, B_state), outer\n",
        "\n",
        "    return StreamingLinearOperator(init_fn, next_fn)\n",
        "\n",
        "  def __matmul__(A: 'StreamingLinearOperator', Z: jnp.ndarray) -> jnp.ndarray:\n",
        "    return jax.lax.scan(A.next, A.init(Z.shape[1:]), Z)[1]\n",
        "\n",
        "def identity() -> StreamingLinearOperator:\n",
        "  \"\"\"Returns the identity matrix as a stateless StreamingLinearOperator.\"\"\"\n",
        "\n",
        "  def init_fn(shape):\n",
        "    del shape  # No state is needed, whatever the row shape.\n",
        "    return ()\n",
        "\n",
        "  def next_fn(state, value):\n",
        "    del state  # Unused: the identity carries no state.\n",
        "    return (), value\n",
        "\n",
        "  return StreamingLinearOperator(init_fn, next_fn)\n",
        "\n",
        "\n",
        "def prefix_sum() -> StreamingLinearOperator:\n",
        "  \"\"\"Returns the (infinite) lower triangular matrix of ones as an operator.\n",
        "\n",
        "  The carried state is the running total of the inputs seen so far, and each\n",
        "  output row equals that running total.\n",
        "  \"\"\"\n",
        "\n",
        "  def init_fn(shape: tuple[int, ...]) -> jnp.ndarray:\n",
        "    # The running total starts at zero, shaped like a single input row.\n",
        "    return jnp.zeros(shape)\n",
        "\n",
        "  def next_fn(\n",
        "      state: jnp.ndarray, value: jnp.ndarray\n",
        "  ) -> tuple[jnp.ndarray, jnp.ndarray]:\n",
        "    running_total = state + value\n",
        "    return running_total, running_total\n",
        "\n",
        "  return StreamingLinearOperator(init_fn, next_fn)\n",
        "\n",
        "# Example: the leading 5 x 5 block of the prefix-sum operator.\n",
        "# NOTE: this rebinds A, which earlier held the 128 x 128 workload.\n",
        "A = prefix_sum()\n",
        "A.materialize(5)"
      ],
      "metadata": {
        "id": "8k-qgseG_NU5",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236379802,
          "user_tz": 420,
          "elapsed": 133,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "4124c3e5-f644-4370-b352-c57262e124a5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Array([[1., 0., 0., 0., 0.],\n",
              "       [1., 1., 0., 0., 0.],\n",
              "       [1., 1., 1., 0., 0.],\n",
              "       [1., 1., 1., 1., 0.],\n",
              "       [1., 1., 1., 1., 1.]], dtype=float64)"
            ]
          },
          "metadata": {},
          "execution_count": 44
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def default_toeplitz_coefficients(n: int) -> jnp.ndarray:\n",
        "  \"\"\"Returns the coefs of the optimal Toeplitz strategy matrix C for max error.\n",
        "\n",
        "  These coefficients were introduced by Fichtenberger, Henzinger, and Upadhyay\n",
        "  in \"Constant Matters: Fine-grained Error Bound on Differentially Private\n",
        "  Continual Observation\"\n",
        "  (https://proceedings.mlr.press/v202/fichtenberger23a/fichtenberger23a.pdf,\n",
        "  https://arxiv.org/pdf/2202.11205).\n",
        "\n",
        "  Args:\n",
        "    n: The number of coefficients to return.\n",
        "\n",
        "  Returns:\n",
        "    The coefficients of the lower-triangular Toeplitz matrix C that\n",
        "    factorizes the prefix sum matrix A as A = C @ C.\n",
        "  \"\"\"\n",
        "  k = jnp.arange(n)\n",
        "  return jnp.cumprod(((2 * k - 1) / (2 * k)).at[0].set(1))"
      ],
      "metadata": {
        "id": "KcpWJ4sLicIi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import abc\n",
        "\n",
        "class Parameterization(abc.ABC):\n",
        "  \"\"\"Abstract base class for parameterized lower triangular encoder matrices.\n",
        "\n",
        "  Subclasses map a set of parameters to a lower triangular encoder matrix C\n",
        "  via `materialize`, and are used as the template strategy inside\n",
        "  ParameterizedMatrixFactorization.\n",
        "\n",
        "  For use with parameterized_optimization.py, C should have fixed (or\n",
        "  bounded) sensitivity under a suitable participation schema.  Enforcing this,\n",
        "  i.e. projecting onto the set of bounded-sensitivity C matrices, is the job\n",
        "  of the Parameterization itself.\n",
        "\n",
        "  Most subclasses should also provide a `default` classmethod that builds an\n",
        "  instance from parameterization-specific arguments, mainly so that it can be\n",
        "  handed to parameterized_optimization.optimize as the initial encoder matrix.\n",
        "  \"\"\"\n",
        "\n",
        "  @property\n",
        "  @abc.abstractmethod\n",
        "  def n(self) -> int:\n",
        "    \"\"\"Number of rows (and columns) of the encoder matrix C.\"\"\"\n",
        "\n",
        "  @abc.abstractmethod\n",
        "  def materialize(self) -> jnp.ndarray:\n",
        "    \"\"\"Builds the encoder matrix C explicitly as an array.\n",
        "\n",
        "    Implementations must use jax primitives only, so that gradients can be\n",
        "    back-propagated through this method during optimization in\n",
        "    parameterized_optimization.py.\n",
        "    \"\"\"\n",
        "\n",
        "  @abc.abstractmethod\n",
        "  def inverse(self) -> StreamingLinearOperator:\n",
        "    \"\"\"Returns C^{-1} as a `StreamingLinearOperator`.\"\"\"\n",
        "\n",
        "@flax.struct.dataclass\n",
        "class ColumnNormalizedBanded(Parameterization):\n",
        "  \"\"\"A column-normalized banded lower triangular n x n matrix.\n",
        "\n",
        "  This matrix class is parameterized by an arbitrary n x b matrix.\n",
        "  C(params) is obtained by setting the first b bands of C based on params.\n",
        "  The matrix is normalized to have sensitivity 1 under a single epoch,\n",
        "  by dividing each column by its respective norm.\n",
        "\n",
        "  Below we show how params relates to the matrix (before column normalization):\n",
        "\n",
        "  ```\n",
        "  params = [a b c]\n",
        "           [d e f]\n",
        "           [g h i]\n",
        "           [j k -]\n",
        "           [m - -]\n",
        "\n",
        "  C = [a        ]\n",
        "      [b d      ]\n",
        "      [c e g    ]\n",
        "      [  f h j  ]\n",
        "      [    i k m]\n",
        "  ```\n",
        "\n",
        "  Row j of params holds column j of C, starting at the diagonal; entries\n",
        "  marked '-' fall outside the n x n matrix and are ignored.\n",
        "  \"\"\"\n",
        "\n",
        "  params: jnp.ndarray\n",
        "\n",
        "  @property\n",
        "  def n(self) -> int:\n",
        "    \"\"\"The size n of the n x n matrix C (the number of rows of params).\"\"\"\n",
        "    return self.params.shape[0]\n",
        "\n",
        "  @property\n",
        "  def bands(self) -> int:\n",
        "    \"\"\"The number b of bands of C (the number of columns of params).\"\"\"\n",
        "    return self.params.shape[1]\n",
        "\n",
        "  @classmethod\n",
        "  def from_banded_toeplitz(\n",
        "      cls, n: int, coefs: jnp.ndarray\n",
        "  ) -> 'ColumnNormalizedBanded':\n",
        "    \"\"\"Construct an instance of this object from banded toeplitz coefficients.\n",
        "\n",
        "    Args:\n",
        "      n: the number of training iterations.\n",
        "      coefs: an array of b toeplitz coefficients defining the strategy.\n",
        "\n",
        "    Returns:\n",
        "      A ColumnNormalizedBanded representation of the banded toeplitz matrix.\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if len(coefs) is not in the range [1, n].\n",
        "    \"\"\"\n",
        "    bands = coefs.size\n",
        "    if bands > n or bands < 1:\n",
        "      raise ValueError(f'len(coefs) must be in the range [1, n], got {bands}')\n",
        "    coefs = coefs / jnp.linalg.norm(coefs)\n",
        "    params = jnp.broadcast_to(coefs, (n, bands))\n",
        "    params = jnp.tril(params[::-1])[::-1]  # set the lower right triangle to 0\n",
        "    return cls(params)\n",
        "\n",
        "  @classmethod\n",
        "  def default(cls, n: int, bands: int) -> 'ColumnNormalizedBanded':\n",
        "    \"\"\"Construct a default instance of this object given n and bands.\n",
        "\n",
        "    This object is initialized by using the fixed toeplitz strategy proposed\n",
        "    in [1; Algorithm 1], truncating to $b$ entries, and column normalizing.\n",
        "    It can act as a useful initialization for further optimization.\n",
        "\n",
        "    [1] https://proceedings.mlr.press/v202/fichtenberger23a/fichtenberger23a.pdf\n",
        "\n",
        "    Args:\n",
        "      n: the number of training iterations.\n",
        "      bands: the number of bands in the strategy.\n",
        "\n",
        "    Returns:\n",
        "      A ColumnNormalizedBanded object (an instance of `cls`).\n",
        "    \"\"\"\n",
        "    coefs = default_toeplitz_coefficients(n)[:bands]\n",
        "    return cls.from_banded_toeplitz(n, coefs)\n",
        "\n",
        "  def materialize(self) -> jnp.ndarray:\n",
        "    \"\"\"Builds C explicitly: scatter params into the bands, then normalize.\"\"\"\n",
        "    I = jnp.arange(self.n)[:, None]\n",
        "    J = jnp.arange(self.n)[None]\n",
        "    D = I - J\n",
        "    # Entry (I, J) with 0 <= D < bands reads params[J, D], stored at flat\n",
        "    # position J * bands + D, shifted by one for the prepended 0; every other\n",
        "    # entry reads that leading 0.\n",
        "    indexer = (D + self.bands * J + 1) * (D >= 0) * (D < self.bands)\n",
        "    C = jnp.append(0, self.params.flatten())[indexer]\n",
        "    return C / jnp.linalg.norm(C, axis=0)\n",
        "\n",
        "  def inverse(self) -> StreamingLinearOperator:\n",
        "    \"\"\"Create $C^{-1}$ as a StreamingLinearOperator object.\n",
        "\n",
        "    Writing C = R D^{-1}, with R the unnormalized banded matrix and D the\n",
        "    diagonal matrix of its column norms, gives C^{-1} = D R^{-1}.  Each step\n",
        "    solves for one row of R^{-1} x by forward substitution and rescales it by\n",
        "    the corresponding column norm.\n",
        "    \"\"\"\n",
        "\n",
        "    def init_fn(shape):\n",
        "      # State: (step index, ring buffer of the last `bands` solved rows).\n",
        "      return 0, jnp.zeros((self.bands,) + shape)\n",
        "\n",
        "    def next_fn(state, value):\n",
        "      index, bufs = state\n",
        "      if self.bands == 1:\n",
        "        # C is diagonal with entries sign(params[i, 0]), its own inverse.\n",
        "        flip = self.params[index, 0] < 0\n",
        "        return (index + 1, bufs), jnp.where(flip, -value, value)\n",
        "      k = index % self.bands\n",
        "      r = jnp.arange(self.bands)\n",
        "      # row[r] = R[index, index - r].  For index < r the wrapped index pairs\n",
        "      # with a ring-buffer slot that is still zero, so it contributes nothing.\n",
        "      row = self.params[index - r, r]\n",
        "      # Algorithm 9 from https://arxiv.org/abs/2306.08153\n",
        "      # Compute xi = (value - row[1:] @ bufs[k-r][1:]) / row[0]\n",
        "      inner = jnp.tensordot(row[1:], bufs[k - r][1:], axes=((0,), (0,)))\n",
        "      xi = (value - inner) / row[0]\n",
        "      # Norm of column `index` of R, counting only band entries that lie\n",
        "      # inside the n x n matrix, exactly as in `materialize`.\n",
        "      in_matrix = r < self.n - index\n",
        "      col_norm = jnp.linalg.norm(jnp.where(in_matrix, self.params[index], 0))\n",
        "      return (index + 1, bufs.at[k].set(xi)), xi * col_norm\n",
        "\n",
        "    return StreamingLinearOperator(init_fn, next_fn)\n",
        "\n",
        "# A 2-banded default strategy for n = 5, shown as a dense matrix.\n",
        "C = ColumnNormalizedBanded.default(bands=2, n=5)\n",
        "C.materialize()"
      ],
      "metadata": {
        "id": "NT47k5x2-7vU",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236380249,
          "user_tz": 420,
          "elapsed": 127,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "41ca3692-1078-475b-acf2-ef5b1c7edd7f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Array([[0.89442719, 0.        , 0.        , 0.        , 0.        ],\n",
              "       [0.4472136 , 0.89442719, 0.        , 0.        , 0.        ],\n",
              "       [0.        , 0.4472136 , 0.89442719, 0.        , 0.        ],\n",
              "       [0.        , 0.        , 0.4472136 , 0.89442719, 0.        ],\n",
              "       [0.        , 0.        , 0.        , 0.4472136 , 1.        ]],      dtype=float64)"
            ]
          },
          "metadata": {},
          "execution_count": 46
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Parameterized Optimization of General Banded Matrices\n"
      ],
      "metadata": {
        "id": "1w280ufj-VqF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "T = TypeVar('T', bound=Parameterization)\n",
        "# Maps parameters to a (loss, gradient with the same structure) pair.\n",
        "LossAndGradFn = Callable[[T], tuple[float, T]]\n",
        "\n",
        "# Default optimizer factory; callers supply fun, value_and_grad and maxiter.\n",
        "# The options below improve robustness of the optimization.\n",
        "DEFAULT_LBFGS_OPTIMIZER = functools.partial(\n",
        "    jaxopt.LBFGS,\n",
        "    linesearch='backtracking',\n",
        "    maxls=100,\n",
        "    decrease_factor=0.70710678118,  # approximately 1 / sqrt(2)\n",
        "    history_size=1,\n",
        "    use_gamma=False,\n",
        "    stop_if_linesearch_fails=True,\n",
        ")\n",
        "\n",
        "def _robustify(fun: LossAndGradFn) -> LossAndGradFn:\n",
        "  \"\"\"Wrap a (loss, grad) function so that jaxopt never sees NaNs.\n",
        "\n",
        "  The loss is passed through jnp.nan_to_num(..., nan=jnp.inf), and every\n",
        "  gradient leaf is passed through jnp.nan_to_num with its default options.\n",
        "  \"\"\"\n",
        "\n",
        "  def robust_fun(C):\n",
        "    loss, grads = fun(C)\n",
        "    clean_grads = jax.tree_util.tree_map(jnp.nan_to_num, grads)\n",
        "    return jnp.nan_to_num(loss, nan=jnp.inf), clean_grads\n",
        "\n",
        "  return robust_fun\n",
        "\n",
        "\n",
        "def implicit_loss(\n",
        "    C: Parameterization,\n",
        "    A: StreamingLinearOperator | None = None,\n",
        ") -> float:\n",
        "  \"\"\"Compute the matrix mechanism loss without materializing n x n matrices.\n",
        "\n",
        "  The objective function is || A C^{-1} ||_F^2 / n. It is assumed that\n",
        "  C has constant fixed sensitivity under an appropriate participation schema.\n",
        "  Gradients are obtained by differentiating this function with JAX (e.g.,\n",
        "  jax.value_and_grad); the checkpointed scan keeps backprop memory bounded.\n",
        "\n",
        "  Args:\n",
        "    C: the strategy matrix, represented implicitly.\n",
        "    A: The workload matrix, represented implicitly. When None, the prefix\n",
        "      sum workload prefix_sum() is used.\n",
        "\n",
        "  Returns:\n",
        "    The matrix mechanism loss (mean squared error), a scalar.\n",
        "  \"\"\"\n",
        "  if A is None:\n",
        "    A = prefix_sum()\n",
        "  B = A * C.inverse()\n",
        "  zero = jnp.zeros(C.n)\n",
        "\n",
        "  def next_state_and_row_norm(state, i):\n",
        "    # Presumably streaming e_i at step i yields row i of B, so the squared\n",
        "    # row norms sum to ||B||_F^2.\n",
        "    ei = zero.at[i].set(1)\n",
        "    state, row = B.next(state, ei)\n",
        "    return state, row @ row\n",
        "\n",
        "  checkpoints = None\n",
        "  if isinstance(C, ColumnNormalizedBanded):\n",
        "    # Pick a power-of-two number of checkpointed scan states so that storing\n",
        "    # that many states of bands * n values stays within about `limit` values.\n",
        "    used = C.bands * C.n\n",
        "    limit = 2**29  # 4 GB of float64 values\n",
        "    budget = max(limit // used, 1)\n",
        "    # log2 (not the natural log) matches the power of two; clamp so that\n",
        "    # between 1 and n checkpoints are requested (never 0.5 or an error).\n",
        "    checkpoints = min(2 ** max(int(np.log2(budget)) - 1, 0), C.n)\n",
        "\n",
        "  # we use equinox rather than jax.lax to allow backprop through the scan\n",
        "  # without OOMing the accelerator.\n",
        "  row_norms_squared = equinox.internal.scan(\n",
        "      next_state_and_row_norm,\n",
        "      B.init((C.n,)),\n",
        "      jnp.arange(C.n),\n",
        "      kind='checkpointed',\n",
        "      checkpoints=checkpoints,\n",
        "  )[1]\n",
        "  # We scale by n so this is the MSE (and we get RMSE by taking sqrt).\n",
        "  return row_norms_squared.sum() / C.n\n",
        "\n",
        "\n",
        "\n",
        "def optimize(\n",
        "    C0: Parameterization,\n",
        "    A: StreamingLinearOperator,\n",
        "    maxiter: int = 100,\n",
        "    method = DEFAULT_LBFGS_OPTIMIZER,\n",
        "    callback: Callable[[jaxopt.OptStep], None] = lambda _: None,\n",
        ") -> Parameterization:\n",
        "  \"\"\"Optimize the strategy using a gradient-based method.\n",
        "\n",
        "  Note that the default optimization method is LBFGS. Due to non-convexity\n",
        "  in this formulation of the optimization problem, it is possible that simple\n",
        "  GradientDescent may work better, but that method generally requires\n",
        "  some level of learning rate tuning. LBFGS has been observed to work pretty\n",
        "  well in our experiments, but if you notice nonconvergence issues, consider\n",
        "  switching the method to jaxopt.GradientDescent.\n",
        "\n",
        "  The loss is always evaluated implicitly via implicit_loss, and its gradient\n",
        "  is obtained with jax.value_and_grad.\n",
        "\n",
        "  Args:\n",
        "    C0: The initial encoder to be optimized.\n",
        "    A: The target workload.\n",
        "    maxiter: The number of optimizer updates to perform.\n",
        "    method: A jaxopt optimizer (e.g., LBFGS or GradientDescent), called as\n",
        "      method(fun=..., value_and_grad=True, maxiter=maxiter).\n",
        "    callback: A function to call after each optimizer step.\n",
        "\n",
        "  Returns:\n",
        "    An optimized encoder having the same structure as C0.\n",
        "  \"\"\"\n",
        "\n",
        "  fun = jax.value_and_grad(implicit_loss)\n",
        "  robust_fun = _robustify(functools.partial(fun, A=A))\n",
        "\n",
        "  optimizer = method(fun=robust_fun, value_and_grad=True, maxiter=maxiter)\n",
        "  step = jaxopt.OptStep(params=C0, state=optimizer.init_state(C0))\n",
        "  # NOTE(review): this loop always runs `maxiter` updates; stopping criteria\n",
        "  # configured on the optimizer (e.g. tol, stop_if_linesearch_fails) are not\n",
        "  # checked here. Confirm this is intended.\n",
        "  for _ in range(maxiter):\n",
        "    step = optimizer.update(step.params, step.state)\n",
        "    callback(step)\n",
        "  return step.params\n",
        "\n",
        "Copt = optimize(C, A)\n",
        "# The optimized strategy has slightly lower error than the default one.\n",
        "default_loss = implicit_loss(C, A)\n",
        "optimized_loss = implicit_loss(Copt, A)\n",
        "print(default_loss, optimized_loss)\n",
        "Copt.materialize()"
      ],
      "metadata": {
        "id": "mmcWnVU4-W6M",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236383953,
          "user_tz": 420,
          "elapsed": 3050,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "ae0c198f-0b06-44b1-d435-be074cf82e2a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2.1914316581836113 2.1672810387539343\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Array([[0.83634719, 0.        , 0.        , 0.        , 0.        ],\n",
              "       [0.54820013, 0.9245377 , 0.        , 0.        , 0.        ],\n",
              "       [0.        , 0.3810906 , 0.90304796, 0.        , 0.        ],\n",
              "       [0.        , 0.        , 0.42953973, 0.94568753, 0.        ],\n",
              "       [0.        , 0.        , 0.        , 0.32507707, 1.        ]],      dtype=float64)"
            ]
          },
          "metadata": {},
          "execution_count": 47
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Banded Toeplitz Optimization"
      ],
      "metadata": {
        "id": "zirvjq4lgxR6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def solve_banded_toeplitz(coef: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray:\n",
        "  \"\"\"Solve T_{coef} x = rhs for x, where T_{coef} is lower-triangular Toeplitz.\n",
        "\n",
        "  T_{coef} is banded: its first column starts with `coef` and is zero below.\n",
        "  The solve is written as a jax.lax.scan of forward-substitution steps so\n",
        "  that gradients can be back-propagated through it (which rules out\n",
        "  scipy.linalg.solve_toeplitz).\n",
        "\n",
        "  Example: coef = [a, b, c], rhs = [1, 1, 1, 1, 1, 1], we solve the following\n",
        "  system for x\n",
        "  ```\n",
        "  [a 0 0 0 0 0] [x_0]   [1]\n",
        "  [b a 0 0 0 0] [x_1]   [1]\n",
        "  [c b a 0 0 0] [x_2] = [1]\n",
        "  [0 c b a 0 0] [x_3]   [1]\n",
        "  [0 0 c b a 0] [x_4]   [1]\n",
        "  [0 0 0 c b a] [x_5]   [1]\n",
        "  ```\n",
        "\n",
        "  Args:\n",
        "    coef: The leading nonzero entries of the first column of the n x n\n",
        "      lower-triangular Toeplitz matrix, where n = len(rhs). Missing entries\n",
        "      (len(coef) < n) are zero; entries beyond the first n are ignored.\n",
        "    rhs: The right hand side vector, of length `n`.\n",
        "\n",
        "  Returns:\n",
        "    The solution to the linear system Toeplitz(coef, n) x = rhs.\n",
        "  \"\"\"\n",
        "  num_coefs = coef.shape[0]\n",
        "  if num_coefs == 1:\n",
        "    return rhs / coef[0]\n",
        "\n",
        "  diagonal, subdiagonals = coef[0], coef[1:]\n",
        "\n",
        "  def substitute(history, target):\n",
        "    # history[j] holds the solution entry computed j + 1 steps ago (zero\n",
        "    # before the start), so this is one step of forward substitution.\n",
        "    x_i = (target - subdiagonals @ history) / diagonal\n",
        "    history = jnp.concatenate([x_i[None], history[:-1]])\n",
        "    return history, x_i\n",
        "\n",
        "  _, solution = jax.lax.scan(substitute, jnp.zeros(num_coefs - 1), rhs)\n",
        "  return solution\n",
        "\n",
        "\n",
        "def toeplitz_prefix_mse(coef: jnp.ndarray, n: int) -> jnp.ndarray:\n",
        "  \"\"\"Mean squared error of the prefix workload for a banded Toeplitz C.\n",
        "\n",
        "  Args:\n",
        "    coef: The leading nonzero entries of the first column of the n x n\n",
        "      lower-triangular banded Toeplitz strategy matrix C.\n",
        "    n: The number of iterations (size of the prefix-sum workload A).\n",
        "\n",
        "  Returns:\n",
        "    ||coef[:n]||^2 * ||A C^{-1}||_F^2 / n, where ||coef[:n]||^2 is the squared\n",
        "    norm of C's first column (the sensitivity term).\n",
        "  \"\"\"\n",
        "  # Only the first n coefficients appear in an n x n matrix, so any extra\n",
        "  # coefficients must not contribute to the sensitivity.\n",
        "  coef = coef[:n]\n",
        "  sensitivity_squared = coef @ coef\n",
        "  # A and C^{-1} are lower-triangular Toeplitz (hence commute), so A C^{-1}\n",
        "  # is too, and its first column is C^{-1} A e_0 = C^{-1} 1.\n",
        "  A_coef = jnp.ones(n)\n",
        "  A_Cinv_coef = solve_banded_toeplitz(coef, A_coef)\n",
        "  # Entry k of that first column appears (n - k) times in the full matrix.\n",
        "  frobenius_norm_squared = (jnp.arange(n)[::-1] + 1) @ A_Cinv_coef**2\n",
        "  return sensitivity_squared * frobenius_norm_squared / n\n",
        "\n",
        "\n",
        "def optimize_banded_toeplitz(\n",
        "    n: int,\n",
        "    bands: int,\n",
        "    maxiter: int = 1000,\n",
        "    method = DEFAULT_LBFGS_OPTIMIZER,\n",
        "    loss_fn: Callable[[jnp.ndarray, int], jnp.ndarray] = toeplitz_prefix_mse,\n",
        ") -> jnp.ndarray:\n",
        "  \"\"\"Optimize over the space of banded Toeplitz strategies on a Prefix workload.\n",
        "\n",
        "  This optimization problem is solved assuming a single participation.\n",
        "  The generated matrices can also be used in the multi-epoch setting\n",
        "  (both fixed_epoch_order and min_sep) as long as bands <= n / epochs.\n",
        "  See https://arxiv.org/abs/2306.08153 for more details.\n",
        "\n",
        "  Args:\n",
        "    n: the number of iterations that defines the workload.\n",
        "    bands: The number of bands in the Toeplitz matrix.\n",
        "    maxiter: The maximum number of optimizer iterations.\n",
        "    method: A jaxopt optimizer (e.g., LBFGS or GradientDescent).\n",
        "    loss_fn: The loss function to minimize (e.g., toeplitz_prefix_mse). It is\n",
        "      called as loss_fn(coefs, n=n) with len(coefs) == bands.\n",
        "\n",
        "  Returns:\n",
        "    The optimized Toeplitz coefficients, with the same shape as the initial\n",
        "    coefficients default_toeplitz_coefficients(n)[:bands].\n",
        "  \"\"\"\n",
        "  loss_and_grad = _robustify(\n",
        "      jax.value_and_grad(functools.partial(loss_fn, n=n))\n",
        "  )\n",
        "  # Start from the truncated fixed Toeplitz strategy.\n",
        "  init = default_toeplitz_coefficients(n)[:bands]\n",
        "\n",
        "  solver = method(fun=loss_and_grad, value_and_grad=True, maxiter=maxiter)\n",
        "  return solver.run(init).params\n",
        "\n",
        "n, bands = 5, 2\n",
        "# Compare the truncated default coefficients against the optimized ones.\n",
        "coef0 = default_toeplitz_coefficients(n)[:bands]\n",
        "coef_opt = optimize_banded_toeplitz(n, bands)\n",
        "print(toeplitz_prefix_mse(coef0, n), toeplitz_prefix_mse(coef_opt, n))"
      ],
      "metadata": {
        "id": "3jiUWtSF-jVl",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236385899,
          "user_tz": 420,
          "elapsed": 1319,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "107e32bf-b408-43b1-950f-2acaf80dbf12"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2.2353515625 2.2307688371924232\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Scalability Experiments"
      ],
      "metadata": {
        "id": "EpiMsPKyqH2L"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Fixed n, vary b"
      ],
      "metadata": {
        "id": "IxczSfAivxOd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Fix n, vary b: time the banded Toeplitz loss and its gradient.\n",
        "n = 2**14\n",
        "toeplitz_results_b = {}\n",
        "\n",
        "for b in [2**k for k in range(1,15)]:\n",
        "  coef = jnp.linspace(1, 0, b)\n",
        "  toeplitz_loss = jax.jit(functools.partial(toeplitz_prefix_mse, n=n))\n",
        "  toeplitz_grad = jax.jit(jax.grad(functools.partial(toeplitz_prefix_mse, n=n)))\n",
        "  # Warm up (compile) so that compilation is excluded from the timings.\n",
        "  jax.block_until_ready((toeplitz_loss(coef), toeplitz_grad(coef)))\n",
        "\n",
        "  # JAX dispatches work asynchronously: block on each result before reading\n",
        "  # the clock so that the measured time includes the computation itself.\n",
        "  t0 = time.perf_counter()\n",
        "  loss = toeplitz_loss(coef).block_until_ready()\n",
        "  t1 = time.perf_counter()\n",
        "  grad = toeplitz_grad(coef).block_until_ready()\n",
        "  t2 = time.perf_counter()\n",
        "  toeplitz_results_b[b] = (t1-t0, t2-t1)\n",
        "  print(b, t1-t0, t2-t1)\n",
        "  if t2 - t0 > 60:\n",
        "    break"
      ],
      "metadata": {
        "id": "DIWJLCICqGx2",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236409101,
          "user_tz": 420,
          "elapsed": 22953,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "e2630409-0c35-4c4b-c366-263a4bebc0bf"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2 0.0524749755859375 0.15679287910461426\n",
            "4 0.0811927318572998 0.21791434288024902\n",
            "8 0.05785560607910156 0.21265006065368652\n",
            "16 0.07308506965637207 0.2609875202178955\n",
            "32 0.07757091522216797 0.34455299377441406\n",
            "64 0.09162354469299316 0.33848118782043457\n",
            "128 0.09942889213562012 0.39403700828552246\n",
            "256 0.09731745719909668 0.32970476150512695\n",
            "512 0.07697510719299316 0.3107779026031494\n",
            "1024 0.10648727416992188 0.3816838264465332\n",
            "2048 0.1475086212158203 0.5264036655426025\n",
            "4096 0.2867777347564697 0.8005540370941162\n",
            "8192 0.5778720378875732 1.4547457695007324\n",
            "16384 1.0360898971557617 2.731982469558716\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Best to run this snippet on GPU\n",
        "# Fix n, vary b for the implicit (streaming) loss and its gradient.\n",
        "implicit_results_b = {}\n",
        "\n",
        "# The workload must be defined before it is bound into `loss` below.\n",
        "A = prefix_sum()\n",
        "loss = jax.jit(functools.partial(implicit_loss, A=A))\n",
        "grad = jax.jit(jax.grad(functools.partial(implicit_loss, A=A)))\n",
        "\n",
        "for b in [2**k for k in range(1,15)]:\n",
        "  C = ColumnNormalizedBanded.default(n, b)\n",
        "  # Shapes change with b, so compile for this b before timing.\n",
        "  jax.block_until_ready((loss(C), grad(C)))\n",
        "\n",
        "  # Block on each result so async dispatch does not hide the compute time.\n",
        "  t0 = time.perf_counter()\n",
        "  jax.block_until_ready(loss(C))\n",
        "  t1 = time.perf_counter()\n",
        "  jax.block_until_ready(grad(C))\n",
        "  t2 = time.perf_counter()\n",
        "  implicit_results_b[b] = (t1-t0, t2-t1)\n",
        "  print(b, t1-t0, t2-t1)\n",
        "  if t2 - t0 > 60:\n",
        "    break"
      ],
      "metadata": {
        "id": "kXUR1484v8UV",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236540556,
          "user_tz": 420,
          "elapsed": 131321,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "8181a94a-947e-4490-a3ae-ac8d80ead66a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2 1.0360057353973389 15.756796836853027\n",
            "4 3.1636645793914795 29.963915586471558\n",
            "8 3.5404155254364014 75.69868350028992\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Fixed b, vary n"
      ],
      "metadata": {
        "id": "z_Dl8Kchzh0g"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Fix b, vary n: time the banded Toeplitz loss and its gradient as n grows\n",
        "# (same b and n range as the implicit experiment below).\n",
        "b = 16\n",
        "toeplitz_results = {}\n",
        "\n",
        "for n in [2**k for k in range(4, 18)]:\n",
        "  coef = jnp.linspace(1, 0, b)\n",
        "  toeplitz_loss = jax.jit(functools.partial(toeplitz_prefix_mse, n=n))\n",
        "  toeplitz_grad = jax.jit(jax.grad(functools.partial(toeplitz_prefix_mse, n=n)))\n",
        "  # Warm up (compile) so that compilation is excluded from the timings.\n",
        "  jax.block_until_ready((toeplitz_loss(coef), toeplitz_grad(coef)))\n",
        "\n",
        "  # Block on each result so async dispatch does not hide the compute time.\n",
        "  t0 = time.perf_counter()\n",
        "  loss = toeplitz_loss(coef).block_until_ready()\n",
        "  t1 = time.perf_counter()\n",
        "  grad = toeplitz_grad(coef).block_until_ready()\n",
        "  t2 = time.perf_counter()\n",
        "  toeplitz_results[n] = (t1-t0, t2-t1)\n",
        "  print(n, t1-t0, t2-t1)\n",
        "  if t2 - t0 > 60:\n",
        "    break"
      ],
      "metadata": {
        "id": "nVLwSqIGzoKs",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236563528,
          "user_tz": 420,
          "elapsed": 22832,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "aa91646e-d9b5-4ee5-ea48-a50ab700224a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2 0.05277872085571289 0.16751694679260254\n",
            "4 0.061669111251831055 0.2167205810546875\n",
            "8 0.06015443801879883 0.23075389862060547\n",
            "16 0.06358504295349121 0.2418069839477539\n",
            "32 0.06878519058227539 0.24698805809020996\n",
            "64 0.06042075157165527 0.2347257137298584\n",
            "128 0.0606541633605957 0.24430155754089355\n",
            "256 0.06907343864440918 0.25476646423339844\n",
            "512 0.07537078857421875 0.29889512062072754\n",
            "1024 0.10672283172607422 0.38060855865478516\n",
            "2048 0.2035684585571289 0.6740512847900391\n",
            "4096 0.30074191093444824 0.844667911529541\n",
            "8192 0.5080418586730957 1.3705871105194092\n",
            "16384 1.015289068222046 2.616572856903076\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Fix b, vary n for the implicit (streaming) loss and its gradient.\n",
        "implicit_results = {}\n",
        "A = prefix_sum()\n",
        "b = 16\n",
        "loss = jax.jit(functools.partial(implicit_loss, A=A))\n",
        "grad = jax.jit(jax.grad(functools.partial(implicit_loss, A=A)))\n",
        "\n",
        "for n in [2**k for k in range(4, 18)]:\n",
        "  C = ColumnNormalizedBanded.default(n, b)\n",
        "  # Shapes change with n, so compile for this n before timing.\n",
        "  jax.block_until_ready((loss(C), grad(C)))\n",
        "\n",
        "  # Block on each result so async dispatch does not hide the compute time.\n",
        "  t0 = time.perf_counter()\n",
        "  jax.block_until_ready(loss(C))\n",
        "  t1 = time.perf_counter()\n",
        "  jax.block_until_ready(grad(C))\n",
        "  t2 = time.perf_counter()\n",
        "  implicit_results[n] = (t1-t0, t2-t1)\n",
        "  print(n, t1-t0, t2-t1)\n",
        "  if t2 - t0 > 60:\n",
        "    break"
      ],
      "metadata": {
        "id": "8tlrVEstqeNh",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1723236958809,
          "user_tz": 420,
          "elapsed": 395141,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "9c7e9161-c3ff-44eb-c546-25ce25a5fa41"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "16 0.18504810333251953 0.8066926002502441\n",
            "32 0.22147727012634277 0.9098904132843018\n",
            "64 0.2816617488861084 1.0883243083953857\n",
            "128 0.2036426067352295 0.8235890865325928\n",
            "256 0.19147944450378418 0.8278636932373047\n",
            "512 0.19395971298217773 0.8496277332305908\n",
            "1024 0.23160147666931152 1.500481367111206\n",
            "2048 0.32336997985839844 2.889063596725464\n",
            "4096 0.6319406032562256 7.7263243198394775\n",
            "8192 1.5835483074188232 22.88981580734253\n",
            "16384 5.810267210006714 146.3994584083557\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Solution Quality"
      ],
      "metadata": {
        "id": "v2-xrS-u3oxC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Solution quality: n = 1024, vary b, using the dense MatrixFactorizer.\n",
        "n = 1024\n",
        "A = jnp.tri(n)\n",
        "\n",
        "for b in [2**k for k in range(2)]:\n",
        "  dense_mf = MatrixFactorizer(n, bands=b)\n",
        "  X = dense_mf.optimize(A)\n",
        "  # loss_and_gradient returns a (loss, gradient) pair; only the loss is needed.\n",
        "  loss, _ = dense_mf.loss_and_gradient(X, A)\n",
        "  rmse = jnp.sqrt(loss)\n",
        "  print(b, rmse)"
      ],
      "metadata": {
        "colab": {
          "height": 397
        },
        "id": "MjVilaso3wgy",
        "executionInfo": {
          "status": "error",
          "timestamp": 1723236230959,
          "user_tz": 420,
          "elapsed": 82468,
          "user": {
            "displayName": "Ryan McKenna",
            "userId": "14049140147394149988"
          }
        },
        "outputId": "d3403c3d-5a42-45f6-b9b5-055290e85124"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "traceback": [
            "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[1;32m<ipython-input-39-e7d6b460a45b>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mb\u001b[0m \u001b[1;32min\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mk\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      6\u001b[0m   \u001b[0mdense_mf\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mMatrixFactorizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbands\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 7\u001b[1;33m   \u001b[0mX\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdense_mf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mA\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m   \u001b[0mrmse\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdense_mf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mloss_and_gradient\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mA\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
            "\u001b[1;32m<ipython-input-28-142dd7ab4ca6>\u001b[0m in \u001b[0;36moptimize\u001b[1;34m(self, A, iters, metric_callback, initial_X)\u001b[0m\n\u001b[0;32m    268\u001b[0m       \u001b[1;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m30\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    269\u001b[0m         \u001b[0mX\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mX1\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mstep_size\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mZ\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 270\u001b[1;33m         \u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdX\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mloss_and_gradient\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mA\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    271\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mjnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0many\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mjnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdX\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0many\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    272\u001b[0m           \u001b[0mstep_size\u001b[0m \u001b[1;33m*=\u001b[0m \u001b[1;36m0.25\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
            "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
          ],
          "debug": {
            "argv": [
              "/export/hda3/borglet/remote_hdd_fs_dirs/0.colab_kernel_brain_frameworks_cpu_mckennar.kernel.mckennar.4430272851038.14b334fb3717c109/mount/server/ml_notebook",
              "kernel",
              "-f",
              "/tmp/ipy-be-pxihz9g8/profile_colab/security/kernel-ca76d1ce-5200-40ac-8298-ac4c2ff2ddc4.json",
              "--profile-dir",
              "/tmp/ipy-be-pxihz9g8/profile_colab",
              "--profile=colab",
              "--ipython-dir=/tmp/ipy-be-jlvx14_2",
              "--no-secure"
            ],
            "build": "Built on Tue Jul 16 10:35:58 2024 (1721151358)\nBuilt by brain-frameworks-releaser@jjbme2.prod.google.com:/google/src/cloud/buildrabbit-username/buildrabbit-client/google3\nBuilt as //learning/grp/tools/ml_python:ml_notebook\nBuild ID: d8cfe869-8fe3-47bc-a3e6-c4487a378da0\nBuilt from changelist 652854365 in a mint client based on //depot/google3\nBuild label: ml_notebook_2024-07-16-10_00_RC00\nBuild platform: gcc-4.X.Y-crosstool-v18-llvm-grtev4-k8\nBuild tool: Blaze, release blaze-2024.07.10-2 (mainline @650821060)\nBuilt with par options [\"--compress\", \"--compress_level=6\", \"--compress\"]\nCurrently running under Python 3.11.8: embedded.\n",
            "user": "mckennar"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "q7LRomz24OOd"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}