{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules before each cell executes, so edits to project code\n",
    "# are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Environment variables are configured here, ahead of the cell that imports jax.\n",
    "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
    "# Append the flag only when it is missing, so re-running this cell in the same\n",
    "# kernel does not keep growing XLA_FLAGS with duplicates.\n",
    "if \"--xla_gpu_triton_gemm_any=True\" not in xla_flags.split():\n",
    "  xla_flags += \" --xla_gpu_triton_gemm_any=True\"\n",
    "## For deterministic results.\n",
    "# xla_flags += \" --xla_gpu_deterministic_ops=true --xla_gpu_autotune_level=0\"\n",
    "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
    "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
    "## In case your machine has GPUs of different types (pjit does not work with mixed GPUs).\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import json\n",
    "from datetime import datetime\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jp\n",
    "import matplotlib.pyplot as plt\n",
    "import mediapy as media\n",
    "import mujoco\n",
    "import numpy as np\n",
    "import wandb\n",
    "from brax.training.agents.ppo import networks as ppo_networks\n",
    "from brax.training.agents.ppo import train as ppo\n",
    "from etils import epath\n",
    "from flax.training import orbax_utils\n",
    "from IPython.display import clear_output, display\n",
    "from orbax import checkpoint as ocp\n",
    "\n",
    "from mujoco_playground import BraxEnvWrapper, locomotion\n",
    "from mujoco_playground.locomotion.h1 import randomize as h1_randomize\n",
    "from mujoco_playground.locomotion.spot import randomize as spot_randomize\n",
    "\n",
    "# Persistent compilation cache settings, applied in the order listed.\n",
    "_JAX_CACHE_OPTIONS = {\n",
    "    \"jax_compilation_cache_dir\": \"/tmp/jax_cache\",\n",
    "    \"jax_persistent_cache_min_entry_size_bytes\": -1,\n",
    "    \"jax_persistent_cache_min_compile_time_secs\": 0,\n",
    "}\n",
    "for _option, _value in _JAX_CACHE_OPTIONS.items():\n",
    "  jax.config.update(_option, _value)\n",
    "\n",
    "# Domain-randomization module for each supported environment name.\n",
    "RANDOMIZER = dict(\n",
    "    SpotInplaceGaitTracking=spot_randomize,\n",
    "    H1InplaceGaitTracking=h1_randomize,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# env_name must be a key of RANDOMIZER: it is used to look up the\n",
    "# randomization function when the trainer is configured.\n",
    "env_name = \"SpotInplaceGaitTracking\"\n",
    "env_cfg = locomotion.get_default_config(env_name)\n",
    "# Bare last expression so the notebook displays the default config.\n",
    "env_cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights & Biases logging is opt-in: set the flag to True to start a run.\n",
    "USE_WANDB = False\n",
    "\n",
    "if USE_WANDB:\n",
    "  # The environment config is stored as the run config, then the environment\n",
    "  # name is added alongside it.\n",
    "  wandb.init(project=\"mjxrl\", config=env_cfg)\n",
    "  wandb.config.update(dict(env_name=env_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SUFFIX = None\n",
    "FINETUNE_PATH = None\n",
    "\n",
    "# Generate unique experiment name.\n",
    "now = datetime.now()\n",
    "timestamp = now.strftime(\"%Y%m%d-%H%M%S\")\n",
    "exp_name = f\"{env_name}-{timestamp}\"\n",
    "if SUFFIX is not None:\n",
    "  exp_name += f\"-{SUFFIX}\"\n",
    "print(f\"Experiment name: {exp_name}\")\n",
    "\n",
    "# Possibly restore from the latest checkpoint.\n",
    "if FINETUNE_PATH is not None:\n",
    "  FINETUNE_PATH = epath.Path(FINETUNE_PATH)\n",
    "  # Keep only step-named directories (checkpoints are saved as <dir>/<step>);\n",
    "  # this skips other entries such as config.json.\n",
    "  latest_ckpts = [\n",
    "      ckpt\n",
    "      for ckpt in FINETUNE_PATH.glob(\"*\")\n",
    "      if ckpt.is_dir() and ckpt.name.isdigit()\n",
    "  ]\n",
    "  if not latest_ckpts:\n",
    "    raise FileNotFoundError(f\"No checkpoints found in: {FINETUNE_PATH}\")\n",
    "  # Order by numeric step: a plain path sort is lexicographic (\"10000\" sorts\n",
    "  # before \"9000\"), and the newest checkpoint is the last element, not the first.\n",
    "  latest_ckpts.sort(key=lambda ckpt: int(ckpt.name))\n",
    "  latest_ckpt = latest_ckpts[-1]\n",
    "  restore_checkpoint_path = latest_ckpt\n",
    "  print(f\"Restoring from: {restore_checkpoint_path}\")\n",
    "else:\n",
    "  restore_checkpoint_path = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All artifacts of this run live under checkpoints/<exp_name>.\n",
    "ckpt_path = epath.Path(\"checkpoints\").resolve() / exp_name\n",
    "ckpt_path.mkdir(parents=True, exist_ok=True)\n",
    "print(f\"Checkpoint path: {ckpt_path}\")\n",
    "\n",
    "# Store the environment config next to the checkpoints.\n",
    "with open(ckpt_path / \"config.json\", \"w\") as config_file:\n",
    "  json.dump(env_cfg.to_dict(), config_file, indent=4)\n",
    "\n",
    "\n",
    "def policy_params_fn(current_step, make_policy, params):\n",
    "  \"\"\"Saves `params` to `ckpt_path/<current_step>` with `force=True`.\n",
    "\n",
    "  Args:\n",
    "    current_step: Step count; its string form names the checkpoint directory.\n",
    "    make_policy: Unused; accepted because the trainer passes it.\n",
    "    params: Parameter pytree handed to the Orbax checkpointer.\n",
    "  \"\"\"\n",
    "  del make_policy\n",
    "  checkpointer = ocp.PyTreeCheckpointer()\n",
    "  checkpointer.save(\n",
    "      ckpt_path / f\"{current_step}\",\n",
    "      params,\n",
    "      force=True,\n",
    "      save_args=orbax_utils.save_args_from_target(params),\n",
    "  )\n",
    "\n",
    "\n",
    "# Policy network with four hidden layers of 128 units.\n",
    "make_networks_factory = functools.partial(\n",
    "    ppo_networks.make_ppo_networks,\n",
    "    policy_hidden_layer_sizes=(128,) * 4,\n",
    ")\n",
    "\n",
    "train_fn = functools.partial(\n",
    "    ppo.train,\n",
    "    # Run length, evaluation count and seed.\n",
    "    num_timesteps=70_000_000,\n",
    "    num_evals=5,\n",
    "    seed=1,\n",
    "    # Environment and rollout settings.\n",
    "    num_envs=8192,\n",
    "    episode_length=env_cfg.episode_length,\n",
    "    action_repeat=1,\n",
    "    unroll_length=20,\n",
    "    normalize_observations=True,\n",
    "    randomization_fn=RANDOMIZER[env_name].domain_randomize,\n",
    "    # Optimisation hyperparameters.\n",
    "    batch_size=256,\n",
    "    num_minibatches=32,\n",
    "    num_updates_per_batch=4,\n",
    "    learning_rate=3e-4,\n",
    "    max_grad_norm=1.0,\n",
    "    entropy_cost=1e-2,\n",
    "    discounting=0.97,\n",
    "    reward_scaling=1.0,\n",
    "    # Networks, checkpoint callback and optional restore path.\n",
    "    network_factory=make_networks_factory,\n",
    "    policy_params_fn=policy_params_fn,\n",
    "    restore_checkpoint_path=restore_checkpoint_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_data, y_data, y_dataerr = [], [], []\n",
    "# times[0] is stamped here; every progress callback appends another stamp.\n",
    "times = [datetime.now()]\n",
    "\n",
    "\n",
    "def progress(num_steps, metrics):\n",
    "  \"\"\"Logs one evaluation and redraws the reward curve.\n",
    "\n",
    "  Args:\n",
    "    num_steps: Step count reported with this evaluation; used as the x value.\n",
    "    metrics: Mapping that contains \"eval/episode_reward\" and\n",
    "      \"eval/episode_reward_std\".\n",
    "  \"\"\"\n",
    "  # Log to wandb.\n",
    "  if USE_WANDB:\n",
    "    wandb.log(metrics, step=num_steps)\n",
    "\n",
    "  # Plot.\n",
    "  clear_output(wait=True)\n",
    "  times.append(datetime.now())\n",
    "  x_data.append(num_steps)\n",
    "  y_data.append(metrics[\"eval/episode_reward\"])\n",
    "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
    "\n",
    "  # Draw on a fresh figure each call instead of the implicit current figure,\n",
    "  # so error bars from earlier callbacks do not pile up on the same axes.\n",
    "  fig, ax = plt.subplots()\n",
    "  ax.set_xlim([0, train_fn.keywords[\"num_timesteps\"] * 1.25])\n",
    "  ax.set_ylim([0, 75])\n",
    "  ax.set_xlabel(\"# environment steps\")\n",
    "  ax.set_ylabel(\"reward per episode\")\n",
    "  ax.set_title(f\"y={y_data[-1]:.3f}\")\n",
    "  ax.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
    "\n",
    "  display(fig)\n",
    "  # Close the figure once shown: it is not left open to be rendered a second\n",
    "  # time when the cell finishes, and figures do not accumulate in memory.\n",
    "  plt.close(fig)\n",
    "\n",
    "\n",
    "env = locomotion.load(env_name, config=env_cfg)\n",
    "eval_env = locomotion.load(env_name, config=env_cfg)\n",
    "make_inference_fn, params, _ = train_fn(\n",
    "    environment=BraxEnvWrapper(env),\n",
    "    progress_fn=progress,\n",
    "    eval_env=BraxEnvWrapper(eval_env),\n",
    ")\n",
    "# Timings need at least one progress callback to have fired.\n",
    "if len(times) > 1:\n",
    "  print(f\"time to jit: {times[1] - times[0]}\")\n",
    "  print(f\"time to train: {times[-1] - times[1]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the policy from the trained params with deterministic=True, then\n",
    "# JIT-compile it for the rollout loop.\n",
    "inference_fn = make_inference_fn(params, deterministic=True)\n",
    "jit_inference_fn = jax.jit(inference_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind eval_env to a newly loaded environment and JIT-compile its reset and\n",
    "# step functions for the rollout below.\n",
    "eval_env = locomotion.load(env_name, config=env_cfg)\n",
    "jit_reset = jax.jit(eval_env.reset)\n",
    "jit_step = jax.jit(eval_env.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from mujoco_playground.locomotion.go1.visualization import draw_joystick_command\n",
    "# x_min = 0.0\n",
    "# x_max = 1.0\n",
    "# def get_scl(x: float) -> float:\n",
    "#   return (x - x_min) / (x_max - x_min)\n",
    "# yaw = 0.2\n",
    "# x = 0.1\n",
    "# cmd = jp.array([x, 0, yaw])\n",
    "\n",
    "rng = jax.random.PRNGKey(12345)\n",
    "rng, reset_rng = jax.random.split(rng)\n",
    "state = jit_reset(reset_rng)\n",
    "# state.info[\"command\"] = cmd\n",
    "\n",
    "# Change gait type.\n",
    "PHASES = jp.array([\n",
    "    [0, jp.pi, jp.pi, 0],  # trot\n",
    "    [0, 0.5 * jp.pi, jp.pi, 1.5 * jp.pi],  # walk\n",
    "    [0, jp.pi, 0, jp.pi],  # pace\n",
    "    [0, 0, jp.pi, jp.pi],  # bound\n",
    "    [0, 0, 0, 0],  # pronk\n",
    "])\n",
    "# PHASES = jp.array(\n",
    "#     [\n",
    "#         [0, jp.pi],  # walk\n",
    "#         [0, 0],  # jump\n",
    "#     ]\n",
    "# )\n",
    "# Row index into PHASES (0 selects the row labelled \"trot\").\n",
    "gait = jp.array(0)\n",
    "state.info[\"gait\"] = gait\n",
    "state.info[\"phase\"] = PHASES[gait]\n",
    "\n",
    "modify_scene_fns = []\n",
    "rollout = []\n",
    "swing_peak = []\n",
    "foot_height = 0.4  # [0.08, 0.4]\n",
    "gait_freq = 2  # [0.5, 4.0]\n",
    "# Use the dt of eval_env, the environment that is actually stepped below,\n",
    "# rather than the separate `env` object created in the training cell.\n",
    "phase_dt = 2 * jp.pi * eval_env.dt * jp.array(gait_freq)\n",
    "state.info[\"phase_dt\"] = phase_dt\n",
    "state.info[\"gait_freq\"] = jp.array(gait_freq)\n",
    "# Seeded generator so the sampled foot heights repeat across runs; the global\n",
    "# NumPy RNG is never seeded in this notebook.\n",
    "foot_height_rng = np.random.default_rng(12345)\n",
    "for i in range(800):\n",
    "  # Resample the commanded foot height every 200 steps.\n",
    "  if i % 200 == 0:\n",
    "    foot_height = foot_height_rng.uniform(0.08, 0.4)\n",
    "    state.info[\"foot_height\"] = jp.array(foot_height)\n",
    "    # gait_freq = np.random.uniform(1.5, 4.0)\n",
    "    # phase_dt = 2 * jp.pi * eval_env.dt * jp.array(gait_freq)\n",
    "    # state.info[\"phase_dt\"] = phase_dt\n",
    "    # state.info[\"gait_freq\"] = jp.array(gait_freq)\n",
    "    # x += 0.2\n",
    "    # yaw = np.random.uniform(-0.7, 0.7)\n",
    "    # cmd = jp.array([x, 0, yaw])\n",
    "    # state.info[\"command\"] = cmd\n",
    "    # Print all the sampled parameters.\n",
    "    # print(\n",
    "    #     f\"foot_height: {foot_height:.3f}, gait_freq: {gait_freq:.3f}, phase_dt: {phase_dt:.3f}, yaw: {yaw:.3f}, x: {x:.3f}\"\n",
    "    # )\n",
    "\n",
    "  act_rng, rng = jax.random.split(rng)\n",
    "  ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "  state = jit_step(state, ctrl)\n",
    "  # state.info[\"command\"] = cmd\n",
    "  # Re-apply the commanded gait timing on the state returned by the step.\n",
    "  state.info[\"phase_dt\"] = phase_dt\n",
    "  state.info[\"gait_freq\"] = jp.array(gait_freq)\n",
    "  rollout.append(state)\n",
    "  swing_peak.append(state.info[\"swing_peak\"])\n",
    "\n",
    "  # xyz = np.array(state.data.xpos[env._torso_body_id])\n",
    "  # xyz += np.array([0, 0, 0.2])\n",
    "  # x_axis = state.data.xmat[env._torso_body_id, 0]\n",
    "  # yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
    "  # modify_scene_fns.append(\n",
    "  #     functools.partial(\n",
    "  #         draw_joystick_command,\n",
    "  #         cmd=cmd,\n",
    "  #         xyz=xyz,\n",
    "  #         foot_height=foot_height,\n",
    "  #         foot_pos=state.data.site_xpos[env._feet_site_id],\n",
    "  #         theta=yaw,\n",
    "  #         scl=get_scl(x),\n",
    "  #     )\n",
    "  # )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualisation options: geom group 2 on, geom group 3 off, contact points on.\n",
    "scene_option = mujoco.MjvOption()\n",
    "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
    "scene_option.geomgroup[3] = False\n",
    "scene_option.geomgroup[2] = True\n",
    "\n",
    "# Keep every `render_every`-th state; the playback rate is scaled to match.\n",
    "render_every = 1\n",
    "traj = rollout[::render_every]\n",
    "fps = 1.0 / eval_env.dt / render_every\n",
    "print(f\"fps: {fps}\")\n",
    "\n",
    "frames = eval_env.render(\n",
    "    traj,\n",
    "    camera=\"side\",\n",
    "    scene_option=scene_option,\n",
    "    height=960,\n",
    "    width=1280,\n",
    ")\n",
    "media.show_video(frames, fps=fps)\n",
    "# media.write_video(\"h1_inplace_gait_tracking.mp4\", frames, fps=fps, qp=18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
