{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7c974745-4680-418a-8f5e-f6856371ade3",
      "metadata": {
        "id": "6PIorEo4wH86"
      },
      "outputs": [],
      "source": [
        "# Enable IPython's autoreload extension in mode 2, so imported modules are\n",
        "# reloaded before each cell executes and edits to local source files are\n",
        "# picked up without restarting the kernel.\n",
        "%load_ext autoreload\n",
        "\n",
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e5997156-d7ca-4965-8c98-7d652bc3fa91",
      "metadata": {
        "id": "NNMD0KPowH86"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Process-level settings for JAX/XLA and MuJoCo rendering.\n",
        "# NOTE(review): presumably these must be set before jax / mujoco are imported\n",
        "# in the next cell -- keep this cell above the imports.\n",
        "\n",
        "# Append the Triton GEMM flag only when it is missing, so re-running this cell\n",
        "# does not keep growing XLA_FLAGS with duplicate entries.\n",
        "triton_gemm_flag = \"--xla_gpu_triton_gemm_any=True\"\n",
        "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
        "if triton_gemm_flag not in xla_flags.split():\n",
        "    xla_flags += f\" {triton_gemm_flag}\"\n",
        "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
        "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
        "os.environ[\"MUJOCO_GL\"] = \"egl\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "05ca1152-6a55-447e-9ab4-1163437b23c1",
      "metadata": {
        "id": "1y_RJYmGwH86"
      },
      "outputs": [],
      "source": [
        "import functools\n",
        "import json\n",
        "from datetime import datetime\n",
        "from typing import Any, Dict, Optional\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jp\n",
        "import pickle\n",
        "import matplotlib.pyplot as plt\n",
        "import mediapy as media\n",
        "import mujoco\n",
        "import numpy as np\n",
        "import wandb\n",
        "from brax.training.agents.ppo import networks as ppo_networks\n",
        "from brax.training.agents.ppo import train as ppo\n",
        "from etils import epath\n",
        "from flax.training import orbax_utils\n",
        "from IPython.display import clear_output, display\n",
        "from orbax import checkpoint as ocp\n",
        "\n",
        "from mujoco_playground import registry\n",
        "from mujoco_playground.config import locomotion_params, manipulation_params, dm_control_suite_params\n",
        "\n",
        "from brax.io import model as brax_io_model\n",
        "from brax.training.acme import running_statistics\n",
        "from brax.training.agents.ppo import networks as brax_ppo_networks\n",
        "from brax.training.agents.sac import networks as brax_sac_networks\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4c1823f2-d423-4b4c-a13f-70144f307ac9",
      "metadata": {
        "id": "j-iyCc-3wH86"
      },
      "outputs": [],
      "source": [
        "def get_inference_fn(\n",
        "    obs_size, act_size, normalize_obs, network_factory_kwargs, params,\n",
        "    is_ppo=True,\n",
        "):\n",
        "    \"\"\"Build a jitted, deterministic policy function from saved parameters.\n",
        "\n",
        "    Args:\n",
        "        obs_size: Observation size handed to the network factory.\n",
        "        act_size: Action size handed to the network factory.\n",
        "        normalize_obs: If truthy, observations are preprocessed with\n",
        "            `running_statistics.normalize`; otherwise they pass through\n",
        "            unchanged.\n",
        "        network_factory_kwargs: Optional extra keyword arguments forwarded to\n",
        "            the network factory (`None` or empty means no extras).\n",
        "        params: Policy parameters passed to the brax policy factory.\n",
        "        is_ppo: Build PPO networks when true, SAC networks otherwise.\n",
        "\n",
        "    Returns:\n",
        "        `jax.jit` applied to the policy created with `deterministic=True`.\n",
        "    \"\"\"\n",
        "    if normalize_obs:\n",
        "        preprocess_fn = running_statistics.normalize\n",
        "    else:\n",
        "        # Identity preprocessing: the second (normalizer) argument is ignored.\n",
        "        preprocess_fn = lambda x, y: x\n",
        "\n",
        "    # PPO and SAC expose parallel network-factory / inference helpers.\n",
        "    if is_ppo:\n",
        "        networks_module = brax_ppo_networks\n",
        "        build_networks = brax_ppo_networks.make_ppo_networks\n",
        "    else:\n",
        "        networks_module = brax_sac_networks\n",
        "        build_networks = brax_sac_networks.make_sac_networks\n",
        "\n",
        "    extra_kwargs = network_factory_kwargs or {}\n",
        "    network = build_networks(\n",
        "        obs_size,\n",
        "        act_size,\n",
        "        preprocess_observations_fn=preprocess_fn,\n",
        "        **extra_kwargs,\n",
        "    )\n",
        "    policy_factory = networks_module.make_inference_fn(network)\n",
        "    return jax.jit(policy_factory(params, deterministic=True))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8c09456f-bbca-4ec9-aaf8-4e782b59ccc6",
      "metadata": {
        "id": "lHwe1zMRwH86"
      },
      "outputs": [],
      "source": [
        "# Load every checkpoint in ../data/pkl and build a jitted inference function\n",
        "# for it. File names are expected to look like '<EnvName>_<algo>[_<extra>].pkl'.\n",
        "# NOTE: both dicts are keyed by env name only, so if several files share an env\n",
        "# name the last one visited wins.\n",
        "policy_params = {}\n",
        "inference_fns = {}\n",
        "for f in epath.Path('../data/pkl').glob('*.pkl'):\n",
        "    # Drop the extension before splitting: str.rstrip('.pkl') removes a *set of\n",
        "    # characters* rather than a suffix, so it can eat into the algorithm name.\n",
        "    env_name, algo, *_ = f.name.removesuffix('.pkl').split('_')\n",
        "    policy_params[env_name] = brax_io_model.load_params(f)\n",
        "\n",
        "    if env_name in registry.locomotion.ALL_ENVS:\n",
        "        algo_params = locomotion_params.brax_ppo_config(env_name)\n",
        "    elif env_name in registry.dm_control_suite.ALL_ENVS:\n",
        "        if algo == 'ppo':\n",
        "            algo_params = dm_control_suite_params.brax_ppo_config(env_name)\n",
        "        elif algo == 'sac':\n",
        "            algo_params = dm_control_suite_params.brax_sac_config(env_name)\n",
        "        else:\n",
        "            # Without this branch `algo_params` would be unbound on the first\n",
        "            # file, or silently reused from the previously processed file.\n",
        "            raise ValueError(\n",
        "                f'Unsupported algorithm {algo!r} in {f.name!r}; expected ppo or sac'\n",
        "            )\n",
        "    elif env_name in registry.manipulation.ALL_ENVS:\n",
        "        algo_params = manipulation_params.brax_ppo_config(env_name)\n",
        "    else:\n",
        "        raise AssertionError(\n",
        "            f'{env_name!r} (from {f.name!r}) is not a registered locomotion, '\n",
        "            'dm_control_suite or manipulation environment'\n",
        "        )\n",
        "\n",
        "    # Forward the config's network_factory settings (if any) to the network\n",
        "    # factory used inside get_inference_fn.\n",
        "    network_factory_kwargs = None\n",
        "    if hasattr(algo_params, 'network_factory'):\n",
        "        network_factory_kwargs = dict(algo_params.network_factory)\n",
        "\n",
        "    env = registry.load(env_name)\n",
        "    inference_fns[env_name] = get_inference_fn(\n",
        "        env.observation_size, env.action_size,\n",
        "        normalize_obs=algo_params.normalize_observations,\n",
        "        network_factory_kwargs=network_factory_kwargs,\n",
        "        params=policy_params[env_name],\n",
        "        is_ppo=(algo == 'ppo')\n",
        "    )\n",
        "    print(env_name, ' Done!')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f938bbd4-03c1-4ffa-a1cc-2992c52251b1",
      "metadata": {
        "id": "yzgiYP5GwH86"
      },
      "outputs": [],
      "source": [
        "# Environment to roll out and how many episodes to record.\n",
        "env_name = 'CheetahRun'\n",
        "N_EPISODES = 1\n",
        "\n",
        "env = registry.load(env_name)\n",
        "cfg = registry.get_default_config(env_name)\n",
        "\n",
        "# Wrap reset/step in jax.jit once; the rollout cell reuses both callables.\n",
        "jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b06896ad-d629-470f-982a-486fa07ee880",
      "metadata": {
        "id": "RSeraO18wH86"
      },
      "outputs": [],
      "source": [
        "# Roll out the loaded policy for N_EPISODES episodes, collecting every state\n",
        "# (including each episode's initial state) in `rollout` for rendering.\n",
        "rng = jax.random.PRNGKey(0)\n",
        "rollout = []\n",
        "\n",
        "for _ in range(N_EPISODES):\n",
        "    # Every episode starts from its own reset with a dedicated sub-key, so a\n",
        "    # later episode never continues from the previous episode's final state.\n",
        "    rng, reset_rng = jax.random.split(rng)\n",
        "    state = jit_reset(reset_rng)\n",
        "    rollout.append(state)\n",
        "\n",
        "    step, done = 0, False\n",
        "    while not done and step < cfg.episode_length:\n",
        "        if step % 100 == 0:\n",
        "            print(step)\n",
        "\n",
        "        # Split off a fresh key for the policy; the carry key is never consumed\n",
        "        # directly, so no key is used twice.\n",
        "        rng, act_rng = jax.random.split(rng)\n",
        "        action = inference_fns[env_name](state.obs, act_rng)[0]\n",
        "        state = jit_step(state, action)\n",
        "        rollout.append(state)\n",
        "\n",
        "        step += cfg.action_repeat\n",
        "        # Check the state that was just produced (not the one stepped from), so\n",
        "        # the loop stops as soon as the environment reports termination.\n",
        "        done = bool(state.done)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2b3e06e1-cbbb-4b09-8138-dd01a446b1f4",
      "metadata": {
        "id": "H0eI4WoswH86"
      },
      "outputs": [],
      "source": [
        "# Playback rate derived from the environment's step duration.\n",
        "fps = 1.0 / env.dt\n",
        "print(f\"fps: {fps}\")\n",
        "\n",
        "# Visualisation options: geom group 2 on, geom group 3 off, and the\n",
        "# contact-point and perturbation-force flags enabled.\n",
        "scene_option = mujoco.MjvOption()\n",
        "scene_option.geomgroup[2] = True\n",
        "scene_option.geomgroup[3] = False\n",
        "for vis_flag in (\n",
        "    mujoco.mjtVisFlag.mjVIS_CONTACTPOINT,\n",
        "    mujoco.mjtVisFlag.mjVIS_PERTFORCE,\n",
        "):\n",
        "    scene_option.flags[vis_flag] = True\n",
        "\n",
        "# Optional render arguments left at their defaults here: `camera` (e.g.\n",
        "# \"track\") and `modify_scene_fns`.\n",
        "frames = env.render(rollout, height=480, width=640, scene_option=scene_option)\n",
        "media.show_video(frames, fps=fps, loop=False)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
