
<!doctype html>
<html lang="en" class="no-js">
  <head>
    
      <meta charset="utf-8">
      <meta name="viewport" content="width=device-width,initial-scale=1">
      
        <meta name="description" content="The documentation for the Rex software library.">
      
      
        <meta name="author" content="Anonymous">
      
      
        <!-- NOTE(review): this canonical URL points under the GitHub repository path
             (github.com/anonymous/rex/...), which looks like a repository URL rather than the
             host where this rendered page is published; presumably generated from the MkDocs
             site_url setting - confirm and correct it there rather than editing this output. -->
        <link rel="canonical" href="https://github.com/anonymous/rex/examples/node_definitions.html">
      
      
        <link rel="prev" href="../index.html">
      
      
        <link rel="next" href="graph_and_environment_creation.html">
      
      
      <link rel="icon" href="../_static/favicon_trex.ico">
      <meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.5.47">
    
    
      
        <title>How to define nodes - Rex</title>
      
    
    
      <link rel="stylesheet" href="../assets/stylesheets/main.6f8fc17f.min.css">
      
        
        <link rel="stylesheet" href="../assets/stylesheets/palette.06af60db.min.css">
      
      
  
  
    
    
  
    
    
  
    
    
  
    
    
  
    
    
  
    
    
  
    
    
  
    
    
  
    
    
  
    
    
  
    
    
  
    
    
  
  
  <style>:root{--md-admonition-icon--note:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M1 7.775V2.75C1 1.784 1.784 1 2.75 1h5.025c.464 0 .91.184 1.238.513l6.25 6.25a1.75 1.75 0 0 1 0 2.474l-5.026 5.026a1.75 1.75 0 0 1-2.474 0l-6.25-6.25A1.75 1.75 0 0 1 1 7.775m1.5 0c0 .066.026.13.073.177l6.25 6.25a.25.25 0 0 0 .354 0l5.025-5.025a.25.25 0 0 0 0-.354l-6.25-6.25a.25.25 0 0 0-.177-.073H2.75a.25.25 0 0 0-.25.25ZM6 5a1 1 0 1 1 0 2 1 1 0 0 1 0-2"/></svg>');--md-admonition-icon--abstract:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M2.5 1.75v11.5c0 .138.112.25.25.25h3.17a.75.75 0 0 1 0 1.5H2.75A1.75 1.75 0 0 1 1 13.25V1.75C1 .784 1.784 0 2.75 0h8.5C12.216 0 13 .784 13 1.75v7.736a.75.75 0 0 1-1.5 0V1.75a.25.25 0 0 0-.25-.25h-8.5a.25.25 0 0 0-.25.25m13.274 9.537zl-4.557 4.45a.75.75 0 0 1-1.055-.008l-1.943-1.95a.75.75 0 0 1 1.062-1.058l1.419 1.425 4.026-3.932a.75.75 0 1 1 1.048 1.074M4.75 4h4.5a.75.75 0 0 1 0 1.5h-4.5a.75.75 0 0 1 0-1.5M4 7.75A.75.75 0 0 1 4.75 7h2a.75.75 0 0 1 0 1.5h-2A.75.75 0 0 1 4 7.75"/></svg>');--md-admonition-icon--info:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M0 8a8 8 0 1 1 16 0A8 8 0 0 1 0 8m8-6.5a6.5 6.5 0 1 0 0 13 6.5 6.5 0 0 0 0-13M6.5 7.75A.75.75 0 0 1 7.25 7h1a.75.75 0 0 1 .75.75v2.75h.25a.75.75 0 0 1 0 1.5h-2a.75.75 0 0 1 0-1.5h.25v-2h-.25a.75.75 0 0 1-.75-.75M8 6a1 1 0 1 1 0-2 1 1 0 0 1 0 2"/></svg>');--md-admonition-icon--tip:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M3.499.75a.75.75 0 0 1 1.5 0v.996C5.9 2.903 6.793 3.65 7.662 4.376l.24.202c-.036-.694.055-1.422.426-2.163C9.1.873 10.794-.045 12.622.26 14.408.558 16 1.94 16 4.25c0 1.278-.954 2.575-2.44 2.734l.146.508.065.22c.203.701.412 1.455.476 2.226.142 1.707-.4 3.03-1.487 3.898C11.714 14.671 10.27 15 8.75 15h-6a.75.75 0 0 1 0-1.5h1.376a4.5 
4.5 0 0 1-.563-1.191 3.84 3.84 0 0 1-.05-2.063 4.65 4.65 0 0 1-2.025-.293.75.75 0 0 1 .525-1.406c1.357.507 2.376-.006 2.698-.318l.009-.01a.747.747 0 0 1 1.06 0 .75.75 0 0 1-.012 1.074c-.912.92-.992 1.835-.768 2.586.221.74.745 1.337 1.196 1.621H8.75c1.343 0 2.398-.296 3.074-.836.635-.507 1.036-1.31.928-2.602-.05-.603-.216-1.224-.422-1.93l-.064-.221c-.12-.407-.246-.84-.353-1.29a2.4 2.4 0 0 1-.507-.441 3.1 3.1 0 0 1-.633-1.248.75.75 0 0 1 1.455-.364c.046.185.144.436.31.627.146.168.353.305.712.305.738 0 1.25-.615 1.25-1.25 0-1.47-.95-2.315-2.123-2.51-1.172-.196-2.227.387-2.706 1.345-.46.92-.27 1.774.019 3.062l.042.19.01.05c.348.443.666.949.94 1.553a.75.75 0 1 1-1.365.62c-.553-1.217-1.32-1.94-2.3-2.768L6.7 5.527c-.814-.68-1.75-1.462-2.692-2.619a3.7 3.7 0 0 0-1.023.88c-.406.495-.663 1.036-.722 1.508.116.122.306.21.591.239.388.038.797-.06 1.032-.19a.75.75 0 0 1 .728 1.31c-.515.287-1.23.439-1.906.373-.682-.067-1.473-.38-1.879-1.193L.75 5.677V5.5c0-.984.48-1.94 1.077-2.664.46-.559 1.05-1.055 1.673-1.353z"/></svg>');--md-admonition-icon--success:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M13.78 4.22a.75.75 0 0 1 0 1.06l-7.25 7.25a.75.75 0 0 1-1.06 0L2.22 9.28a.75.75 0 0 1 .018-1.042.75.75 0 0 1 1.042-.018L6 10.94l6.72-6.72a.75.75 0 0 1 1.06 0"/></svg>');--md-admonition-icon--question:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M0 8a8 8 0 1 1 16 0A8 8 0 0 1 0 8m8-6.5a6.5 6.5 0 1 0 0 13 6.5 6.5 0 0 0 0-13M6.92 6.085h.001a.749.749 0 1 1-1.342-.67c.169-.339.436-.701.849-.977C6.845 4.16 7.369 4 8 4a2.76 2.76 0 0 1 1.637.525c.503.377.863.965.863 1.725 0 .448-.115.83-.329 1.15-.205.307-.47.513-.692.662-.109.072-.22.138-.313.195l-.006.004a6 6 0 0 0-.26.16 1 1 0 0 0-.276.245.75.75 0 0 1-1.248-.832c.184-.264.42-.489.692-.661q.154-.1.313-.195l.007-.004c.1-.061.182-.11.258-.161a1 1 0 0 0 .277-.245C8.96 6.514 9 6.427 9 6.25a.61.61 0 0 0-.262-.525A1.27 1.27 0 0 0 
8 5.5c-.369 0-.595.09-.74.187a1 1 0 0 0-.34.398M9 11a1 1 0 1 1-2 0 1 1 0 0 1 2 0"/></svg>');--md-admonition-icon--warning:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M6.457 1.047c.659-1.234 2.427-1.234 3.086 0l6.082 11.378A1.75 1.75 0 0 1 14.082 15H1.918a1.75 1.75 0 0 1-1.543-2.575Zm1.763.707a.25.25 0 0 0-.44 0L1.698 13.132a.25.25 0 0 0 .22.368h12.164a.25.25 0 0 0 .22-.368Zm.53 3.996v2.5a.75.75 0 0 1-1.5 0v-2.5a.75.75 0 0 1 1.5 0M9 11a1 1 0 1 1-2 0 1 1 0 0 1 2 0"/></svg>');--md-admonition-icon--failure:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M2.344 2.343za8 8 0 0 1 11.314 11.314A8.002 8.002 0 0 1 .234 10.089a8 8 0 0 1 2.11-7.746m1.06 10.253a6.5 6.5 0 1 0 9.108-9.275 6.5 6.5 0 0 0-9.108 9.275M6.03 4.97 8 6.94l1.97-1.97a.749.749 0 0 1 1.275.326.75.75 0 0 1-.215.734L9.06 8l1.97 1.97a.749.749 0 0 1-.326 1.275.75.75 0 0 1-.734-.215L8 9.06l-1.97 1.97a.749.749 0 0 1-1.275-.326.75.75 0 0 1 .215-.734L6.94 8 4.97 6.03a.75.75 0 0 1 .018-1.042.75.75 0 0 1 1.042-.018"/></svg>');--md-admonition-icon--danger:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M9.504.43a1.516 1.516 0 0 1 2.437 1.713L10.415 5.5h2.123c1.57 0 2.346 1.909 1.22 3.004l-7.34 7.142a1.25 1.25 0 0 1-.871.354h-.302a1.25 1.25 0 0 1-1.157-1.723L5.633 10.5H3.462c-1.57 0-2.346-1.909-1.22-3.004zm1.047 1.074L3.286 8.571A.25.25 0 0 0 3.462 9H6.75a.75.75 0 0 1 .694 1.034l-1.713 4.188 6.982-6.793A.25.25 0 0 0 12.538 7H9.25a.75.75 0 0 1-.683-1.06l2.008-4.418.003-.006-.004-.009-.006-.006-.008-.001q-.005 0-.009.004"/></svg>');--md-admonition-icon--bug:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M4.72.22a.75.75 0 0 1 1.06 0l1 .999a3.5 3.5 0 0 1 2.441 0l.999-1a.748.748 0 0 1 1.265.332.75.75 0 0 1-.205.729l-.775.776c.616.63.995 1.493.995 2.444v.327q0 .15-.025.292c.408.14.764.392 
1.029.722l1.968-.787a.75.75 0 0 1 .556 1.392L13 7.258V9h2.25a.75.75 0 0 1 0 1.5H13v.5q-.002.615-.141 1.186l2.17.868a.75.75 0 0 1-.557 1.392l-2.184-.873A5 5 0 0 1 8 16a5 5 0 0 1-4.288-2.427l-2.183.873a.75.75 0 0 1-.558-1.392l2.17-.868A5 5 0 0 1 3 11v-.5H.75a.75.75 0 0 1 0-1.5H3V7.258L.971 6.446a.75.75 0 0 1 .558-1.392l1.967.787c.265-.33.62-.583 1.03-.722a1.7 1.7 0 0 1-.026-.292V4.5c0-.951.38-1.814.995-2.444L4.72 1.28a.75.75 0 0 1 0-1.06m.53 6.28a.75.75 0 0 0-.75.75V11a3.5 3.5 0 1 0 7 0V7.25a.75.75 0 0 0-.75-.75ZM6.173 5h3.654A.17.17 0 0 0 10 4.827V4.5a2 2 0 1 0-4 0v.327c0 .096.077.173.173.173"/></svg>');--md-admonition-icon--example:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M5 5.782V2.5h-.25a.75.75 0 0 1 0-1.5h6.5a.75.75 0 0 1 0 1.5H11v3.282l3.666 5.76C15.619 13.04 14.543 15 12.767 15H3.233c-1.776 0-2.852-1.96-1.899-3.458Zm-2.4 6.565a.75.75 0 0 0 .633 1.153h9.534a.75.75 0 0 0 .633-1.153L12.225 10.5h-8.45ZM9.5 2.5h-3V6c0 .143-.04.283-.117.403L4.73 9h6.54L9.617 6.403A.75.75 0 0 1 9.5 6Z"/></svg>');--md-admonition-icon--quote:url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M1.75 2.5h10.5a.75.75 0 0 1 0 1.5H1.75a.75.75 0 0 1 0-1.5m4 5h8.5a.75.75 0 0 1 0 1.5h-8.5a.75.75 0 0 1 0-1.5m0 5h8.5a.75.75 0 0 1 0 1.5h-8.5a.75.75 0 0 1 0-1.5M2.5 7.75v6a.75.75 0 0 1-1.5 0v-6a.75.75 0 0 1 1.5 0"/></svg>');}</style>



    
    
      
    
    
      
        
        
        <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
        <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
        <style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
      
    
    
      <link rel="stylesheet" href="../assets/_mkdocstrings.css">
    
      <link rel="stylesheet" href="../_static/custom_css.css">
    
      <link rel="stylesheet" href="../css/ansi-colours.css">
    
      <link rel="stylesheet" href="../css/jupyter-cells.css">
    
      <link rel="stylesheet" href="../css/pandas-dataframe.css">
    
    <script>
      // Theme storage helpers. Keys are prefixed with the site-root pathname (__md_scope,
      // resolved from ".." relative to this page) so entries are scoped per site.
      __md_scope=new URL("..",location);
      // 31-multiplier rolling string hash (not used elsewhere in this chunk).
      __md_hash=e=>[...e].reduce(((e,_)=>(e<<5)-e+_.charCodeAt(0)),0);
      // Read and JSON-decode a stored value. Returns null when the key is missing, when the
      // stored text is not valid JSON, or when storage is inaccessible (e.g. blocked), instead
      // of throwing and aborting the inline script that called it. localStorage is resolved
      // inside the try because a default parameter would be evaluated outside it.
      __md_get=(e,_,t=__md_scope)=>{try{return JSON.parse((void 0===_?localStorage:_).getItem(t.pathname+"."+e))}catch(r){return null}};
      // JSON-encode and store a value; any failure (quota exceeded, storage blocked) is ignored.
      __md_set=(e,_,t,a=__md_scope)=>{try{(void 0===t?localStorage:t).setItem(a.pathname+"."+e,JSON.stringify(_))}catch(r){}};
    </script>
    
      

    
    
    
  </head>
  
  
    
    
      
    
    
    
    
    <body dir="ltr" data-md-color-scheme="default" data-md-color-primary="white" data-md-color-accent="amber">
  
    
    <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
    <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
    <label class="md-overlay" for="__drawer"></label>
    <div data-md-component="skip">
      
        
        <a href="#defining-nodes-in-rex-robotic-environments-with-jax" class="md-skip">
          Skip to content
        </a>
      
    </div>
    <div data-md-component="announce">
      
    </div>
    
    
      

  

<header class="md-header md-header--shadow" data-md-component="header">
  <nav class="md-header__inner md-grid" aria-label="Header">
    <a href="../index.html" title="Rex" class="md-header__button md-logo" aria-label="Rex" data-md-component="logo">
      
  
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 2v1h-1v6h-1v1H9v1H8v1H7v1H5v-1H4v-1H3V9H2v6h1v1h1v1h1v1h1v4h2v-1H7v-1h1v-1h1v-1h1v1h1v3h2v-1h-1v-4h1v-1h1v-1h1v-3h1v1h1v-2h-2V9h5V8h-3V7h5V3h-1V2m-7 1h1v1h-1Z"/></svg>

    </a>
    <label class="md-header__button md-icon" for="__drawer">
      
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
    </label>
    <div class="md-header__title" data-md-component="header-title">
      <div class="md-header__ellipsis">
        <div class="md-header__topic">
          <span class="md-ellipsis">
            Rex
          </span>
        </div>
        <div class="md-header__topic" data-md-component="header-topic">
          <span class="md-ellipsis">
            
              How to define nodes
            
          </span>
        </div>
      </div>
    </div>
    
      
        <form class="md-header__option" data-md-component="palette">
  
    
    
    
    <input class="md-option" data-md-color-media="" data-md-color-scheme="default" data-md-color-primary="white" data-md-color-accent="amber"  aria-label="Switch to dark mode"  type="radio" name="__palette" id="__palette_0">
    
      <label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_1" hidden>
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m17.75 4.09-2.53 1.94.91 3.06-2.63-1.81-2.63 1.81.91-3.06-2.53-1.94L12.44 4l1.06-3 1.06 3zm3.5 6.91-1.64 1.25.59 1.98-1.7-1.17-1.7 1.17.59-1.98L15.75 11l2.06-.05L18.5 9l.69 1.95zm-2.28 4.95c.83-.08 1.72 1.1 1.19 1.85-.32.45-.66.87-1.08 1.27C15.17 23 8.84 23 4.94 19.07c-3.91-3.9-3.91-10.24 0-14.14.4-.4.82-.76 1.27-1.08.75-.53 1.93.36 1.85 1.19-.27 2.86.69 5.83 2.89 8.02a9.96 9.96 0 0 0 8.02 2.89m-1.64 2.02a12.08 12.08 0 0 1-7.8-3.47c-2.17-2.19-3.33-5-3.49-7.82-2.81 3.14-2.7 7.96.31 10.98 3.02 3.01 7.84 3.12 10.98.31"/></svg>
      </label>
    
  
    
    
    
    <input class="md-option" data-md-color-media="" data-md-color-scheme="slate" data-md-color-primary="black" data-md-color-accent="amber"  aria-label="Switch to light mode"  type="radio" name="__palette" id="__palette_1">
    
      <label class="md-header__button md-icon" title="Switch to light mode" for="__palette_0" hidden>
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 7a5 5 0 0 1 5 5 5 5 0 0 1-5 5 5 5 0 0 1-5-5 5 5 0 0 1 5-5m0 2a3 3 0 0 0-3 3 3 3 0 0 0 3 3 3 3 0 0 0 3-3 3 3 0 0 0-3-3m0-7 2.39 3.42C13.65 5.15 12.84 5 12 5s-1.65.15-2.39.42zM3.34 7l4.16-.35A7.2 7.2 0 0 0 5.94 8.5c-.44.74-.69 1.5-.83 2.29zm.02 10 1.76-3.77a7.131 7.131 0 0 0 2.38 4.14zM20.65 7l-1.77 3.79a7.02 7.02 0 0 0-2.38-4.15zm-.01 10-4.14.36c.59-.51 1.12-1.14 1.54-1.86.42-.73.69-1.5.83-2.29zM12 22l-2.41-3.44c.74.27 1.55.44 2.41.44.82 0 1.63-.17 2.37-.44z"/></svg>
      </label>
    
  
</form>
      
    
    
      <script>
        // Re-apply the reader's saved color palette (stored under "__palette") to <body>.
        var palette=__md_get("__palette");
        if(palette&&palette.color){
          // A palette saved in "automatic" mode is resolved against the current OS color
          // preference by copying the attributes of the matching palette toggle input.
          if("(prefers-color-scheme)"===palette.color.media){
            var media=matchMedia("(prefers-color-scheme: light)"),
                input=document.querySelector(media.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']");
            // The toggles on this page declare data-md-color-media="", so no input may match;
            // skip resolution instead of throwing on a null input.
            if(input){
              palette.color.media=input.getAttribute("data-md-color-media");
              palette.color.scheme=input.getAttribute("data-md-color-scheme");
              palette.color.primary=input.getAttribute("data-md-color-primary");
              palette.color.accent=input.getAttribute("data-md-color-accent");
            }
          }
          for(var[key,value]of Object.entries(palette.color))document.body.setAttribute("data-md-color-"+key,value);
        }
      </script>
    
    
    
      <label class="md-header__button md-icon" for="__search">
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
      </label>
      <div class="md-search" data-md-component="search" role="dialog">
  <label class="md-search__overlay" for="__search"></label>
  <div class="md-search__inner" role="search">
    <form class="md-search__form" name="search">
      <input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
      <label class="md-search__icon md-icon" for="__search">
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
      </label>
      <nav class="md-search__options" aria-label="Search">
        
        <button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
          
          <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
        </button>
      </nav>
      
    </form>
    <div class="md-search__output">
      <div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
        <div class="md-search-result" data-md-component="search-result">
          <div class="md-search-result__meta">
            Initializing search
          </div>
          <ol class="md-search-result__list" role="presentation"></ol>
        </div>
      </div>
    </div>
  </div>
</div>
    
    
      <div class="md-header__source">
        
<!-- Repository link in the header; title was the raw i18n key "source.link.title". -->
<a href="https://github.com/anonymous/rex" title="Go to repository" class="md-source" data-md-component="source">
  <div class="md-source__icon md-icon">
    
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 496 512"><!--! Font Awesome Free 6.7.1 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M165.9 397.4c0 2-2.3 3.6-5.2 3.6-3.3.3-5.6-1.3-5.6-3.6 0-2 2.3-3.6 5.2-3.6 3-.3 5.6 1.3 5.6 3.6m-31.1-4.5c-.7 2 1.3 4.3 4.3 4.9 2.6 1 5.6 0 6.2-2s-1.3-4.3-4.3-5.2c-2.6-.7-5.5.3-6.2 2.3m44.2-1.7c-2.9.7-4.9 2.6-4.6 4.9.3 2 2.9 3.3 5.9 2.6 2.9-.7 4.9-2.6 4.6-4.6-.3-1.9-3-3.2-5.9-2.9M244.8 8C106.1 8 0 113.3 0 252c0 110.9 69.8 205.8 169.5 239.2 12.8 2.3 17.3-5.6 17.3-12.1 0-6.2-.3-40.4-.3-61.4 0 0-70 15-84.7-29.8 0 0-11.4-29.1-27.8-36.6 0 0-22.9-15.7 1.6-15.4 0 0 24.9 2 38.6 25.8 21.9 38.6 58.6 27.5 72.9 20.9 2.3-16 8.8-27.1 16-33.7-55.9-6.2-112.3-14.3-112.3-110.5 0-27.5 7.6-41.3 23.6-58.9-2.6-6.5-11.1-33.3 2.6-67.9 20.9-6.5 69 27 69 27 20-5.6 41.5-8.5 62.8-8.5s42.8 2.9 62.8 8.5c0 0 48.1-33.6 69-27 13.7 34.7 5.2 61.4 2.6 67.9 16 17.7 25.8 31.5 25.8 58.9 0 96.5-58.9 104.2-114.8 110.5 9.2 7.9 17 22.9 17 46.4 0 33.7-.3 75.4-.3 83.6 0 6.5 4.6 14.4 17.3 12.1C428.2 457.8 496 362.9 496 252 496 113.3 383.5 8 244.8 8M97.2 352.9c-1.3 1-1 3.3.7 5.2 1.6 1.6 3.9 2.3 5.2 1 1.3-1 1-3.3-.7-5.2-1.6-1.6-3.9-2.3-5.2-1m-10.8-8.1c-.7 1.3.3 2.9 2.3 3.9 1.6 1 3.6.7 4.3-.7.7-1.3-.3-2.9-2.3-3.9-2-.6-3.6-.3-4.3.7m32.4 35.6c-1.6 1.3-1 4.3 1.3 6.2 2.3 2.3 5.2 2.6 6.5 1 1.3-1.3.7-4.3-1.3-6.2-2.2-2.3-5.2-2.6-6.5-1m-11.4-14.7c-1.6 1-1.6 3.6 0 5.9s4.3 3.3 5.6 2.3c1.6-1.3 1.6-3.9 0-6.2-1.4-2.3-4-3.3-5.6-2"/></svg>
  </div>
  <div class="md-source__repository">
    anonymous/rex
  </div>
</a>

      </div>
    
  </nav>
  
</header>
    
    <div class="md-container" data-md-component="container">
      
      
        
          
        
      
      <main class="md-main" data-md-component="main">
        <div class="md-main__inner md-grid">
          
            
              
              <div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
                <div class="md-sidebar__scrollwrap">
                  <div class="md-sidebar__inner">
                    



  

<nav class="md-nav md-nav--primary md-nav--integrated" aria-label="Navigation" data-md-level="0">
  <label class="md-nav__title" for="__drawer">
    <a href="../index.html" title="Rex" class="md-nav__button md-logo" aria-label="Rex" data-md-component="logo">
      
  
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 2v1h-1v6h-1v1H9v1H8v1H7v1H5v-1H4v-1H3V9H2v6h1v1h1v1h1v1h1v4h2v-1H7v-1h1v-1h1v-1h1v1h1v3h2v-1h-1v-4h1v-1h1v-1h1v-3h1v1h1v-2h-2V9h5V8h-3V7h5V3h-1V2m-7 1h1v1h-1Z"/></svg>

    </a>
    Rex
  </label>
  
    <div class="md-nav__source">
      
<!-- Repository link in the navigation drawer; title was the raw i18n key "source.link.title". -->
<a href="https://github.com/anonymous/rex" title="Go to repository" class="md-source" data-md-component="source">
  <div class="md-source__icon md-icon">
    
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 496 512"><!--! Font Awesome Free 6.7.1 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M165.9 397.4c0 2-2.3 3.6-5.2 3.6-3.3.3-5.6-1.3-5.6-3.6 0-2 2.3-3.6 5.2-3.6 3-.3 5.6 1.3 5.6 3.6m-31.1-4.5c-.7 2 1.3 4.3 4.3 4.9 2.6 1 5.6 0 6.2-2s-1.3-4.3-4.3-5.2c-2.6-.7-5.5.3-6.2 2.3m44.2-1.7c-2.9.7-4.9 2.6-4.6 4.9.3 2 2.9 3.3 5.9 2.6 2.9-.7 4.9-2.6 4.6-4.6-.3-1.9-3-3.2-5.9-2.9M244.8 8C106.1 8 0 113.3 0 252c0 110.9 69.8 205.8 169.5 239.2 12.8 2.3 17.3-5.6 17.3-12.1 0-6.2-.3-40.4-.3-61.4 0 0-70 15-84.7-29.8 0 0-11.4-29.1-27.8-36.6 0 0-22.9-15.7 1.6-15.4 0 0 24.9 2 38.6 25.8 21.9 38.6 58.6 27.5 72.9 20.9 2.3-16 8.8-27.1 16-33.7-55.9-6.2-112.3-14.3-112.3-110.5 0-27.5 7.6-41.3 23.6-58.9-2.6-6.5-11.1-33.3 2.6-67.9 20.9-6.5 69 27 69 27 20-5.6 41.5-8.5 62.8-8.5s42.8 2.9 62.8 8.5c0 0 48.1-33.6 69-27 13.7 34.7 5.2 61.4 2.6 67.9 16 17.7 25.8 31.5 25.8 58.9 0 96.5-58.9 104.2-114.8 110.5 9.2 7.9 17 22.9 17 46.4 0 33.7-.3 75.4-.3 83.6 0 6.5 4.6 14.4 17.3 12.1C428.2 457.8 496 362.9 496 252 496 113.3 383.5 8 244.8 8M97.2 352.9c-1.3 1-1 3.3.7 5.2 1.6 1.6 3.9 2.3 5.2 1 1.3-1 1-3.3-.7-5.2-1.6-1.6-3.9-2.3-5.2-1m-10.8-8.1c-.7 1.3.3 2.9 2.3 3.9 1.6 1 3.6.7 4.3-.7.7-1.3-.3-2.9-2.3-3.9-2-.6-3.6-.3-4.3.7m32.4 35.6c-1.6 1.3-1 4.3 1.3 6.2 2.3 2.3 5.2 2.6 6.5 1 1.3-1.3.7-4.3-1.3-6.2-2.2-2.3-5.2-2.6-6.5-1m-11.4-14.7c-1.6 1-1.6 3.6 0 5.9s4.3 3.3 5.6 2.3c1.6-1.3 1.6-3.9 0-6.2-1.4-2.3-4-3.3-5.6-2"/></svg>
  </div>
  <div class="md-source__repository">
    anonymous/rex
  </div>
</a>

    </div>
  
  <ul class="md-nav__list" data-md-scrollfix>
    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../index.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Getting Started
  </span>
  

      </a>
    </li>
  

    
      
      
  
  
    
  
  
  
    
    
    
      
        
        
      
    
    
    <li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_2" checked>
        
          
          <!-- Section heading label; the invalid empty tabindex="" (ignored by browsers) is omitted. -->
          <label class="md-nav__link" for="__nav_2" id="__nav_2_label">
            
  
  <span class="md-ellipsis">
    Examples
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_2_label" aria-expanded="true">
          <label class="md-nav__title" for="__nav_2">
            <span class="md-nav__icon md-icon"></span>
            Examples
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
    
  
  
  
    
    
    
      
    
    
    <li class="md-nav__item md-nav__item--active md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_2_1" checked>
        
          
          <label class="md-nav__link" for="__nav_2_1" id="__nav_2_1_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    Introductory
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_2_1_label" aria-expanded="true">
          <label class="md-nav__title" for="__nav_2_1">
            <span class="md-nav__icon md-icon"></span>
            Introductory
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
    
  
  
  
    <li class="md-nav__item md-nav__item--active">
      
      <input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
      
      
        
      
      
      <a href="node_definitions.html" class="md-nav__link md-nav__link--active">
        
  
  <span class="md-ellipsis">
    How to define nodes
  </span>
  

      </a>
      
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="graph_and_environment_creation.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Graphs and environments
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    
    
    
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_2_2" >
        
          
          <label class="md-nav__link" for="__nav_2_2" id="__nav_2_2_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    Advanced
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_2_2_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_2_2">
            <span class="md-nav__icon md-icon"></span>
            Advanced
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="sim2real.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Sim2real with a pendulum
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

    
      
      
  
  
  
  
    
    
    
      
        
        
      
    
    
    <li class="md-nav__item md-nav__item--section md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_3" >
        
          
          <!-- Section heading label; the invalid empty tabindex="" (ignored by browsers) is omitted. -->
          <label class="md-nav__link" for="__nav_3" id="__nav_3_label">
            
  
  <span class="md-ellipsis">
    Usage
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_3_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_3">
            <span class="md-nav__icon md-icon"></span>
            Usage
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/base.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Base
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/node.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Node
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    
    
    
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_3_3" >
        
          
          <label class="md-nav__link" for="__nav_3_3" id="__nav_3_3_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    Graphs
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_3_3_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_3_3">
            <span class="md-nav__icon md-icon"></span>
            Graphs
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/asynchronous.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Asynchronous
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/compiled.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Compiled
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/artificial.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Artificial
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/record.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Record
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    
    
    
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_3_4" >
        
          
          <label class="md-nav__link" for="__nav_3_4" id="__nav_3_4_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    Delays
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_3_4_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_3_4">
            <span class="md-nav__icon md-icon"></span>
            Delays
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/delays.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Delays
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/gmm_estimator.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Gmm estimator
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    
    
    
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_3_5" >
        
          
          <label class="md-nav__link" for="__nav_3_5" id="__nav_3_5_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    System identification
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_3_5_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_3_5">
            <span class="md-nav__icon md-icon"></span>
            System identification
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/evosax.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Evosax
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/cem.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Cross-Entropy Method
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/transforms.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Transforms
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    
    
    
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_3_6" >
        
          
          <label class="md-nav__link" for="__nav_3_6" id="__nav_3_6_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    Reinforcement learning
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_3_6_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_3_6">
            <span class="md-nav__icon md-icon"></span>
            Reinforcement learning
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/environment.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Environment and Wrappers
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../api/ppo.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Proximal Policy Optimization
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    
    
    
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_3_7" >
        
          
          <label class="md-nav__link" for="__nav_3_7" id="__nav_3_7_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    Misc
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_3_7_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_3_7">
            <span class="md-nav__icon md-icon"></span>
            Misc
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../citation.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Citation
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

    
  </ul>
</nav>
                  </div>
                </div>
              </div>
            
            
          
          
            <div class="md-content" data-md-component="content">
              <article class="md-content__inner md-typeset">
                
                  


<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.3/jquery.min.js"></script>
<script>
(function() {
  // CDN base for the Jupyter widget renderer scripts. Both the current and the
  // legacy renderer URLs must be absolute: a bare 'jupyter-js-widgets@*/...'
  // path would be resolved relative to this page and fail to load.
  var WIDGETS_CDN = 'https://unpkg.com/';

  /**
   * Inject RequireJS and a Jupyter widget embed script into <body> so that
   * widget views serialized into the page can be rendered.
   *
   * The renderer is chosen from the first
   * <script type="application/vnd.jupyter.widget-view+json"> element, if any:
   * parsed state whose version_major is missing or < 2 selects the legacy
   * jupyter-js-widgets embed script; otherwise the
   * @jupyter-widgets/html-manager AMD embed script is used.
   */
  function addWidgetsRenderer() {
    var requireJsScript = document.createElement('script');
    requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js';

    var mimeElement = document.querySelector('script[type="application/vnd.jupyter.widget-view+json"]');
    var jupyterWidgetsScript = document.createElement('script');
    var widgetRendererSrc = WIDGETS_CDN + '@jupyter-widgets/html-manager@*/dist/embed-amd.js';
    var widgetState;

    // Fallback for older widget versions. A JSON parse failure is deliberately
    // ignored: the default (current) renderer URL is kept in that case.
    try {
      widgetState = mimeElement && JSON.parse(mimeElement.innerHTML);

      if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) {
        widgetRendererSrc = WIDGETS_CDN + 'jupyter-js-widgets@*/dist/embed.js';
      }
    } catch (e) {}

    jupyterWidgetsScript.src = widgetRendererSrc;

    // Script elements created via createElement are async by default and may
    // execute in any order. The AMD embed script relies on RequireJS having
    // run first, so opt both out of async to execute them in insertion order.
    requireJsScript.async = false;
    jupyterWidgetsScript.async = false;

    document.body.appendChild(requireJsScript);
    document.body.appendChild(jupyterWidgetsScript);
  }

  document.addEventListener('DOMContentLoaded', addWidgetsRenderer);
}());
</script>

<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h1 id="defining-nodes-in-rex-robotic-environments-with-jax">Defining Nodes in <strong>rex</strong> (Robotic Environments with jaX) <a href="https://colab.research.google.com/github/anonymous/rex/blob/master/examples/node_definitions.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" width="140"></a><a class="headerlink" href="#defining-nodes-in-rex-robotic-environments-with-jax" title="Permanent link">¤</a></h1>
<p>This notebook offers an introductory tutorial for <strong>rex (Robotic Environments with jaX)</strong>, a <strong>JAX-based framework</strong> for creating <strong>graph-based environments</strong> tailored for <strong>sim2real robotics</strong>.</p>
<p>In this tutorial, we will guide you through the process of defining <strong>nodes</strong>, which are the <strong>fundamental building blocks</strong> for constructing <strong>graph-based simulations</strong> and <strong>real-world systems</strong> within rex. Specifically, we will demonstrate how to define the nodes used in the <strong>sim2real.ipynb</strong> notebook.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">

<div class="highlight"><pre><span></span><code><span class="c1"># @title Install Necessary Libraries</span>
<span class="c1"># @markdown This cell installs the required libraries for the project.</span>
<span class="c1"># @markdown If you are running this notebook in Google Colab, most libraries should already be installed.</span>

<span class="k">try</span><span class="p">:</span>
    <span class="kn">import</span> <span class="nn">rex</span>  <span class="c1"># noqa: F401</span>

    <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Rex already installed&quot;</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
    <span class="nb">print</span><span class="p">(</span>
        <span class="s2">&quot;Installing rex via `pip install rex-lib[examples]`. &quot;</span>
        <span class="s2">&quot;If you are running this in a Colab notebook, you can ignore this message.&quot;</span>
    <span class="p">)</span>
    <span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="n">rex</span><span class="o">-</span><span class="n">lib</span><span class="p">[</span><span class="n">examples</span><span class="p">]</span>
</code></pre></div>

</div>
<div class="output_wrapper">
<div class="output">
<div class="output_area">
<div class="output_subarea output_stream output_stdout output_text">
<pre>
<code>Installing rex via `pip install rex-lib[examples]`. If you are running this in a Colab notebook, you can ignore this message.
Collecting rex-lib[examples]
  Downloading rex_lib-0.0.5-py3-none-any.whl.metadata (15 kB)
Collecting dill&gt;=0.3.8 (from rex-lib[examples])
  Downloading dill-0.3.9-py3-none-any.whl.metadata (10 kB)
Collecting distrax&gt;=0.1.5 (from rex-lib[examples])
  Downloading distrax-0.1.5-py3-none-any.whl.metadata (13 kB)
Collecting equinox&gt;=0.11.4 (from rex-lib[examples])
  Downloading equinox-0.11.7-py3-none-any.whl.metadata (18 kB)
Collecting evosax&gt;=0.1.6 (from rex-lib[examples])
  Downloading evosax-0.1.6-py3-none-any.whl.metadata (26 kB)
Requirement already satisfied: flax&gt;=0.8.5 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.8.5)
Collecting gymnasium&gt;=0.29.1 (from rex-lib[examples])
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Requirement already satisfied: jax&gt;=0.4.30 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.4.33)
Requirement already satisfied: matplotlib&gt;=3.7.0 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (3.7.1)
Requirement already satisfied: networkx&gt;=3.2.1 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (3.3)
Requirement already satisfied: optax&gt;=0.2.3 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.2.3)
Collecting seaborn&gt;=0.13.2 (from rex-lib[examples])
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting supergraph&gt;=0.0.8 (from rex-lib[examples])
  Downloading supergraph-0.0.8-py3-none-any.whl.metadata (1.2 kB)
Requirement already satisfied: termcolor&gt;=2.4.0 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (2.4.0)
Requirement already satisfied: tqdm&gt;=4.66.4 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (4.66.5)
Collecting brax&gt;=0.10.5 (from rex-lib[examples])
  Downloading brax-0.11.0-py3-none-any.whl.metadata (7.7 kB)
Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.4.0)
Collecting dm-env (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading dm_env-1.6-py3-none-any.whl.metadata (966 bytes)
Requirement already satisfied: etils in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.9.4)
Requirement already satisfied: flask in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (2.2.5)
Collecting flask-cors (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading Flask_Cors-5.0.0-py2.py3-none-any.whl.metadata (5.5 kB)
Requirement already satisfied: grpcio in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.64.1)
Requirement already satisfied: gym in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (0.25.2)
Requirement already satisfied: jaxlib&gt;=0.4.6 in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (0.4.33)
Collecting jaxopt (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading jaxopt-0.8.3-py3-none-any.whl.metadata (2.6 kB)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.1.4)
Collecting ml-collections (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
     <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">77.9/77.9 kB</span> <span class="ansi-red-fg">1.8 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
  Preparing metadata (setup.py) ... done
Collecting mujoco (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading mujoco-3.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
     <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">44.4/44.4 kB</span> <span class="ansi-red-fg">1.2 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Collecting mujoco-mjx (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading mujoco_mjx-3.2.3-py3-none-any.whl.metadata (3.4 kB)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.26.4)
Requirement already satisfied: orbax-checkpoint in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (0.6.4)
Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (10.4.0)
Collecting pytinyrenderer (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading pytinyrenderer-0.0.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)
Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.13.1)
Collecting tensorboardX (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Collecting trimesh (from brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading trimesh-4.4.9-py3-none-any.whl.metadata (18 kB)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (4.12.2)
Requirement already satisfied: chex&gt;=0.1.8 in /usr/local/lib/python3.10/dist-packages (from distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.1.87)
Requirement already satisfied: tensorflow-probability&gt;=0.15.0 in /usr/local/lib/python3.10/dist-packages (from distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.24.0)
Collecting jaxtyping&gt;=0.2.20 (from equinox&gt;=0.11.4-&gt;rex-lib[examples])
  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from evosax&gt;=0.1.6-&gt;rex-lib[examples]) (6.0.2)
Collecting dotmap (from evosax&gt;=0.1.6-&gt;rex-lib[examples])
  Downloading dotmap-1.3.30-py3-none-any.whl.metadata (3.2 kB)
Requirement already satisfied: msgpack in /usr/local/lib/python3.10/dist-packages (from flax&gt;=0.8.5-&gt;rex-lib[examples]) (1.0.8)
Requirement already satisfied: tensorstore in /usr/local/lib/python3.10/dist-packages (from flax&gt;=0.8.5-&gt;rex-lib[examples]) (0.1.66)
Requirement already satisfied: rich&gt;=11.1 in /usr/local/lib/python3.10/dist-packages (from flax&gt;=0.8.5-&gt;rex-lib[examples]) (13.9.1)
Requirement already satisfied: cloudpickle&gt;=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium&gt;=0.29.1-&gt;rex-lib[examples]) (2.2.1)
Collecting farama-notifications&gt;=0.0.1 (from gymnasium&gt;=0.29.1-&gt;rex-lib[examples])
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Requirement already satisfied: ml-dtypes&gt;=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax&gt;=0.4.30-&gt;rex-lib[examples]) (0.4.1)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax&gt;=0.4.30-&gt;rex-lib[examples]) (3.4.0)
Requirement already satisfied: contourpy&gt;=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (1.3.0)
Requirement already satisfied: cycler&gt;=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (0.12.1)
Requirement already satisfied: fonttools&gt;=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (4.54.1)
Requirement already satisfied: kiwisolver&gt;=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (1.4.7)
Requirement already satisfied: packaging&gt;=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (24.1)
Requirement already satisfied: pyparsing&gt;=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (3.1.4)
Requirement already satisfied: python-dateutil&gt;=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (2.8.2)
Requirement already satisfied: pandas&gt;=1.2 in /usr/local/lib/python3.10/dist-packages (from seaborn&gt;=0.13.2-&gt;rex-lib[examples]) (2.2.2)
Requirement already satisfied: toolz&gt;=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex&gt;=0.1.8-&gt;distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.12.1)
Collecting typeguard==2.13.3 (from jaxtyping&gt;=0.2.20-&gt;equinox&gt;=0.11.4-&gt;rex-lib[examples])
  Downloading typeguard-2.13.3-py3-none-any.whl.metadata (3.6 kB)
Requirement already satisfied: pytz&gt;=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas&gt;=1.2-&gt;seaborn&gt;=0.13.2-&gt;rex-lib[examples]) (2024.2)
Requirement already satisfied: tzdata&gt;=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas&gt;=1.2-&gt;seaborn&gt;=0.13.2-&gt;rex-lib[examples]) (2024.2)
Requirement already satisfied: six&gt;=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil&gt;=2.7-&gt;matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (1.16.0)
Requirement already satisfied: markdown-it-py&gt;=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich&gt;=11.1-&gt;flax&gt;=0.8.5-&gt;rex-lib[examples]) (3.0.0)
Requirement already satisfied: pygments&lt;3.0.0,&gt;=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich&gt;=11.1-&gt;flax&gt;=0.8.5-&gt;rex-lib[examples]) (2.18.0)
Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability&gt;=0.15.0-&gt;distrax&gt;=0.1.5-&gt;rex-lib[examples]) (4.4.2)
Requirement already satisfied: gast&gt;=0.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability&gt;=0.15.0-&gt;distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.6.0)
Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability&gt;=0.15.0-&gt;distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.1.8)
Requirement already satisfied: Werkzeug&gt;=2.2.2 in /usr/local/lib/python3.10/dist-packages (from flask-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.0.4)
Requirement already satisfied: itsdangerous&gt;=2.0 in /usr/local/lib/python3.10/dist-packages (from flask-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (2.2.0)
Requirement already satisfied: click&gt;=8.0 in /usr/local/lib/python3.10/dist-packages (from flask-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (8.1.7)
Requirement already satisfied: MarkupSafe&gt;=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (2.1.5)
Requirement already satisfied: gym-notices&gt;=0.0.4 in /usr/local/lib/python3.10/dist-packages (from gym-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (0.0.8)
Requirement already satisfied: contextlib2 in /usr/local/lib/python3.10/dist-packages (from ml-collections-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (21.6.0)
Collecting glfw (from mujoco-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples])
  Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl.metadata (5.4 kB)
Requirement already satisfied: pyopengl in /usr/local/lib/python3.10/dist-packages (from mujoco-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.1.7)
Requirement already satisfied: nest_asyncio in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.6.0)
Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.20.3)
Requirement already satisfied: humanize in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (4.10.0)
Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py&gt;=2.2.0-&gt;rich&gt;=11.1-&gt;flax&gt;=0.8.5-&gt;rex-lib[examples]) (0.1.2)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (2024.6.1)
Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (6.4.5)
Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.20.2)
Downloading brax-0.11.0-py3-none-any.whl (998 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">998.6/998.6 kB</span> <span class="ansi-red-fg">11.8 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading dill-0.3.9-py3-none-any.whl (119 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">119.4/119.4 kB</span> <span class="ansi-red-fg">5.3 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading distrax-0.1.5-py3-none-any.whl (319 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">319.7/319.7 kB</span> <span class="ansi-red-fg">9.4 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading equinox-0.11.7-py3-none-any.whl (178 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">178.4/178.4 kB</span> <span class="ansi-red-fg">7.5 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading evosax-0.1.6-py3-none-any.whl (240 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">240.4/240.4 kB</span> <span class="ansi-red-fg">8.5 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">958.1/958.1 kB</span> <span class="ansi-red-fg">13.3 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading seaborn-0.13.2-py3-none-any.whl (294 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">294.9/294.9 kB</span> <span class="ansi-red-fg">8.4 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading supergraph-0.0.8-py3-none-any.whl (65 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">65.5/65.5 kB</span> <span class="ansi-red-fg">2.5 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading rex_lib-0.0.5-py3-none-any.whl (115 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">115.1/115.1 kB</span> <span class="ansi-red-fg">4.9 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Downloading jaxtyping-0.2.34-py3-none-any.whl (42 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">42.4/42.4 kB</span> <span class="ansi-red-fg">1.7 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Downloading dm_env-1.6-py3-none-any.whl (26 kB)
Downloading dotmap-1.3.30-py3-none-any.whl (11 kB)
Downloading Flask_Cors-5.0.0-py2.py3-none-any.whl (14 kB)
Downloading jaxopt-0.8.3-py3-none-any.whl (172 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">172.3/172.3 kB</span> <span class="ansi-red-fg">5.2 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading mujoco-3.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.1 MB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">6.1/6.1 MB</span> <span class="ansi-red-fg">23.7 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading mujoco_mjx-3.2.3-py3-none-any.whl (6.7 MB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">6.7/6.7 MB</span> <span class="ansi-red-fg">12.5 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading pytinyrenderer-0.0.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">1.9/1.9 MB</span> <span class="ansi-red-fg">13.9 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">101.7/101.7 kB</span> <span class="ansi-red-fg">3.1 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading trimesh-4.4.9-py3-none-any.whl (700 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">700.1/700.1 kB</span> <span class="ansi-red-fg">20.4 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)
   <span class="ansi-black-intense-fg">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span class="ansi-green-fg">211.8/211.8 kB</span> <span class="ansi-red-fg">7.5 MB/s</span> eta <span class="ansi-cyan-fg">0:00:00</span>
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... done
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94507 sha256=8b83b1225aa4d52136d84206a5cb94da537f08a16dbd7b480fa90dd833c1cf78
  Stored in directory: /root/.cache/pip/wheels/7b/89/c9/a9b87790789e94aadcfc393c283e3ecd5ab916aed0a31be8fe
Successfully built ml-collections
Installing collected packages: pytinyrenderer, glfw, farama-notifications, dotmap, typeguard, trimesh, tensorboardX, supergraph, ml-collections, gymnasium, dm-env, dill, jaxtyping, seaborn, mujoco, flask-cors, mujoco-mjx, jaxopt, equinox, distrax, evosax, brax, rex-lib
  Attempting uninstall: typeguard
    Found existing installation: typeguard 4.3.0
    Uninstalling typeguard-4.3.0:
      Successfully uninstalled typeguard-4.3.0
  Attempting uninstall: seaborn
    Found existing installation: seaborn 0.13.1
    Uninstalling seaborn-0.13.1:
      Successfully uninstalled seaborn-0.13.1
<span class="ansi-red-fg">ERROR: pip&#39;s dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
inflect 7.4.0 requires typeguard&gt;=4.0.1, but you have typeguard 2.13.3 which is incompatible.</span><span class="ansi-red-fg">
</span>Successfully installed brax-0.11.0 dill-0.3.9 distrax-0.1.5 dm-env-1.6 dotmap-1.3.30 equinox-0.11.7 evosax-0.1.6 farama-notifications-0.0.4 flask-cors-5.0.0 glfw-2.7.0 gymnasium-1.0.0 jaxopt-0.8.3 jaxtyping-0.2.34 ml-collections-0.1.1 mujoco-3.2.3 mujoco-mjx-3.2.3 pytinyrenderer-0.0.14 rex-lib-0.0.5 seaborn-0.13.2 supergraph-0.0.8 tensorboardX-2.6.2.2 trimesh-4.4.9 typeguard-2.13.3
</code>
</pre>
</div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h1 id="introduction-to-nodes-in-rex">Introduction to Nodes in Rex<a class="headerlink" href="#introduction-to-nodes-in-rex" title="Permanent link">¤</a></h1>
<p>In <strong>Rex</strong>, a <strong>node</strong> represents a fundamental computational unit within a graph-based system. Nodes encapsulate specific functionality and interact by passing data through connections, forming a network that can model complex systems. This tutorial introduces how to define nodes, specify their properties like rates and delays, and manage their interactions within a graph.</p>
<h2 id="defining-nodes">Defining Nodes<a class="headerlink" href="#defining-nodes" title="Permanent link">¤</a></h2>
<p>Nodes are defined by creating subclasses of the <code>BaseNode</code> class. This base class provides a standardized API and essential functionality that all nodes inherit. When defining a node, you can specify several parameters directly in the <code>__init__</code> method:</p>
<ul>
<li><strong><code>name</code></strong>: A unique identifier for the node.</li>
<li><strong><code>rate</code></strong>: The frequency at which the node's <code>step</code> method is called (in Hz).</li>
<li><strong><code>delay</code></strong> (optional, default <code>None</code>): The expected computation delay of the node (in seconds), used for phase-shifting.</li>
<li><strong><code>delay_dist</code></strong> (optional, default <code>None</code>): A distribution representing variability in the node's computation delay, used in simulation.</li>
<li><strong><code>advance</code></strong> (optional, default <code>False</code>): If <code>True</code>, the node's <code>step</code> method is triggered as soon as all inputs are ready; if <code>False</code>, execution is throttled until the scheduled time.</li>
<li><strong><code>scheduling</code></strong> (optional, default <code>Scheduling.FREQUENCY</code>): Determines how the node's execution is scheduled. Options include <code>Scheduling.FREQUENCY</code> and <code>Scheduling.PHASE</code>.</li>
<li><strong><code>color</code></strong> (optional): Used for visualization purposes.</li>
<li><strong><code>order</code></strong> (optional): Determines the node's order in visualizations.</li>
</ul>
<p>Here's a basic example of a node definition:</p>
<div class="highlight"><pre><span></span><code><span class="k">class</span> <span class="nc">MyNode</span><span class="p">(</span><span class="n">BaseNode</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
        <span class="n">rate</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
        <span class="n">delay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>  <span class="c1"># Expected computation delay (used for phase-shifting)</span>
        <span class="n">delay_dist</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DelayDistribution</span><span class="p">,</span> <span class="n">distrax</span><span class="o">.</span><span class="n">Distribution</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>  <span class="c1"># Sim. computation delay</span>
        <span class="n">advance</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
        <span class="n">scheduling</span><span class="p">:</span> <span class="n">Scheduling</span> <span class="o">=</span> <span class="n">Scheduling</span><span class="o">.</span><span class="n">FREQUENCY</span><span class="p">,</span>
        <span class="n">color</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">order</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">rate</span><span class="p">,</span> <span class="n">delay</span><span class="p">,</span> <span class="n">delay_dist</span><span class="p">,</span> <span class="n">advance</span><span class="p">,</span> <span class="n">scheduling</span><span class="p">,</span> <span class="n">color</span><span class="p">,</span> <span class="n">order</span><span class="p">)</span>
        <span class="c1"># Additional initialization if needed</span>

    <span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="c1"># Initialize parameters</span>
        <span class="k">return</span> <span class="n">MyParams</span><span class="p">()</span>

    <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="c1"># Initialize state</span>
        <span class="k">return</span> <span class="n">MyState</span><span class="p">()</span>

    <span class="k">def</span> <span class="nf">init_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="c1"># Initialize default output</span>
        <span class="k">return</span> <span class="n">MyOutput</span><span class="p">()</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">):</span>
        <span class="c1"># Node&#39;s computation logic</span>
        <span class="n">new_state</span> <span class="o">=</span> <span class="o">...</span>
        <span class="n">output</span> <span class="o">=</span> <span class="o">...</span>
        <span class="k">return</span> <span class="n">step_state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="n">new_state</span><span class="p">),</span> <span class="n">output</span>
</code></pre></div>
<h2 id="connecting-nodes">Connecting Nodes<a class="headerlink" href="#connecting-nodes" title="Permanent link">¤</a></h2>
<p>Nodes interact by passing outputs from one node to the inputs of another. This is achieved through the <code>connect</code> method, which establishes a connection between two nodes.</p>
<h3 id="connection-api">Connection API<a class="headerlink" href="#connection-api" title="Permanent link">¤</a></h3>
<p>When connecting nodes, you can specify several parameters that control the nature of the connection:</p>
<ul>
<li><strong><code>output_node</code></strong>: node whose output will be connected as an input.</li>
<li><strong><code>blocking</code></strong>: If <code>True</code>, the receiving node waits for the input before proceeding. This can create dependencies between nodes.</li>
<li><strong><code>delay</code></strong>: An additional delay introduced in the connection, which can control the phase shift between nodes.</li>
<li><strong><code>delay_dist</code></strong>: Used in simulation to model communication delays between nodes.</li>
<li><strong><code>window</code></strong>: Determines how many past messages are stored and accessible in the input buffer.</li>
<li><strong><code>skip</code></strong>: If <code>True</code>, the connection is skipped when messages arrive simultaneously, helping resolve cyclic dependencies.</li>
<li><strong><code>jitter</code></strong>: Controls how to handle irregularities in message timing (e.g., <code>Jitter.LATEST</code> uses the most recent message).</li>
<li><strong><code>name</code></strong>: A shadow name for the input; defaults to the output node's name.</li>
</ul>
<h4 id="including-delay_dist-in-connection">Including <code>delay_dist</code> in Connection<a class="headerlink" href="#including-delay_dist-in-connection" title="Permanent link">¤</a></h4>
<p>The <code>delay_dist</code> parameter allows you to specify a distribution that models the variability in communication delay between nodes. This is particularly useful in simulations where network latency or message passing delays are significant.</p>
<h4 id="resolving-cyclic-dependencies-with-skip">Resolving Cyclic Dependencies with <code>skip</code><a class="headerlink" href="#resolving-cyclic-dependencies-with-skip" title="Permanent link">¤</a></h4>
<p>In graphs where nodes depend on each other's outputs (creating a cycle), the <code>skip</code> parameter can be used to resolve the dependency. By setting <code>skip=True</code> on a connection, you instruct the receiving node to proceed without waiting for the current message if it arrives simultaneously. This breaks the cycle and allows the system to function.</p>
<h4 id="example-connection">Example Connection<a class="headerlink" href="#example-connection" title="Permanent link">¤</a></h4>
<div class="highlight"><pre><span></span><code><span class="n">node_a</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span>
    <span class="n">output_node</span><span class="o">=</span><span class="n">node_b</span><span class="p">,</span>
    <span class="n">blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    <span class="n">delay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>  <span class="c1"># Expected communication delay (used for phase-shifting)</span>
    <span class="n">delay_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mf">0.005</span><span class="p">),</span> <span class="c1"># Sim. communication delay</span>
    <span class="n">window</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
    <span class="n">skip</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
    <span class="n">jitter</span><span class="o">=</span><span class="n">Jitter</span><span class="o">.</span><span class="n">LATEST</span><span class="p">,</span>
    <span class="n">name</span><span class="o">=</span><span class="s2">&quot;input_from_b&quot;</span>
<span class="p">)</span>
</code></pre></div>
<p>In this example, <code>node_a</code> connects to <code>node_b</code> with a blocking connection, an added delay of 0.01 seconds, and a delay distribution for simulation purposes. The <code>window</code> size is set to 5, meaning the last five messages are stored. The <code>skip</code> parameter is <code>False</code>, so the node will wait for the input.</p>
<h2 id="node-data-structure">Node Data Structure<a class="headerlink" href="#node-data-structure" title="Permanent link">¤</a></h2>
<p>Nodes manage four main types of data, each represented as a <a href="https://jax.readthedocs.io/en/latest/working-with-pytrees.html#patterns">pytree</a> and typically implemented with immutable dataclasses for efficiency and safety:</p>
<ol>
<li><strong>Parameters</strong>: Static configurations that usually remain constant during execution.</li>
<li><strong>State</strong>: Dynamic data that evolves over time with each <code>step</code>.</li>
<li><strong>Outputs</strong>: Data produced by a node's <code>step</code> method and sent to connected nodes.</li>
<li><strong>Inputs</strong>: Buffers that hold incoming data from other nodes, respecting the specified window size.</li>
</ol>
<h3 id="immutable-dataclasses">Immutable Dataclasses<a class="headerlink" href="#immutable-dataclasses" title="Permanent link">¤</a></h3>
<p>Using immutable dataclasses (e.g., via <code>@struct.dataclass</code> from Flax) ensures that the data structures are compatible with JAX's JIT compilation and functional programming paradigms. Additionally, dataclasses allow you to define specific methods related to the data structure, providing encapsulation and clarity.</p>
<div class="highlight"><pre><span></span><code><span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">MyParams</span><span class="p">:</span>
    <span class="n">some_parameter</span><span class="p">:</span> <span class="nb">float</span>

    <span class="k">def</span> <span class="nf">adjust_parameter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">factor</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">some_parameter</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">some_parameter</span> <span class="o">*</span> <span class="n">factor</span><span class="p">)</span>

<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">MyState</span><span class="p">:</span>
    <span class="n">some_state_variable</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span>

    <span class="k">def</span> <span class="nf">update_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">delta</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">some_state_variable</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">some_state_variable</span> <span class="o">+</span> <span class="n">delta</span><span class="p">)</span>

<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">MyOutput</span><span class="p">:</span>
    <span class="n">some_output_data</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span>
</code></pre></div>
<p>In this example, <code>MyParams</code> and <code>MyState</code> include methods to adjust parameters and update state, respectively. This encapsulation enhances code organization and readability.</p>
<h3 id="initialization">Initialization<a class="headerlink" href="#initialization" title="Permanent link">¤</a></h3>
<p>Node data is initialized using specific methods that you should override:</p>
<ul>
<li><strong><code>init_params</code></strong>: Initializes the node's parameters.</li>
<li><strong><code>init_state</code></strong>: Initializes the node's state.</li>
<li><strong><code>init_output</code></strong>: Provides a default output, useful for initializing input buffers in connected nodes.</li>
</ul>
<p>These methods are typically called during the graph's initialization phase using <code>graph.init()</code>.</p>
<h2 id="the-step-method-in-detail">The <code>step</code> Method in Detail<a class="headerlink" href="#the-step-method-in-detail" title="Permanent link">¤</a></h2>
<p>The <code>step</code> method defines how a node processes inputs and updates its state at each timestep. It receives a <code>StepState</code> object with all necessary information.</p>
<h3 id="stepstate-attributes"><code>StepState</code> Attributes<a class="headerlink" href="#stepstate-attributes" title="Permanent link">¤</a></h3>
<ul>
<li><strong><code>rng</code></strong>: JAX PRNG key. If the step uses randomness, split the key and return the updated key in the new <code>StepState</code>.</li>
<li><strong><code>state</code></strong>: Node's current state.</li>
<li><strong><code>params</code></strong>: Static parameters influencing behavior.</li>
<li><strong><code>inputs</code></strong>: Dictionary of <code>InputState</code> instances (keyed by input names).</li>
<li><strong><code>eps</code></strong>: Episode number, which identifies the computation graph currently used for simulation (unrelated to the RL episode number).</li>
<li><strong><code>seq</code></strong>: Current step number (auto-increments with each step).</li>
<li><strong><code>ts</code></strong>: Timestamp at the start of the step.</li>
</ul>
<h3 id="accessing-inputs">Accessing Inputs<a class="headerlink" href="#accessing-inputs" title="Permanent link">¤</a></h3>
<p>Each <code>InputState</code> in <code>step_state.inputs</code> contains:</p>
<ul>
<li><strong><code>data</code></strong>: Messages from the connected node.</li>
<li><strong><code>seq</code></strong>: Sequence numbers of the received messages.</li>
<li><strong><code>ts_sent</code></strong>: Timestamps when messages were sent.</li>
<li><strong><code>ts_recv</code></strong>: Timestamps when messages were received.</li>
</ul>
<p>For example, to access the most recent message:</p>
<div class="highlight"><pre><span></span><code><span class="n">latest_sensor_input</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;sensor&#39;</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span>
</code></pre></div>
<h3 id="implementing-the-step-method">Implementing the <code>step</code> Method<a class="headerlink" href="#implementing-the-step-method" title="Permanent link">¤</a></h3>
<p>The typical steps to implement the <code>step</code> method can be condensed into the following block:</p>
<div class="highlight"><pre><span></span><code><span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">):</span>
    <span class="c1"># Unpack StepState</span>
    <span class="n">rng</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">rng</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">state</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">params</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span>

    <span class="c1"># Access latest input</span>
    <span class="n">control_signal</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;controller&#39;</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span>

    <span class="c1"># Update state</span>
    <span class="n">new_state_variable</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">some_state_variable</span> <span class="o">+</span> <span class="n">control_signal</span> <span class="o">*</span> <span class="n">params</span><span class="o">.</span><span class="n">gain</span>
    <span class="n">new_state</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">some_state_variable</span><span class="o">=</span><span class="n">new_state_variable</span><span class="p">)</span>

    <span class="c1"># Produce output</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">MyOutput</span><span class="p">(</span><span class="n">some_output_data</span><span class="o">=</span><span class="n">new_state_variable</span><span class="p">)</span>

    <span class="c1"># Update RNG if randomness is involved</span>
    <span class="n">rng</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng</span><span class="p">)</span>

    <span class="c1"># Return updated StepState and output</span>
    <span class="k">return</span> <span class="n">step_state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="n">new_state</span><span class="p">,</span> <span class="n">rng</span><span class="o">=</span><span class="n">rng</span><span class="p">),</span> <span class="n">output</span>
</code></pre></div>
<h3 id="working-with-time-and-sequence">Working with Time and Sequence<a class="headerlink" href="#working-with-time-and-sequence" title="Permanent link">¤</a></h3>
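<p>Use <code>eps</code>, <code>ts</code>, and <code>seq</code> for time-dependent logic. The snippet below shows the idea in plain Python. When <code>step</code> is JIT-compiled, these fields are traced arrays, so express such branches with <code>jnp.where</code> or <code>jax.lax.cond</code> instead of a Python <code>if</code>:</p>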
<div class="highlight"><pre><span></span><code><span class="k">if</span> <span class="n">step_state</span><span class="o">.</span><span class="n">ts</span> <span class="o">&gt;</span> <span class="n">params</span><span class="o">.</span><span class="n">activation_time</span><span class="p">:</span>
    <span class="c1"># Perform time-based logic</span>
    <span class="k">pass</span>
</code></pre></div>
<h3 id="handling-input-windows">Handling Input Windows<a class="headerlink" href="#handling-input-windows" title="Permanent link">¤</a></h3>
<p>If the input window size is greater than 1, you can access past messages. For example, with a window of at least 3, the following retrieves the three most recent messages:</p>
<div class="highlight"><pre><span></span><code><span class="n">recent_sensor_data</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;sensor_input&#39;</span><span class="p">][</span><span class="o">-</span><span class="mi">3</span><span class="p">:]</span><span class="o">.</span><span class="n">data</span>
</code></pre></div>
<h3 id="jit-compilation-and-side-effects-handling-with-external-callbacks">JIT Compilation and Side Effects Handling with External Callbacks<a class="headerlink" href="#jit-compilation-and-side-effects-handling-with-external-callbacks" title="Permanent link">¤</a></h3>
<p>Rex advocates for JIT-compiling the <code>step</code> method of each node to enhance performance. However, interfacing with real hardware often involves side effects that JAX's JIT compilation doesn't handle natively.</p>
<p>To include side-effecting code (e.g., sending commands to actuators, reading sensor data), you must use JAX's external callback mechanism. This involves wrapping side-effecting functions with <code>jax.experimental.io_callback</code> to ensure compatibility with JIT compilation.</p>
<p>Refer to the <a href="https://jax.readthedocs.io/en/latest/external_callbacks.html">JAX documentation on external callbacks</a> for detailed guidance.</p>
<div class="highlight"><pre><span></span><code><span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">):</span>
    <span class="c1"># Compute outputs</span>
    <span class="n">output</span> <span class="o">=</span> <span class="o">...</span>

    <span class="c1"># Side-effecting function</span>
    <span class="k">def</span> <span class="nf">_apply_action</span><span class="p">(</span><span class="n">action</span><span class="p">):</span>
        <span class="c1"># Code that interacts with hardware</span>
        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>  <span class="c1"># Dummy return value</span>

    <span class="c1"># Wrap side-effecting code</span>
    <span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">io_callback</span><span class="p">(</span>
        <span class="n">_apply_action</span><span class="p">,</span>
        <span class="n">result_shape</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">),</span>
        <span class="n">arg</span><span class="o">=</span><span class="n">output</span><span class="o">.</span><span class="n">some_output_data</span>
    <span class="p">)</span>

    <span class="c1"># Update state and return</span>
    <span class="k">return</span> <span class="n">step_state</span><span class="p">,</span> <span class="n">output</span>
</code></pre></div>
<h2 id="real-world-nodes-and-lifecycle-methods">Real-World Nodes and Lifecycle Methods<a class="headerlink" href="#real-world-nodes-and-lifecycle-methods" title="Permanent link">¤</a></h2>
<p>When nodes interface with real hardware or external systems, additional lifecycle management is necessary. The <code>BaseNode</code> API accommodates this through:</p>
<ul>
<li><strong><code>startup</code></strong>: Called before an episode starts, allowing the node to prepare (e.g., initialize hardware).</li>
<li><strong><code>stop</code></strong>: Called after an episode ends, enabling the node to clean up resources or safely shut down hardware.</li>
</ul>
<div class="highlight"><pre><span></span><code><span class="k">class</span> <span class="nc">RealWorldNode</span><span class="p">(</span><span class="n">BaseNode</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
        <span class="n">rate</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
        <span class="n">delay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">delay_dist</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DelayDistribution</span><span class="p">,</span> <span class="n">distrax</span><span class="o">.</span><span class="n">Distribution</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">advance</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
        <span class="n">scheduling</span><span class="p">:</span> <span class="n">Scheduling</span> <span class="o">=</span> <span class="n">Scheduling</span><span class="o">.</span><span class="n">FREQUENCY</span><span class="p">,</span>
        <span class="n">color</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">order</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">rate</span><span class="p">,</span> <span class="n">delay</span><span class="p">,</span> <span class="n">delay_dist</span><span class="p">,</span> <span class="n">advance</span><span class="p">,</span> <span class="n">scheduling</span><span class="p">,</span> <span class="n">color</span><span class="p">,</span> <span class="n">order</span><span class="p">)</span>
        <span class="c1"># Additional initialization if needed</span>

    <span class="k">def</span> <span class="nf">startup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span><span class="p">,</span> <span class="n">timeout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="c1"># Initialize hardware connections</span>
        <span class="k">return</span> <span class="kc">True</span>  <span class="c1"># Return True if successful</span>

    <span class="k">def</span> <span class="nf">stop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">timeout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
        <span class="c1"># Safely shut down hardware</span>
        <span class="k">return</span> <span class="kc">True</span>
</code></pre></div>
<h2 id="summary">Summary<a class="headerlink" href="#summary" title="Permanent link">¤</a></h2>
<p>By following these guidelines, you can define robust and efficient nodes within the Rex framework. Nodes can be customized extensively through their parameters and state, connected flexibly to form complex graphs, and optimized using JIT compilation. Proper handling of side effects ensures that nodes interfacing with real-world systems remain performant and reliable.</p>
<p>In the following examples, we'll implement specific nodes that illustrate these concepts in practice.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">

<div class="highlight"><pre><span></span><code><span class="c1"># @title Example: Agent</span>

<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>

<span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">struct</span>
<span class="kn">from</span> <span class="nn">flax.core</span> <span class="kn">import</span> <span class="n">FrozenDict</span>
<span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">jnp</span>

<span class="kn">from</span> <span class="nn">rex</span> <span class="kn">import</span> <span class="n">base</span>
<span class="kn">from</span> <span class="nn">rex.base</span> <span class="kn">import</span> <span class="n">GraphState</span><span class="p">,</span> <span class="n">StepState</span>
<span class="kn">from</span> <span class="nn">rex.node</span> <span class="kn">import</span> <span class="n">BaseNode</span>
<span class="kn">from</span> <span class="nn">rex.ppo</span> <span class="kn">import</span> <span class="n">Policy</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">AgentOutput</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Agent&#39;s output&quot;&quot;&quot;</span>

    <span class="n">action</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span>  <span class="c1"># Torque to apply to the pendulum</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">AgentParams</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
    <span class="c1"># Policy</span>
    <span class="n">policy</span><span class="p">:</span> <span class="n">Policy</span>
    <span class="c1"># Observations</span>
    <span class="n">num_act</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="n">pytree_node</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>  <span class="c1"># Action history length</span>
    <span class="n">num_obs</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="n">pytree_node</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>  <span class="c1"># Observation history length</span>
    <span class="c1"># Action</span>
    <span class="n">max_torque</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="c1"># Initial state</span>
    <span class="n">init_method</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="n">pytree_node</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>  <span class="c1"># &quot;random&quot;, &quot;parametrized&quot;</span>
    <span class="n">parametrized</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span>
    <span class="n">max_th</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">max_thdot</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="c1"># Train</span>
    <span class="n">gamma</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">tmax</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>

    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">process_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">FrozenDict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">base</span><span class="o">.</span><span class="n">InputState</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">:</span>
        <span class="n">th</span><span class="p">,</span> <span class="n">thdot</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">&quot;sensor&quot;</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">th</span><span class="p">,</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">&quot;sensor&quot;</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">thdot</span>
        <span class="n">obs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">jnp</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">th</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">th</span><span class="p">),</span> <span class="n">thdot</span><span class="p">])</span>
        <span class="k">return</span> <span class="n">obs</span>

    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">get_observation</span><span class="p">(</span><span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">:</span>
        <span class="c1"># Unpack StepState</span>
        <span class="n">inputs</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">state</span>

        <span class="c1"># Convert inputs to single observation</span>
        <span class="n">single_obs</span> <span class="o">=</span> <span class="n">AgentParams</span><span class="o">.</span><span class="n">process_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>

        <span class="c1"># Concatenate with previous observations</span>
        <span class="n">obs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">single_obs</span><span class="p">,</span> <span class="n">state</span><span class="o">.</span><span class="n">history_obs</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span> <span class="n">state</span><span class="o">.</span><span class="n">history_act</span><span class="o">.</span><span class="n">flatten</span><span class="p">()])</span>
        <span class="k">return</span> <span class="n">obs</span>

    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">update_state</span><span class="p">(</span><span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">,</span> <span class="n">action</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;AgentState&quot;</span><span class="p">:</span>
        <span class="c1"># Unpack StepState</span>
        <span class="n">state</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">state</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">params</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span>

        <span class="c1"># Convert inputs to observation</span>
        <span class="n">single_obs</span> <span class="o">=</span> <span class="n">AgentParams</span><span class="o">.</span><span class="n">process_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>

        <span class="c1"># Update obs history</span>
        <span class="k">if</span> <span class="n">params</span><span class="o">.</span><span class="n">num_obs</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">history_obs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">roll</span><span class="p">(</span><span class="n">state</span><span class="o">.</span><span class="n">history_obs</span><span class="p">,</span> <span class="n">shift</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
            <span class="n">history_obs</span> <span class="o">=</span> <span class="n">history_obs</span><span class="o">.</span><span class="n">at</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">single_obs</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">history_obs</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">history_obs</span>

        <span class="c1"># Update act history</span>
        <span class="k">if</span> <span class="n">params</span><span class="o">.</span><span class="n">num_act</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">history_act</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">roll</span><span class="p">(</span><span class="n">state</span><span class="o">.</span><span class="n">history_act</span><span class="p">,</span> <span class="n">shift</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
            <span class="n">history_act</span> <span class="o">=</span> <span class="n">history_act</span><span class="o">.</span><span class="n">at</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">action</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">history_act</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">history_act</span>

        <span class="n">new_state</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">history_obs</span><span class="o">=</span><span class="n">history_obs</span><span class="p">,</span> <span class="n">history_act</span><span class="o">=</span><span class="n">history_act</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">new_state</span>

    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">to_output</span><span class="p">(</span><span class="n">action</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">AgentOutput</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">AgentOutput</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">)</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">AgentState</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
    <span class="n">history_act</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span>  <span class="c1"># History of actions</span>
    <span class="n">history_obs</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span>  <span class="c1"># History of observations</span>
    <span class="n">init_th</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Initial angle of the pendulum</span>
    <span class="n">init_thdot</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Initial angular velocity of the pendulum</span>


<span class="k">class</span> <span class="nc">Agent</span><span class="p">(</span><span class="n">BaseNode</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">AgentParams</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">AgentParams</span><span class="p">(</span>
            <span class="n">policy</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>  <span class="c1"># Policy must be set by the user</span>
            <span class="n">num_act</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>  <span class="c1"># Number of actions to keep in history</span>
            <span class="n">num_obs</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>  <span class="c1"># Number of observations to keep in history</span>
            <span class="n">max_torque</span><span class="o">=</span><span class="mf">2.0</span><span class="p">,</span>  <span class="c1"># Maximum torque that can be applied to the pendulum</span>
            <span class="n">init_method</span><span class="o">=</span><span class="s2">&quot;parametrized&quot;</span><span class="p">,</span>  <span class="c1"># &quot;random&quot; or &quot;parametrized&quot;</span>
            <span class="n">parametrized</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">]),</span>  <span class="c1"># [th, thdot]</span>
            <span class="n">max_th</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span>  <span class="c1"># Maximum initial angle of the pendulum</span>
            <span class="n">max_thdot</span><span class="o">=</span><span class="mf">9.0</span><span class="p">,</span>  <span class="c1"># Maximum initial angular velocity of the pendulum</span>
            <span class="n">gamma</span><span class="o">=</span><span class="mf">0.99</span><span class="p">,</span>  <span class="c1"># Discount factor (used during training)</span>
            <span class="n">tmax</span><span class="o">=</span><span class="mf">3.0</span><span class="p">,</span>  <span class="c1"># Maximum time for an episode (used during training)</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">AgentState</span><span class="p">:</span>
        <span class="n">graph_state</span> <span class="o">=</span> <span class="n">graph_state</span> <span class="ow">or</span> <span class="n">base</span><span class="o">.</span><span class="n">GraphState</span><span class="p">()</span>
        <span class="n">params</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">))</span>
        <span class="n">history_act</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">params</span><span class="o">.</span><span class="n">num_act</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>  <span class="c1"># [torque]</span>
        <span class="n">history_obs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">params</span><span class="o">.</span><span class="n">num_obs</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>  <span class="c1"># [cos(th), sin(th), thdot]</span>

        <span class="c1"># Set the initial state of the pendulum</span>
        <span class="k">if</span> <span class="n">params</span><span class="o">.</span><span class="n">init_method</span> <span class="o">==</span> <span class="s2">&quot;parametrized&quot;</span><span class="p">:</span>
            <span class="n">init_th</span><span class="p">,</span> <span class="n">init_thdot</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">parametrized</span>
        <span class="k">elif</span> <span class="n">params</span><span class="o">.</span><span class="n">init_method</span> <span class="o">==</span> <span class="s2">&quot;random&quot;</span><span class="p">:</span>
            <span class="n">rng</span> <span class="o">=</span> <span class="n">rng</span> <span class="k">if</span> <span class="n">rng</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
            <span class="n">rngs</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">num</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
            <span class="n">init_th</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">rngs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">minval</span><span class="o">=-</span><span class="n">params</span><span class="o">.</span><span class="n">max_th</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">params</span><span class="o">.</span><span class="n">max_th</span><span class="p">)</span>
            <span class="n">init_thdot</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">rngs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">minval</span><span class="o">=-</span><span class="n">params</span><span class="o">.</span><span class="n">max_thdot</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">params</span><span class="o">.</span><span class="n">max_thdot</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid init_method: </span><span class="si">{</span><span class="n">params</span><span class="o">.</span><span class="n">init_method</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">AgentState</span><span class="p">(</span><span class="n">history_act</span><span class="o">=</span><span class="n">history_act</span><span class="p">,</span> <span class="n">history_obs</span><span class="o">=</span><span class="n">history_obs</span><span class="p">,</span> <span class="n">init_th</span><span class="o">=</span><span class="n">init_th</span><span class="p">,</span> <span class="n">init_thdot</span><span class="o">=</span><span class="n">init_thdot</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">AgentOutput</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default output of the node.&quot;&quot;&quot;</span>
        <span class="n">rng</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="k">if</span> <span class="n">rng</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">rng</span>
        <span class="n">graph_state</span> <span class="o">=</span> <span class="n">graph_state</span> <span class="ow">or</span> <span class="n">base</span><span class="o">.</span><span class="n">GraphState</span><span class="p">()</span>
        <span class="n">params</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">))</span>
        <span class="n">action</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">minval</span><span class="o">=-</span><span class="n">params</span><span class="o">.</span><span class="n">max_torque</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">params</span><span class="o">.</span><span class="n">max_torque</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">AgentOutput</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">action</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">StepState</span><span class="p">,</span> <span class="n">AgentOutput</span><span class="p">]:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Step the node.&quot;&quot;&quot;</span>
        <span class="c1"># Unpack StepState</span>
        <span class="n">rng</span><span class="p">,</span> <span class="n">params</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">rng</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">params</span>

        <span class="c1"># Prepare output</span>
        <span class="n">rng</span><span class="p">,</span> <span class="n">rng_net</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">params</span><span class="o">.</span><span class="n">policy</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>  <span class="c1"># Use policy to get action</span>
            <span class="n">obs</span> <span class="o">=</span> <span class="n">AgentParams</span><span class="o">.</span><span class="n">get_observation</span><span class="p">(</span><span class="n">step_state</span><span class="p">)</span>
            <span class="n">action</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">policy</span><span class="o">.</span><span class="n">get_action</span><span class="p">(</span><span class="n">obs</span><span class="p">,</span> <span class="n">rng</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>  <span class="c1"># Deterministic action; pass rng=rng_net instead for stochastic policies</span>
        <span class="k">else</span><span class="p">:</span>  <span class="c1"># Random action if no policy is set</span>
            <span class="n">action</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">rng_net</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">minval</span><span class="o">=-</span><span class="n">params</span><span class="o">.</span><span class="n">max_torque</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">params</span><span class="o">.</span><span class="n">max_torque</span><span class="p">)</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">AgentParams</span><span class="o">.</span><span class="n">to_output</span><span class="p">(</span><span class="n">action</span><span class="p">)</span>  <span class="c1"># Convert action to output message</span>

        <span class="c1"># Update step_state (observation and action history)</span>
        <span class="n">new_state</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">step_state</span><span class="p">,</span> <span class="n">action</span><span class="p">)</span>  <span class="c1"># Update state</span>
        <span class="n">new_step_state</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">rng</span><span class="o">=</span><span class="n">rng</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="n">new_state</span><span class="p">)</span>  <span class="c1"># Update step_state</span>
        <span class="k">return</span> <span class="n">new_step_state</span><span class="p">,</span> <span class="n">output</span>
</code></pre></div>

</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">

<div class="highlight"><pre><span></span><code><span class="c1"># @title Example: Actuator</span>

<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>

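<span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>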
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">onp</span>
<span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">struct</span>

<span class="kn">from</span> <span class="nn">rex</span> <span class="kn">import</span> <span class="n">base</span>
<span class="kn">from</span> <span class="nn">rex.base</span> <span class="kn">import</span> <span class="n">GraphState</span><span class="p">,</span> <span class="n">StepState</span>
<span class="kn">from</span> <span class="nn">rex.jax_utils</span> <span class="kn">import</span> <span class="n">tree_dynamic_slice</span>
<span class="kn">from</span> <span class="nn">rex.node</span> <span class="kn">import</span> <span class="n">BaseNode</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">ActuatorOutput</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Pendulum actuator output&quot;&quot;&quot;</span>

    <span class="n">action</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span>  <span class="c1"># Torque to apply to the pendulum</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">ActuatorParams</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Pendulum actuator param definition&quot;&quot;&quot;</span>

    <span class="n">actuator_delay</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>


<span class="k">class</span> <span class="nc">Actuator</span><span class="p">(</span><span class="n">BaseNode</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;This is a simple actuator node definition that could interface a real actuator.</span>

<span class="sd">    When interfacing real hardware, you would send the action to real hardware in the .step method.</span>
<span class="sd">    Optionally, you could also specify a startup routine that is called right before an episode starts.</span>
<span class="sd">    Finally, a stop routine is called after the episode is done.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;No special initialization needed.&quot;&quot;&quot;</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActuatorParams</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default params of the node.&quot;&quot;&quot;</span>
        <span class="n">actuator_delay</span> <span class="o">=</span> <span class="mf">0.05</span>
        <span class="k">return</span> <span class="n">ActuatorParams</span><span class="p">(</span><span class="n">actuator_delay</span><span class="o">=</span><span class="n">actuator_delay</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActuatorOutput</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default output of the node.&quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="n">ActuatorOutput</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">startup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">base</span><span class="o">.</span><span class="n">GraphState</span><span class="p">,</span> <span class="n">timeout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Starts the node in the state specified by graph_state.</span>

<span class="sd">        This method is called right before an episode starts.</span>
<span class="sd">        It can be used to move (a real) robot to a starting position as specified by the graph_state.</span>

<span class="sd">        Not used when running in compiled mode.</span>
<span class="sd">        :param graph_state: The graph state.</span>
<span class="sd">        :param timeout: The timeout of the startup.</span>
<span class="sd">        :return: Whether the node has started successfully.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># Move robot to starting position specified by graph_state (e.g. graph_state.state[&quot;agent&quot;].init_th)</span>
        <span class="k">return</span> <span class="kc">True</span>  <span class="c1"># Not doing anything here</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">StepState</span><span class="p">,</span> <span class="n">ActuatorOutput</span><span class="p">]:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;If we were to control a real robot, we would send the action to the robot here.&quot;&quot;&quot;</span>
        <span class="c1"># Prepare output</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="s2">&quot;agent&quot;</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">ActuatorOutput</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">output</span><span class="o">.</span><span class="n">action</span><span class="p">)</span>

        <span class="k">def</span> <span class="nf">_apply_action</span><span class="p">(</span><span class="n">action</span><span class="p">):</span>
<span class="w">            </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">            Not really doing anything here, just a dummy implementation.</span>
<span class="sd">            Include some side-effecting code here (e.g. sending the action to a real robot).</span>

<span class="sd">            Because the .step method may be jit-compiled, any side-effecting code must be wrapped in a host callback (e.g. jax.experimental.io_callback).</span>
<span class="sd">            See the jax documentation for more information on how to do this:</span>
<span class="sd">            https://jax.readthedocs.io/en/latest/external-callbacks.html</span>
<span class="sd">            &quot;&quot;&quot;</span>
            <span class="c1"># print(f&quot;Applying action: {action}&quot;) # Apply action to the robot</span>
            <span class="k">return</span> <span class="n">onp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>  <span class="c1"># Must match dtype and shape of return_shape</span>

        <span class="c1"># Apply action to the robot</span>
        <span class="n">return_shape</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>  <span class="c1"># Result spec: _apply_action must return a value with this dtype and shape</span>
        <span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">io_callback</span><span class="p">(</span><span class="n">_apply_action</span><span class="p">,</span> <span class="n">return_shape</span><span class="p">,</span> <span class="n">output</span><span class="p">)</span>

        <span class="c1"># Update state</span>
        <span class="n">new_step_state</span> <span class="o">=</span> <span class="n">step_state</span>
        <span class="k">return</span> <span class="n">new_step_state</span><span class="p">,</span> <span class="n">output</span>

    <span class="k">def</span> <span class="nf">stop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">timeout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Stopping routine that is called after the episode is done.</span>

<span class="sd">        **IMPORTANT** It may happen that stop is called *before* the final .step call of an episode returns,</span>
<span class="sd">        which may cause unsafe behavior when the final step undoes the work of the .stop method.</span>
<span class="sd">        This should be handled by the user. For example, by stopping &quot;longer&quot; before returning here.</span>

<span class="sd">        Only called when running asynchronously.</span>
<span class="sd">        :param timeout: The timeout of the stop</span>
<span class="sd">        :return: Whether the node has stopped successfully.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="c1"># Stop the robot (e.g. set the torque to 0)</span>
        <span class="k">return</span> <span class="kc">True</span>


<span class="k">class</span> <span class="nc">SimActuator</span><span class="p">(</span><span class="n">BaseNode</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;This is a simple simulated actuator node definition that can either</span>
<span class="sd">    1. Feedthrough the agent&#39;s action (for normal operation, e.g., training).</span>
<span class="sd">       Optionally, you could include some noise or other modifications to the action.</span>
<span class="sd">    2. Reapply the recorded actuator outputs for system identification if available.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">outputs</span><span class="p">:</span> <span class="n">ActuatorOutput</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Initialize the simulated actuator.</span>

<span class="sd">        If outputs are provided, the recorded actuator outputs are reapplied (e.g., for system identification);</span>
<span class="sd">        otherwise, the agent&#39;s action is fed through.</span>

<span class="sd">        :param outputs: Recorded actuator outputs to be reapplied, or None to feed through the agent&#39;s action.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_outputs</span> <span class="o">=</span> <span class="n">outputs</span>

    <span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActuatorParams</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default params of the node.&quot;&quot;&quot;</span>
        <span class="n">actuator_delay</span> <span class="o">=</span> <span class="mf">0.05</span>
        <span class="k">return</span> <span class="n">ActuatorParams</span><span class="p">(</span><span class="n">actuator_delay</span><span class="o">=</span><span class="n">actuator_delay</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ActuatorOutput</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default output of the node.&quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="n">ActuatorOutput</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">StepState</span><span class="p">,</span> <span class="n">ActuatorOutput</span><span class="p">]:</span>
        <span class="c1"># Get action from dataset if available, else use the one provided by the agent</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_outputs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>  <span class="c1"># Use the recorded action (for system identification)</span>
            <span class="n">output</span> <span class="o">=</span> <span class="n">tree_dynamic_slice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_outputs</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">step_state</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">seq</span><span class="p">]))</span>
            <span class="n">output</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_o</span><span class="p">:</span> <span class="n">_o</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">output</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>  <span class="c1"># Feedthrough the agent&#39;s action (for normal operation, e.g., training)</span>
            <span class="n">output</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="s2">&quot;agent&quot;</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span>
            <span class="n">output</span> <span class="o">=</span> <span class="n">ActuatorOutput</span><span class="p">(</span><span class="n">action</span><span class="o">=</span><span class="n">output</span><span class="o">.</span><span class="n">action</span><span class="p">)</span>
        <span class="n">new_step_state</span> <span class="o">=</span> <span class="n">step_state</span>
        <span class="k">return</span> <span class="n">new_step_state</span><span class="p">,</span> <span class="n">output</span>
</code></pre></div>

</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">

<div class="highlight"><pre><span></span><code><span class="c1"># @title Example: Sensor</span>

<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>

<span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">struct</span>

<span class="kn">from</span> <span class="nn">rex</span> <span class="kn">import</span> <span class="n">base</span>
<span class="kn">from</span> <span class="nn">rex.base</span> <span class="kn">import</span> <span class="n">GraphState</span><span class="p">,</span> <span class="n">StepState</span>
<span class="kn">from</span> <span class="nn">rex.node</span> <span class="kn">import</span> <span class="n">BaseNode</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">SensorOutput</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Output message definition of the sensor node.&quot;&quot;&quot;</span>

    <span class="n">th</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">thdot</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">SensorParams</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Other than the sensor delay, we don&#39;t have any other parameters.</span>
<span class="sd">    You could add more parameters here if needed, such as noise levels etc.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="n">sensor_delay</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">SensorState</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;We use this state to record the reconstruction loss.&quot;&quot;&quot;</span>

    <span class="n">loss_th</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">loss_thdot</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>


<span class="k">class</span> <span class="nc">Sensor</span><span class="p">(</span><span class="n">BaseNode</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;This is a simple sensor node definition that interfaces a real sensor.</span>

<span class="sd">    When interfacing real hardware, you would grab the sensor measurement in the .step method.</span>
<span class="sd">    Optionally, you could also specify a startup routine that is called right before an episode starts.</span>
<span class="sd">    Finally, a stop routine is called after the episode is done.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;No special initialization needed.&quot;&quot;&quot;</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SensorParams</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default params of the node.&quot;&quot;&quot;</span>
        <span class="n">sensor_delay</span> <span class="o">=</span> <span class="mf">0.05</span>
        <span class="k">return</span> <span class="n">SensorParams</span><span class="p">(</span><span class="n">sensor_delay</span><span class="o">=</span><span class="n">sensor_delay</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SensorOutput</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default output of the node.&quot;&quot;&quot;</span>
        <span class="c1"># Define some fixed initial sensor values</span>
        <span class="n">th</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span>
        <span class="n">thdot</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="k">return</span> <span class="n">SensorOutput</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">th</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">thdot</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">startup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">base</span><span class="o">.</span><span class="n">GraphState</span><span class="p">,</span> <span class="n">timeout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Starts the node in the state specified by graph_state.</span>

<span class="sd">        This method is called right before an episode starts.</span>
<span class="sd">        It can be used to move (a real) robot to a starting position as specified by the graph_state.</span>

<span class="sd">        Not used when running in compiled mode.</span>
<span class="sd">        :param graph_state: The graph state.</span>
<span class="sd">        :param timeout: The timeout of the startup.</span>
<span class="sd">        :return: Whether the node has started successfully.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="kc">True</span>  <span class="c1"># Not doing anything here</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">StepState</span><span class="p">,</span> <span class="n">SensorOutput</span><span class="p">]:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;If we were to interface real hardware, we would grab the sensor measurement here.</span>

<span class="sd">        Because the .step method may be jit-compiled, any side-effecting code must be wrapped in a host callback</span>
<span class="sd">        (e.g. jax.experimental.io_callback). See the jax documentation for more information on how to do this:</span>
<span class="sd">        https://jax.readthedocs.io/en/latest/external-callbacks.html</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">world</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="s2">&quot;world&quot;</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span>

        <span class="k">def</span> <span class="nf">_grab_measurement</span><span class="p">():</span>
<span class="w">            </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">            Not really doing anything here, just a dummy implementation.</span>
<span class="sd">            Include some side-effecting code here (e.g. grabbing measurement from sensor).</span>

<span class="sd">            Because the .step method may be jit-compiled, any side-effecting code must be wrapped in a host callback (e.g. jax.experimental.io_callback).</span>
<span class="sd">            See the jax documentation for more information on how to do this:</span>
<span class="sd">            https://jax.readthedocs.io/en/latest/external-callbacks.html</span>
<span class="sd">            &quot;&quot;&quot;</span>
            <span class="c1"># print(&quot;Grabbing sensor measurement&quot;)</span>
            <span class="n">sensor_msg</span> <span class="o">=</span> <span class="n">onp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>  <span class="c1"># Dummy sensor measurement (not actually used)</span>
            <span class="k">return</span> <span class="n">sensor_msg</span>  <span class="c1"># Must match dtype and shape of return_shape</span>

        <span class="c1"># Grab sensor measurement</span>
        <span class="n">return_shape</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>  <span class="c1"># Result spec: _grab_measurement must return a value with this dtype and shape</span>
        <span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">io_callback</span><span class="p">(</span><span class="n">_grab_measurement</span><span class="p">,</span> <span class="n">return_shape</span><span class="p">)</span>

        <span class="c1"># Prepare output</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">SensorOutput</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">world</span><span class="o">.</span><span class="n">th</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">world</span><span class="o">.</span><span class="n">thdot</span><span class="p">)</span>

        <span class="c1"># Update state (NOOP)</span>
        <span class="n">new_step_state</span> <span class="o">=</span> <span class="n">step_state</span>

        <span class="k">return</span> <span class="n">new_step_state</span><span class="p">,</span> <span class="n">output</span>

    <span class="k">def</span> <span class="nf">stop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">timeout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Stopping routine that is called after the episode is done.</span>

<span class="sd">        **IMPORTANT** It may happen that stop is called *before* the final .step call of an episode returns,</span>
<span class="sd">        which may cause unsafe behavior when the final step undoes the work of the .stop method.</span>
<span class="sd">        This should be handled by the user. For example, by stopping &quot;longer&quot; before returning here.</span>

<span class="sd">        Only called when running asynchronously.</span>
<span class="sd">        :param timeout: The timeout of the stop</span>
<span class="sd">        :return: Whether the node has stopped successfully.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="kc">True</span>  <span class="c1"># Not doing anything here</span>


<span class="k">class</span> <span class="nc">SimSensor</span><span class="p">(</span><span class="n">BaseNode</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;This is a simple simulated sensor node definition that can either</span>
<span class="sd">    1. Convert the world state into a realistic sensor measurement (for normal operation, e.g., training).</span>
<span class="sd">       Optionally, you could include some noise or other modifications to the sensor measurement.</span>
<span class="sd">    2. Calculate a reconstruction loss based on the sensor measurement and the recorded sensor outputs.</span>

<span class="sd">    By calculating and aggregating the reconstruction loss here, we take time-scale differences and delays into account.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">outputs</span><span class="p">:</span> <span class="n">SensorOutput</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Initialize a simulated sensor for system identification.</span>

<span class="sd">        If outputs are provided, we will calculate the reconstruction loss based on the recorded sensor outputs.</span>

<span class="sd">        :param outputs: Recorded sensor outputs used to calculate the reconstruction loss (for system identification).</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_outputs</span> <span class="o">=</span> <span class="n">outputs</span>

    <span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SensorParams</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default params of the node.&quot;&quot;&quot;</span>
        <span class="n">sensor_delay</span> <span class="o">=</span> <span class="mf">0.05</span>
        <span class="k">return</span> <span class="n">SensorParams</span><span class="p">(</span><span class="n">sensor_delay</span><span class="o">=</span><span class="n">sensor_delay</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SensorState</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default state of the node.&quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="n">SensorState</span><span class="p">(</span><span class="n">loss_th</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">loss_thdot</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>  <span class="c1"># Initialize reconstruction loss to zero at the start of the episode</span>

    <span class="k">def</span> <span class="nf">init_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">SensorOutput</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default output of the node.&quot;&quot;&quot;</span>
        <span class="c1"># Fixed (not random, rng is unused) initial sensor values: th = pi (downward), thdot = 0 (at rest)</span>
        <span class="n">th</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span>
        <span class="n">thdot</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="k">return</span> <span class="n">SensorOutput</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">th</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">thdot</span><span class="p">)</span>  <span class="c1"># Fix the initial sensor values</span>

    <span class="k">def</span> <span class="nf">init_delays</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]]:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Initialize trainable communication delays.</span>

<span class="sd">        **Note** These only include trainable delays that were specified while connecting the nodes.</span>

<span class="sd">        :param rng: Random number generator.</span>
<span class="sd">        :param graph_state: The graph state whose params for this node provide the sensor delay.</span>
<span class="sd">                            Falls back to init_params when no params are stored for this node.</span>
<span class="sd">        :return: Trainable delays (e.g., {input_name: delay}). Can be an incomplete dictionary.</span>
<span class="sd">                 Entries for non-trainable delays or non-existent connections are ignored.</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">graph_state</span> <span class="o">=</span> <span class="n">graph_state</span> <span class="ow">or</span> <span class="n">GraphState</span><span class="p">()</span>
        <span class="n">params</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">))</span>
        <span class="n">delays</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;world&quot;</span><span class="p">:</span> <span class="n">params</span><span class="o">.</span><span class="n">sensor_delay</span><span class="p">}</span>  <span class="c1"># Delay the &quot;world&quot; input by the sensor_delay param</span>
        <span class="k">return</span> <span class="n">delays</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">StepState</span><span class="p">,</span> <span class="n">SensorOutput</span><span class="p">]:</span><!-- Highlighted source of step: forwards the newest world message as the sensor output and, when self._outputs is set, accumulates reconstruction losses in the node state. -->
        <span class="c1"># Determine output</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span><span class="p">[</span><span class="s2">&quot;world&quot;</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><!-- Only the most recent ([-1]) world message is used. -->
        <span class="n">output</span> <span class="o">=</span> <span class="n">SensorOutput</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">data</span><span class="o">.</span><span class="n">th</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">data</span><span class="o">.</span><span class="n">thdot</span><span class="p">)</span>

        <span class="c1"># Calculate loss</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_outputs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>  <span class="c1"># Calculate reconstruction loss and aggregate in state</span>
            <span class="n">output_rec</span> <span class="o">=</span> <span class="n">tree_dynamic_slice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_outputs</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">step_state</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">seq</span><span class="p">]))</span><!-- NOTE(review): presumably selects the recorded output at (eps, seq); TODO confirm tree_dynamic_slice semantics. -->
            <span class="n">output_rec</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_o</span><span class="p">:</span> <span class="n">_o</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">output_rec</span><span class="p">)</span><!-- Indexes away the two leading axes of every leaf (presumably size 1 after the slice; TODO confirm). -->
            <span class="n">th_rec</span><span class="p">,</span> <span class="n">thdot_rec</span> <span class="o">=</span> <span class="n">output_rec</span><span class="o">.</span><span class="n">th</span><span class="p">,</span> <span class="n">output_rec</span><span class="o">.</span><span class="n">thdot</span>
            <span class="n">state</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">state</span>
            <span class="n">loss_th</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">loss_th</span> <span class="o">+</span> <span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">th</span><span class="p">)</span> <span class="o">-</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">th_rec</span><span class="p">))</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">th</span><span class="p">)</span> <span class="o">-</span> <span class="n">jnp</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">th_rec</span><span class="p">))</span> <span class="o">**</span> <span class="mi">2</span><!-- Angle error is measured on (sin, cos), so it is insensitive to 2*pi wrapping of th. -->
            <span class="n">loss_thdot</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">loss_thdot</span> <span class="o">+</span> <span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">thdot</span> <span class="o">-</span> <span class="n">thdot_rec</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
            <span class="n">new_state</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">loss_th</span><span class="o">=</span><span class="n">loss_th</span><span class="p">,</span> <span class="n">loss_thdot</span><span class="o">=</span><span class="n">loss_thdot</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>  <span class="c1"># NOOP</span>
            <span class="n">new_state</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">state</span>

        <span class="c1"># Update step_state</span>
        <span class="n">new_step_state</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="n">new_state</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">new_step_state</span><span class="p">,</span> <span class="n">output</span>
</code></pre></div>

</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">

<div class="highlight"><pre><span></span><code><span class="c1"># @title Example: ODE simulation node</span>

<span class="kn">from</span> <span class="nn">math</span> <span class="kn">import</span> <span class="n">ceil</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>

<span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
<span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">struct</span>

<span class="kn">from</span> <span class="nn">rex</span> <span class="kn">import</span> <span class="n">base</span>
<span class="kn">from</span> <span class="nn">rex.base</span> <span class="kn">import</span> <span class="n">GraphState</span><span class="p">,</span> <span class="n">StepState</span>
<span class="kn">from</span> <span class="nn">rex.node</span> <span class="kn">import</span> <span class="n">BaseWorld</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">OdeParams</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span><!-- Highlighted source of OdeParams: pendulum ODE coefficients plus the integration settings dt and dt_substeps_min, and an RK4 integrator (step). -->
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Pendulum ode param definition&quot;&quot;&quot;</span>

    <span class="n">max_speed</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">J</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">mass</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">length</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">b</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">K</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">R</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">c</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">dt_substeps_min</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="n">pytree_node</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><!-- dt and dt_substeps_min are static (pytree_node=False): substeps feeds math.ceil/int and scan length, which need concrete Python numbers. -->
    <span class="n">dt</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="n">pytree_node</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">substeps</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
        <span class="n">substeps</span> <span class="o">=</span> <span class="n">ceil</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dt</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dt_substeps_min</span><span class="p">)</span><!-- Rounds up so each substep is at most dt_substeps_min. NOTE(review): float rounding in the division may add one extra substep when dt is meant to be an exact multiple. -->
        <span class="k">return</span> <span class="nb">int</span><span class="p">(</span><span class="n">substeps</span><span class="p">)</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">dt_substeps</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span><!-- Actual substep size: dt split evenly over substeps. -->
        <span class="n">substeps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">substeps</span>
        <span class="n">dt_substeps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dt</span> <span class="o">/</span> <span class="n">substeps</span>
        <span class="k">return</span> <span class="n">dt_substeps</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">substeps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dt_substeps</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="s2">&quot;OdeState&quot;</span><span class="p">,</span> <span class="n">us</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="s2">&quot;OdeState&quot;</span><span class="p">,</span> <span class="s2">&quot;OdeState&quot;</span><span class="p">]:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Step the pendulum ode.&quot;&quot;&quot;</span>

        <span class="k">def</span> <span class="nf">_scan_fn</span><span class="p">(</span><span class="n">_x</span><span class="p">,</span> <span class="n">_u</span><span class="p">):</span>
            <span class="n">next_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_runge_kutta4</span><span class="p">(</span><span class="n">dt_substeps</span><span class="p">,</span> <span class="n">_x</span><span class="p">,</span> <span class="n">_u</span><span class="p">)</span><!-- One RK4 substep per leading-axis entry of us; thdot is clipped to [-max_speed, max_speed] after every substep. -->
            <span class="c1"># Clip velocity</span>
            <span class="n">clip_thdot</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">next_x</span><span class="o">.</span><span class="n">thdot</span><span class="p">,</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">max_speed</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_speed</span><span class="p">)</span>
            <span class="n">next_x</span> <span class="o">=</span> <span class="n">next_x</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">thdot</span><span class="o">=</span><span class="n">clip_thdot</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">next_x</span><span class="p">,</span> <span class="n">next_x</span>

        <span class="n">x_final</span><span class="p">,</span> <span class="n">x_substeps</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span><span class="n">_scan_fn</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">us</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="n">substeps</span><span class="p">)</span><!-- Returns the final carry and the per-substep states stacked by jax.lax.scan. -->
        <span class="k">return</span> <span class="n">x_final</span><span class="p">,</span> <span class="n">x_substeps</span>

    <span class="k">def</span> <span class="nf">_runge_kutta4</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dt</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="s2">&quot;OdeState&quot;</span><span class="p">,</span> <span class="n">u</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OdeState&quot;</span><span class="p">:</span><!-- Classic 4th-order Runge-Kutta step with u held constant; relies on OdeState supporting + and scalar * (presumably from base.Base, TODO confirm). -->
        <span class="n">k1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ode</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">u</span><span class="p">)</span>
        <span class="n">k2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ode</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">k1</span> <span class="o">*</span> <span class="n">dt</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">u</span><span class="p">)</span>
        <span class="n">k3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ode</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">k2</span> <span class="o">*</span> <span class="n">dt</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">u</span><span class="p">)</span>
        <span class="n">k4</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ode</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">k3</span> <span class="o">*</span> <span class="n">dt</span><span class="p">,</span> <span class="n">u</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="p">(</span><span class="n">k1</span> <span class="o">+</span> <span class="n">k2</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">k3</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">k4</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">dt</span> <span class="o">/</span> <span class="mi">6</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_ode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="s2">&quot;OdeState&quot;</span><span class="p">,</span> <span class="n">u</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OdeState&quot;</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;dx function for the pendulum ode&quot;&quot;&quot;</span>
        <span class="c1"># Downward := [pi, 0], Upward := [0, 0]</span>
        <span class="n">g</span><span class="p">,</span> <span class="n">J</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="mf">9.81</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">J</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mass</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">K</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">R</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">c</span>  <span class="c1"># noqa: E741</span>
        <span class="n">th</span><span class="p">,</span> <span class="n">thdot</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">th</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">thdot</span>
        <span class="n">activation</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sign</span><span class="p">(</span><span class="n">thdot</span><span class="p">)</span>
        <span class="n">ddx</span> <span class="o">=</span> <span class="p">(</span><span class="n">u</span> <span class="o">*</span> <span class="n">K</span> <span class="o">/</span> <span class="n">R</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">g</span> <span class="o">*</span> <span class="n">l</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">th</span><span class="p">)</span> <span class="o">-</span> <span class="n">b</span> <span class="o">*</span> <span class="n">thdot</span> <span class="o">-</span> <span class="n">thdot</span> <span class="o">*</span> <span class="n">K</span> <span class="o">*</span> <span class="n">K</span> <span class="o">/</span> <span class="n">R</span> <span class="o">-</span> <span class="n">c</span> <span class="o">*</span> <span class="n">activation</span><span class="p">)</span> <span class="o">/</span> <span class="n">J</span><!-- Angular acceleration; terms presumably: input torque u*K/R, gravity m*g*l*sin(th), viscous damping b*thdot, back-EMF-like thdot*K*K/R and Coulomb-like friction c*sign(thdot), all divided by J. -->
        <span class="k">return</span> <span class="n">OdeState</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">thdot</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">ddx</span><span class="p">,</span> <span class="n">loss_task</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>  <span class="c1"># No derivative for loss_task</span><!-- Returns the time derivative as an OdeState: d(th) = thdot, d(thdot) = ddx; loss_task gets 0.0 so RK4 integration leaves it unchanged. -->


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">OdeState</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span><!-- Highlighted source of OdeState: integrated pendulum state (th, thdot) plus loss_task, which OdeParams._ode gives a zero derivative. -->
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Pendulum state definition&quot;&quot;&quot;</span>

    <span class="n">loss_task</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">th</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">thdot</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">OdeOutput</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span><!-- Highlighted source of OdeOutput: the (th, thdot) message of the world node; the sensor node reads these two fields from its world input. -->
<span class="w">    </span><span class="sd">&quot;&quot;&quot;World output definition&quot;&quot;&quot;</span>

    <span class="n">th</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">thdot</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>


<span class="k">class</span> <span class="nc">OdeWorld</span><span class="p">(</span><span class="n">BaseWorld</span><span class="p">):</span>  <span class="c1"># We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want</span>
    <span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">OdeParams</span><span class="p">:</span><!-- Highlighted source of OdeWorld.init_params: hard-coded pendulum coefficients (rng and graph_state unused); dt is derived from the node rate as 1 / self.rate. -->
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default params of the node.&quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="n">OdeParams</span><span class="p">(</span>
            <span class="n">max_speed</span><span class="o">=</span><span class="mf">40.0</span><span class="p">,</span>  <span class="c1"># Clip angular velocity to this value</span>
            <span class="n">J</span><span class="o">=</span><span class="mf">0.000159931461600856</span><span class="p">,</span>  <span class="c1"># 0.000159931461600856,</span>
            <span class="n">mass</span><span class="o">=</span><span class="mf">0.0508581731919534</span><span class="p">,</span>  <span class="c1"># 0.0508581731919534,</span>
            <span class="n">length</span><span class="o">=</span><span class="mf">0.0415233722862552</span><span class="p">,</span>  <span class="c1"># 0.0415233722862552,</span>
            <span class="n">b</span><span class="o">=</span><span class="mf">1.43298488e-05</span><span class="p">,</span>  <span class="c1"># 1.43298488358436e-05,</span><!-- The trailing Python comments record full-precision values; b, K and R are passed truncated. -->
            <span class="n">K</span><span class="o">=</span><span class="mf">0.03333912</span><span class="p">,</span>  <span class="c1"># 0.0333391179016334,</span>
            <span class="n">R</span><span class="o">=</span><span class="mf">7.73125142</span><span class="p">,</span>  <span class="c1"># 7.73125142447252,</span>
            <span class="n">c</span><span class="o">=</span><span class="mf">0.000975041213361349</span><span class="p">,</span>  <span class="c1"># 0.000975041213361349,</span>
            <span class="c1"># Backend parameters</span>
            <span class="n">dt_substeps_min</span><span class="o">=</span><span class="mi">1</span> <span class="o">/</span> <span class="mi">100</span><span class="p">,</span>  <span class="c1"># Minimum substep size for ode integration</span>
            <span class="n">dt</span><span class="o">=</span><span class="mi">1</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">rate</span><span class="p">,</span>  <span class="c1"># Time step per .step() call</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">OdeState</span><span class="p">:</span><!-- Displayed method init_state: builds the default OdeState (th, thdot, loss_task). rng is not used in this body. -->
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default state of the node.&quot;&quot;&quot;</span>
        <span class="n">graph_state</span> <span class="o">=</span> <span class="n">graph_state</span> <span class="ow">or</span> <span class="n">GraphState</span><span class="p">()</span><!-- A None (or otherwise falsy) graph_state is replaced by an empty GraphState() -->

        <span class="c1"># Try to grab state from graph_state</span>
        <span class="n">state</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;agent&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span><!-- Initial conditions are read from the "agent" entry of graph_state.state when present. NOTE(review): presumably that state defines init_th and init_thdot; confirm against the agent node -->
        <span class="n">init_th</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">init_th</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><!-- Fallback initial angle: jnp.pi -->
        <span class="n">init_thdot</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">init_thdot</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mf">0.0</span><!-- Fallback initial angular velocity: 0.0 -->
        <span class="k">return</span> <span class="n">OdeState</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">init_th</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">init_thdot</span><span class="p">,</span> <span class="n">loss_task</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span><!-- The accumulated task loss starts at 0.0 -->

    <span class="k">def</span> <span class="nf">init_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">OdeOutput</span><span class="p">:</span><!-- Displayed method init_output: returns an OdeOutput that mirrors th and thdot of this node's state -->
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default output of the node.&quot;&quot;&quot;</span>
        <span class="n">graph_state</span> <span class="o">=</span> <span class="n">graph_state</span> <span class="ow">or</span> <span class="n">GraphState</span><span class="p">()</span><!-- A None (or otherwise falsy) graph_state is replaced by an empty GraphState() -->
        <span class="c1"># Grab output from state</span>
        <span class="n">world_state</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">))</span><!-- Uses this node's entry in graph_state.state if present, else a fresh init_state(rng, graph_state). NOTE: Python evaluates the get() default eagerly, so init_state runs even when the entry exists -->
        <span class="k">return</span> <span class="n">OdeOutput</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">world_state</span><span class="o">.</span><span class="n">th</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">world_state</span><span class="o">.</span><span class="n">thdot</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_delays</span><span class="p">(</span><!-- Displayed method init_delays: returns a dict of per-input delays; only the "actuator" input is considered here. rng is not used in this body. -->
        <span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">base</span><span class="o">.</span><span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><!-- NOTE(review): annotated as base.GraphState while sibling methods use GraphState; presumably the same class imported two ways -->
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]]:</span>
        <span class="n">graph_state</span> <span class="o">=</span> <span class="n">graph_state</span> <span class="ow">or</span> <span class="n">GraphState</span><span class="p">()</span>
        <span class="n">params</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;actuator&quot;</span><span class="p">)</span><!-- Params of the "actuator" node; get() has no default, so this is presumably None when the entry is missing (assumes dict-like params) -->
        <span class="n">delays</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="s2">&quot;actuator_delay&quot;</span><span class="p">):</span><!-- Duck-typed check: only params exposing actuator_delay contribute a delay; hasattr(None, ...) is False, so missing params yield an empty dict -->
            <span class="n">delays</span><span class="p">[</span><span class="s2">&quot;actuator&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">actuator_delay</span>
        <span class="k">return</span> <span class="n">delays</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">StepState</span><span class="p">,</span> <span class="n">OdeOutput</span><span class="p">]:</span><!-- Displayed method step: integrates the ODE over one node step, accumulates the task loss and returns (updated StepState, OdeOutput) -->
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Step the node.&quot;&quot;&quot;</span>
        <span class="c1"># Unpack StepState</span>
        <span class="n">_</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">rng</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">state</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">params</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span><!-- rng is unpacked but discarded -->

        <span class="c1"># Apply dynamics</span>
        <span class="n">u</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">&quot;actuator&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">action</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>  <span class="c1"># [-1] to get the latest action, [0] reduces the dimension to scalar</span>
        <span class="n">us</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">u</span><span class="p">]</span> <span class="o">*</span> <span class="n">params</span><span class="o">.</span><span class="n">substeps</span><span class="p">)</span><!-- Zero-order hold: the same control u is repeated for every substep -->
        <span class="n">new_state</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">params</span><span class="o">.</span><span class="n">substeps</span><span class="p">,</span> <span class="n">params</span><span class="o">.</span><span class="n">dt_substeps</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">us</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><!-- [0] keeps the first value returned by params.step. NOTE(review): presumably the final state after all substeps; confirm against OdeParams.step -->
        <span class="n">next_th</span><span class="p">,</span> <span class="n">next_thdot</span> <span class="o">=</span> <span class="n">new_state</span><span class="o">.</span><span class="n">th</span><span class="p">,</span> <span class="n">new_state</span><span class="o">.</span><span class="n">thdot</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">OdeOutput</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">next_th</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">next_thdot</span><span class="p">)</span>  <span class="c1"># Prepare output</span>

        <span class="c1"># Calculate cost (penalize angle error, angular velocity and input voltage)</span>
        <span class="n">norm_next_th</span> <span class="o">=</span> <span class="n">next_th</span> <span class="o">-</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">floor</span><span class="p">((</span><span class="n">next_th</span> <span class="o">+</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><span class="p">))</span><!-- Wraps next_th into [-pi, pi), so the angle penalty measures the distance to 0 modulo 2*pi -->
        <span class="n">loss_task</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">loss_task</span> <span class="o">+</span> <span class="n">norm_next_th</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="p">(</span><span class="n">next_thdot</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="mi">10</span> <span class="o">*</span> <span class="nb">abs</span><span class="p">(</span><span class="n">norm_next_th</span><span class="p">)))</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="mf">0.01</span> <span class="o">*</span> <span class="n">u</span><span class="o">**</span><span class="mi">2</span><!-- Running sum of: squared wrapped angle, 0.1 times the squared velocity scaled by 1/(1 + 10*|angle|) (so weighted most near angle 0), and 0.01 times the squared control -->

        <span class="c1"># Update state</span>
        <span class="n">new_state</span> <span class="o">=</span> <span class="n">new_state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">loss_task</span><span class="o">=</span><span class="n">loss_task</span><span class="p">)</span>
        <span class="n">new_step_state</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="n">new_state</span><span class="p">)</span><!-- Only the state field of the StepState is replaced; rng, params and inputs are carried over unchanged -->
        <span class="k">return</span> <span class="n">new_step_state</span><span class="p">,</span> <span class="n">output</span>
</code></pre></div>

</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">

<div class="highlight"><pre><span></span><code><span class="c1"># @title Example: Brax simulation node</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>

<span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">struct</span>

<span class="kn">from</span> <span class="nn">rex</span> <span class="kn">import</span> <span class="n">base</span>
<span class="kn">from</span> <span class="nn">rex.base</span> <span class="kn">import</span> <span class="n">GraphState</span><span class="p">,</span> <span class="n">StepState</span>
<span class="kn">from</span> <span class="nn">rex.node</span> <span class="kn">import</span> <span class="n">BaseWorld</span>


<span class="k">try</span><span class="p">:</span><!-- Optional dependency: imports the three Brax pipelines and the mjcf loader, printing an install hint if brax is missing -->
    <span class="kn">from</span> <span class="nn">brax.generalized</span> <span class="kn">import</span> <span class="n">pipeline</span> <span class="k">as</span> <span class="n">gen_pipeline</span>
    <span class="kn">from</span> <span class="nn">brax.io</span> <span class="kn">import</span> <span class="n">mjcf</span>
    <span class="kn">from</span> <span class="nn">brax.positional</span> <span class="kn">import</span> <span class="n">pipeline</span> <span class="k">as</span> <span class="n">pos_pipeline</span>
    <span class="kn">from</span> <span class="nn">brax.spring</span> <span class="kn">import</span> <span class="n">pipeline</span> <span class="k">as</span> <span class="n">spring_pipeline</span>

    <span class="n">Systems</span> <span class="o">=</span> <span class="n">Union</span><span class="p">[</span><span class="n">gen_pipeline</span><span class="o">.</span><span class="n">System</span><span class="p">,</span> <span class="n">spring_pipeline</span><span class="o">.</span><span class="n">System</span><span class="p">,</span> <span class="n">pos_pipeline</span><span class="o">.</span><span class="n">System</span><span class="p">]</span><!-- Type alias: union of the System types exposed by the three pipelines -->
    <span class="n">Pipelines</span> <span class="o">=</span> <span class="n">Union</span><span class="p">[</span><span class="n">gen_pipeline</span><span class="o">.</span><span class="n">State</span><span class="p">,</span> <span class="n">spring_pipeline</span><span class="o">.</span><span class="n">State</span><span class="p">,</span> <span class="n">pos_pipeline</span><span class="o">.</span><span class="n">State</span><span class="p">]</span><!-- Despite its name, Pipelines is the union of the pipelines' State types; it annotates pipeline states below -->
<span class="k">except</span> <span class="ne">ModuleNotFoundError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Brax not installed. Install it with `pip install brax`&quot;</span><span class="p">)</span>
    <span class="k">raise</span> <span class="n">e</span><!-- Re-raises the original ModuleNotFoundError after printing the hint -->


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">BraxParams</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Parameters of the Brax disk pendulum simulation.&quot;&quot;&quot;</span>

    <span class="n">max_speed</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Clip angular velocity to this value</span>
    <span class="n">damping</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Written to sys.dof.damping</span>
    <span class="n">armature</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Written to sys.dof.armature</span>
    <span class="n">gear</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Written to sys.actuator.gear</span>
    <span class="n">mass_weight</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Link mass, also used for the cylinder inertia</span>
    <span class="n">radius_weight</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Radius in the cylinder inertia 0.5 * m * r**2</span>
    <span class="n">offset</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># y-position of the link inertia frame</span>
    <span class="n">friction_loss</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>  <span class="c1"># Opposes motion; divided by the gear and subtracted from the control</span>
    <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="n">pytree_node</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>  <span class="c1"># One of &quot;generalized&quot;, &quot;spring&quot; or &quot;positional&quot;</span>
    <span class="n">dt</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="n">pytree_node</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>  <span class="c1"># Time covered by one .step() call, split into substeps</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">substeps</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Number of substeps needed so that no substep exceeds the backend substep size (1/100).&quot;&quot;&quot;</span>
        <span class="n">dt_substeps_per_backend</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;generalized&quot;</span><span class="p">:</span> <span class="mi">1</span> <span class="o">/</span> <span class="mi">100</span><span class="p">,</span> <span class="s2">&quot;spring&quot;</span><span class="p">:</span> <span class="mi">1</span> <span class="o">/</span> <span class="mi">100</span><span class="p">,</span> <span class="s2">&quot;positional&quot;</span><span class="p">:</span> <span class="mi">1</span> <span class="o">/</span> <span class="mi">100</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">backend</span><span class="p">]</span>
        <span class="n">substeps</span> <span class="o">=</span> <span class="n">ceil</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dt</span> <span class="o">/</span> <span class="n">dt_substeps_per_backend</span><span class="p">)</span>  <span class="c1"># Round up so each substep stays within the limit</span>
        <span class="k">return</span> <span class="nb">int</span><span class="p">(</span><span class="n">substeps</span><span class="p">)</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">dt_substeps</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Substep size that splits dt evenly into `substeps` steps.&quot;&quot;&quot;</span>
        <span class="n">substeps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">substeps</span>
        <span class="n">dt_substeps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dt</span> <span class="o">/</span> <span class="n">substeps</span>
        <span class="k">return</span> <span class="n">dt_substeps</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">pipeline</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Pipelines</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Brax pipeline module selected by `backend`.&quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;generalized&quot;</span><span class="p">:</span> <span class="n">gen_pipeline</span><span class="p">,</span> <span class="s2">&quot;spring&quot;</span><span class="p">:</span> <span class="n">spring_pipeline</span><span class="p">,</span> <span class="s2">&quot;positional&quot;</span><span class="p">:</span> <span class="n">pos_pipeline</span><span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">backend</span><span class="p">]</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">sys</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Systems</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Brax system from DISK_PENDULUM_XML with this instance&#39;s inertia, gear, armature, damping and timestep.&quot;&quot;&quot;</span>
        <span class="c1"># Note: the XML is re-parsed on every access of this property</span>
        <span class="n">base_sys</span> <span class="o">=</span> <span class="n">mjcf</span><span class="o">.</span><span class="n">loads</span><span class="p">(</span><span class="n">DISK_PENDULUM_XML</span><span class="p">)</span>
        <span class="c1"># Appropriately replace parameters for the disk pendulum</span>
        <span class="n">itransform</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">link</span><span class="o">.</span><span class="n">inertia</span><span class="o">.</span><span class="n">transform</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">pos</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">offset</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">]]))</span>
        <span class="n">i</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">link</span><span class="o">.</span><span class="n">inertia</span><span class="o">.</span><span class="n">i</span><span class="o">.</span><span class="n">at</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span>
            <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mass_weight</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">radius_weight</span><span class="o">**</span><span class="mi">2</span>
        <span class="p">)</span>  <span class="c1"># inertia of cylinder in local frame.</span>
        <span class="n">inertia</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">link</span><span class="o">.</span><span class="n">inertia</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">transform</span><span class="o">=</span><span class="n">itransform</span><span class="p">,</span> <span class="n">mass</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">mass_weight</span><span class="p">]),</span> <span class="n">i</span><span class="o">=</span><span class="n">i</span><span class="p">)</span>
        <span class="n">link</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">link</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">inertia</span><span class="o">=</span><span class="n">inertia</span><span class="p">)</span>
        <span class="n">actuator</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">actuator</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">gear</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">gear</span><span class="p">]))</span>
        <span class="n">dof</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">dof</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">armature</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">armature</span><span class="p">]),</span> <span class="n">damping</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">damping</span><span class="p">]))</span>
        <span class="n">opt</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">timestep</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dt_substeps</span><span class="p">)</span>
        <span class="n">new_sys</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">link</span><span class="o">=</span><span class="n">link</span><span class="p">,</span> <span class="n">actuator</span><span class="o">=</span><span class="n">actuator</span><span class="p">,</span> <span class="n">dof</span><span class="o">=</span><span class="n">dof</span><span class="p">,</span> <span class="n">opt</span><span class="o">=</span><span class="n">opt</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">new_sys</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">substeps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dt_substeps</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Pipelines</span><span class="p">,</span> <span class="n">us</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Pipelines</span><span class="p">,</span> <span class="n">Pipelines</span><span class="p">]:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Step the pendulum ode: apply one control per substep and return (final state, stacked substep states).&quot;&quot;&quot;</span>
        <span class="c1"># Build the system once (self.sys re-parses the XML on each access) and set its timestep</span>
        <span class="n">base_sys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sys</span>
        <span class="n">sys</span> <span class="o">=</span> <span class="n">base_sys</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">opt</span><span class="o">=</span><span class="n">base_sys</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">timestep</span><span class="o">=</span><span class="n">dt_substeps</span><span class="p">))</span>

        <span class="k">def</span> <span class="nf">_scan_fn</span><span class="p">(</span><span class="n">_x</span><span class="p">,</span> <span class="n">_u</span><span class="p">):</span>
            <span class="c1"># Add friction loss opposing the velocity of the current substep state _x (not the initial state x)</span>
            <span class="n">thdot</span> <span class="o">=</span> <span class="n">_x</span><span class="o">.</span><span class="n">qd</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
            <span class="n">activation</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sign</span><span class="p">(</span><span class="n">thdot</span><span class="p">)</span>
            <span class="n">friction</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">friction_loss</span> <span class="o">*</span> <span class="n">activation</span> <span class="o">/</span> <span class="n">sys</span><span class="o">.</span><span class="n">actuator</span><span class="o">.</span><span class="n">gear</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
            <span class="n">_u_friction</span> <span class="o">=</span> <span class="n">_u</span> <span class="o">-</span> <span class="n">friction</span>
            <span class="c1"># Step</span>
            <span class="c1"># NOTE(review): always uses the generalized pipeline, even though self.pipeline selects one by backend; confirm intent</span>
            <span class="n">next_x</span> <span class="o">=</span> <span class="n">gen_pipeline</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">sys</span><span class="p">,</span> <span class="n">_x</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">_u_friction</span><span class="p">)[</span><span class="kc">None</span><span class="p">])</span>
            <span class="c1"># Clip velocity</span>
            <span class="n">next_x</span> <span class="o">=</span> <span class="n">next_x</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">qd</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">next_x</span><span class="o">.</span><span class="n">qd</span><span class="p">,</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">max_speed</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_speed</span><span class="p">))</span>
            <span class="k">return</span> <span class="n">next_x</span><span class="p">,</span> <span class="n">next_x</span>

        <span class="n">x_final</span><span class="p">,</span> <span class="n">x_substeps</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span><span class="n">_scan_fn</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">us</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="n">substeps</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x_final</span><span class="p">,</span> <span class="n">x_substeps</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">BraxState</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Pendulum state definition&quot;&quot;&quot;</span>

    <span class="n">loss_task</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">pipeline_state</span><span class="p">:</span> <span class="n">Pipelines</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">th</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pipeline_state</span><span class="o">.</span><span class="n">q</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>

    <span class="nd">@property</span>
    <span class="k">def</span> <span class="nf">thdot</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pipeline_state</span><span class="o">.</span><span class="n">qd</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>


<span class="nd">@struct</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">BraxOutput</span><span class="p">(</span><span class="n">base</span><span class="o">.</span><span class="n">Base</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;World output definition&quot;&quot;&quot;</span>

    <span class="n">th</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>
    <span class="n">thdot</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">typing</span><span class="o">.</span><span class="n">ArrayLike</span><span class="p">]</span>


<span class="k">class</span> <span class="nc">BraxWorld</span><span class="p">(</span><span class="n">BaseWorld</span><span class="p">):</span>  <span class="c1"># We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;generalized&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="o">=</span> <span class="n">backend</span>

    <span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">BraxParams</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default params of the node.&quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="n">BraxParams</span><span class="p">(</span>
            <span class="c1"># Realistic parameters for the disk pendulum</span>
            <span class="n">max_speed</span><span class="o">=</span><span class="mf">40.0</span><span class="p">,</span>
            <span class="n">damping</span><span class="o">=</span><span class="mf">0.00015877</span><span class="p">,</span>
            <span class="n">armature</span><span class="o">=</span><span class="mf">6.4940527e-06</span><span class="p">,</span>
            <span class="n">gear</span><span class="o">=</span><span class="mf">0.00428677</span><span class="p">,</span>
            <span class="n">mass_weight</span><span class="o">=</span><span class="mf">0.05076142</span><span class="p">,</span>
            <span class="n">radius_weight</span><span class="o">=</span><span class="mf">0.05121992</span><span class="p">,</span>
            <span class="n">offset</span><span class="o">=</span><span class="mf">0.04161447</span><span class="p">,</span>
            <span class="n">friction_loss</span><span class="o">=</span><span class="mf">0.00097525</span><span class="p">,</span>
            <span class="c1"># Backend parameters</span>
            <span class="n">dt</span><span class="o">=</span><span class="mi">1</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">rate</span><span class="p">,</span>
            <span class="n">backend</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">backend</span><span class="p">,</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">BraxState</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default state of the node.&quot;&quot;&quot;</span>
        <span class="n">graph_state</span> <span class="o">=</span> <span class="n">graph_state</span> <span class="ow">or</span> <span class="n">GraphState</span><span class="p">()</span>

        <span class="c1"># Try to grab state from graph_state</span>
        <span class="n">state</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;agent&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
        <span class="n">init_th</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">init_th</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span>
        <span class="n">init_thdot</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">init_thdot</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mf">0.0</span>

        <span class="c1"># Set the initial state of the disk pendulum</span>
        <span class="n">params</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">))</span>
        <span class="n">sys</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">sys</span>
        <span class="n">q</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">init_q</span><span class="o">.</span><span class="n">at</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">init_th</span><span class="p">)</span>
        <span class="n">qd</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">init_thdot</span><span class="p">])</span>
        <span class="n">pipeline_state</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">pipeline</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">sys</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">qd</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">BraxState</span><span class="p">(</span><span class="n">pipeline_state</span><span class="o">=</span><span class="n">pipeline_state</span><span class="p">,</span> <span class="n">loss_task</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">init_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rng</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">:</span> <span class="n">GraphState</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">BraxOutput</span><span class="p">:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Default output of the node.&quot;&quot;&quot;</span>
        <span class="n">graph_state</span> <span class="o">=</span> <span class="n">graph_state</span> <span class="ow">or</span> <span class="n">GraphState</span><span class="p">()</span>
        <span class="c1"># Grab output from state</span>
        <span class="n">state</span> <span class="o">=</span> <span class="n">graph_state</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">graph_state</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">BraxOutput</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">state</span><span class="o">.</span><span class="n">pipeline_state</span><span class="o">.</span><span class="n">q</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">thdot</span><span class="o">=</span><span class="n">state</span><span class="o">.</span><span class="n">pipeline_state</span><span class="o">.</span><span class="n">qd</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step_state</span><span class="p">:</span> <span class="n">StepState</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">StepState</span><span class="p">,</span> <span class="n">BraxOutput</span><span class="p">]:</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Step the node.&quot;&quot;&quot;</span>

        <span class="c1"># Unpack StepState</span>
        <span class="n">_</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">rng</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">state</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">params</span><span class="p">,</span> <span class="n">step_state</span><span class="o">.</span><span class="n">inputs</span>

        <span class="c1"># Apply dynamics</span>
        <span class="n">u</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">&quot;actuator&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">action</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>  <span class="c1"># [-1] to get the latest action, [0] reduces the dimension to scalar</span>
        <span class="n">us</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">u</span><span class="p">]</span> <span class="o">*</span> <span class="n">params</span><span class="o">.</span><span class="n">substeps</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">pipeline_state</span>
        <span class="n">next_x</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">params</span><span class="o">.</span><span class="n">substeps</span><span class="p">,</span> <span class="n">params</span><span class="o">.</span><span class="n">dt_substeps</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">us</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">new_state</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">pipeline_state</span><span class="o">=</span><span class="n">next_x</span><span class="p">)</span>
        <span class="n">next_th</span><span class="p">,</span> <span class="n">next_thdot</span> <span class="o">=</span> <span class="n">new_state</span><span class="o">.</span><span class="n">th</span><span class="p">,</span> <span class="n">new_state</span><span class="o">.</span><span class="n">thdot</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">BraxOutput</span><span class="p">(</span><span class="n">th</span><span class="o">=</span><span class="n">next_th</span><span class="p">,</span> <span class="n">thdot</span><span class="o">=</span><span class="n">next_thdot</span><span class="p">)</span>  <span class="c1"># Prepare output</span>

        <span class="c1"># Calculate cost (penalize angle error, angular velocity and input voltage)</span>
        <span class="n">norm_next_th</span> <span class="o">=</span> <span class="n">next_th</span> <span class="o">-</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">floor</span><span class="p">((</span><span class="n">next_th</span> <span class="o">+</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><span class="p">))</span>
        <span class="n">loss_task</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">loss_task</span> <span class="o">+</span> <span class="n">norm_next_th</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="p">(</span><span class="n">next_thdot</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="mi">10</span> <span class="o">*</span> <span class="nb">abs</span><span class="p">(</span><span class="n">norm_next_th</span><span class="p">)))</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="mf">0.01</span> <span class="o">*</span> <span class="n">u</span><span class="o">**</span><span class="mi">2</span>

        <span class="c1"># Update state</span>
        <span class="n">new_state</span> <span class="o">=</span> <span class="n">new_state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">loss_task</span><span class="o">=</span><span class="n">loss_task</span><span class="p">)</span>
        <span class="n">new_step_state</span> <span class="o">=</span> <span class="n">step_state</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state</span><span class="o">=</span><span class="n">new_state</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">new_step_state</span><span class="p">,</span> <span class="n">output</span>


<span class="n">DISK_PENDULUM_XML</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2">&lt;mujoco model=&quot;disk_pendulum&quot;&gt;</span>
<span class="s2">    &lt;compiler inertiafromgeom=&quot;auto&quot; angle=&quot;radian&quot; coordinate=&quot;local&quot; eulerseq=&quot;xyz&quot; autolimits=&quot;true&quot;/&gt;</span>
<span class="s2">    &lt;option gravity=&quot;0 0 -9.81&quot; timestep=&quot;0.01&quot; iterations=&quot;10&quot;/&gt;</span>
<span class="s2">    &lt;custom&gt;</span>
<span class="s2">        &lt;numeric data=&quot;10&quot; name=&quot;constraint_ang_damping&quot;/&gt; &lt;!-- positional &amp; spring --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;1&quot; name=&quot;spring_inertia_scale&quot;/&gt;  &lt;!-- positional &amp; spring --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;0&quot; name=&quot;ang_damping&quot;/&gt;  &lt;!-- positional &amp; spring --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;0&quot; name=&quot;spring_mass_scale&quot;/&gt;  &lt;!-- positional &amp; spring --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;0.5&quot; name=&quot;joint_scale_pos&quot;/&gt; &lt;!-- positional --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;0.1&quot; name=&quot;joint_scale_ang&quot;/&gt; &lt;!-- positional --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;3000&quot; name=&quot;constraint_stiffness&quot;/&gt;  &lt;!-- spring --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;10000&quot; name=&quot;constraint_limit_stiffness&quot;/&gt;  &lt;!-- spring --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;50&quot; name=&quot;constraint_vel_damping&quot;/&gt;  &lt;!-- spring --&gt;</span>
<span class="s2">        &lt;numeric data=&quot;10&quot; name=&quot;solver_maxls&quot;/&gt;  &lt;!-- generalized --&gt;</span>
<span class="s2">    &lt;/custom&gt;</span>

<span class="s2">    &lt;asset&gt;</span>
<span class="s2">        &lt;texture builtin=&quot;flat&quot; height=&quot;1278&quot; mark=&quot;cross&quot; markrgb=&quot;1 1 1&quot; name=&quot;texgeom&quot; random=&quot;0.01&quot; rgb1=&quot;0.8 0.6 0.4&quot; rgb2=&quot;0.8 0.6 0.4&quot; type=&quot;cube&quot; width=&quot;127&quot;/&gt;</span>
<span class="s2">        &lt;material name=&quot;geom&quot; texture=&quot;texgeom&quot; texuniform=&quot;true&quot;/&gt;</span>
<span class="s2">    &lt;/asset&gt;</span>

<span class="s2">    &lt;default&gt;</span>
<span class="s2">        &lt;geom contype=&quot;0&quot; friction=&quot;1 0.1 0.1&quot; material=&quot;geom&quot;/&gt;</span>
<span class="s2">    &lt;/default&gt;</span>

<span class="s2">    &lt;worldbody&gt;</span>
<span class="s2">        &lt;light cutoff=&quot;100&quot; diffuse=&quot;1 1 1&quot; dir=&quot;-0 0 -1.3&quot; directional=&quot;true&quot; exponent=&quot;1&quot; pos=&quot;0 0 1.3&quot; specular=&quot;.1 .1 .1&quot;/&gt;</span>
<span class="s2">        &lt;geom name=&quot;table&quot; type=&quot;plane&quot; pos=&quot;0 0.0 -0.1&quot; size=&quot;1 1 0.1&quot; contype=&quot;8&quot; conaffinity=&quot;11&quot; condim=&quot;3&quot;/&gt;</span>
<span class="s2">        &lt;body name=&quot;disk&quot; pos=&quot;0.0 0.0 0.0&quot; euler=&quot;1.5708 0.0 0.0&quot;&gt;</span>
<span class="s2">            &lt;joint name=&quot;hinge_joint&quot; type=&quot;hinge&quot; axis=&quot;0 0 1&quot; range=&quot;-180 180&quot; armature=&quot;0.00022993&quot; damping=&quot;0.0001&quot; limited=&quot;false&quot;/&gt;</span>
<span class="s2">            &lt;geom name=&quot;disk_geom&quot; type=&quot;cylinder&quot; size=&quot;0.06 0.001&quot; contype=&quot;0&quot; conaffinity=&quot;0&quot; condim=&quot;3&quot; mass=&quot;0.0&quot;/&gt;</span>
<span class="s2">            &lt;geom name=&quot;mass_geom&quot; type=&quot;cylinder&quot; size=&quot;0.02 0.005&quot; contype=&quot;0&quot; conaffinity=&quot;0&quot;  condim=&quot;3&quot; rgba=&quot;0.04 0.04 0.04 1&quot;</span>
<span class="s2">                  pos=&quot;0.0 0.04 0.&quot; mass=&quot;0.05085817&quot;/&gt;</span>
<span class="s2">        &lt;/body&gt;</span>
<span class="s2">    &lt;/worldbody&gt;</span>

<span class="s2">    &lt;actuator&gt;</span>
<span class="s2">        &lt;motor joint=&quot;hinge_joint&quot; ctrllimited=&quot;false&quot; ctrlrange=&quot;-3.0 3.0&quot;  gear=&quot;0.01&quot;/&gt;</span>
<span class="s2">    &lt;/actuator&gt;</span>
<span class="s2">&lt;/mujoco&gt;</span>
<span class="s2">&quot;&quot;&quot;</span>
</code></pre></div>

</div>
</div>
<!-- An empty trailing notebook cell was removed here: it rendered as a blank code box. This page is generated by mkdocs/mkdocs-material, so also delete the empty cell from the source notebook; otherwise the next docs build brings it back. -->












                
              </article>
            </div>
          
          
<script>var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script>
        </div>
        
      </main>
      
        <footer class="md-footer">
          <div class="md-footer-meta md-typeset">
            <div class="md-footer-meta__inner md-grid">
              <!-- Theme attribution line; renders as "Made with Material for MkDocs". -->
              <div class="md-copyright">
                Made with
                <a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">Material for MkDocs</a>
              </div>
            </div>
          </div>
        </footer>
      
    </div>
    <!-- Dialog container; its inner element is empty in the generated markup.
         Presumably filled at runtime by the theme bundle (data-md-component="dialog"). -->
    <div class="md-dialog" data-md-component="dialog">
      <div class="md-dialog__inner md-typeset"></div>
    </div>
    
    
    <!-- Theme configuration as inline JSON: site base path, enabled features,
         search worker script and UI translation strings. Presumably parsed by
         bundle.*.min.js below. The non-JavaScript type makes this a data block
         that the browser does not execute, so keep type="application/json". -->
    <script id="__config" type="application/json">{"base": "..", "features": ["navigation.sections", "toc.integrate", "header.autohide"], "search": "../assets/javascripts/workers/search.6ce7567c.min.js", "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}}</script>
    
    
      <!-- Script order matters: theme bundle first, then the local MathJax
           configuration, then the MathJax loader. MathJax 3 reads an existing
           window.MathJax object as its configuration when it loads, so
           ../_static/mathjax.js (assumed to define that object; TODO confirm)
           must stay before the CDN loader. -->
      <script src="../assets/javascripts/bundle.83f73b43.min.js"></script>
      
        <script src="../_static/mathjax.js"></script>
      
        <!-- NOTE(review): third-party script loaded from a floating major tag
             (mathjax@3) with no integrity/crossorigin attributes, so the CDN can
             serve any 3.x build unverified. This page is generated, so pin an
             exact MathJax version and add its SRI hash in the site's
             extra_javascript configuration rather than editing this file. -->
        <script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
      
    
  </body>
</html>