{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b864c9f",
   "metadata": {},
   "source": [
    "# Training and evaluating a PPO agent on a CartPole-v1 SCB (synthetic contextual bandit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0117ab49",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import chex\n",
    "import optax\n",
    "import gymnax\n",
    "import distrax\n",
    "import numpy as np\n",
    "from jax import numpy as jnp\n",
    "from functools import partial\n",
    "from flax.training.train_state import TrainState\n",
    "from flax import serialization, struct, linen as nn\n",
    "from flax.linen.initializers import constant, orthogonal\n",
    "from typing import Sequence, Callable, Sequence, NamedTuple, Any, Tuple, Union, Optional"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b82a1f2e",
   "metadata": {},
   "source": [
    "## Environment setup\n",
    "Since the checkpoint is pretty small, we include it in the notebook and load it later.\n",
    "For the synthetic environment, we define a neural network, as well as a lightweight wrapper that turns it into a gymnax environment (gymnax by Robert Lange, https://github.com/RobertTLange/gymnax). This way we can use it with an existing implementation of PPO."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c0ccd1f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "state_dict={\"network_params\":{\"params\":{\"initial_state\":{\"Dense_0\":{\"bias\":jnp.array([1.26869515e-01,1.41259685e-01,5.43528274e-02,1.15194455e-01,1.37809336e-01,-8.72058421e-03,7.89055824e-02,-1.99180320e-01,5.40970191e-02,-1.58668771e-01,-7.18499571e-02,7.40191340e-02,-5.22077642e-03,1.69407353e-01,-1.63449153e-01,-8.64701793e-02,1.33883998e-01,5.43285348e-02,9.15792733e-02,1.67521797e-02,-5.08493260e-02,-1.16616540e-01,-1.93413608e-02,2.39237607e-01,1.01968758e-01,9.53310281e-02,-1.04294889e-01,-1.03872642e-01,-1.04525527e-02,-3.95094976e-05,-6.64104968e-02,-1.17814451e-01]),\"kernel\":jnp.array([[0.8846168,0.13692066,0.50594884,0.14921904,-0.14166242,-0.65235424,-0.36718902,-0.10609503,-0.07121946,-0.6565477,-0.67549443,-0.2454942,0.47280326,0.45149615,0.5787622,-0.33040312,0.08890592,-0.9357842,0.06001355,-1.1585295,0.5899545,-0.13386008,-0.29321316,-1.1517092,-0.6731166,0.07064947,0.01272697,0.0737064,-0.23607291,-0.01386749,0.5704219,-0.23453556],[0.14356384,-0.16271158,-0.22376521,0.6414362,-0.45695016,0.37988603,0.00965512,-0.59674424,0.50628364,0.8395857,0.27464223,-0.38726056,-0.3193502,0.994187,-0.06332397,-0.5112642,0.278541,-0.8270671,0.13532868,-0.21969965,0.21553119,-0.75505424,-0.14268877,-0.14371847,0.42653775,0.09487649,0.61873585,0.2655668,-0.17130065,0.6027759,0.27072605,-0.09727632],[0.71527255,-1.0190362,-0.4969015,0.02646949,0.45954618,-0.57374597,-0.16900837,0.20839304,0.7320672,0.23996888,-0.38599852,-0.36110643,0.35913327,0.32669476,-0.56855094,0.13774277,0.61382556,-0.4893724,0.7475364,-0.32178655,0.45245612,-0.29185754,0.14671113,0.76569647,0.4284497,-0.39157662,0.42457896,-0.42482936,0.49772158,0.603451,0.34575403,-0.48154423],[-0.15213561,-0.31649578,-0.4752063,0.09686822,-0.87038094,-0.4419285,-0.10130058,0.56353116,0.8133533,0.26534626,-0.4635922,-0.25251764,-0.34841132,-0.44593742,0.02763209,-0.0444576,-0.2943847,0.20854786,-0.6691867,-0.17127077,0.33252814,0.53534824,-0.5804024,0.85732436,0.16224997,-0.17276968,-0.099
49151,0.90296096,-0.3980796,-0.13659588,-0.21636754,-0.23377497]])},\"Dense_1\":{\"bias\":jnp.array([-0.23853292,-0.20640945,0.12088947,0.0406028]),\"kernel\":jnp.array([[0.06586141,-0.20556618,-0.07747183,0.20390964],[-0.12181008,-0.3436964,0.30577776,0.08894946],[0.43202174,0.05864399,-0.16084975,-0.07460365],[-0.08624116,-0.13685812,-0.13052797,-0.22605382],[-0.21751063,-0.27933103,-0.09660649,0.09112976],[0.1193767,-0.03688768,0.36235967,0.1014033],[-0.21952452,0.33819488,0.11265649,0.00769656],[0.23252207,-0.19156662,-0.12310604,0.14480147],[-0.19720016,-0.04912206,-0.15244123,0.11449951],[-0.21839556,0.5338239,-0.01835317,-0.05727997],[-0.15315637,-0.16337974,-0.10488401,0.21952122],[0.39603424,-0.26559308,-0.2598905,-0.08239831],[-0.12711985,-0.23589969,0.02675047,-0.20909514],[-0.29801863,0.08326443,-0.07796545,0.24707524],[0.0743818,-0.29463637,0.03389455,-0.09728192],[0.19888052,0.10196938,-0.09099034,0.06020924],[0.06921208,0.05015983,0.3570317,-0.2957178],[-0.08112952,-0.08914044,-0.0372033,-0.09259935],[-0.25601706,0.11024189,-0.06126438,-0.05513002],[-0.2238151,-0.3609388,-0.1398257,-0.09238093],[-0.08744401,-0.29385138,0.28948435,-0.25384825],[-0.251413,-0.05812272,0.2759809,0.01767877],[-0.23875394,0.03044049,0.11000813,0.02629472],[-0.22074798,0.20562223,-0.02642709,0.10477567],[-0.03424083,0.14251348,0.11503636,-0.23182449],[0.00481565,-0.44274,-0.06756,0.09635373],[0.01311758,0.16591552,-0.09169737,0.05384627],[-0.32783306,-0.1564861,-0.08647604,-0.1260688],[-0.00060268,-0.08288515,0.15344582,-0.00333162],[-0.06070196,-0.33916306,-0.08288447,0.12863146],[-0.29709512,-0.05443503,0.2524976,0.15655148],[-0.11524608,0.14356235,-0.10542386,-0.09798729]])}},\"reward\":{\"Dense_0\":{\"bias\":jnp.array([-0.04167449,-0.07238774,-0.0096982,-0.16192774,-0.2881943,-0.06213038,0.03506619,-0.06296758,-0.02475221,0.11831518,-0.04385006,0.0393946,-0.18567328,-0.06754404,0.00148166,0.08987777,-0.1324569,0.43911675,-0.04311865,-0.00372663,0.09977508,0.01560554,0.03
539776,0.03180509,0.08150925,-0.0201122,0.08423095,-0.14955288,0.00356325,-0.03897597,0.09318526,-0.005307]),\"kernel\":jnp.array([[-2.4536091e-01,-8.0870318e-01,3.3167094e-01,2.3530486e-01,2.6437560e-01,-8.8771588e-01,-1.0929084e+00,-2.5805163e-01,-3.0810103e-01,3.8640857e-01,3.5454881e-01,-3.1734530e-02,4.0621221e-02,2.4050610e-02,8.1039466e-02,7.1752167e-01,-6.1280094e-02,1.3606001e-01,-2.2257902e-01,2.2453496e-01,3.0671027e-02,7.4387938e-02,1.6392831e-02,-1.4216736e-01,4.3170811e-03,5.6454962e-01,4.5116369e-02,9.8077953e-03,-4.5192521e-02,1.3466522e-01,3.7277527e-02,1.3829154e-01],[1.9400182e-01,-3.1993192e-01,2.6569979e-02,2.7178788e-01,-5.2814597e-01,-4.5833400e-01,-2.7351955e-01,6.3416243e-01,4.9927495e-02,-1.0000003e-01,-1.4595738e-01,2.9583827e-01,-4.4046786e-02,-4.6217285e-02,1.6516691e-01,-7.0881766e-01,1.5911405e-01,-5.3984519e-02,-1.8340956e-01,-3.2576588e-01,-4.3285894e-01,2.3185222e-01,2.0340200e-01,-5.3867114e-01,3.7474450e-01,-5.1858658e-01,-5.7007395e-02,-3.1988436e-01,5.7587802e-01,1.6344950e-01,4.4552717e-01,1.7630380e-01],[-1.4233860e-01,1.7829532e-02,-1.0612557e+00,-7.9193342e-01,1.9613400e-02,2.8794643e-02,-9.2248458e-01,-2.7885491e-01,-8.3697006e-02,9.7373307e-02,-4.0297535e-01,4.2252842e-01,-1.1134618e-01,-7.5674150e-05,-6.8942986e-02,3.5175362e-01,4.7212580e-01,4.1825497e-01,-7.5464617e-03,4.9670687e-01,-3.3075050e-01,-5.9676951e-01,-5.6231016e-01,-1.9755715e-01,3.9418671e-01,-7.3708993e-01,-4.6680066e-01,4.2716280e-01,-2.0219013e-02,-6.9139689e-02,7.4614488e-02,9.6012127e-01],[7.9288311e-02,3.1390530e-01,-6.6116011e-01,6.4054120e-01,8.6064953e-01,-1.7447892e-01,4.6767259e-01,-1.0840139e-01,-3.8617238e-01,8.8799989e-01,-2.9928008e-01,-8.0398631e-01,5.8623767e-01,-4.2677864e-01,-4.5739237e-02,-1.4057746e-02,3.3015648e-01,-1.8960208e-02,3.1241691e-01,-3.5168323e-01,5.0410533e-01,-4.9086684e-01,-4.0347627e-01,-9.6256810e-01,7.3750234e-01,-2.6644869e-02,-4.1501999e-01,-3.8961864e-01,8.0006999e-01,-2.6040918e-01,2.6258588e-01,1.7658389e-01],[2.7
756354e-01,-4.7244169e-02,-3.7886355e-02,3.7895960e-01,7.2201604e-01,1.6328713e-01,-5.1706457e-01,-3.3519265e-01,-4.6358889e-01,-3.6470577e-01,-4.4690678e-01,-8.3203256e-02,3.6670938e-02,-3.5881382e-01,7.9708505e-01,-4.8750471e-03,-1.8828426e-01,6.0292047e-01,7.9636678e-02,4.4521722e-01,-6.3242525e-01,2.1781521e-01,-2.6042330e-01,1.1783392e-01,2.4848078e-01,3.1149933e-01,-2.3994823e-01,5.1283771e-01,4.4937325e-03,-8.7292272e-01,-2.1010350e-02,4.0797612e-01],[5.7491571e-01,1.1168423e-01,7.5962961e-01,2.1093254e-01,7.9654664e-02,-4.5237586e-01,-2.6296297e-01,9.8921865e-02,5.1479787e-01,1.3980688e-01,7.5591260e-01,2.2263643e-01,-1.6671263e-02,-4.1758355e-01,3.1558576e-01,4.9016047e-01,6.3634598e-01,5.1799482e-01,2.0280409e-01,1.3416648e-01,2.8802583e-01,1.2608196e-01,-4.2647141e-01,-5.6586492e-01,-6.1420959e-01,4.7620454e-01,2.4228865e-01,4.4928291e-01,1.5510845e-01,-8.2135014e-02,1.7583951e-01,-1.9948937e-01]])},\"Dense_1\":{\"bias\":jnp.array([0.31663114]),\"kernel\":jnp.array([[0.00961514],[0.17498876],[-0.68293875],[-0.12586214],[-0.22713538],[-0.15064897],[-0.2100513],[-0.43393072],[-0.26706964],[0.27843222],[-0.24953763],[-0.03312748],[0.07204171],[-0.11664964],[0.0068073],[0.08322901],[0.3876118],[-0.12667817],[0.08091829],[-0.19998641],[-0.07741469],[0.2848988],[0.40409082],[0.8589095],[-0.9778193],[0.03011897],[-0.24080771],[-0.2940478],[0.04018991],[-0.04501918],[0.18329582],[-0.7836947]])}}}}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "06b8da77",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Plain multilayer perceptron, adapted from the Flax examples (https://github.com/google/flax).\n",
    "\n",
    "    Every entry of ``features`` except the last is a hidden ``nn.Dense`` layer\n",
    "    followed by ``activation``; the last entry is the width of the linear output layer.\n",
    "    \"\"\"\n",
    "\n",
    "    features: Sequence[int]\n",
    "    activation: Callable = nn.tanh\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x):\n",
    "        hidden_widths, out_width = self.features[:-1], self.features[-1]\n",
    "        for width in hidden_widths:\n",
    "            x = nn.Dense(width)(x)\n",
    "            x = self.activation(x)\n",
    "        # Output layer is linear (no activation).\n",
    "        return nn.Dense(out_width)(x)\n",
    "\n",
    "    \n",
    "class SynthEnvMLP(nn.Module):\n",
    "    \"\"\"Synthetic environment network with an initial-state head and a reward head.\n",
    "\n",
    "    Attributes:\n",
    "        state_size: number of entries in each generated observation.\n",
    "        latent_dist: distribution sampled in ``reset`` to generate initial states.\n",
    "        features: hidden-layer widths used by every sub-network.\n",
    "        activation: hidden-layer nonlinearity.\n",
    "    \"\"\"\n",
    "\n",
    "    state_size: int\n",
    "    latent_dist: distrax.Distribution\n",
    "    features: Sequence[int] = (32,)\n",
    "    activation: Callable = nn.relu\n",
    "\n",
    "    def setup(self):\n",
    "        self.initial_state = MLP([*self.features, self.state_size], self.activation)\n",
    "        self.reward = MLP([*self.features, 1], self.activation)\n",
    "        # Not called by ``reset``/``step`` below, so Flax creates no parameters for\n",
    "        # them (the checkpoint only holds ``initial_state`` and ``reward``).\n",
    "        self.next_state_delta = MLP([*self.features, self.state_size], self.activation)\n",
    "        self.done = MLP([*self.features, 1], self.activation)\n",
    "\n",
    "    def __call__(self, rng, state, action, only_reward=False):\n",
    "        \"\"\"Run ``reset`` and ``step`` once so that ``init`` creates all used parameters.\n",
    "\n",
    "        ``only_reward`` is accepted for API compatibility and ignored.\n",
    "        \"\"\"\n",
    "        return *self.reset(rng), *self.step(state, action)\n",
    "\n",
    "    def reset(self, rng, sample_shape=(1,)):\n",
    "        \"\"\"Sample latent noise of ``sample_shape`` and map it through the initial-state MLP.\"\"\"\n",
    "        z = self.latent_dist.sample(seed=rng, sample_shape=sample_shape)\n",
    "        return self.initial_state(z)\n",
    "\n",
    "    def step(self, state, action, only_reward=False):\n",
    "        \"\"\"Predict one reward per batch row from the concatenated (state, action).\n",
    "\n",
    "        The leading axis of ``state`` and ``action`` is treated as the batch; the\n",
    "        remaining axes are flattened. Returns an array of shape ``(batch,)``.\n",
    "        ``only_reward`` is accepted for API compatibility and ignored.\n",
    "        \"\"\"\n",
    "        batch_size = state.shape[0]\n",
    "        x = jnp.hstack([state.reshape(batch_size, -1), action.reshape(batch_size, -1)])\n",
    "        return self.reward(x).squeeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e265edf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "@struct.dataclass\n",
    "class SynthEnvParams:\n",
    "    \"\"\"Parameters of SynthEnv: the Flax variables dict of the SynthEnvMLP network.\"\"\"\n",
    "    network_params: chex.ArrayTree\n",
    "\n",
    "@struct.dataclass\n",
    "class SynthEnvState:\n",
    "    \"\"\"State of SynthEnv: just the current observation produced by the network.\"\"\"\n",
    "    obs: chex.Array\n",
    "\n",
    "\n",
    "class SynthEnv(gymnax.environments.environment.Environment):\n",
    "    \"\"\"Gymnax environment whose observations and rewards come from SynthEnvMLP.\n",
    "\n",
    "    Observation and action spaces are borrowed from the target gymnax environment.\n",
    "    ``reset_env`` samples an observation from the network and ``step_env`` scores an\n",
    "    action with the reward head; every episode terminates after a single step.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, target_env):\n",
    "        super().__init__()\n",
    "        self.target_env, self.target_env_params = gymnax.make(target_env)\n",
    "        target_obs_space = self.target_env.observation_space(self.target_env_params)\n",
    "        self.obs_size = np.prod(target_obs_space.shape)\n",
    "        target_act_space = self.action_space(self.target_env_params)\n",
    "        if isinstance(target_act_space, gymnax.environments.spaces.Discrete):\n",
    "            # Discrete actions are one-hot encoded, so the network sees ``n`` inputs.\n",
    "            self.action_size = target_act_space.n\n",
    "        else:\n",
    "            self.action_size = np.prod(target_act_space.shape)\n",
    "        # Standard-normal latent noise with one dimension per observation entry.\n",
    "        noise = distrax.MultivariateNormalDiag(\n",
    "            loc=jnp.zeros(self.obs_size), scale_diag=jnp.ones(self.obs_size)\n",
    "        )\n",
    "        self.network = SynthEnvMLP(self.obs_size, noise)\n",
    "\n",
    "    @property\n",
    "    def default_params(self):\n",
    "        \"\"\"Network variables from initialising SynthEnvMLP with ``PRNGKey(0)``.\"\"\"\n",
    "        key = jax.random.PRNGKey(0)\n",
    "        dummy_obs = jnp.zeros((1, self.obs_size))\n",
    "        dummy_action = jnp.zeros((1, self.action_size))\n",
    "        # The same key seeds parameter init and is also passed as ``rng`` to __call__.\n",
    "        variables = self.network.init(key, key, dummy_obs, dummy_action)\n",
    "        return SynthEnvParams(variables)\n",
    "\n",
    "    def step_env(self, key, state: SynthEnvState, action, params: SynthEnvParams):\n",
    "        \"\"\"Score ``action`` in ``state``; the state itself is returned unchanged.\"\"\"\n",
    "        discrete = isinstance(self.action_space(params), gymnax.environments.spaces.Discrete)\n",
    "        net_action = jax.nn.one_hot(action, self.action_size) if discrete else action\n",
    "        batched_reward = self.network.apply(\n",
    "            params.network_params,\n",
    "            jnp.expand_dims(state.obs, 0),\n",
    "            jnp.expand_dims(net_action, 0),\n",
    "            method=\"step\",\n",
    "        )\n",
    "        done = self.is_terminal(state, params)\n",
    "        # The returned observation is irrelevant: the episode is over and gymnax auto-resets.\n",
    "        return state.obs, state, batched_reward.squeeze(0), done, {}\n",
    "\n",
    "    def reset_env(self, key, params, carry=None):\n",
    "        \"\"\"Sample an initial observation from the network (``carry`` is unused).\"\"\"\n",
    "        obs = self.network.apply(params.network_params, key, method=\"reset\")\n",
    "        return obs, SynthEnvState(obs)\n",
    "\n",
    "    def get_obs(self, state):\n",
    "        \"\"\"Return the observation stored in ``state``.\"\"\"\n",
    "        return state.obs\n",
    "\n",
    "    def is_terminal(self, state, params):\n",
    "        \"\"\"Always ``True``: every episode lasts exactly one step.\"\"\"\n",
    "        return True\n",
    "\n",
    "    def action_space(self, params):\n",
    "        \"\"\"Action space of the target environment (``params`` is ignored).\"\"\"\n",
    "        return self.target_env.action_space(self.target_env_params)\n",
    "\n",
    "    def observation_space(self, params):\n",
    "        \"\"\"Observation space of the target environment (``params`` is ignored).\"\"\"\n",
    "        return self.target_env.observation_space(self.target_env_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4df997fa",
   "metadata": {},
   "source": [
    "## PPO implementation\n",
    "\n",
    "For the sake of this notebook, we use a preexisting implementation of PPO from purejaxrl by Chris Lu (https://github.com/luchris429/purejaxrl). We modified it slightly to train on the synthetic environment; see the first few lines of `make_train`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e6e68804",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Implementation of PPO \n",
    "# \n",
    "# Slightly modified to train on synthetic environment\n",
    "\n",
    "class GymnaxWrapper(object):\n",
    "    \"\"\"Base class for Gymnax wrappers; unknown attributes are forwarded to the wrapped env.\"\"\"\n",
    "\n",
    "    def __init__(self, env):\n",
    "        self._env = env\n",
    "\n",
    "    # provide proxy access to regular attributes of wrapped object\n",
    "    def __getattr__(self, name):\n",
    "        # ``__getattr__`` only runs when normal lookup fails. If ``_env`` itself is\n",
    "        # missing (e.g. while copy/pickle rebuilds the object before ``__init__``),\n",
    "        # forwarding would recurse forever, so raise AttributeError instead.\n",
    "        if name == \"_env\":\n",
    "            raise AttributeError(name)\n",
    "        return getattr(self._env, name)\n",
    "\n",
    "\n",
    "class FlattenObservationWrapper(GymnaxWrapper):\n",
    "    \"\"\"Wrapper that flattens every observation into a 1-D vector.\"\"\"\n",
    "\n",
    "    def __init__(self, env: gymnax.environments.environment.Environment):\n",
    "        super().__init__(env)\n",
    "\n",
    "    def observation_space(self, params) -> gymnax.environments.spaces.Box:\n",
    "        \"\"\"Wrapped Box space with its shape collapsed to ``(prod(shape),)``.\"\"\"\n",
    "        inner_space = self._env.observation_space(params)\n",
    "        assert isinstance(\n",
    "            inner_space, gymnax.environments.spaces.Box\n",
    "        ), \"Only Box spaces are supported for now.\"\n",
    "        flat_shape = (np.prod(inner_space.shape),)\n",
    "        return gymnax.environments.spaces.Box(\n",
    "            low=inner_space.low,\n",
    "            high=inner_space.high,\n",
    "            shape=flat_shape,\n",
    "            dtype=inner_space.dtype,\n",
    "        )\n",
    "\n",
    "    @partial(jax.jit, static_argnums=(0,))\n",
    "    def reset(\n",
    "        self, key: chex.PRNGKey, params: Optional[gymnax.environments.environment.EnvParams] = None\n",
    "    ) -> Tuple[chex.Array, gymnax.environments.environment.EnvState]:\n",
    "        \"\"\"Reset the wrapped env and return its observation as a 1-D vector.\"\"\"\n",
    "        inner_obs, state = self._env.reset(key, params)\n",
    "        return jnp.ravel(inner_obs), state\n",
    "\n",
    "    @partial(jax.jit, static_argnums=(0,))\n",
    "    def step(\n",
    "        self,\n",
    "        key: chex.PRNGKey,\n",
    "        state: gymnax.environments.environment.EnvState,\n",
    "        action: Union[int, float],\n",
    "        params: Optional[gymnax.environments.environment.EnvParams] = None,\n",
    "    ) -> Tuple[chex.Array, gymnax.environments.environment.EnvState, float, bool, dict]:\n",
    "        \"\"\"Step the wrapped env and return its observation as a 1-D vector.\"\"\"\n",
    "        inner_obs, next_state, reward, done, info = self._env.step(key, state, action, params)\n",
    "        return jnp.ravel(inner_obs), next_state, reward, done, info\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class LogEnvState:\n",
    "    \"\"\"Wrapped env state plus running and last-completed episode statistics.\"\"\"\n",
    "    env_state: gymnax.environments.environment.EnvState\n",
    "    # return / length accumulated so far in the current, unfinished episode\n",
    "    episode_returns: float\n",
    "    episode_lengths: int\n",
    "    # return / length of the most recently completed episode\n",
    "    returned_episode_returns: float\n",
    "    returned_episode_lengths: int\n",
    "    # incremented on every step; not cleared at episode boundaries\n",
    "    timestep: int\n",
    "\n",
    "\n",
    "class LogWrapper(GymnaxWrapper):\n",
    "    \"\"\"Wrapper that records episode returns and lengths and reports them in ``info``.\"\"\"\n",
    "\n",
    "    def __init__(self, env: gymnax.environments.environment.Environment):\n",
    "        super().__init__(env)\n",
    "\n",
    "    @partial(jax.jit, static_argnums=(0,))\n",
    "    def reset(\n",
    "        self, key: chex.PRNGKey, params: Optional[gymnax.environments.environment.EnvParams] = None\n",
    "    ) -> Tuple[chex.Array, gymnax.environments.environment.EnvState]:\n",
    "        \"\"\"Reset the wrapped env and start all statistics at zero.\"\"\"\n",
    "        obs, inner_state = self._env.reset(key, params)\n",
    "        return obs, LogEnvState(inner_state, 0, 0, 0, 0, 0)\n",
    "\n",
    "    @partial(jax.jit, static_argnums=(0,))\n",
    "    def step(\n",
    "        self,\n",
    "        key: chex.PRNGKey,\n",
    "        state: gymnax.environments.environment.EnvState,\n",
    "        action: Union[int, float],\n",
    "        params: Optional[gymnax.environments.environment.EnvParams] = None,\n",
    "    ) -> Tuple[chex.Array, gymnax.environments.environment.EnvState, float, bool, dict]:\n",
    "        \"\"\"Step the wrapped env and update the episode statistics.\n",
    "\n",
    "        When ``done`` is set, the finished episode's totals are copied into the\n",
    "        ``returned_*`` fields and the running totals are reset to zero.\n",
    "        \"\"\"\n",
    "        obs, inner_state, reward, done, info = self._env.step(\n",
    "            key, state.env_state, action, params\n",
    "        )\n",
    "        running_return = state.episode_returns + reward\n",
    "        running_length = state.episode_lengths + 1\n",
    "        not_done = 1 - done\n",
    "        logged = LogEnvState(\n",
    "            env_state=inner_state,\n",
    "            episode_returns=running_return * not_done,\n",
    "            episode_lengths=running_length * not_done,\n",
    "            returned_episode_returns=state.returned_episode_returns * not_done\n",
    "            + running_return * done,\n",
    "            returned_episode_lengths=state.returned_episode_lengths * not_done\n",
    "            + running_length * done,\n",
    "            timestep=state.timestep + 1,\n",
    "        )\n",
    "        info[\"returned_episode_returns\"] = logged.returned_episode_returns\n",
    "        info[\"returned_episode_lengths\"] = logged.returned_episode_lengths\n",
    "        info[\"timestep\"] = logged.timestep\n",
    "        info[\"returned_episode\"] = done\n",
    "        return obs, logged, reward, done, info\n",
    "\n",
    "class ActorCritic(nn.Module):\n",
    "    \"\"\"Separate actor and critic MLPs, each with two hidden layers of 64 units.\n",
    "\n",
    "    Returns a categorical policy over ``action_dim`` actions and a value estimate\n",
    "    with the trailing unit axis squeezed away. ``activation == \"relu\"`` selects\n",
    "    ReLU; any other value selects tanh.\n",
    "    \"\"\"\n",
    "\n",
    "    action_dim: int  # number of discrete actions (width of the policy logits)\n",
    "    activation: str = \"tanh\"\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x):\n",
    "        if self.activation == \"relu\":\n",
    "            activation = nn.relu\n",
    "        else:\n",
    "            activation = nn.tanh\n",
    "        # Actor head.\n",
    "        actor_mean = nn.Dense(\n",
    "            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)\n",
    "        )(x)\n",
    "        actor_mean = activation(actor_mean)\n",
    "        actor_mean = nn.Dense(\n",
    "            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)\n",
    "        )(actor_mean)\n",
    "        actor_mean = activation(actor_mean)\n",
    "        # Small-gain init keeps the initial logits near zero (near-uniform policy).\n",
    "        actor_mean = nn.Dense(\n",
    "            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)\n",
    "        )(actor_mean)\n",
    "        pi = distrax.Categorical(logits=actor_mean)\n",
    "\n",
    "        # Critic head (no layers shared with the actor).\n",
    "        critic = nn.Dense(\n",
    "            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)\n",
    "        )(x)\n",
    "        critic = activation(critic)\n",
    "        critic = nn.Dense(\n",
    "            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)\n",
    "        )(critic)\n",
    "        critic = activation(critic)\n",
    "        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(\n",
    "            critic\n",
    "        )\n",
    "\n",
    "        return pi, jnp.squeeze(critic, axis=-1)\n",
    "\n",
    "\n",
    "class Transition(NamedTuple):\n",
    "    \"\"\"One environment step collected during a rollout (batched over the vmapped envs).\"\"\"\n",
    "\n",
    "    done: jnp.ndarray\n",
    "    action: jnp.ndarray\n",
    "    value: jnp.ndarray\n",
    "    reward: jnp.ndarray\n",
    "    log_prob: jnp.ndarray\n",
    "    obs: jnp.ndarray  # observation the action was chosen from\n",
    "    info: dict  # info dict returned by env.step (a pytree of arrays, not a single array)\n",
    "\n",
    "\n",
    "def make_train(config):\n",
    "    \"\"\"Build a jittable PPO ``train(rng)`` function for the synthetic environment.\n",
    "\n",
    "    Adapted from purejaxrl: the agent is trained on ``SynthEnv(config['ENV_NAME'])``\n",
    "    whose network weights are loaded from the global ``state_dict``.\n",
    "\n",
    "    Args:\n",
    "        config: PPO hyperparameters. Mutated in place: the derived keys\n",
    "            ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` are added.\n",
    "\n",
    "    Returns:\n",
    "        ``train(rng)`` returning ``{'runner_state': ..., 'metrics': ...}`` where\n",
    "        ``metrics`` holds the per-step ``info`` dicts stacked over all updates.\n",
    "    \"\"\"\n",
    "    # NUM_UPDATES is a lax.scan length and MINIBATCH_SIZE feeds reshapes; cast both\n",
    "    # to int so a float TOTAL_TIMESTEPS (e.g. 5e4) does not leak a float step count.\n",
    "    config[\"NUM_UPDATES\"] = int(\n",
    "        config[\"TOTAL_TIMESTEPS\"] // config[\"NUM_STEPS\"] // config[\"NUM_ENVS\"]\n",
    "    )\n",
    "    config[\"MINIBATCH_SIZE\"] = int(\n",
    "        config[\"NUM_ENVS\"] * config[\"NUM_STEPS\"] // config[\"NUM_MINIBATCHES\"]\n",
    "    )\n",
    "    env = SynthEnv(config[\"ENV_NAME\"])\n",
    "    env_params = serialization.from_state_dict(env.default_params, state_dict)\n",
    "    env = FlattenObservationWrapper(env)\n",
    "    env = LogWrapper(env)\n",
    "\n",
    "    def linear_schedule(count):\n",
    "        frac = (\n",
    "            1.0\n",
    "            - (count // (config[\"NUM_MINIBATCHES\"] * config[\"UPDATE_EPOCHS\"]))\n",
    "            / config[\"NUM_UPDATES\"]\n",
    "        )\n",
    "        return config[\"LR\"] * frac\n",
    "\n",
    "    def train(rng):\n",
    "        # INIT NETWORK\n",
    "        network = ActorCritic(\n",
    "            env.action_space(env_params).n, activation=config[\"ACTIVATION\"]\n",
    "        )\n",
    "        rng, _rng = jax.random.split(rng)\n",
    "        init_x = jnp.zeros(env.observation_space(env_params).shape)\n",
    "        network_params = network.init(_rng, init_x)\n",
    "        if config[\"ANNEAL_LR\"]:\n",
    "            tx = optax.chain(\n",
    "                optax.clip_by_global_norm(config[\"MAX_GRAD_NORM\"]),\n",
    "                optax.adam(learning_rate=linear_schedule, eps=1e-5),\n",
    "            )\n",
    "        else:\n",
    "            tx = optax.chain(\n",
    "                optax.clip_by_global_norm(config[\"MAX_GRAD_NORM\"]),\n",
    "                optax.adam(config[\"LR\"], eps=1e-5),\n",
    "            )\n",
    "        train_state = TrainState.create(\n",
    "            apply_fn=network.apply,\n",
    "            params=network_params,\n",
    "            tx=tx,\n",
    "        )\n",
    "\n",
    "        # INIT ENV\n",
    "        rng, _rng = jax.random.split(rng)\n",
    "        reset_rng = jax.random.split(_rng, config[\"NUM_ENVS\"])\n",
    "        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)\n",
    "\n",
    "        # TRAIN LOOP\n",
    "        def _update_step(runner_state, unused):\n",
    "            # COLLECT TRAJECTORIES\n",
    "            def _env_step(runner_state, unused):\n",
    "                train_state, env_state, last_obs, rng = runner_state\n",
    "\n",
    "                # SELECT ACTION\n",
    "                rng, _rng = jax.random.split(rng)\n",
    "                pi, value = network.apply(train_state.params, last_obs)\n",
    "                action = pi.sample(seed=_rng)\n",
    "                log_prob = pi.log_prob(action)\n",
    "\n",
    "                # STEP ENV\n",
    "                rng, _rng = jax.random.split(rng)\n",
    "                rng_step = jax.random.split(_rng, config[\"NUM_ENVS\"])\n",
    "                obsv, env_state, reward, done, info = jax.vmap(\n",
    "                    env.step, in_axes=(0, 0, 0, None)\n",
    "                )(rng_step, env_state, action, env_params)\n",
    "                transition = Transition(\n",
    "                    done, action, value, reward, log_prob, last_obs, info\n",
    "                )\n",
    "                runner_state = (train_state, env_state, obsv, rng)\n",
    "                return runner_state, transition\n",
    "\n",
    "            runner_state, traj_batch = jax.lax.scan(\n",
    "                _env_step, runner_state, None, config[\"NUM_STEPS\"]\n",
    "            )\n",
    "\n",
    "            # CALCULATE ADVANTAGE\n",
    "            train_state, env_state, last_obs, rng = runner_state\n",
    "            _, last_val = network.apply(train_state.params, last_obs)\n",
    "\n",
    "            def _calculate_gae(traj_batch, last_val):\n",
    "                def _get_advantages(gae_and_next_value, transition):\n",
    "                    gae, next_value = gae_and_next_value\n",
    "                    done, value, reward = (\n",
    "                        transition.done,\n",
    "                        transition.value,\n",
    "                        transition.reward,\n",
    "                    )\n",
    "                    delta = reward + config[\"GAMMA\"] * next_value * (1 - done) - value\n",
    "                    gae = (\n",
    "                        delta\n",
    "                        + config[\"GAMMA\"] * config[\"GAE_LAMBDA\"] * (1 - done) * gae\n",
    "                    )\n",
    "                    return (gae, value), gae\n",
    "\n",
    "                _, advantages = jax.lax.scan(\n",
    "                    _get_advantages,\n",
    "                    (jnp.zeros_like(last_val), last_val),\n",
    "                    traj_batch,\n",
    "                    reverse=True,\n",
    "                    unroll=16,\n",
    "                )\n",
    "                return advantages, advantages + traj_batch.value\n",
    "\n",
    "            advantages, targets = _calculate_gae(traj_batch, last_val)\n",
    "\n",
    "            # UPDATE NETWORK\n",
    "            def _update_epoch(update_state, unused):\n",
    "                def _update_minbatch(train_state, batch_info):\n",
    "                    traj_batch, advantages, targets = batch_info\n",
    "\n",
    "                    def _loss_fn(params, traj_batch, gae, targets):\n",
    "                        # RERUN NETWORK\n",
    "                        pi, value = network.apply(params, traj_batch.obs)\n",
    "                        log_prob = pi.log_prob(traj_batch.action)\n",
    "\n",
    "                        # CALCULATE VALUE LOSS\n",
    "                        value_pred_clipped = traj_batch.value + (\n",
    "                            value - traj_batch.value\n",
    "                        ).clip(-config[\"CLIP_EPS\"], config[\"CLIP_EPS\"])\n",
    "                        value_losses = jnp.square(value - targets)\n",
    "                        value_losses_clipped = jnp.square(value_pred_clipped - targets)\n",
    "                        value_loss = (\n",
    "                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()\n",
    "                        )\n",
    "\n",
    "                        # CALCULATE ACTOR LOSS\n",
    "                        ratio = jnp.exp(log_prob - traj_batch.log_prob)\n",
    "                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)\n",
    "                        loss_actor1 = ratio * gae\n",
    "                        loss_actor2 = (\n",
    "                            jnp.clip(\n",
    "                                ratio,\n",
    "                                1.0 - config[\"CLIP_EPS\"],\n",
    "                                1.0 + config[\"CLIP_EPS\"],\n",
    "                            )\n",
    "                            * gae\n",
    "                        )\n",
    "                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)\n",
    "                        loss_actor = loss_actor.mean()\n",
    "                        entropy = pi.entropy().mean()\n",
    "\n",
    "                        total_loss = (\n",
    "                            loss_actor\n",
    "                            + config[\"VF_COEF\"] * value_loss\n",
    "                            - config[\"ENT_COEF\"] * entropy\n",
    "                        )\n",
    "                        return total_loss, (value_loss, loss_actor, entropy)\n",
    "\n",
    "                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)\n",
    "                    total_loss, grads = grad_fn(\n",
    "                        train_state.params, traj_batch, advantages, targets\n",
    "                    )\n",
    "                    train_state = train_state.apply_gradients(grads=grads)\n",
    "                    return train_state, total_loss\n",
    "\n",
    "                train_state, traj_batch, advantages, targets, rng = update_state\n",
    "                rng, _rng = jax.random.split(rng)\n",
    "                batch_size = config[\"MINIBATCH_SIZE\"] * config[\"NUM_MINIBATCHES\"]\n",
    "                assert (\n",
    "                    batch_size == config[\"NUM_STEPS\"] * config[\"NUM_ENVS\"]\n",
    "                ), \"batch size must be equal to number of steps * number of envs\"\n",
    "                permutation = jax.random.permutation(_rng, batch_size)\n",
    "                batch = (traj_batch, advantages, targets)\n",
    "                batch = jax.tree_util.tree_map(\n",
    "                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch\n",
    "                )\n",
    "                shuffled_batch = jax.tree_util.tree_map(\n",
    "                    lambda x: jnp.take(x, permutation, axis=0), batch\n",
    "                )\n",
    "                minibatches = jax.tree_util.tree_map(\n",
    "                    lambda x: jnp.reshape(\n",
    "                        x, [config[\"NUM_MINIBATCHES\"], -1] + list(x.shape[1:])\n",
    "                    ),\n",
    "                    shuffled_batch,\n",
    "                )\n",
    "                train_state, total_loss = jax.lax.scan(\n",
    "                    _update_minbatch, train_state, minibatches\n",
    "                )\n",
    "                update_state = (train_state, traj_batch, advantages, targets, rng)\n",
    "                return update_state, total_loss\n",
    "\n",
    "            update_state = (train_state, traj_batch, advantages, targets, rng)\n",
    "            update_state, loss_info = jax.lax.scan(\n",
    "                _update_epoch, update_state, None, config[\"UPDATE_EPOCHS\"]\n",
    "            )\n",
    "            train_state = update_state[0]\n",
    "            metric = traj_batch.info\n",
    "            rng = update_state[-1]\n",
    "            if config.get(\"DEBUG\"):\n",
    "                def callback(info):\n",
    "                    return_values = info[\"returned_episode_returns\"][info[\"returned_episode\"]]\n",
    "                    timesteps = info[\"timestep\"][info[\"returned_episode\"]] * config[\"NUM_ENVS\"]\n",
    "                    for t in range(len(timesteps)):\n",
    "                        print(f\"global step={timesteps[t]}, episodic return={return_values[t]}\")\n",
    "                jax.debug.callback(callback, metric)\n",
    "\n",
    "            runner_state = (train_state, env_state, last_obs, rng)\n",
    "            return runner_state, metric\n",
    "\n",
    "        rng, _rng = jax.random.split(rng)\n",
    "        runner_state = (train_state, env_state, obsv, _rng)\n",
    "        runner_state, metric = jax.lax.scan(\n",
    "            _update_step, runner_state, None, config[\"NUM_UPDATES\"]\n",
    "        )\n",
    "        return {\"runner_state\": runner_state, \"metrics\": metric}\n",
    "\n",
    "    return train"
   ]
  },
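  {
   "cell_type": "markdown",
   "id": "0df24274",
   "metadata": {},
   "source": [
    "## Training an agent on the SCB\n",
    "We simply reuse the PPO configuration from GitHub, so there is no need to worry about tuning :-) Note that `TOTAL_TIMESTEPS` is reduced from the original 5e5 to 5e4."
   ]
  },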
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "72259577",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PPO hyperparameters consumed by make_train.\n",
    "config = dict(\n",
    "    LR=2.5e-4,\n",
    "    NUM_ENVS=4,\n",
    "    NUM_STEPS=128,\n",
    "    TOTAL_TIMESTEPS=5e4,  # was originally 5e5\n",
    "    UPDATE_EPOCHS=4,\n",
    "    NUM_MINIBATCHES=4,\n",
    "    GAMMA=0.99,\n",
    "    GAE_LAMBDA=0.95,\n",
    "    CLIP_EPS=0.2,\n",
    "    ENT_COEF=0.01,\n",
    "    VF_COEF=0.5,\n",
    "    MAX_GRAD_NORM=0.5,\n",
    "    ACTIVATION=\"tanh\",\n",
    "    ENV_NAME=\"CartPole-v1\",\n",
    "    ANNEAL_LR=True,\n",
    "    DEBUG=False,  # True prints episodic returns during training via jax.debug.callback\n",
    ")\n",
    "\n",
    "# Compile the whole training loop once with jit, then run it from a fixed seed.\n",
    "rng = jax.random.PRNGKey(0)\n",
    "train_jit = jax.jit(make_train(config))\n",
    "out = train_jit(rng)"
   ]
  },
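  {
   "cell_type": "markdown",
   "id": "3174bdf3",
   "metadata": {},
   "source": [
    "## Evaluating the agent\n",
    "We quickly set up a vmappable evaluation function so that many seeds can be evaluated in parallel. The agent is evaluated in the standard gymnax CartPole-v1 environment, not in the SCB it was trained on."
   ]
  },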
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "58933ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "@struct.dataclass\n",
    "class EvalState:\n",
    "    \"\"\"Loop carry for `evaluate`: the running state of a single rollout.\"\"\"\n",
    "\n",
    "    rng: chex.PRNGKey  # split at every step for action sampling and env stepping\n",
    "    env_state: Any  # environment state returned by env.reset / env.step\n",
    "    last_obs: chex.Array  # latest observation, fed to the policy\n",
    "    done: bool = False  # done flag from env.step; the rollout stops once it is True\n",
    "    return_: float = 0.0  # undiscounted sum of the rewards collected so far\n",
    "    t: int = 0  # number of env.step calls made so far\n",
    "\n",
    "\n",
    "def evaluate(act, env, env_params, rng, max_steps=None):\n",
    "    \"\"\"Roll out one episode with the policy `act` and return its undiscounted return.\n",
    "\n",
    "    Written purely with JAX primitives (a `lax.while_loop`), so it can be vmapped over `rng`.\n",
    "\n",
    "    Args:\n",
    "        act: callable `(obs, rng) -> action`.\n",
    "        env: environment exposing `reset(rng, params)` and `step(rng, state, action, params)`.\n",
    "        env_params: environment parameters forwarded to `reset` and `step`.\n",
    "        rng: PRNG key, split into a reset key and a rollout key.\n",
    "        max_steps: optional cap on the number of environment steps. With the default\n",
    "            `None` the loop ends only when `env.step` reports `done`, so an\n",
    "            environment without termination or time limit would loop forever.\n",
    "\n",
    "    Returns:\n",
    "        Sum of rewards collected until the episode ended (or `max_steps` was reached).\n",
    "    \"\"\"\n",
    "\n",
    "    def step(state):\n",
    "        rng, rng_act, rng_step = jax.random.split(state.rng, 3)\n",
    "        action = act(state.last_obs, rng_act)\n",
    "        obs, env_state, reward, done, _ = env.step(\n",
    "            rng_step, state.env_state, action, env_params\n",
    "        )\n",
    "        return EvalState(\n",
    "            rng=rng,\n",
    "            env_state=env_state,\n",
    "            last_obs=obs,\n",
    "            done=done,\n",
    "            return_=state.return_ + reward.squeeze(),\n",
    "            t=state.t + 1,\n",
    "        )\n",
    "\n",
    "    def keep_running(state):\n",
    "        running = jnp.logical_not(state.done)\n",
    "        if max_steps is not None:\n",
    "            # Plain Python check: the default (uncapped) loop condition is unchanged.\n",
    "            running = jnp.logical_and(running, state.t < max_steps)\n",
    "        return running\n",
    "\n",
    "    rng_reset, rng_eval = jax.random.split(rng)\n",
    "    obs, env_state = env.reset(rng_reset, env_params)\n",
    "    final_state = jax.lax.while_loop(\n",
    "        keep_running, step, EvalState(rng_eval, env_state, obs)\n",
    "    )\n",
    "    return final_state.return_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "76e62171",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean return after training in synthetic environment: 500.0\n"
     ]
    }
   ],
   "source": [
    "def make_act(train_state):\n",
    "    \"\"\"Return a policy `act(obs, rng)` that samples actions from the network in `train_state`.\"\"\"\n",
    "\n",
    "    def act(obs, rng):\n",
    "        pi, _ = train_state.apply_fn(train_state.params, obs)  # value estimate is unused\n",
    "        return pi.sample(seed=rng)\n",
    "\n",
    "    return act\n",
    "\n",
    "\n",
    "# Bind the trained state explicitly instead of reading the global `out` inside `act`.\n",
    "act = make_act(out[\"runner_state\"][0])\n",
    "N_EVAL_EPISODES = 200  # one evaluation episode per PRNG key\n",
    "\n",
    "env, env_params = gymnax.make(\"CartPole-v1\")\n",
    "rng = jax.random.PRNGKey(0)\n",
    "rngs = jax.random.split(rng, N_EVAL_EPISODES)\n",
    "returns = jax.vmap(evaluate, in_axes=(None, None, None, 0))(act, env, env_params, rngs)\n",
    "print(f\"Mean return after training in synthetic environment: {returns.mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd024d0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax",
   "language": "python",
   "name": "jax"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
