{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import lax\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "((Array(45, dtype=int32), Array(1, dtype=int32, weak_type=True)), Array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32))\n"
     ]
    }
   ],
   "source": [
    "def f(carry, x):\n",
    "    return (carry[0] + x, 1), carry[0] + x\n",
    "# Thread f over 0..9 and print the pair (final carry, stacked per-step outputs).\n",
    "print(lax.scan(f, init=(0, 0), xs=jnp.arange(10)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Array([[1024., 1024.],\n",
      "       [1024., 1024.]], dtype=float32), Array([   4.0000005,    8.000001 ,   16.000002 ,   32.000004 ,\n",
      "         64.00001  ,  128.00002  ,  256.00003  ,  512.00006  ,\n",
      "       1024.0001   , 2048.0002   ], dtype=float32))\n"
     ]
    }
   ],
   "source": [
    "def f_pow(A_k, k, A):\n",
    "    return A_k @ A, jnp.linalg.norm(A_k @ A, ord=2)\n",
    "\n",
    "# Carry starts as a 2x2 all-ones matrix; A is bound to the same all-ones matrix for all 10 steps.\n",
    "print(lax.scan(partial(f_pow, A=jnp.ones((2, 2))), init=jnp.ones((2, 2)), xs=jnp.arange(10)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "C = jnp.ones((2,2))\n",
    "A = jnp.ones((2,2))\n",
    "B = jnp.ones((2,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(A_prev, k, A, B, C, ord, pow):\n",
    "    A_k = A @ A_prev\n",
    "    norm = jnp.linalg.norm(C @ A_k @ B, ord=ord)**pow\n",
    "    return A_k, norm\n",
    "ones = jnp.ones((2, 2))\n",
    "# Bind A, B and C to the all-ones matrix and collect the squared ord=2 norm from each of 10 steps.\n",
    "_, traj = lax.scan(\n",
    "    partial(f, A=ones, B=ones, C=ones, ord=2, pow=2),\n",
    "    init=ones,\n",
    "    xs=jnp.arange(10),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(89478416., dtype=float32)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sum of all per-step values collected in traj.\n",
    "traj.sum()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
