{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"index.html","title":"Getting Started","text":"<p>Rex is a JAX-powered framework for sim-to-real robotics.</p> <p>Key features:</p> <ul> <li>Graph-based design: Model asynchronous systems with nodes for sensing, actuation, and computation.</li> <li>Latency-aware modeling: Simulate delay effects for hardware, computation, and communication channels.</li> <li>Real-time and parallelized runtimes: Run real-world experiments or accelerated parallelized simulations.</li> <li>Seamless integration with JAX: Utilize JAX's autodiff, JIT compilation, and GPU/TPU acceleration.</li> <li>System identification tools: Estimate dynamics and delays directly from real-world data.</li> <li>Modular and extensible: Compatible with various simulation engines (e.g., Brax, MuJoCo).</li> <li>Unified sim2real pipeline: Train delay-aware policies in simulation and deploy them on real-world systems.</li> </ul>"},{"location":"index.html#sim-to-real-workflow","title":"Sim-to-Real Workflow","text":"<ol> <li>Interface Real Systems: Define nodes for sensors, actuators, and computation to represent real-world systems.</li> <li>Build Simulation: Swap real-world nodes with simulated ones (e.g., physics engines, motor dynamics).</li> <li>System Identification: Estimate system dynamics and delays from real-world data.</li> <li>Policy Training: Train delay-aware policies in simulation, accounting for realistic dynamics and delays.</li> <li>Evaluation: Evaluate trained policies on the real-world system, and iterate on the design.</li> </ol>"},{"location":"index.html#installation","title":"Installation","text":"<pre><code>pip install rex-lib\n</code></pre> <p>Requires Python 3.9+ and JAX 0.4.30+.</p>"},{"location":"index.html#quick-example","title":"Quick example","text":"<p>Here's a simple example of a pendulum system.  
The real-world system is defined with nodes interfacing hardware for sensing and actuation: <pre><code>from rex.asynchronous import AsyncGraph\nfrom rex.examples.pendulum import Actuator, Agent, Sensor\n\nsensor = Sensor(rate=50)        # 50 Hz sampling rate\nagent = Agent(rate=30)          # 30 Hz policy execution rate\nactuator = Actuator(rate=50)    # 50 Hz control rate\nnodes = dict(sensor=sensor, agent=agent, actuator=actuator)\n\nagent.connect(sensor)       # Agent receives sensor data\nactuator.connect(agent)     # Actuator receives agent commands\ngraph = AsyncGraph(nodes, agent) # Graph for real-world execution\n\ngraph_state = graph.init()  # Initial states of all nodes\ngraph.warmup(graph_state)   # Jit-compiles the graph (only once).\nfor _ in range(100):        # Run the graph for 100 steps\n    graph_state = graph.run(graph_state) # Run for one step\ngraph.stop()                # Stop asynchronous nodes\ndata = graph.get_record()   # Get recorded data from the graph\n</code></pre> In simulation, we replace the hardware-interfacing nodes with simulated ones, add delay models, and add a physics simulation node: <pre><code>from distrax import Normal\nfrom rex.constants import Clock, RealTimeFactor\nfrom rex.asynchronous import AsyncGraph\nfrom rex.examples.pendulum import SimActuator, Agent, SimSensor, BraxWorld\n\nsensor = SimSensor(rate=50, delay_dist=Normal(0.01, 0.001))     # Process delay\nagent = Agent(rate=30, delay_dist=Normal(0.02, 0.005))          # Computational delay\nactuator = SimActuator(rate=50, delay_dist=Normal(0.01, 0.001)) # Process delay\nworld = BraxWorld(rate=100)  # 100 Hz physics simulation\nnodes = dict(sensor=sensor, agent=agent, actuator=actuator, world=world)\n\nsensor.connect(world, delay_dist=Normal(0.001, 0.001)) # Sensor delay\nagent.connect(sensor, delay_dist=Normal(0.001, 0.001)) # Communication delay\nactuator.connect(agent, delay_dist=Normal(0.001, 0.001)) # Communication delay\nworld.connect(actuator, 
delay_dist=Normal(0.001, 0.001), # Actuator delay\n              skip=True) # Breaks algebraic loop in the graph\ngraph = AsyncGraph(nodes, agent,\n                   clock=Clock.SIMULATED, # Simulates based on delay_dist\n                   real_time_factor=RealTimeFactor.FAST_AS_POSSIBLE)\n\ngraph_state = graph.init()  # Initial states of all nodes\ngraph.warmup(graph_state)   # Jit-compiles the graph\nfor _ in range(100):        # Run the graph for 100 steps\n    graph_state = graph.run(graph_state) # Run for one step\ngraph.stop()                # Stop asynchronous nodes\ndata = graph.get_record()   # Get recorded data from the graph\n</code></pre> Nodes define their params, state, and outputs as JAX PyTrees: <pre><code>from rex.node import BaseNode\n\n# SomePyTree is a placeholder for any JAX PyTree (e.g., a dict).\nclass Agent(BaseNode):\n    def init_params(self, rng=None, graph_state=None):\n        return SomePyTree(a=..., b=...)\n\n    def init_state(self, rng=None, graph_state=None):\n        return SomePyTree(x1=..., x2=...)\n\n    def init_output(self, rng=None, graph_state=None):\n        return SomePyTree(y1=..., y2=...)\n\n    # Jit-compiled via graph.warmup for faster execution\n    def step(self, step_state): # Called at the node's rate\n        ss = step_state  # Shorten name\n        # Read params and current state\n        params, state = ss.params, ss.state\n        # Current episode, sequence number, timestamp\n        eps, seq, ts = ss.eps, ss.seq, ss.ts\n        # Grab the data and I/O timestamps\n        cam = ss.inputs[\"sensor\"] # Received messages (InputState)\n        cam.data, cam.ts_sent, cam.ts_recv\n        ... 
# Some computation for new_state, output\n        new_state = SomePyTree(x1=..., x2=...)\n        output = SomePyTree(y1=..., y2=...)\n        # Update step_state for next step call\n        new_ss = ss.replace(state=new_state)\n        return new_ss, output # Sends output\n</code></pre></p>"},{"location":"index.html#next-steps","title":"Next steps","text":"<p>If this quick start has got you interested, then have a look at the sim2real.ipynb notebook for an example of a sim-to-real workflow using Rex.</p>"},{"location":"index.html#citation","title":"Citation","text":"<p>If you found this library to be useful in academic work, then please cite: (OpenReview)</p> <pre><code>@article{anonymous2024rex,\n  title={{REX: GPU-Accelerated Sim2Real Framework with Delay and Dynamics Estimation}},\n  author={Anonymous},\n  journal={OpenReview},\n  year={2024}\n}\n</code></pre> <p>(Also consider starring the project on GitHub.)</p>"},{"location":"citation.html","title":"Citation","text":"<p>If you found this library to be useful in academic work, then please cite: (OpenReview)</p> <pre><code>@article{anonymous2024rex,\n  title={{REX: GPU-Accelerated Sim2Real Framework with Delay and Dynamics Estimation}},\n  author={Anonymous},\n  journal={OpenReview},\n  year={2024}\n}\n</code></pre> <p>(Also consider starring the project on GitHub.)</p>"},{"location":"api/artificial.html","title":"Artificial","text":""},{"location":"api/artificial.html#rex.artificial.generate_graphs","title":"<code>rex.artificial.generate_graphs(nodes: Dict[str, BaseNode], ts_max: float, rng: jax.Array = None, num_episodes: int = 1) -&gt; Graph</code>","text":"<p>Generate graphs based on the nodes, computation delays, and communication delays.</p> <p>All nodes are assumed to have a rate and name attribute. Moreover, all nodes are assumed to run and communicate asynchronously. 
In other words, their timestamps are independent.</p> <p>Parameters:</p> <ul> <li> <code>nodes</code>               (<code>Dict[str, BaseNode]</code>)           \u2013            <p>Dictionary of nodes.</p> </li> <li> <code>ts_max</code>               (<code>float</code>)           \u2013            <p>Final time.</p> </li> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> <li> <code>num_episodes</code>               (<code>int</code>, default:                   <code>1</code> )           \u2013            <p>Number of graphs to generate.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Graph</code>           \u2013            <p>Graphs for each episode.</p> </li> </ul>"},{"location":"api/artificial.html#rex.artificial.augment_graphs","title":"<code>rex.artificial.augment_graphs(graphs: Graph, nodes: Dict[str, BaseNode], rng: jax.Array = None) -&gt; Graph</code>","text":"<p>Augment graphs based on the nodes, computation delays, and communication delays.</p> <p>When augmenting, the graphs are expanded with additional vertices and edges based on the provided nodes. Nodes not in graphs.vertices are added to the graphs according to the specified delay_dist. 
Edges between vertices are added for connections not present in graphs.edges.</p> <p>Parameters:</p> <ul> <li> <code>graphs</code>               (<code>Graph</code>)           \u2013            <p>Graphs to augment.</p> </li> <li> <code>nodes</code>               (<code>Dict[str, BaseNode]</code>)           \u2013            <p>Dictionary of nodes.</p> </li> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Graph</code>           \u2013            <p>Augmented graphs.</p> </li> </ul>"},{"location":"api/artificial.html#rex.base.Graph","title":"<code>rex.base.Graph</code>","text":"<p>A computation graph data structure that holds the vertices and edges of a computation graph.</p> <p>This data structure is used to represent the computation graph of a system. It holds the vertices and edges of the graph. The vertices represent consecutive step calls of nodes, and the edges represent the data flow between connected nodes.</p> <p>Stateful edges must not be included in the edges, but are implicitly assumed. In other words, consecutive sequence numbers of the same node are assumed to be connected.</p> <p>The graph should be directed and acyclic. Cycles are not allowed.</p> <p>Attributes:</p> <ul> <li> <code>vertices</code>               (<code>Dict[str, Vertex]</code>)           \u2013            <p>A dictionary of vertices. The keys are the unique names of the node type, and the values are the vertices.</p> </li> <li> <code>edges</code>               (<code>Dict[Tuple[str, str], Edge]</code>)           \u2013            <p>A dictionary of edges. The keys are of the form (n1, n2), where n1 and n2 are the unique names of the    output and input nodes, respectively. 
The values are the edges.</p> </li> </ul>"},{"location":"api/artificial.html#rex.base.Graph.__len__","title":"<code>__len__() -&gt; int</code>","text":"<p>Get the number of episodes in the graph.</p> <p>Returns:</p> <ul> <li> <code>int</code>           \u2013            <p>The number of episodes (i.e. the number of graphs if the graph is batched).</p> </li> </ul>"},{"location":"api/artificial.html#rex.base.Graph.__getitem__","title":"<code>__getitem__(val: int) -&gt; Graph</code>","text":"<p>If the graph is batched and holds the graphs of multiple episodes, this function returns the graph of the specified episode.</p> <p>Parameters:</p> <ul> <li> <code>val</code>               (<code>int</code>)           \u2013            <p>The episode to get the graph of.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Graph</code>           \u2013            <p>The graph of the specified episode.</p> </li> </ul>"},{"location":"api/artificial.html#rex.base.Graph.stack","title":"<code>stack(graphs_raw: List[Graph]) -&gt; Graph</code>  <code>staticmethod</code>","text":"<p>Stack multiple graphs into a single graph.</p> Padding <p>If the graphs have different lengths, the vertices and edges are padded with -1.</p> <p>Parameters:</p> <ul> <li> <code>graphs_raw</code>               (<code>List[Graph]</code>)           \u2013            <p>A list of graphs to stack.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Graph</code>           \u2013            <p>A single graph with the vertices and edges stacked.</p> </li> </ul>"},{"location":"api/artificial.html#rex.base.Graph.filter","title":"<code>filter(nodes: Dict[str, BaseNode], filter_edges: bool = True) -&gt; Graph</code>","text":"<p>Filter the graph to only include the nodes and edges that are in the nodes dictionary.</p> <p>Parameters:</p> <ul> <li> <code>nodes</code>               (<code>Dict[str, BaseNode]</code>)           \u2013            <p>A dictionary of nodes. 
The keys are the unique names of the nodes, and the values are the nodes.</p> </li> <li> <code>filter_edges</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>If True, only include the nodes that are connected to the nodes in the dictionary.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Graph</code>           \u2013            <p>A new graph with only the nodes and edges that are in the nodes dictionary.</p> </li> </ul>"},{"location":"api/artificial.html#rex.base.Vertex","title":"<code>rex.base.Vertex</code>","text":"<p>A vertex data structure that holds the sequence numbers and timestamps of a node.</p> <p>This data structure may be batched and hold data for multiple episodes. The last dimension represents the sequence numbers during the episode.</p> <p>If the timestamps are not available, set ts_start and ts_end to a dummy value (e.g. 0.0).</p> <p>Ideally, for every vertex seq[i] there should be an edge with seq_out[i] for every connected node in the graph.</p> <p>Attributes:</p> <ul> <li> <code>seq</code>               (<code>Union[int, Array]</code>)           \u2013            <p>The sequence number of the node. Should start at 0 and increase by 1 every step (no gaps).</p> </li> <li> <code>ts_start</code>               (<code>Union[float, Array]</code>)           \u2013            <p>The start time of the computation of the node (i.e. when the node starts processing step 'seq').</p> </li> <li> <code>ts_end</code>               (<code>Union[float, Array]</code>)           \u2013            <p>The end time of the computation of the node (i.e. when the node finishes processing step 'seq').</p> </li> </ul>"},{"location":"api/artificial.html#rex.base.Edge","title":"<code>rex.base.Edge</code>","text":"<p>An edge data structure that holds the sequence numbers and timestamps of a connection.</p> <p>This data structure may be batched and hold data for multiple episodes. 
The last dimension represents the data during the episode.</p> <p>Given a message from node_out to node_in, the sequence number of the sent message is seq_out. The message is received at node_in at time ts_recv. seq_in is the sequence number of the node_in call in which the message is processed.</p> <p>For outputs that were never received, set seq_in to -1.</p> <p>If the received timestamps are not available, set ts_recv to a dummy value (e.g. 0.0).</p> <p>Attributes:</p> <ul> <li> <code>seq_out</code>               (<code>Union[int, Array]</code>)           \u2013            <p>The sequence number of the message. Must be monotonically increasing.</p> </li> <li> <code>seq_in</code>               (<code>Union[int, Array]</code>)           \u2013            <p>The sequence number of the call in which the message is processed. Must be monotonically increasing.</p> </li> <li> <code>ts_recv</code>               (<code>Union[float, Array]</code>)           \u2013            <p>The time the message is received at the input node. 
Must be monotonically increasing.</p> </li> </ul>"},{"location":"api/asynchronous.html","title":"Asynchronous","text":""},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph","title":"<code>rex.asynchronous.AsyncGraph</code>","text":""},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.max_steps","title":"<code>max_steps</code>  <code>property</code>","text":"<p>The maximum number of steps.</p>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.max_eps","title":"<code>max_eps</code>  <code>property</code>","text":"<p>The maximum number of episodes.</p>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.__init__","title":"<code>__init__(nodes: Dict[str, BaseNode], supervisor: BaseNode, clock: Clock = Clock.WALL_CLOCK, real_time_factor: Union[float, int] = RealTimeFactor.REAL_TIME)</code>","text":"<p>Creates an interface around all nodes in the graph.</p> <p>As a mental model, it helps to think of the graph as dividing the nodes into two groups:</p> <ol> <li>Supervisor Node: The designated node that controls the graph's execution flow.</li> <li>All Other Nodes: These nodes form the environment the supervisor interacts with.</li> </ol> <p>This partitioning of nodes essentially creates an agent-environment interface, where the supervisor node acts as the agent, and the remaining nodes represent the environment. The graph provides gym-like <code>.reset</code> and <code>.step</code> methods that mirror reinforcement learning interfaces:</p> <ul> <li><code>.init</code>: Initializes the graph state, which includes the state of all nodes.</li> <li><code>.reset</code>: Initializes the system and returns the initial observation as would be seen by the supervisor node.</li> <li><code>.step</code>: Advances the graph by one step (i.e. 
steps all nodes except the supervisor) and returns the next observation.</li> </ul> <p>As a result, the timestep of graph.step is determined by the rate of the supervisor node (i.e., <code>1/supervisor.rate</code>).</p> <p>Parameters:</p> <ul> <li> <code>nodes</code>               (<code>Dict[str, BaseNode]</code>)           \u2013            <p>Dictionary of nodes that make up the graph.</p> </li> <li> <code>supervisor</code>               (<code>BaseNode</code>)           \u2013            <p>The designated node that controls the graph's execution flow.</p> </li> <li> <code>clock</code>               (<code>Clock</code>, default:                   <code>WALL_CLOCK</code> )           \u2013            <p>Determines how time is managed in the graph. Choices include <code>Clock.SIMULATED</code> for virtual simulations and <code>Clock.WALL_CLOCK</code> for real-time applications.</p> </li> <li> <code>real_time_factor</code>               (<code>Union[float, int]</code>, default:                   <code>REAL_TIME</code> )           \u2013            <p>Sets the speed of the simulation. It can simulate as fast as possible (<code>RealTimeFactor.FAST_AS_POSSIBLE</code>), in real-time (<code>RealTimeFactor.REAL_TIME</code>), or at any custom speed relative to real-time.</p> </li> </ul>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.init","title":"<code>init(rng: jax.typing.ArrayLike = None, params: Dict[str, base.Params] = None, order: Tuple[str, ...] = None) -&gt; base.GraphState</code>","text":"<p>Initializes the graph state, with optional arguments for the RNG, predefined node params, and the initialization order.</p> <p>Nodes are initialized in a specified order, with the option to override params. 
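</p> <p>A minimal sketch of calling <code>init</code>, assuming the <code>graph</code> from the quick example with nodes named <code>sensor</code>, <code>agent</code>, and <code>actuator</code>. It passes an explicit order that lists every node (whether a partial order is accepted is not documented here) and then re-initializes with predefined params for the agent only. Reading those params back from <code>graph_state.params</code> is only a placeholder; in practice they might come from, e.g., system identification.</p> <pre><code>import jax\n\n# Assumes `graph` is the AsyncGraph from the quick example.\nrng = jax.random.PRNGKey(0)\n\n# Initialize all nodes, listing every node name in the desired order.\ngraph_state = graph.init(rng=rng, order=('sensor', 'actuator', 'agent'))\n\n# GraphState.params maps node names to their params.\nagent_params = graph_state.params['agent']  # placeholder source of params\n\n# Predefined params for a subset of the nodes (here only the agent).\ngraph_state = graph.init(rng=rng, params={'agent': agent_params})\n</code></pre> <p>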
Useful for setting up the graph state before running the graph with .run, .rollout, or .reset.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>ArrayLike</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator seed or state.</p> </li> <li> <code>params</code>               (<code>Dict[str, Params]</code>, default:                   <code>None</code> )           \u2013            <p>Predefined params for (a subset of) the nodes.</p> </li> <li> <code>order</code>               (<code>Tuple[str, ...]</code>, default:                   <code>None</code> )           \u2013            <p>The order in which nodes are initialized.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>GraphState</code>           \u2013            <p>The initialized graph state.</p> </li> </ul>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.set_record_settings","title":"<code>set_record_settings(params: Union[Dict[str, bool], bool] = None, rng: Union[Dict[str, bool], bool] = None, inputs: Union[Dict[str, bool], bool] = None, state: Union[Dict[str, bool], bool] = None, output: Union[Dict[str, bool], bool] = None, max_records: Union[Dict[str, int], int] = None) -&gt; None</code>","text":"<p>Sets the record settings for the nodes in the graph.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record the params of the nodes.</p> </li> <li> <code>rng</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record the RNG states of the nodes.</p> </li> <li> <code>inputs</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record the input states of the nodes.</p> </li> 
<li> <code>state</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record the state of the nodes.</p> </li> <li> <code>output</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record the output of the nodes.</p> </li> <li> <code>max_records</code>               (<code>Union[Dict[str, int], int]</code>, default:                   <code>None</code> )           \u2013            <p>The maximum number of records to store for each node.</p> </li> </ul>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.warmup","title":"<code>warmup(graph_state: base.GraphState, device_step: Union[Dict[str, jax.Device], jax.Device] = None, device_dist: Union[Dict[str, jax.Device], jax.Device] = None, jit_step: Union[Dict[str, bool], bool] = True, profile: Union[Dict[str, bool], bool] = False, verbose: bool = False)</code>","text":"<p>Ahead-of-time compilation of step and I/O functions to avoid latency at runtime.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The graph state that is expected to be used during runtime.</p> </li> <li> <code>device_step</code>               (<code>Union[Dict[str, Device], Device]</code>, default:                   <code>None</code> )           \u2013            <p>The device to compile the step functions on. It's also the device used to prepare the input states.          If None, the default device is used.</p> </li> <li> <code>device_dist</code>               (<code>Union[Dict[str, Device], Device]</code>, default:                   <code>None</code> )           \u2013            <p>The device to compile the sampling of the delay distribution functions on. If None, the default device is used.          
Only relevant when using a simulated clock.</p> </li> <li> <code>jit_step</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>True</code> )           \u2013            <p>Whether to compile the step functions with JIT. JIT-compiled step functions are faster, but by default they cannot have side effects. Either wrap the side-effecting code in a JAX callback, or set jit_step=False for those nodes.</p> </li> <li> <code>profile</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>False</code> )           \u2013            <p>Whether to compile the step functions with time profiling. IMPORTANT: this test-runs the step functions, which may lead to unexpected side effects.</p> </li> <li> <code>verbose</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            <p>Whether to print time profiling information.</p> </li> </ul>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.stop","title":"<code>stop(timeout: float = None) -&gt; None</code>","text":"<p>Stops the graph and all its nodes.</p> <p>Parameters:</p> <ul> <li> <code>timeout</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>The maximum time to wait for the graph to stop. If None, it waits indefinitely.</p> </li> </ul>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.run","title":"<code>run(graph_state: base.GraphState, timeout: float = None) -&gt; base.GraphState</code>","text":"<p>Executes one step of the graph including the supervisor node and returns the updated graph state.</p> <p>Unlike the .step method, it automatically progresses the graph state after the supervisor has executed. 
This method is different from the gym API, as it uses the .step method of the supervisor node, while the reset and step methods allow the user to override the .step method.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state, or initial graph state from .init().</p> </li> <li> <code>timeout</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>The maximum time to wait for the graph to complete a step.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>GraphState</code>           \u2013            <p>Updated graph state. It returns directly after the supervisor node's step() is run.</p> </li> </ul>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.reset","title":"<code>reset(graph_state: base.GraphState, timeout: float = None) -&gt; Tuple[base.GraphState, base.StepState]</code>","text":"<p>Prepares the graph for execution by resetting it to a state before the supervisor node's execution.</p> <p>Returns the graph and step state just before what would be the supervisor's step, mimicking the initial observation return of a gym environment's reset method. 
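</p> <p>A minimal sketch of the gym-style loop, assuming the <code>graph</code> and the supervisor <code>agent</code> from the quick example. The agent's own <code>step</code> stands in for a policy here; a custom policy would instead produce the updated step state and the output (the action) itself before passing both to <code>graph.step</code>.</p> <pre><code>import jax\n\n# Assumes `graph` (AsyncGraph) and `agent` (its supervisor) from the quick example.\ngraph_state = graph.init(rng=jax.random.PRNGKey(0))\ngraph.warmup(graph_state)\n\n# reset() returns the supervisor's step state, i.e. the initial observation.\ngraph_state, step_state = graph.reset(graph_state)\nfor _ in range(100):\n    # Stand-in policy: the supervisor's own step() returns (new step state, output).\n    step_state, output = agent.step(step_state)\n    # Send the output (action) to connected nodes and get the next observation.\n    graph_state, step_state = graph.step(graph_state, step_state, output)\ngraph.stop()\n</code></pre> <p>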
The step state can be considered the initial observation of a gym environment.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The graph state from .init().</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[GraphState, StepState]</code>           \u2013            <p>Tuple of the new graph state and the supervisor node's step state before execution of the first step.</p> </li> </ul>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.step","title":"<code>step(graph_state: base.GraphState, step_state: base.StepState = None, output: base.Output = None) -&gt; Tuple[base.GraphState, base.StepState]</code>","text":"<p>Executes one step of the graph, optionally overriding the supervisor node's execution.</p> <p>If step_state and output are provided, they override the supervisor's step, allowing for custom step implementations. Otherwise, the supervisor's step() is executed as usual.</p> <p>When providing the updated step_state and output, the provided output can be viewed as the action that the agent would take in a gym environment, which is sent to nodes connected to the supervisor node.</p> <p>Start every episode with a call to reset() using the initial graph state from init(), then call step() repeatedly.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>step_state</code>               (<code>StepState</code>, default:                   <code>None</code> )           \u2013            <p>Custom step state for the supervisor node.</p> </li> <li> <code>output</code>               (<code>Output</code>, default:                   <code>None</code> )           \u2013            <p>Custom output for the supervisor node.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[GraphState, StepState]</code>           \u2013            <p>Tuple of the new graph 
state and the supervisor node's step state before execution of the next step.</p> </li> </ul>"},{"location":"api/asynchronous.html#rex.asynchronous.AsyncGraph.get_record","title":"<code>get_record() -&gt; base.EpisodeRecord</code>","text":"<p>Gets the episode record for all nodes in the graph.</p> <p>Returns:</p> <ul> <li> <code>EpisodeRecord</code>           \u2013            <p>Returns the episode record for all nodes in the graph.</p> </li> </ul>"},{"location":"api/base.html","title":"Base","text":""},{"location":"api/base.html#rex.base.InputState","title":"<code>rex.base.InputState</code>","text":"<p>A ring buffer that holds the inputs for a node's input channel.</p> <p>The size of the buffer is determined by the window size of the corresponding connection (i.e. node.connect(..., window=...)).</p> <p>Attributes:</p> <ul> <li> <code>seq</code>               (<code>ArrayLike</code>)           \u2013            <p>the sequence number of the received message</p> </li> <li> <code>ts_sent</code>               (<code>ArrayLike</code>)           \u2013            <p>the time the message was sent</p> </li> <li> <code>ts_recv</code>               (<code>ArrayLike</code>)           \u2013            <p>the time the message was received</p> </li> <li> <code>data</code>               (<code>Output</code>)           \u2013            <p>the message of the connection (arbitrary pytree structure)</p> </li> <li> <code>delay_dist</code>               (<code>DelayDistribution</code>)           \u2013            <p>the delay distribution of the connection</p> </li> </ul>"},{"location":"api/base.html#rex.base.InputState.__getitem__","title":"<code>__getitem__(val: int) -&gt; InputState</code>","text":"<p>Get the value of the ring buffer at a specific index.</p> <p>This is useful for indexing all the values of the ring buffer at a specific index.</p> <p>Parameters:</p> <ul> <li> <code>val</code>               (<code>int</code>)           \u2013            <p>the index to get the 
value from</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>InputState</code>           \u2013            <p>The input state at the specific index</p> </li> </ul>"},{"location":"api/base.html#rex.base.InputState.push","title":"<code>push(seq: int, ts_sent: float, ts_recv: float, data: Any) -&gt; InputState</code>","text":"<p>Push a new message into the ring buffer.</p> <p>Parameters:</p> <ul> <li> <code>seq</code>               (<code>int</code>)           \u2013            <p>the sequence number of the received message</p> </li> <li> <code>ts_sent</code>               (<code>float</code>)           \u2013            <p>the time the message was sent</p> </li> <li> <code>ts_recv</code>               (<code>float</code>)           \u2013            <p>the time the message was received</p> </li> <li> <code>data</code>               (<code>Any</code>)           \u2013            <p>the message of the connection (arbitrary pytree structure)</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>InputState</code>           \u2013            <p>The new input state with the message pushed into the ring buffer.</p> </li> </ul>"},{"location":"api/base.html#rex.base.InputState.from_outputs","title":"<code>from_outputs(seq: ArrayLike, ts_sent: ArrayLike, ts_recv: ArrayLike, outputs: Any, delay_dist: DelayDistribution, is_data: bool = False) -&gt; InputState</code>  <code>classmethod</code>","text":"<p>Create an InputState from a list of messages, timestamps, and sequence numbers.</p> <p>The oldest message should be first in the list.</p> <p>Parameters:</p> <ul> <li> <code>seq</code>               (<code>ArrayLike</code>)           \u2013            <p>the sequence number of the received message</p> </li> <li> <code>ts_sent</code>               (<code>ArrayLike</code>)           \u2013            <p>the time the message was sent</p> </li> <li> <code>ts_recv</code>               (<code>ArrayLike</code>)           \u2013            <p>the time the message was received</p> </li> <li> 
<code>outputs</code>               (<code>Any</code>)           \u2013            <p>the messages of the connection (arbitrary pytree structure)</p> </li> <li> <code>is_data</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            <p>if True, the outputs are already a stacked pytree structure</p> </li> <li> <code>delay_dist</code>               (<code>DelayDistribution</code>)           \u2013            <p>the delay distribution of the connection</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>InputState</code>           \u2013            <p>The input state with the messages and timestamps in the ring buffer.</p> </li> </ul>"},{"location":"api/base.html#rex.base.StepState","title":"<code>rex.base.StepState</code>","text":"<p>Step state definition.</p> <p>It holds all the information that is required to step a node.</p> <p>Attributes:</p> <ul> <li> <code>rng</code>               (<code>Array</code>)           \u2013            <p>The random number generator. Used for sampling random processes. If used, it should be updated.</p> </li> <li> <code>state</code>               (<code>State</code>)           \u2013            <p>The state of the node. Usually dynamic during an episode.</p> </li> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The parameters of the node. Usually static during an episode.</p> </li> <li> <code>inputs</code>               (<code>FrozenDict[str, InputState]</code>)           \u2013            <p>The inputs of the node. See InputState.</p> </li> <li> <code>eps</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The current episode number. Relates to the computation graph, not the episode counter of an environment.</p> </li> <li> <code>seq</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The current step number. 
Automatically increases by 1 every step.</p> </li> <li> <code>ts</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The current time step at the start of the step. Determined by the computation graph.</p> </li> </ul>"},{"location":"api/base.html#rex.base.GraphState","title":"<code>rex.base.GraphState</code>","text":"<p>Graph state definition.</p> <p>It holds all the information that is required to step a graph.</p> <p>Attributes:</p> <ul> <li> <code>step</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The current step number. Automatically increases by 1 every step.</p> </li> <li> <code>eps</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The current episode number. To update the episode, use GraphState.replace_eps.</p> </li> <li> <code>rng</code>               (<code>FrozenDict[str, Array]</code>)           \u2013            <p>The random number generators for each node in the graph.</p> </li> <li> <code>seq</code>               (<code>FrozenDict[str, Union[int, ArrayLike]]</code>)           \u2013            <p>The current step number for each node in the graph.</p> </li> <li> <code>ts</code>               (<code>FrozenDict[str, Union[float, ArrayLike]]</code>)           \u2013            <p>The start time of the step for each node in the graph.</p> </li> <li> <code>params</code>               (<code>FrozenDict[str, Params]</code>)           \u2013            <p>The parameters for each node in the graph.</p> </li> <li> <code>state</code>               (<code>FrozenDict[str, State]</code>)           \u2013            <p>The state for each node in the graph.</p> </li> <li> <code>inputs</code>               (<code>FrozenDict[str, FrozenDict[str, InputState]]</code>)           \u2013            <p>The inputs for each node in the graph.</p> </li> <li> <code>timings_eps</code>               (<code>Timings</code>)           \u2013            <p>The 
timings data structure that describes the execution and partitioning of the graph.</p> </li> <li> <code>buffer</code>               (<code>FrozenDict[str, Output]</code>)           \u2013            <p>The output buffer of the graph. It holds the outputs of nodes during the execution. Input buffers are automatically filled with the outputs of previously executed step calls of other nodes.</p> </li> <li> <code>aux</code>               (<code>FrozenDict[str, Any]</code>)           \u2013            <p>Auxiliary data that can be used to store additional information (e.g. records, wrappers, etc.).</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base","title":"<code>rex.base.Base</code>","text":"<p>Base functionality extending all dataclasses. These methods allow dataclasses to be operated on like arrays/matrices.</p> <p>Note: Credits to the authors of the brax library for this implementation.</p> Tip <p>Use this base class for all state, output, and param pytrees.</p>"},{"location":"api/base.html#rex.base.Base.__str__","title":"<code>__str__()</code>","text":"<p>Return a string representation of the dataclass.</p>"},{"location":"api/base.html#rex.base.Base.__add__","title":"<code>__add__(o: Any) -&gt; Any</code>","text":"<p>Element-wise addition of two pytrees.</p> <p>Parameters:</p> <ul> <li> <code>o</code>               (<code>Any</code>)           \u2013            <p>The other pytree to add.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The resulting pytree after applying the element-wise operation.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.__sub__","title":"<code>__sub__(o: Any) -&gt; Any</code>","text":"<p>Element-wise subtraction of two pytrees.</p> <p>Parameters:</p> <ul> <li> <code>o</code>               (<code>Any</code>)           \u2013            <p>The other pytree to subtract.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The resulting pytree 
after applying the element-wise operation.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.__mul__","title":"<code>__mul__(o: Any) -&gt; Any</code>","text":"<p>Element-wise multiplication of two pytrees.</p> <p>Parameters:</p> <ul> <li> <code>o</code>               (<code>Any</code>)           \u2013            <p>The other pytree to multiply by.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The resulting pytree after applying the element-wise operation.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.__neg__","title":"<code>__neg__() -&gt; Any</code>","text":"<p>Element-wise negation of a pytree.</p> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The resulting pytree after applying the element-wise operation.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.__truediv__","title":"<code>__truediv__(o: Any) -&gt; Any</code>","text":"<p>Element-wise division of two pytrees.</p> <p>Parameters:</p> <ul> <li> <code>o</code>               (<code>Any</code>)           \u2013            <p>The other pytree to divide by.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The resulting pytree after applying the element-wise operation.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.__getitem__","title":"<code>__getitem__(val: int) -&gt; Any</code>","text":"<p>Get a specific value from the dataclass.</p> <p>Parameters:</p> <ul> <li> <code>val</code>               (<code>int</code>)           \u2013            <p>The index of the value to get from the dataclass.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The value from the dataclass.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.replace","title":"<code>replace(*args: Any, **kwargs: Any) -&gt; Any</code>","text":"<p>Replace fields in the dataclass.</p> <p>Parameters:</p> <ul> <li> <code>*args</code>               (<code>Any</code>, default:    
               <code>()</code> )           \u2013            <p>The fields to replace.</p> </li> <li> <code>**kwargs</code>               (<code>Any</code>, default:                   <code>{}</code> )           \u2013            <p>The fields to replace.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.reshape","title":"<code>reshape(shape: Sequence[int]) -&gt; Any</code>","text":"<p>Reshape the dataclass.</p> <p>Parameters:</p> <ul> <li> <code>shape</code>               (<code>Sequence[int]</code>)           \u2013            <p>The target shape of the dataclass.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The reshaped dataclass.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.select","title":"<code>select(o: Any, cond: jax.Array) -&gt; Any</code>","text":"<p>Select elements from two pytrees based on a condition.</p> <p>Parameters:</p> <ul> <li> <code>o</code>               (<code>Any</code>)           \u2013            <p>The other pytree to select elements from.</p> </li> <li> <code>cond</code>               (<code>Array</code>)           \u2013            <p>The condition to select elements based on.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The resulting pytree after applying the condition.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.slice","title":"<code>slice(beg: int, end: int) -&gt; Any</code>","text":"<p>Slice the dataclass.</p> <p>Parameters:</p> <ul> <li> <code>beg</code>               (<code>int</code>)           \u2013            <p>The beginning of the slice.</p> </li> <li> <code>end</code>               (<code>int</code>)           \u2013            <p>The end of the slice.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The sliced dataclass.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.take","title":"<code>take(i: int, axis: int = 0) -&gt; 
Any</code>","text":"<p>Take elements from the dataclass.</p> <p>Parameters:</p> <ul> <li> <code>i</code>               (<code>int</code>)           \u2013            <p>The index of the elements to take.</p> </li> <li> <code>axis</code>               (<code>int</code>, default:                   <code>0</code> )           \u2013            <p>The axis to take the elements from.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The taken elements from the dataclass.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.concatenate","title":"<code>concatenate(*others: Any, axis: int = 0) -&gt; Any</code>","text":"<p>Concatenate the dataclass with other dataclasses.</p> <p>Parameters:</p> <ul> <li> <code>*others</code>               (<code>Any</code>, default:                   <code>()</code> )           \u2013            <p>The other dataclasses to concatenate.</p> </li> <li> <code>axis</code>               (<code>int</code>, default:                   <code>0</code> )           \u2013            <p>The axis to concatenate the dataclasses on.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The concatenated dataclass.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.index_set","title":"<code>index_set(idx: Union[jax.Array, Sequence[jax.Array]], o: Any) -&gt; Any</code>","text":"<p>Set elements in the dataclass based on an index.</p> <p>Parameters:</p> <ul> <li> <code>idx</code>               (<code>Union[Array, Sequence[Array]]</code>)           \u2013            <p>The index at which to set the elements.</p> </li> <li> <code>o</code>               (<code>Any</code>)           \u2013            <p>The elements to set.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The dataclass with the elements set.</p> </li> </ul>"},{"location":"api/base.html#rex.base.Base.index_sum","title":"<code>index_sum(idx: Union[jax.Array, Sequence[jax.Array]], o: Any) -&gt; 
Any</code>","text":"<p>Sum elements in the dataclass based on an index.</p> <p>Parameters:</p> <ul> <li> <code>idx</code>               (<code>Union[Array, Sequence[Array]]</code>)           \u2013            <p>The index to sum the elements.</p> </li> <li> <code>o</code>               (<code>Any</code>)           \u2013            <p>The elements to sum.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The dataclass with the summed elements.</p> </li> </ul>"},{"location":"api/cem.html","title":"Cross-Entropy Method","text":""},{"location":"api/cem.html#rex.cem.cem","title":"<code>rex.cem.cem(loss: Loss, solver: CEMSolver, init_state: CEMState, transform: Transform, max_steps: int = 100, rng: jax.Array = None, verbose: bool = True) -&gt; Tuple[CEMState, jax.typing.ArrayLike]</code>","text":"<p>Run the Cross-Entropy Method (can be jit-compiled).</p> <p>Parameters:</p> <ul> <li> <code>loss</code>               (<code>Loss</code>)           \u2013            <p>Loss function.</p> </li> <li> <code>solver</code>               (<code>CEMSolver</code>)           \u2013            <p>CEM Solver.</p> </li> <li> <code>init_state</code>               (<code>CEMState</code>)           \u2013            <p>Initial state of the CEM Solver.</p> </li> <li> <code>transform</code>               (<code>Transform</code>)           \u2013            <p>Transform function (e.g. 
denormalization, extension, etc.).</p> </li> <li> <code>max_steps</code>               (<code>int</code>, default:                   <code>100</code> )           \u2013            <p>Maximum number of steps to run the CEM Solver.</p> </li> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> <li> <code>verbose</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Whether to print progress.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[CEMState, ArrayLike]</code>           \u2013            <p>The final state of the CEM Solver and the losses at each step.</p> </li> </ul>"},{"location":"api/cem.html#rex.cem.CEMSolver","title":"<code>rex.cem.CEMSolver</code>","text":"<p>See https://arxiv.org/pdf/1907.03613.pdf for details on CEM.</p> <p>Attributes:</p> <ul> <li> <code>u_min</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Minimum values for the parameters (pytree).</p> </li> <li> <code>u_max</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Maximum values for the parameters (pytree).</p> </li> <li> <code>evolution_smoothing</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>Smoothing factor for updating the mean and standard deviation.</p> </li> <li> <code>num_samples</code>               (<code>int</code>)           \u2013            <p>Number of samples per iteration.</p> </li> <li> <code>elite_portion</code>               (<code>float</code>)           \u2013            <p>The portion of the samples to consider as elites.</p> </li> </ul>"},{"location":"api/cem.html#rex.cem.CEMSolver.init","title":"<code>init(u_min: Dict[str, Params], u_max: Dict[str, Params], num_samples: int = 100, evolution_smoothing: Union[float, jax.typing.ArrayLike] = 0.1, 
elite_portion: float = 0.1) -&gt; CEMSolver</code>  <code>classmethod</code>","text":"<p>Initialize the Cross-Entropy Method (CEM) Solver.</p> <p>Parameters:</p> <ul> <li> <code>u_min</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Minimum values for the parameters (pytree).</p> </li> <li> <code>u_max</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Maximum values for the parameters (pytree).</p> </li> <li> <code>num_samples</code>               (<code>int</code>, default:                   <code>100</code> )           \u2013            <p>Number of samples per iteration.</p> </li> <li> <code>evolution_smoothing</code>               (<code>Union[float, ArrayLike]</code>, default:                   <code>0.1</code> )           \u2013            <p>See https://arxiv.org/pdf/1907.03613.pdf for details.</p> </li> <li> <code>elite_portion</code>               (<code>float</code>, default:                   <code>0.1</code> )           \u2013            <p>See https://arxiv.org/pdf/1907.03613.pdf for details.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>CEMSolver</code> (              <code>CEMSolver</code> )          \u2013            <p>An instance of the CEMSolver class.</p> </li> </ul>"},{"location":"api/cem.html#rex.cem.CEMSolver.init_state","title":"<code>init_state(mean: Dict[str, Params], stdev: Dict[str, Params] = None) -&gt; CEMState</code>","text":"<p>Initialize the state of the CEM Solver.</p> <p>Parameters:</p> <ul> <li> <code>mean</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Mean values for the parameters (pytree).</p> </li> <li> <code>stdev</code>               (<code>Dict[str, Params]</code>, default:                   <code>None</code> )           \u2013            <p>(Normalized) Standard deviation values for the parameters (pytree).</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>CEMState</code> (    
          <code>CEMState</code> )          \u2013            <p>The initialized state of the CEM Solver.</p> </li> </ul>"},{"location":"api/cem.html#rex.cem.CEMState","title":"<code>rex.cem.CEMState</code>","text":"<p>State of the CEM Solver.</p> <p>Attributes:</p> <ul> <li> <code>mean</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Mean values for the parameters (pytree).</p> </li> <li> <code>stdev</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Standard deviation values for the parameters (pytree).</p> </li> <li> <code>bestsofar</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Best-so-far values for the parameters (pytree).</p> </li> <li> <code>bestsofar_loss</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>Loss of the best-so-far values.</p> </li> </ul>"},{"location":"api/compiled.html","title":"Compiled","text":""},{"location":"api/compiled.html#rex.graph.Graph","title":"<code>rex.graph.Graph</code>","text":""},{"location":"api/compiled.html#rex.graph.Graph.max_steps","title":"<code>max_steps</code>  <code>property</code>","text":"<p>The maximum number of steps.</p> <p>That's usually the number of vertices of the supervisor in the raw computation graphs.</p>"},{"location":"api/compiled.html#rex.graph.Graph.timings","title":"<code>timings: base.Timings</code>  <code>property</code>","text":"<p>Timings of the supergraph.</p> <p>Contains all predication masks to convert the supergraph to the correct partition given the current episode and step.</p>"},{"location":"api/compiled.html#rex.graph.Graph.graphs","title":"<code>graphs: base.Graph</code>  <code>property</code>","text":"<p>Graphs after applying windows to the raw computation graphs.</p>"},{"location":"api/compiled.html#rex.graph.Graph.graphs_raw","title":"<code>graphs_raw: base.Graph</code>  
<code>property</code>","text":"<p>Raw computation graphs.</p>"},{"location":"api/compiled.html#rex.graph.Graph.Gs","title":"<code>Gs: List[nx.DiGraph]</code>  <code>property</code>","text":"<p>List of networkx graphs after applying windows to the raw computation graphs.</p>"},{"location":"api/compiled.html#rex.graph.Graph.S","title":"<code>S: nx.DiGraph</code>  <code>property</code>","text":"<p>The supergraph.</p>"},{"location":"api/compiled.html#rex.graph.Graph.__init__","title":"<code>__init__(nodes: Dict[str, BaseNode], supervisor: BaseNode, graphs_raw: base.Graph, skip: List[str] = None, supergraph: Supergraph = Supergraph.MCS, prune: bool = True, S_init: nx.DiGraph = None, backtrack: int = 20, debug: bool = False, progress_bar: bool = True, buffer_sizes: Dict[str, int] = None, extra_padding: int = 0)</code>","text":"<p>Compile the graph with nodes, a supervisor, and target computation graphs.</p> <p>This class finds a partitioning and supergraph to efficiently represent all raw computation graphs. It exposes .step and .reset methods that resemble the gym API. In addition, it provides .run and .rollout methods. See the individual methods for more information.</p> <p>The supervisor node defines the boundary between partitions, and essentially dictates the timestep of every step call.</p> <p>\"Raw\" computation graphs only take into account the data flow of a system. They do not account for the fact that some messages may be reused in multiple step calls when no new data is available, nor that some messages may be discarded when they fall out of the buffer. Therefore, we first modify the raw computation graphs to take into account the buffer sizes (i.e. 
window sizes) for every connection.</p> <p>Parameters:</p> <ul> <li> <code>nodes</code>               (<code>Dict[str, BaseNode]</code>)           \u2013            <p>Dictionary of nodes.</p> </li> <li> <code>supervisor</code>               (<code>BaseNode</code>)           \u2013            <p>Supervisor node.</p> </li> <li> <code>graphs_raw</code>               (<code>Graph</code>)           \u2013            <p>Raw computation graphs. Must be acyclic.</p> </li> <li> <code>skip</code>               (<code>List[str]</code>, default:                   <code>None</code> )           \u2013            <p>List of nodes to skip during graph execution.</p> </li> <li> <code>supergraph</code>               (<code>Supergraph</code>, default:                   <code>MCS</code> )           \u2013            <p>Supergraph mode. Options are MCS, TOPOLOGICAL, and GENERATIONAL.</p> </li> <li> <code>prune</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Prune nodes that are not ancestors of the supervisor node. Setting to False ensures that all nodes up until the time of the last supervisor node are included.</p> </li> <li> <code>S_init</code>               (<code>DiGraph</code>, default:                   <code>None</code> )           \u2013            <p>Initial supergraph.</p> </li> <li> <code>backtrack</code>               (<code>int</code>, default:                   <code>20</code> )           \u2013            <p>Backtrack parameter for MCS supergraph mode.</p> </li> <li> <code>debug</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            <p>Debug mode. 
Validates the partitioning and supergraph and times various compilation steps.</p> </li> <li> <code>progress_bar</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Show progress bar during supergraph generation.</p> </li> <li> <code>buffer_sizes</code>               (<code>Dict[str, int]</code>, default:                   <code>None</code> )           \u2013            <p>Dictionary of buffer sizes for each connection.</p> </li> <li> <code>extra_padding</code>               (<code>int</code>, default:                   <code>0</code> )           \u2013            <p>Extra padding for buffer sizes.</p> </li> </ul>"},{"location":"api/compiled.html#rex.graph.Graph.init","title":"<code>init(rng: jax.typing.ArrayLike = None, params: Dict[str, base.Params] = None, starting_step: Union[int, jax.typing.ArrayLike] = 0, starting_eps: jax.typing.ArrayLike = 0, randomize_eps: bool = False, order: Tuple[str, ...] = None) -&gt; base.GraphState</code>","text":"<p>Initializes the graph state with optional parameters for RNG and step states.</p> <p>Nodes are initialized in a specified order, with the option to override params. 
Useful for setting up the graph state before running the graph with .run, .rollout, or .reset.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>ArrayLike</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator seed or state.</p> </li> <li> <code>params</code>               (<code>Dict[str, Params]</code>, default:                   <code>None</code> )           \u2013            <p>Predefined params for (a subset of) the nodes.</p> </li> <li> <code>starting_step</code>               (<code>Union[int, ArrayLike]</code>, default:                   <code>0</code> )           \u2013            <p>The simulation's starting step.</p> </li> <li> <code>starting_eps</code>               (<code>ArrayLike</code>, default:                   <code>0</code> )           \u2013            <p>The starting episode.</p> </li> <li> <code>randomize_eps</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            <p>If True, randomly selects the starting episode.</p> </li> <li> <code>order</code>               (<code>Tuple[str, ...]</code>, default:                   <code>None</code> )           \u2013            <p>The order in which nodes are initialized.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>GraphState</code>           \u2013            <p>The initialized graph state.</p> </li> </ul>"},{"location":"api/compiled.html#rex.graph.Graph.init_record","title":"<code>init_record(graph_state: base.GraphState, params: Union[Dict[str, bool], bool] = None, rng: Union[Dict[str, bool], bool] = None, inputs: Union[Dict[str, bool], bool] = None, state: Union[Dict[str, bool], bool] = None, output: Union[Dict[str, bool], bool] = None) -&gt; base.GraphState</code>","text":"<p>Sets the record settings for the nodes in the graph.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The 
initial graph state from .init().</p> </li> <li> <code>params</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record params for each node. Logged once.</p> </li> <li> <code>rng</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record rng for each node. Logged each step.</p> </li> <li> <code>inputs</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record inputs for each node. Logged each step. Can become very large.</p> </li> <li> <code>state</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record state for each node. Logged each step.</p> </li> <li> <code>output</code>               (<code>Union[Dict[str, bool], bool]</code>, default:                   <code>None</code> )           \u2013            <p>Whether to record output for each node. Logged each step.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>GraphState</code>           \u2013            <p>The updated graph state with record settings.</p> </li> </ul>"},{"location":"api/compiled.html#rex.graph.Graph.run","title":"<code>run(graph_state: base.GraphState) -&gt; base.GraphState</code>","text":"<p>Executes one step of the graph including the supervisor node and returns the updated graph state.</p> <p>Different from the <code>.step</code> method, it automatically progresses the graph state post-supervisor execution, suitable for <code>jax.lax.scan</code> or <code>jax.lax.fori_loop</code> operations. 
This method differs from the gym API: it uses the supervisor node's <code>.step</code> method, whereas the <code>reset</code> and <code>step</code> methods allow the user to override the supervisor's <code>.step</code> method.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state, or the initial graph state from <code>.init()</code>.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>GraphState</code>           \u2013            <p>Updated graph state. It returns directly after the supervisor node's step is run.</p> </li> </ul>"},{"location":"api/compiled.html#rex.graph.Graph.reset","title":"<code>reset(graph_state: base.GraphState) -&gt; Tuple[base.GraphState, base.StepState]</code>","text":"<p>Prepares the graph for execution by resetting it to a state before the supervisor node's execution.</p> <p>Returns the graph and step state just before what would be the supervisor's step, mimicking the initial observation returned by a gym environment's reset method. The step state can be considered the initial observation of a gym environment.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The graph state from .init().</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[GraphState, StepState]</code>           \u2013            <p>Tuple of the new graph state and the supervisor node's step state before execution of the first step.</p> </li> </ul>"},{"location":"api/compiled.html#rex.graph.Graph.step","title":"<code>step(graph_state: base.GraphState, step_state: base.StepState = None, output: base.Output = None) -&gt; Tuple[base.GraphState, base.StepState]</code>","text":"<p>Executes one step of the graph, optionally overriding the supervisor node's execution.</p> <p>If <code>step_state</code> and <code>output</code> are provided, they override the supervisor's step, allowing for custom step implementations. 
Otherwise, the supervisor's <code>step()</code> is executed as usual.</p> <p>When providing the updated <code>step_state</code> and <code>output</code>, the provided output can be viewed as the action that the agent would take in a gym environment, which is sent to nodes connected to the supervisor node.</p> <p>Start every episode with a call to <code>reset()</code> using the initial graph state from <code>init()</code>, then call <code>step()</code> repeatedly.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>step_state</code>               (<code>StepState</code>, default:                   <code>None</code> )           \u2013            <p>Custom step state for the supervisor node.</p> </li> <li> <code>output</code>               (<code>Output</code>, default:                   <code>None</code> )           \u2013            <p>Custom output for the supervisor node.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[GraphState, StepState]</code>           \u2013            <p>Tuple of the new graph state and the supervisor node's step state before execution of the next step.</p> </li> </ul>"},{"location":"api/compiled.html#rex.graph.Graph.rollout","title":"<code>rollout(graph_state: base.GraphState, max_steps: int = None, carry_only: bool = True) -&gt; base.GraphState</code>","text":"<p>Executes the graph for a specified number of steps or until a condition is met, starting from a given step and episode.</p> <p>Utilizes the run method for execution, with an option to return only the final graph state or a sequence of all graph states. By virtue of using the run method, it does not allow for overriding the supervisor node's step method. That is, the supervisor node's step method is used during the rollout.</p> Note <p>To record the rollout, use the init_record method on the graph_state before calling this method and set carry_only=True. 
Then, the record is available in graph_state.aux[\"record\"].</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The initial graph state.</p> </li> <li> <code>max_steps</code>               (<code>int</code>, default:                   <code>None</code> )           \u2013            <p>The maximum number of steps to execute. If None, runs until a stop condition is met.</p> </li> <li> <code>carry_only</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>If True, returns only the final graph state; otherwise returns all states.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>GraphState</code>           \u2013            <p>The final graph state, or the sequence of all graph states, after execution.</p> </li> </ul>"},{"location":"api/delays.html","title":"Delays","text":""},{"location":"api/delays.html#rex.base.DelayDistribution","title":"<code>rex.base.DelayDistribution</code>","text":"<p>A delay distribution data structure.</p>"},{"location":"api/delays.html#rex.base.DelayDistribution.reset","title":"<code>reset(rng: jax.Array) -&gt; DelayDistribution</code>","text":"<p>Reset the distribution (e.g. 
random number generator).</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>)           \u2013            <p>random number generator</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>DelayDistribution</code> (              <code>DelayDistribution</code> )          \u2013            <p>the reset distribution</p> </li> </ul>"},{"location":"api/delays.html#rex.base.DelayDistribution.sample","title":"<code>sample(shape: Union[int, Tuple] = None) -&gt; Tuple[DelayDistribution, jax.Array]</code>","text":"<p>Sample from the distribution.</p> <p>Parameters:</p> <ul> <li> <code>shape</code>               (<code>Union[int, Tuple]</code>, default:                   <code>None</code> )           \u2013            <p>the shape of the sample</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[DelayDistribution, Array]</code>           \u2013            <p>The new distribution and the sample</p> </li> </ul>"},{"location":"api/delays.html#rex.base.DelayDistribution.quantile","title":"<code>quantile(q: float) -&gt; float</code>","text":"<p>Compute the quantile of the distribution.</p> <p>Parameters:</p> <ul> <li> <code>q</code>               (<code>float</code>)           \u2013            <p>the probability level in [0, 1]</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>float</code>           \u2013            <p>The quantile value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.DelayDistribution.mean","title":"<code>mean() -&gt; float</code>","text":"<p>Compute the mean of the distribution.</p> <p>Returns:</p> <ul> <li> <code>float</code>           \u2013            <p>The mean value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.DelayDistribution.pdf","title":"<code>pdf(x: float) -&gt; float</code>","text":"<p>Compute the probability density function of the distribution.</p> <p>Parameters:</p> <ul> <li> <code>x</code>               (<code>float</code>)           \u2013            <p>the value at which to compute the pdf</p> </li> </ul> 
<p>Returns:</p> <ul> <li> <code>float</code>           \u2013            <p>The pdf value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.StaticDist","title":"<code>rex.base.StaticDist</code>","text":"<p>               Bases: <code>DelayDistribution</code></p> <p>A wrapper around distrax distributions to make them compatible with the DelayDistribution interface.</p> <p>Attributes:</p> <ul> <li> <code>rng</code>               (<code>Array</code>)           \u2013            <p>the random number generator</p> </li> <li> <code>dist</code>               (<code>Distribution</code>)           \u2013            <p>the (static) distrax distribution (e.g. Normal, MixtureSameFamily, etc.)</p> </li> </ul>"},{"location":"api/delays.html#rex.base.StaticDist.create","title":"<code>create(dist: distrax.Distribution) -&gt; StaticDist</code>  <code>classmethod</code>","text":"<p>Create a static distribution.</p> <p>Parameters:</p> <ul> <li> <code>dist</code>               (<code>Distribution</code>)           \u2013            <p>the distrax distribution</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StaticDist</code>           \u2013            <p>The static distribution</p> </li> </ul>"},{"location":"api/delays.html#rex.base.StaticDist.reset","title":"<code>reset(rng: jax.Array) -&gt; StaticDist</code>","text":"<p>Reset the random number generator.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>)           \u2013            <p>random number generator</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>DelayDistribution</code> (              <code>StaticDist</code> )          \u2013            <p>the reset distribution</p> </li> </ul>"},{"location":"api/delays.html#rex.base.StaticDist.sample","title":"<code>sample(shape: Union[int, Tuple] = None) -&gt; Tuple[StaticDist, jax.Array]</code>","text":"<p>Sample from the distribution.</p> <p>Parameters:</p> <ul> <li> <code>shape</code>               (<code>Union[int, Tuple]</code>, default:       
            <code>None</code> )           \u2013            <p>the shape of the sample</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[StaticDist, Array]</code>           \u2013            <p>The new distribution and the sample</p> </li> </ul>"},{"location":"api/delays.html#rex.base.StaticDist.quantile","title":"<code>quantile(q: float) -&gt; Union[float, jax.typing.ArrayLike]</code>","text":"<p>Compute the quantile of the distribution.</p> <p>Parameters:</p> <ul> <li> <code>q</code>               (<code>float</code>)           \u2013            <p>the probability level in [0, 1]</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Union[float, ArrayLike]</code>           \u2013            <p>The quantile value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.StaticDist.mean","title":"<code>mean() -&gt; float</code>","text":"<p>Compute the mean of the distribution.</p> <p>Returns:</p> <ul> <li> <code>float</code>           \u2013            <p>The mean value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.StaticDist.pdf","title":"<code>pdf(x: float) -&gt; float</code>","text":"<p>Compute the probability density function of the distribution.</p> <p>Parameters:</p> <ul> <li> <code>x</code>               (<code>float</code>)           \u2013            <p>the value at which to compute the pdf</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>float</code>           \u2013            <p>The pdf value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist","title":"<code>rex.base.TrainableDist</code>","text":"<p>               Bases: <code>DelayDistribution</code></p> <p>Attributes:</p> <ul> <li> <code>alpha</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>a value in [0, 1] that parameterizes the delay between min and max</p> </li> <li> <code>min</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>the minimum expected delay</p> </li> <li> <code>max</code>               (<code>Union[float, ArrayLike]</code>)           
\u2013            <p>the maximum expected delay</p> </li> <li> <code>interp</code>               (<code>str</code>)           \u2013            <p>the interpolation method: \"zoh\" (zero-order hold between received messages), \"linear\" (linear interpolation between received messages), or \"linear_real_only\" (linear interpolation that only starts between actually received messages).</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.create","title":"<code>create(delay: Union[float, jax.typing.ArrayLike], min: Union[float, jax.typing.ArrayLike], max: Union[float, jax.typing.ArrayLike], interp: str = 'zoh') -&gt; TrainableDist</code>  <code>classmethod</code>","text":"<p>Creates a trainable distribution by converting the desired delay into alpha, a value in [0, 1].</p> <p>Parameters:</p> <ul> <li> <code>delay</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>the desired delay</p> </li> <li> <code>min</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>the minimum expected delay</p> </li> <li> <code>max</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>the maximum expected delay</p> </li> <li> <code>interp</code>               (<code>str</code>, default:                   <code>'zoh'</code> )           \u2013            <p>the interpolation method (\"zoh\", \"linear\", \"linear_real_only\").</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>TrainableDist</code>           \u2013            <p>The trainable distribution</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.reset","title":"<code>reset(rng: jax.Array) -&gt; TrainableDist</code>","text":"<p>Does nothing, as the distribution is deterministic.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>)           \u2013            <p>random number generator</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>DelayDistribution</code> (              
<code>TrainableDist</code> )          \u2013            <p>the reset distribution</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.sample","title":"<code>sample(shape: Union[int, Tuple] = None) -&gt; Tuple[TrainableDist, jax.Array]</code>","text":"<p>Sample from the distribution.</p> <p>Parameters:</p> <ul> <li> <code>shape</code>               (<code>Union[int, Tuple]</code>, default:                   <code>None</code> )           \u2013            <p>the shape of the sample</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[TrainableDist, Array]</code>           \u2013            <p>The new distribution and the sample</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.quantile","title":"<code>quantile(q: float) -&gt; float</code>","text":"<p>Compute the quantile of the distribution.</p> <p>As the distribution is deterministic, every quantile equals the distribution's constant delay value, regardless of q.</p> <p>Parameters:</p> <ul> <li> <code>q</code>               (<code>float</code>)           \u2013            <p>the probability level in [0, 1]</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>float</code>           \u2013            <p>The quantile value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.mean","title":"<code>mean() -&gt; Union[float, jax.typing.ArrayLike]</code>","text":"<p>Compute the mean of the distribution.</p> <p>Returns:</p> <ul> <li> <code>Union[float, ArrayLike]</code>           \u2013            <p>The mean value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.pdf","title":"<code>pdf(x: float) -&gt; Union[jax.Array, float]</code>","text":"<p>Compute the probability density function of the distribution.</p> <p>Parameters:</p> <ul> <li> <code>x</code>               (<code>float</code>)           \u2013            <p>the value at which to compute the pdf</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Union[Array, float]</code>           \u2013            <p>The pdf 
value</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.equivalent","title":"<code>equivalent(other: DelayDistribution) -&gt; bool</code>","text":"<p>Check if two delay distributions are equivalent</p> <p>Parameters:</p> <ul> <li> <code>other</code>               (<code>DelayDistribution</code>)           \u2013            <p>the other delay distribution</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>bool</code>           \u2013            <p>True if the two delay distributions are equivalent, False otherwise</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.window","title":"<code>window(rate_out: Union[float, int]) -&gt; int</code>","text":"<p>Compute the additional window size needed for the delay distribution.</p> <p>Parameters:</p> <ul> <li> <code>rate_out</code>               (<code>Union[float, int]</code>)           \u2013            <p>the output rate of the connection</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>int</code>           \u2013            <p>The additional window size needed for the delay distribution</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.apply_delay","title":"<code>apply_delay(rate_out: float, input: InputState, ts_start: Union[float, jax.typing.ArrayLike]) -&gt; InputState</code>","text":"<p>Apply the delay to the input state.</p> <p>The delay is determined by the delay distribution of the connection.</p> <p>Parameters:</p> <ul> <li> <code>rate_out</code>               (<code>float</code>)           \u2013            <p>the output rate of the connection</p> </li> <li> <code>input</code>               (<code>InputState</code>)           \u2013            <p>the input state</p> </li> <li> <code>ts_start</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>the start time of the computation</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>InputState</code>           \u2013            <p>The input state with the delay applied. 
This reduces the window size of the input by self.window(rate_out).</p> </li> </ul>"},{"location":"api/delays.html#rex.base.TrainableDist.get_alpha","title":"<code>get_alpha(delay: Union[float, jax.typing.ArrayLike]) -&gt; Union[float, jax.typing.ArrayLike]</code>","text":"<p>Convert a delay value into the corresponding alpha value in [0, 1].</p> <p>Parameters:</p> <ul> <li> <code>delay</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>the delay value</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Union[float, ArrayLike]</code>           \u2013            <p>The alpha value</p> </li> </ul>"},{"location":"api/environment.html","title":"Environment and Wrappers","text":""},{"location":"api/environment.html#rex.rl.BaseEnv","title":"<code>rex.rl.BaseEnv</code>","text":""},{"location":"api/environment.html#rex.rl.BaseEnv.graph","title":"<code>graph = graph</code>  <code>instance-attribute</code>","text":""},{"location":"api/environment.html#rex.rl.BaseEnv.max_steps","title":"<code>max_steps: Union[int, jax.typing.ArrayLike]</code>  <code>property</code>","text":"<p>The maximum number of steps in the environment.</p> <p>By default, this is the maximum number of times the supervisor (i.e., the agent) is stepped in the provided computation graph. You can override this property to provide a custom value (smaller than the default). 
This value is used as the episode length when evaluating the environment during training.</p>"},{"location":"api/environment.html#rex.rl.BaseEnv.observation_space","title":"<code>observation_space(graph_state: base.GraphState) -&gt; Box</code>","text":"<p>Returns the observation space.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The graph state.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Box</code>           \u2013            <p>The observation space</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.BaseEnv.action_space","title":"<code>action_space(graph_state: base.GraphState) -&gt; Box</code>","text":"<p>Returns the action space.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The graph state.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Box</code>           \u2013            <p>The action space</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.BaseEnv.reset","title":"<code>reset(rng: jax.Array = None) -&gt; ResetReturn</code>","text":"<p>Reset the environment.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator. 
Used to initialize a new graph state.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>ResetReturn</code>           \u2013            <p>The initial graph state, observation, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.BaseEnv.step","title":"<code>step(graph_state: base.GraphState, action: jax.Array) -&gt; StepReturn</code>","text":"<p>Step the environment.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>action</code>               (<code>Array</code>)           \u2013            <p>The action to take.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StepReturn</code>           \u2013            <p>The updated graph state, observation, reward, terminated, truncated, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.BaseWrapper","title":"<code>rex.rl.BaseWrapper</code>","text":"<p>               Bases: <code>object</code></p> <p>Base class for wrappers.</p>"},{"location":"api/environment.html#rex.rl.BaseWrapper.__init__","title":"<code>__init__(env: Union[BaseEnv, Environment, BaseWrapper])</code>","text":"<p>Initialize the wrapper.</p> <p>Parameters:</p> <ul> <li> <code>env</code>               (<code>Union[BaseEnv, Environment, BaseWrapper]</code>)           \u2013            <p>The environment to wrap.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.BaseWrapper.__getattr__","title":"<code>__getattr__(name: str) -&gt; Any</code>","text":"<p>Proxy access to regular attributes of wrapped object.</p> <p>Parameters:</p> <ul> <li> <code>name</code>               (<code>str</code>)           \u2013            <p>The name of the attribute.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Any</code>           \u2013            <p>The attribute of the wrapped object.</p> </li> 
</ul>"},{"location":"api/environment.html#rex.rl.AutoResetWrapper","title":"<code>rex.rl.AutoResetWrapper</code>","text":"<p>               Bases: <code>BaseWrapper</code></p>"},{"location":"api/environment.html#rex.rl.AutoResetWrapper.__init__","title":"<code>__init__(env: Union[BaseEnv, Environment, BaseWrapper], fixed_init: bool = True)</code>","text":"<p>The AutoResetWrapper resets the environment inside the step method whenever an episode is done.</p> <p>When fixed_init is True, a fixed initial state (stored during reset) is reused instead of actually resetting the environment. This gives every episode the same initial state and avoids the cost of a full reset, which can be expensive in some cases.</p> <p>Parameters:</p> <ul> <li> <code>env</code>               (<code>Union[BaseEnv, Environment, BaseWrapper]</code>)           \u2013            <p>The environment to wrap.</p> </li> <li> <code>fixed_init</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Whether to use a fixed initial state.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.AutoResetWrapper.reset","title":"<code>reset(rng: jax.Array = None) -&gt; ResetReturn</code>","text":"<p>Reset the environment and return the initial state.</p> <p>If fixed_init is True, the initial state is stored in the aux of the graph state.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>ResetReturn</code>           \u2013            <p>The initial graph state, observation, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.AutoResetWrapper.step","title":"<code>step(graph_state: base.GraphState, action: jax.Array) -&gt; StepReturn</code>","text":"<p>Step the environment and reset the state if the episode is 
done.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>action</code>               (<code>Array</code>)           \u2013            <p>The action to take.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StepReturn</code>           \u2013            <p>The updated graph state, observation, reward, terminated, truncated, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.LogWrapper","title":"<code>rex.rl.LogWrapper</code>","text":"<p>               Bases: <code>BaseWrapper</code></p>"},{"location":"api/environment.html#rex.rl.LogWrapper.__init__","title":"<code>__init__(env: Union[BaseEnv, Environment, BaseWrapper])</code>","text":"<p>Log the episode returns and lengths.</p> <p>Parameters:</p> <ul> <li> <code>env</code>               (<code>Union[BaseEnv, Environment, BaseWrapper]</code>)           \u2013            <p>The environment to wrap.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.LogWrapper.reset","title":"<code>reset(rng: jax.Array = None) -&gt; ResetReturn</code>","text":"<p>Stores the log state in the aux of the graph state.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>ResetReturn</code>           \u2013            <p>The initial graph state, observation, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.LogWrapper.step","title":"<code>step(graph_state: base.GraphState, action: jax.Array) -&gt; StepReturn</code>","text":"<p>Logs the episode returns and lengths.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>action</code>               (<code>Array</code>)           \u2013 
           <p>The action to take.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StepReturn</code>           \u2013            <p>The updated graph state, observation, reward, terminated, truncated, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.LogState","title":"<code>rex.rl.LogState</code>","text":"<p>Attributes:</p> <ul> <li> <code>episode_returns</code>               (<code>float</code>)           \u2013            <p>The sum of the rewards in the episode.</p> </li> <li> <code>episode_lengths</code>               (<code>int</code>)           \u2013            <p>The number of steps in the episode.</p> </li> <li> <code>returned_episode_returns</code>               (<code>float</code>)           \u2013            <p>The sum of the rewards in the episode that was returned.</p> </li> <li> <code>returned_episode_lengths</code>               (<code>int</code>)           \u2013            <p>The number of steps in the episode that was returned.</p> </li> <li> <code>timestep</code>               (<code>int</code>)           \u2013            <p>The current timestep.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.SquashActionWrapper","title":"<code>rex.rl.SquashActionWrapper</code>","text":"<p>               Bases: <code>BaseWrapper</code></p>"},{"location":"api/environment.html#rex.rl.SquashActionWrapper.__init__","title":"<code>__init__(env: Union[BaseEnv, Environment, BaseWrapper], squash: bool = True)</code>","text":"<p>Squashes the action space to [-1, 1] and maps each action back to the original range before stepping the environment.</p> <p>Parameters:</p> <ul> <li> <code>env</code>               (<code>Union[BaseEnv, Environment, BaseWrapper]</code>)           \u2013            <p>The environment to wrap.</p> </li> <li> <code>squash</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Whether to squash the action space.</p> </li> 
</ul>"},{"location":"api/environment.html#rex.rl.SquashActionWrapper.reset","title":"<code>reset(rng: jax.Array = None) -&gt; ResetReturn</code>","text":"<p>Puts the action space scaling in the aux of the graph state.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>ResetReturn</code>           \u2013            <p>The initial graph state, observation, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.SquashActionWrapper.step","title":"<code>step(graph_state: base.GraphState, action: jax.Array) -&gt; StepReturn</code>","text":"<p>Unscales the action to the original range of the action space before stepping the environment.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>action</code>               (<code>Array</code>)           \u2013            <p>The (scaled) action to take.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StepReturn</code>           \u2013            <p>The updated graph state, observation, reward, terminated, truncated, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.SquashActionWrapper.action_space","title":"<code>action_space(graph_state: base.GraphState) -&gt; Box</code>","text":"<p>Scales the action space to [-1, 1] if squash is True.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The graph state.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Box</code>           \u2013            <p>The scaled action space</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.SquashState","title":"<code>rex.rl.SquashState</code>","text":"<p>Attributes:</p> <ul> <li> <code>low</code>               (<code>Array</code>)         
  \u2013            <p>The lower bound of the action space.</p> </li> <li> <code>high</code>               (<code>Array</code>)           \u2013            <p>The upper bound of the action space.</p> </li> <li> <code>squash</code>               (<code>bool</code>)           \u2013            <p>Whether to squash the action space.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.SquashState.action_space","title":"<code>action_space: Box</code>  <code>property</code>","text":"<p>Returns:</p> <ul> <li> <code>Box</code>           \u2013            <p>The scaled action space.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.SquashState.scale","title":"<code>scale(x: jax.Array) -&gt; jax.Array</code>","text":"<p>Scales the input to [-1, 1] and unsquashes.</p> <p>Parameters:</p> <ul> <li> <code>x</code>               (<code>Array</code>)           \u2013            <p>The input to scale.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Array</code>           \u2013            <p>The scaled input.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.SquashState.unsquash","title":"<code>unsquash(x: jax.Array) -&gt; jax.Array</code>","text":"<p>If squash is True, x is squashed to [-1, 1] and then unscaled to the original range [low, high]. Otherwise, x is clipped to the range of the action space.</p> <p>Parameters:</p> <ul> <li> <code>x</code>               (<code>Array</code>)           \u2013            <p>The input to unscale.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Array</code>           \u2013            <p>The unscaled input.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.ClipActionWrapper","title":"<code>rex.rl.ClipActionWrapper</code>","text":"<p>               Bases: <code>BaseWrapper</code></p>"},{"location":"api/environment.html#rex.rl.ClipActionWrapper.step","title":"<code>step(graph_state: base.GraphState, action: jax.Array) -&gt; StepReturn</code>","text":"<p>Clips the action to the action space before stepping the environment.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>action</code>               (<code>Array</code>)           \u2013            <p>The action to take.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StepReturn</code>           \u2013            <p>The updated graph state, observation, reward, terminated, truncated, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.VecEnvWrapper","title":"<code>rex.rl.VecEnvWrapper</code>","text":"<p>               Bases: <code>BaseWrapper</code></p>"},{"location":"api/environment.html#rex.rl.VecEnvWrapper.__init__","title":"<code>__init__(env: Union[BaseEnv, Environment, BaseWrapper], in_axes: Union[int, None, Sequence[Any]] = 0)</code>","text":"<p>Vectorizes the environment.</p> <p>Parameters:</p> <ul> <li> <code>env</code>               (<code>Union[BaseEnv, Environment, BaseWrapper]</code>)           \u2013            <p>The environment to wrap.</p> </li> <li> <code>in_axes</code>               (<code>Union[int, None, Sequence[Any]]</code>, default:                   <code>0</code> )           \u2013            <p>The axes to map over.</p> </li> 
</ul>"},{"location":"api/environment.html#rex.rl.NormalizeVecObservationWrapper","title":"<code>rex.rl.NormalizeVecObservationWrapper</code>","text":"<p>               Bases: <code>BaseWrapper</code></p>"},{"location":"api/environment.html#rex.rl.NormalizeVecObservationWrapper.__init__","title":"<code>__init__(env: Union[BaseEnv, Environment, BaseWrapper], clip_obs: float = 10.0)</code>","text":"<p>Normalize the observations to have zero mean and unit variance.</p> <p>Parameters:</p> <ul> <li> <code>env</code>               (<code>Union[BaseEnv, Environment, BaseWrapper]</code>)           \u2013            <p>The environment to wrap.</p> </li> <li> <code>clip_obs</code>               (<code>float</code>, default:                   <code>10.0</code> )           \u2013            <p>The clipping value.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.NormalizeVecObservationWrapper.reset","title":"<code>reset(rng: jax.Array = None) -&gt; ResetReturn</code>","text":"<p>Places the normalization state in the aux of the graph state.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>ResetReturn</code>           \u2013            <p>The initial graph state, observation, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.NormalizeVecObservationWrapper.step","title":"<code>step(graph_state: base.GraphState, action: jax.Array) -&gt; StepReturn</code>","text":"<p>Normalize the observations to have zero mean and unit variance before returning them.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>action</code>               (<code>Array</code>)           \u2013            <p>The action to take.</p> </li> </ul> <p>Returns:</p> <ul> <li> 
<code>StepReturn</code>           \u2013            <p>The updated graph state, observation, reward, terminated, truncated, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.NormalizeVecReward","title":"<code>rex.rl.NormalizeVecReward</code>","text":"<p>               Bases: <code>BaseWrapper</code></p>"},{"location":"api/environment.html#rex.rl.NormalizeVecReward.__init__","title":"<code>__init__(env: Union[BaseEnv, Environment, BaseWrapper], gamma: Union[float, jax.typing.ArrayLike], clip_reward: float = 10.0)</code>","text":"<p>Normalize the rewards to have zero mean and unit variance.</p> <p>Parameters:</p> <ul> <li> <code>env</code>               (<code>Union[BaseEnv, Environment, BaseWrapper]</code>)           \u2013            <p>The environment to wrap.</p> </li> <li> <code>gamma</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The discount factor.</p> </li> <li> <code>clip_reward</code>               (<code>float</code>, default:                   <code>10.0</code> )           \u2013            <p>The clipping value.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.NormalizeVecReward.reset","title":"<code>reset(rng: jax.Array = None) -&gt; ResetReturn</code>","text":"<p>Places the normalization state in the aux of the graph state.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>ResetReturn</code>           \u2013            <p>The initial graph state, observation, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.NormalizeVecReward.step","title":"<code>step(graph_state: base.GraphState, action: jax.Array) -&gt; StepReturn</code>","text":"<p>Normalize the rewards to have zero mean and unit variance before returning them.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>          
     (<code>GraphState</code>)           \u2013            <p>The current graph state.</p> </li> <li> <code>action</code>               (<code>Array</code>)           \u2013            <p>The action to take.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StepReturn</code>           \u2013            <p>The updated graph state, observation, reward, terminated, truncated, and info</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.NormalizeVec","title":"<code>rex.rl.NormalizeVec</code>","text":"<p>Attributes:</p> <ul> <li> <code>mean</code>           \u2013            <p>The mean of the observations.</p> </li> <li> <code>var</code>           \u2013            <p>The variance of the observations.</p> </li> <li> <code>count</code>           \u2013            <p>The number of observations.</p> </li> <li> <code>return_val</code>           \u2013            <p>The return value.</p> </li> <li> <code>clip</code>           \u2013            <p>The clipping value.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.NormalizeVec.normalize","title":"<code>normalize(x: jax.Array, clip: bool = True, subtract_mean: bool = True) -&gt; jax.Array</code>","text":"<p>Normalize x to have zero mean and unit variance.</p> <p>Parameters:</p> <ul> <li> <code>x</code>               (<code>Array</code>)           \u2013            <p>The input to normalize.</p> </li> <li> <code>clip</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Whether to clip the input.</p> </li> <li> <code>subtract_mean</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Whether to subtract the mean.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Array</code>           \u2013            <p>The normalized input.</p> </li> </ul>"},{"location":"api/environment.html#rex.rl.NormalizeVec.denormalize","title":"<code>denormalize(x: jax.Array, add_mean: bool = True) -&gt; jax.Array</code>","text":"<p>Denormalize x to have the original mean and variance.</p> <p>Parameters:</p> <ul> <li> <code>x</code>               (<code>Array</code>)           \u2013            <p>The input to denormalize.</p> </li> <li> <code>add_mean</code>               
(<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Whether to add the mean.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Array</code>           \u2013            <p>The denormalized input.</p> </li> </ul>"},{"location":"api/evosax.html","title":"Evosax","text":""},{"location":"api/evosax.html#rex.evo.evo","title":"<code>rex.evo.evo(loss: Loss, solver: EvoSolver, init_state: evx.strategy.EvoState, transform: Transform, max_steps: int = 100, rng: jax.Array = None, verbose: bool = True, logger: LogState = None) -&gt; Tuple[evx.strategy.EvoState, LogState, jax.Array]</code>","text":"<p>Run the Evolutionary Solver (can be jit-compiled).</p> <p>Parameters:</p> <ul> <li> <code>loss</code>               (<code>Loss</code>)           \u2013            <p>Loss function.</p> </li> <li> <code>solver</code>               (<code>EvoSolver</code>)           \u2013            <p>Evolutionary Solver.</p> </li> <li> <code>init_state</code>               (<code>EvoState</code>)           \u2013            <p>Initial state of the Evolutionary Solver.</p> </li> <li> <code>transform</code>               (<code>Transform</code>)           \u2013            <p>Transform function to go from a normalized set of trainable parameters to the denormalized and extended set of parameters.</p> </li> <li> <code>max_steps</code>               (<code>int</code>, default:                   <code>100</code> )           \u2013            <p>Maximum number of steps to run the Evolutionary Solver.</p> </li> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> <li> <code>verbose</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Whether to print the progress.</p> </li> <li> <code>logger</code>               (<code>LogState</code>, default:                   <code>None</code> 
)           \u2013            <p>Logger for the Evolutionary Solver.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>final_state</code> (              <code>EvoState</code> )          \u2013            <p>Final state of the Evolutionary Solver.</p> </li> <li> <code>logger</code> (              <code>LogState</code> )          \u2013            <p>Logger for the Evolutionary Solver.</p> </li> <li> <code>losses</code> (              <code>Array</code> )          \u2013            <p>Losses at each step.</p> </li> </ul>"},{"location":"api/evosax.html#rex.evo.EvoSolver","title":"<code>rex.evo.EvoSolver</code>","text":"<p>Evolutionary Solver class to manage the evolutionary strategy and its parameters.</p> <p>Attributes:</p> <ul> <li> <code>strategy_params</code>               (<code>EvoParams</code>)           \u2013            <p>Parameters for the evolutionary strategy.</p> </li> <li> <code>strategy</code>               (<code>Strategy</code>)           \u2013            <p>Instance of the evolutionary strategy.</p> </li> <li> <code>strategy_name</code>               (<code>str</code>)           \u2013            <p>Name of the strategy used.</p> </li> </ul>"},{"location":"api/evosax.html#rex.evo.EvoSolver.init","title":"<code>init(u_min: Dict[str, Params], u_max: Dict[str, Params], strategy: str, strategy_kwargs: Dict = None, fitness_kwargs: Dict = None) -&gt; EvoSolver</code>  <code>classmethod</code>","text":"<p>Initialize the Evolutionary Solver.</p> <p>Parameters:</p> <ul> <li> <code>u_min</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Minimum values for the parameters (pytree).</p> </li> <li> <code>u_max</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>(Normalized) Maximum values for the parameters (pytree).</p> </li> <li> <code>strategy</code>               (<code>str</code>)           \u2013            <p>Name of the strategy to use from evosax.Strategies.</p> </li> <li> 
<code>strategy_kwargs</code>               (<code>Dict</code>, default:                   <code>None</code> )           \u2013            <p>Keyword arguments to pass to the strategy.</p> </li> <li> <code>fitness_kwargs</code>               (<code>Dict</code>, default:                   <code>None</code> )           \u2013            <p>Keyword arguments to pass to the fitness function of the strategy.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>EvoSolver</code>           \u2013            <p>EvoSolver instance.</p> </li> </ul>"},{"location":"api/evosax.html#rex.evo.EvoSolver.init_state","title":"<code>init_state(mean: Dict[str, Params], rng: jax.Array = None) -&gt; EvoState</code>","text":"<p>Initialize the state of the Evolutionary Solver.</p> <p>Parameters:</p> <ul> <li> <code>mean</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>Normalized mean values for the parameters (pytree).</p> </li> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>EvoState</code> (              <code>EvoState</code> )          \u2013            <p>The initialized state of the Evolutionary Solver.</p> </li> </ul>"},{"location":"api/evosax.html#rex.evo.EvoSolver.init_logger","title":"<code>init_logger(num_generations: int, top_k: int = 5, maximize: bool = False) -&gt; LogState</code>","text":"<p>Initialize the logger for the Evolutionary Solver.</p> <p>Parameters:</p> <ul> <li> <code>num_generations</code>               (<code>int</code>)           \u2013            <p>Number of generations to log.</p> </li> <li> <code>top_k</code>               (<code>int</code>, default:                   <code>5</code> )           \u2013            <p>Number of top individuals to log.</p> </li> <li> <code>maximize</code>               (<code>bool</code>, default:                   <code>False</code> )        
   \u2013            <p>Whether the strategy is maximizing or minimizing.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>LogState</code> (              <code>LogState</code> )          \u2013            <p>The initialized log state.</p> </li> </ul>"},{"location":"api/evosax.html#rex.evo.EvoState","title":"<code>rex.evo.EvoState = evx.strategy.EvoState</code>  <code>module-attribute</code>","text":""},{"location":"api/evosax.html#rex.evo.LogState","title":"<code>rex.evo.LogState</code>","text":"<p>LogState class to manage the logging of evolutionary strategy states.</p> <p>Attributes:</p> <ul> <li> <code>state</code>               (<code>Dict</code>)           \u2013            <p>The current state of the logger.</p> </li> <li> <code>logger</code>               (<code>ESLog</code>)           \u2013            <p>The logger instance used for logging.</p> </li> </ul>"},{"location":"api/evosax.html#rex.evo.LogState.save","title":"<code>save(filename: str)</code>","text":"<p>Save the log state to a file.</p> <p>Parameters:</p> <ul> <li> <code>filename</code>               (<code>str</code>)           \u2013            <p>The name of the file to save to.</p> </li> </ul>"},{"location":"api/evosax.html#rex.evo.LogState.load","title":"<code>load(filename: str) -&gt; LogState</code>","text":"<p>Load the log state from a file.</p> <p>Parameters:</p> <ul> <li> <code>filename</code>               (<code>str</code>)           \u2013            <p>The name of the file to load from.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>LogState</code> (              <code>LogState</code> )          \u2013            <p>The loaded log state.</p> </li> </ul>"},{"location":"api/evosax.html#rex.evo.LogState.plot","title":"<code>plot(title: str, ylims: List[int] = None, fig: plt.Figure = None, ax: plt.Axes = None, no_legend: bool = False) -&gt; Tuple[plt.Figure, plt.Axes]</code>","text":"<p>Plot the log state.</p> <p>Parameters:</p> <ul> <li> <code>title</code>               
(<code>str</code>)           \u2013            <p>The title of the plot.</p> </li> <li> <code>ylims</code>               (<code>List[int]</code>, default:                   <code>None</code> )           \u2013            <p>The y-axis limits.</p> </li> <li> <code>fig</code>               (<code>Figure</code>, default:                   <code>None</code> )           \u2013            <p>The figure to plot on.</p> </li> <li> <code>ax</code>               (<code>Axes</code>, default:                   <code>None</code> )           \u2013            <p>The axes to plot on.</p> </li> <li> <code>no_legend</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            <p>Whether to omit the legend.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[Figure, Axes]</code>           \u2013            <p>The plot.</p> </li> </ul>"},{"location":"api/gmm_estimator.html","title":"Gmm estimator","text":""},{"location":"api/gmm_estimator.html#rex.gmm_estimator.GMMEstimator","title":"<code>rex.gmm_estimator.GMMEstimator</code>","text":""},{"location":"api/gmm_estimator.html#rex.gmm_estimator.GMMEstimator.__init__","title":"<code>__init__(data: jax.typing.ArrayLike, name: str = 'GMM', threshold: float = 1e-07, verbose: bool = True)</code>","text":"<p>Gaussian Mixture Model Estimator.</p> <p>Parameters:</p> <ul> <li> <code>data</code>               (<code>ArrayLike</code>)           \u2013            <p>1D array of delay data.</p> </li> <li> <code>name</code>               (<code>str</code>, default:                   <code>'GMM'</code> )           \u2013            <p>Name of the model.</p> </li> <li> <code>threshold</code>               (<code>float</code>, default:                   <code>1e-07</code> )           \u2013            <p>Threshold for determining if the data is deterministic.</p> </li> <li> <code>verbose</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013       
     <p>Whether to print progress.</p> </li> </ul>"},{"location":"api/gmm_estimator.html#rex.gmm_estimator.GMMEstimator.fit","title":"<code>fit(num_steps: int = 100, num_components: int = 2, step_size: float = 0.05, seed: int = 0)</code>","text":"<p>Fit the model to the data.</p> <p>Parameters:</p> <ul> <li> <code>num_steps</code>               (<code>int</code>, default:                   <code>100</code> )           \u2013            <p>Number of steps to train the model.</p> </li> <li> <code>num_components</code>               (<code>int</code>, default:                   <code>2</code> )           \u2013            <p>Number of components in the mixture model.</p> </li> <li> <code>step_size</code>               (<code>float</code>, default:                   <code>0.05</code> )           \u2013            <p>Step size for the optimizer.</p> </li> <li> <code>seed</code>               (<code>int</code>, default:                   <code>0</code> )           \u2013            <p>Random seed.</p> </li> </ul>"},{"location":"api/gmm_estimator.html#rex.gmm_estimator.GMMEstimator.get_dist","title":"<code>get_dist(percentile: float = 0.99) -&gt; base.StaticDist</code>","text":"<p>Get the fitted distribution.</p> <p>Parameters:</p> <ul> <li> <code>percentile</code>               (<code>float</code>, default:                   <code>0.99</code> )           \u2013            <p>Percentile used to prune mixture components that contribute little to the distribution.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StaticDist</code>           \u2013            <p>The distribution object.</p> </li> </ul>"},{"location":"api/gmm_estimator.html#rex.gmm_estimator.GMMEstimator.plot_hist","title":"<code>plot_hist(ax: plt.Axes = None, edgecolor: str = None, facecolor: str = None, bins: int = 100, xmin: float = None, xmax: float = None, num_points: int = 1000, plot_dist: bool = True) -&gt; plt.Axes</code>","text":"<p>Plot the histogram of the data and the fitted distribution.</p> 
<p>Parameters:</p> <ul> <li> <code>ax</code>               (<code>Axes</code>, default:                   <code>None</code> )           \u2013            <p>Axes to plot on.</p> </li> <li> <code>edgecolor</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>Edge color of the histogram.</p> </li> <li> <code>facecolor</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>Face color of the histogram.</p> </li> <li> <code>bins</code>               (<code>int</code>, default:                   <code>100</code> )           \u2013            <p>Number of bins for the histogram.</p> </li> <li> <code>xmin</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>Minimum x value for the histogram. Can be used to avoid outliers.</p> </li> <li> <code>xmax</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>Maximum x value for the histogram. 
Can be used to avoid outliers.</p> </li> <li> <code>num_points</code>               (<code>int</code>, default:                   <code>1000</code> )           \u2013            <p>Number of points to plot the distribution.</p> </li> <li> <code>plot_dist</code>               (<code>bool</code>, default:                   <code>True</code> )           \u2013            <p>Whether to plot the fitted distribution.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Axes</code>           \u2013            <p>The axes with the plot.</p> </li> </ul>"},{"location":"api/gmm_estimator.html#rex.gmm_estimator.GMMEstimator.plot_loss","title":"<code>plot_loss(ax: plt.Axes = None, edgecolor: str = None) -&gt; plt.Axes</code>","text":"<p>Plot the loss function.</p> <p>Parameters:</p> <ul> <li> <code>ax</code>               (<code>Axes</code>, default:                   <code>None</code> )           \u2013            <p>Axes to plot on.</p> </li> <li> <code>edgecolor</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>Edge color of the plot.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Axes</code>           \u2013            <p>The axes with the plot.</p> </li> </ul>"},{"location":"api/gmm_estimator.html#rex.gmm_estimator.GMMEstimator.plot_normalized_weights","title":"<code>plot_normalized_weights(ax: plt.Axes = None, edgecolor: str = None) -&gt; plt.Axes</code>","text":"<p>Plot the normalized weights.</p> <p>Parameters:</p> <ul> <li> <code>ax</code>               (<code>Axes</code>, default:                   <code>None</code> )           \u2013            <p>Axes to plot on.</p> </li> <li> <code>edgecolor</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>Edge color of the plot.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Axes</code>           \u2013            <p>The axes with the plot.</p> </li> 
</ul>"},{"location":"api/gmm_estimator.html#rex.gmm_estimator.GMMEstimator.animate_training","title":"<code>animate_training(num_frames: int = 30, fig: plt.Figure = None, ax: plt.Axes = None, edgecolor: str = None, facecolor: str = None, bins: int = 40, xmin: float = None, xmax: float = None, num_points: int = 1000) -&gt; matplotlib.animation.FuncAnimation</code>","text":"<p>Animate the training process.</p> <p>Parameters:</p> <ul> <li> <code>num_frames</code>               (<code>int</code>, default:                   <code>30</code> )           \u2013            <p>Number of frames to animate.</p> </li> <li> <code>fig</code>               (<code>Figure</code>, default:                   <code>None</code> )           \u2013            <p>Figure to plot on.</p> </li> <li> <code>ax</code>               (<code>Axes</code>, default:                   <code>None</code> )           \u2013            <p>Axes to plot on.</p> </li> <li> <code>edgecolor</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>Edge color of the histogram.</p> </li> <li> <code>facecolor</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>Face color of the histogram.</p> </li> <li> <code>bins</code>               (<code>int</code>, default:                   <code>40</code> )           \u2013            <p>Number of bins for the histogram.</p> </li> <li> <code>xmin</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>Minimum x value for the histogram. Can be used to avoid outliers.</p> </li> <li> <code>xmax</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>Maximum x value for the histogram. 
Can be used to avoid outliers.</p> </li> <li> <code>num_points</code>               (<code>int</code>, default:                   <code>1000</code> )           \u2013            <p>Number of points to plot the distribution.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>FuncAnimation</code>           \u2013            <p>The animation object.</p> </li> </ul>"},{"location":"api/node.html","title":"Node","text":""},{"location":"api/node.html#rex.node.BaseNode","title":"<code>rex.node.BaseNode</code>","text":""},{"location":"api/node.html#rex.node.BaseNode.info","title":"<code>info: base.NodeInfo</code>  <code>property</code>","text":"<p>Get the node info.</p>"},{"location":"api/node.html#rex.node.BaseNode.__init__","title":"<code>__init__(name: str, rate: float, delay: float = None, delay_dist: Union[base.DelayDistribution, distrax.Distribution] = None, advance: bool = False, scheduling: Scheduling = Scheduling.FREQUENCY, color: str = None, order: int = None)</code>","text":"<p>Base node class. 
All nodes should inherit from this class.</p> Basic template for a node class: <pre><code>class MyNode(BaseNode):\n    def __init__(self, *args, extra_arg, **kwargs):  # Optional\n        super().__init__(*args, **kwargs)\n        self.extra_arg = extra_arg\n\n    def init_params(self, rng=None, graph_state=None):  # Optional\n        return MyParams(param1=1.0, param2=2.0)\n\n    def init_state(self, rng=None, graph_state=None):  # Optional\n        return MyState(state1=1.0, state2=2.0)\n\n    def init_output(self, rng=None, graph_state=None):  # Required\n        return MyOutput(output1=1.0, output2=2.0)\n\n    def init_delays(self, rng=None, graph_state=None):  # Optional\n        # Set trainable delays to values from params\n        params = graph_state.params[self.name]\n        return {\"some_node\": params.param1}  # Connected node name\n\n    def startup(self, graph_state, timeout=None):  # Optional\n        # Move the robot to a starting position\n        return True\n\n    def step(self, step_state):  # Required\n        # Unpack step state\n        params = step_state.params\n        state = step_state.state\n        inputs = step_state.inputs\n        # Calculate output\n        output = MyOutput(...)\n        # Update state\n        new_state = MyState(...)\n        return step_state.replace(state=new_state), output\n\n    def stop(self, timeout=None):  # Optional\n        # Safely stop the robot at the end of the episode\n        return True\n</code></pre> <p>Parameters:</p> <ul> <li> <code>name</code>               (<code>str</code>)           \u2013            <p>The name of the node (unique).</p> </li> <li> <code>rate</code>               (<code>float</code>)           \u2013            <p>The rate of the node (Hz).</p> </li> <li> <code>delay</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>The expected computation delay of the node (s). 
Used to calculate the phase shift.</p> </li> <li> <code>delay_dist</code>               (<code>Union[DelayDistribution, Distribution]</code>, default:                   <code>None</code> )           \u2013            <p>The computation delay distribution of the node for simulation.</p> </li> <li> <code>advance</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            <p>Whether the node's step triggers as soon as all inputs are ready, or throttles until the scheduled time.</p> </li> <li> <code>scheduling</code>               (<code>Scheduling</code>, default:                   <code>FREQUENCY</code> )           \u2013            <p>The scheduling of the node. If <code>FREQUENCY</code>, the node is scheduled at a fixed rate, ignoring any phase shift w.r.t. the clock. If <code>PHASE</code>, the node steps are scheduled at a fixed rate and phase w.r.t. the clock.</p> </li> <li> <code>color</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>The color of the node (for visualization).</p> </li> <li> <code>order</code>               (<code>int</code>, default:                   <code>None</code> )           \u2013            <p>The order of the node (for visualization).</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.connect","title":"<code>connect(output_node: BaseNode, blocking: bool = False, delay: float = None, delay_dist: Union[base.DelayDistribution, distrax.Distribution] = None, window: int = 1, skip: bool = False, jitter: Jitter = Jitter.LATEST, name: str = None)</code>","text":"<p>Connects the node to an output node, so that the outputs of that node are received as inputs.</p> <p>Parameters:</p> <ul> <li> <code>output_node</code>               (<code>BaseNode</code>)           \u2013            <p>The node whose outputs this node receives.</p> </li> <li> <code>blocking</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            
<p>Whether the connection is blocking.</p> </li> <li> <code>delay</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>The expected communication delay of the connection.</p> </li> <li> <code>delay_dist</code>               (<code>Union[DelayDistribution, Distribution]</code>, default:                   <code>None</code> )           \u2013            <p>The communication delay distribution of the connection for simulation.</p> </li> <li> <code>window</code>               (<code>int</code>, default:                   <code>1</code> )           \u2013            <p>The window size of the connection. It determines how many output messages are used as input to the <code>.step()</code> function.</p> </li> <li> <code>skip</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            <p>Whether to skip the connection. It resolves cyclic dependencies by skipping the output if it arrives at the same time as the start of the <code>.step()</code> function (i.e., <code>step_state.ts</code>).</p> </li> <li> <code>jitter</code>               (<code>Jitter</code>, default:                   <code>LATEST</code> )           \u2013            <p>How to handle jitter on the connection. If <code>LATEST</code>, the latest messages are used. If <code>BUFFER</code>, the messages are buffered and used in accordance with the expected delay.</p> </li> <li> <code>name</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>A shadow name for the connected node. 
If <code>None</code>, the name of the output node is used.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.init_params","title":"<code>init_params(rng: jax.Array = None, graph_state: base.GraphState = None) -&gt; base.Params</code>","text":"<p>Initialize the params of the node.</p> <p>The params are composed of values that remain constant during an episode (e.g., network weights).</p> <p>At this point, the graph state may contain the params of other nodes required to get the default params. The order of node initialization can be specified in <code>Graph.init(... order=[node1, node2, ...])</code>.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> <li> <code>graph_state</code>               (<code>GraphState</code>, default:                   <code>None</code> )           \u2013            <p>The graph state that may be used to get the default params.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The default params of the node.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.init_state","title":"<code>init_state(rng: jax.Array = None, graph_state: base.GraphState = None) -&gt; base.State</code>","text":"<p>Initialize the state of the node.</p> <p>The state is composed of values that are updated during the episode in the <code>.step()</code> function (e.g., position, velocity).</p> <p>At this point, the params of all nodes are already initialized and present in the graph state (if specified). Moreover, the state of other nodes required to get the default state may also be present in the graph state. The order of node initialization can be specified in <code>Graph.init(... 
order=[node1, node2, ...])</code>.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> <li> <code>graph_state</code>               (<code>GraphState</code>, default:                   <code>None</code> )           \u2013            <p>The graph state that may be used to get the default state.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>State</code>           \u2013            <p>The default state of the node.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.init_inputs","title":"<code>init_inputs(rng: jax.Array = None, graph_state: base.GraphState = None) -&gt; FrozenDict[str, base.InputState]</code>","text":"<p>Initialize default inputs for the node.</p> <p>Fills input buffers with default outputs from connected nodes. Used during the initial steps of an episode when input buffers are not yet filled.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> <li> <code>graph_state</code>               (<code>GraphState</code>, default:                   <code>None</code> )           \u2013            <p>The graph state that may be used to get the default inputs.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>FrozenDict[str, InputState]</code>           \u2013            <p>The default inputs of the node.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.init_delays","title":"<code>init_delays(rng: jax.Array = None, graph_state: base.GraphState = None) -&gt; Dict[str, Union[float, jax.typing.ArrayLike]]</code>","text":"<p>Initialize trainable communication delays.</p> Note <p>These delays include only trainable connections. 
To make a delay trainable, replace the parameters in the delay distribution with trainable parameters.</p> A rough template for the <code>init_delays</code> function is as follows: <pre><code>def init_delays(self, rng=None, graph_state=None):\n    # Assumes graph_state contains the params of the node\n    params = graph_state.params[self.name]\n    # Keys are the names of connected nodes\n    trainable_delays = {\"world\": params.delay_param}\n    return trainable_delays\n</code></pre> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> <li> <code>graph_state</code>               (<code>GraphState</code>, default:                   <code>None</code> )           \u2013            <p>The graph state that may be used to get the trainable delays (e.g., from the params of the node).</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Dict[str, Union[float, ArrayLike]]</code>           \u2013            <p>Trainable delays. Can be an incomplete dictionary. Entries for non-trainable delays or non-existent connections are ignored.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.init_step_state","title":"<code>init_step_state(rng: jax.Array = None, graph_state: base.GraphState = None) -&gt; base.StepState</code>","text":"<p>Initializes the step state of the node, which is used to run the <code>seq</code>-th step of the node at time <code>ts</code>. The step state is built from:</p> <ul> <li><code>BaseNode.init_params</code></li> <li><code>BaseNode.init_state</code></li> <li><code>BaseNode.init_inputs</code> using <code>BaseNode.init_output</code> of connected nodes (to fill the input buffers)</li> </ul> Note <p>If a node's initialization depends on the params, state, or inputs of other nodes, this may fail. 
In such cases, the user can provide a graph state with the necessary information to get the default step state.</p> <p>Parameters:</p> <ul> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator.</p> </li> <li> <code>graph_state</code>               (<code>GraphState</code>, default:                   <code>None</code> )           \u2013            <p>The graph state that may be used to get the default step state.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>StepState</code>           \u2013            <p>The default step state of the node.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.startup","title":"<code>startup(graph_state: base.GraphState, timeout: float = None) -&gt; bool</code>","text":"<p>Initializes the node to the state specified by <code>graph_state</code>. This method is called just before an episode starts. It can be used to move a real robot to a starting position as specified by the <code>graph_state</code>.</p> Note <p>Only called when running asynchronously.</p> <p>Parameters:</p> <ul> <li> <code>graph_state</code>               (<code>GraphState</code>)           \u2013            <p>The graph state.</p> </li> <li> <code>timeout</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>The timeout of the startup.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>bool</code>           \u2013            <p>Whether the node has started successfully.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.stop","title":"<code>stop(timeout: float = None) -&gt; bool</code>","text":"<p>Stopping routine that is called after the episode is done.</p> Note <p>Only called when running asynchronously.</p> Warning <p><code>.stop</code> may be called before the final <code>.step</code> call of an episode returns. This can cause unsafe behavior if the final step 
undoes the work of the <code>.stop</code> method. The user should handle this, for example by stopping \"longer\" before returning here.</p> <p>Parameters:</p> <ul> <li> <code>timeout</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>The timeout of the stop.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>bool</code>           \u2013            <p>Whether the node has stopped successfully.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.step","title":"<code>step(step_state: base.StepState) -&gt; Tuple[base.StepState, base.Output]</code>","text":"<p>Execute the node for the <code>seq</code>-th time step at time <code>ts</code>. This function updates the node's state and generates an output, which is sent to connected nodes. It is called at the node's rate. Users are expected to update the state (and <code>rng</code> if used), but not <code>seq</code> and <code>ts</code>, as these are updated automatically.</p> Wrapping side-effecting code <p>Side-effecting code should be wrapped to ensure execution on the host machine when using <code>jax.jit</code>. 
See here for more info.</p> A rough template for the step function is as follows: <pre><code>def step(self, step_state: base.StepState) -&gt; Tuple[base.StepState, base.Output]:\n    # Per input with `input_name`, the following information is available:\n    step_state.inputs[input_name][window_index].data # A window_index of -1 selects the most recent message.\n    step_state.inputs[input_name][window_index].seq # The sequence number of the message.\n    step_state.inputs[input_name][window_index].ts_sent # The time the message was sent.\n    step_state.inputs[input_name][window_index].ts_recv # The time the message was received.\n\n    # The following information is available for the node:\n    step_state.params # The parameters of the node.\n    step_state.state # The state of the node.\n    step_state.eps # The episode number.\n    step_state.seq # The sequence number.\n    step_state.ts # The time of the step within the episode.\n    step_state.rng # The random number generator.\n\n    # Calculate the output and the updated state\n    new_rng, rng_step = jax.random.split(step_state.rng)  # Use rng_step for randomness in this step.\n    output = ...\n    new_state = ...\n\n    # Update the state of the node (seq and ts are updated automatically)\n    new_ss = step_state.replace(rng=new_rng, state=new_state)\n    return new_ss, output\n</code></pre> <p>Parameters:</p> <ul> <li> <code>step_state</code>               (<code>StepState</code>)           \u2013            <p>The step state of the node.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Tuple[StepState, Output]</code>           \u2013            <p>The updated step state and the output of the node.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.now","title":"<code>now() -&gt; float</code>","text":"<p>Get the time elapsed since the start of the episode, according to the simulated or wall clock.</p> <p>Returns:</p> <ul> <li> <code>float</code>           \u2013            <p>Time since the start of the episode. 
Only returns &gt; 0 timestamps if running asynchronously.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.set_delay","title":"<code>set_delay(delay_dist: Union[base.DelayDistribution, distrax.Distribution] = None, delay: float = None)</code>","text":"<p>Set the delay distribution and delay for the computation delay of the node.</p> <p>Parameters:</p> <ul> <li> <code>delay_dist</code>               (<code>Union[DelayDistribution, Distribution]</code>, default:                   <code>None</code> )           \u2013            <p>The delay distribution to simulate.</p> </li> <li> <code>delay</code>               (<code>float</code>, default:                   <code>None</code> )           \u2013            <p>The delay to take into account for the phase shift.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.from_info","title":"<code>from_info(info: base.NodeInfo, **kwargs: Dict[str, Any])</code>  <code>classmethod</code>","text":"<p>Re-instantiates a Node from a NodeInfo object.</p> Don't forget to call <code>connect_from_info()</code>. <p>Make sure to call connect_from_info() on the resulting subclass object to restore the connections.</p> Note <p>This method attempts to restore the subclass object from the BaseNode object. Hence, it requires any additional arguments to be passed as keyword arguments. 
Moreover, the signature of the subclass must be the same as that of <code>BaseNode</code>, except for the additional <code>*args</code> and <code>**kwargs</code>.</p> <p>Parameters:</p> <ul> <li> <code>info</code>               (<code>NodeInfo</code>)           \u2013            <p>Node info object.</p> </li> <li> <code>**kwargs</code>               (<code>Dict[str, Any]</code>, default:                   <code>{}</code> )           \u2013            <p>Additional keyword arguments for the subclass.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseNode.connect_from_info","title":"<code>connect_from_info(infos: [str, base.InputInfo], nodes: Dict[str, BaseNode])</code>","text":"<p>Connects the node to other nodes based on the input infos.</p> <p>Parameters:</p> <ul> <li> <code>infos</code>               (<code>[str, InputInfo]</code>)           \u2013            <p>A dictionary of input names to input infos.</p> </li> <li> <code>nodes</code>               (<code>Dict[str, BaseNode]</code>)           \u2013            <p>A dictionary of node names to node objects.</p> </li> </ul>"},{"location":"api/node.html#rex.node.BaseWorld","title":"<code>rex.node.BaseWorld</code>","text":"<p>               Bases: <code>BaseNode</code></p>"},{"location":"api/node.html#rex.node.BaseWorld.__init__","title":"<code>__init__(name: str, rate: float, color: str = None, order: int = None, **kwargs)</code>","text":"<p>Base node class for world (i.e. simulator) nodes.</p> <p>A convenience class that pre-sets parameters for nodes that simulate real-world processes, that is, nodes that simulate continuous processes in a discrete manner.</p> <ul> <li>The delay distribution is set to the time step of the node (~1/rate). 
It is currently set slightly below the time step to ensure numerical stability, as otherwise additional delay may be introduced unavoidably.</li> <li><code>advance</code> is set to <code>False</code>, as the world node should adhere to the rate of the node.</li> <li><code>scheduling</code> is set to <code>FREQUENCY</code>, as the world node should adhere to the rate of the node.</li> </ul> <p>Parameters:</p> <ul> <li> <code>name</code>               (<code>str</code>)           \u2013            <p>The name of the node (unique).</p> </li> <li> <code>rate</code>               (<code>float</code>)           \u2013            <p>The rate of the node (Hz).</p> </li> <li> <code>color</code>               (<code>str</code>, default:                   <code>None</code> )           \u2013            <p>The color of the node (for visualization).</p> </li> <li> <code>order</code>               (<code>int</code>, default:                   <code>None</code> )           \u2013            <p>The order of the node (for visualization).</p> </li> </ul>"},{"location":"api/ppo.html","title":"Proximal Policy Optimization","text":""},{"location":"api/ppo.html#rex.ppo.train","title":"<code>rex.ppo.train(env: Union[BaseEnv, Environment], config: Config, rng: jax.Array) -&gt; PPOResult</code>","text":"<p>Train the PPO model.</p> <p>Based on the PPO implementation from purejaxrl: https://github.com/luchris429/purejaxrl</p> <p>Parameters:</p> <ul> <li> <code>env</code>               (<code>Union[BaseEnv, Environment]</code>)           \u2013            <p>The environment to train on.</p> </li> <li> <code>config</code>               (<code>Config</code>)           \u2013            <p>Configuration for the PPO algorithm.</p> </li> <li> <code>rng</code>               (<code>Array</code>)           \u2013            <p>Random number generator key.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>PPOResult</code> (              <code>PPOResult</code> )          \u2013            <p>The result of the training process.</p> </li> 
</ul>"},{"location":"api/ppo.html#rex.ppo.Config","title":"<code>rex.ppo.Config</code>","text":"<p>               Bases: <code>Base</code></p> <p>Configuration for PPO.</p> <p>Inherit from this class and override the <code>EVAL_METRICS_JAX_CB</code> and <code>EVAL_METRICS_HOST_CB</code> methods to customize the evaluation metrics and the host-side callback for the evaluation metrics.</p> <p>Attributes:</p> <ul> <li> <code>LR</code>               (<code>float</code>)           \u2013            <p>The learning rate.</p> </li> <li> <code>NUM_ENVS</code>               (<code>int</code>)           \u2013            <p>The number of parallel environments.</p> </li> <li> <code>NUM_STEPS</code>               (<code>int</code>)           \u2013            <p>The number of steps to run in each environment per update.</p> </li> <li> <code>TOTAL_TIMESTEPS</code>               (<code>int</code>)           \u2013            <p>The total number of timesteps to run.</p> </li> <li> <code>UPDATE_EPOCHS</code>               (<code>int</code>)           \u2013            <p>The number of epochs to run per update.</p> </li> <li> <code>NUM_MINIBATCHES</code>               (<code>int</code>)           \u2013            <p>The number of minibatches to split the data into.</p> </li> <li> <code>GAMMA</code>               (<code>float</code>)           \u2013            <p>The discount factor.</p> </li> <li> <code>GAE_LAMBDA</code>               (<code>float</code>)           \u2013            <p>The Generalized Advantage Estimation (GAE) parameter.</p> </li> <li> <code>CLIP_EPS</code>               (<code>float</code>)           \u2013            <p>The clipping parameter for the ratio in the policy loss.</p> </li> <li> <code>ENT_COEF</code>               (<code>float</code>)           \u2013            <p>The coefficient of the entropy regularizer.</p> </li> <li> <code>VF_COEF</code>               (<code>float</code>)           \u2013            <p>The value function coefficient.</p> 
</li> <li> <code>MAX_GRAD_NORM</code>               (<code>float</code>)           \u2013            <p>The maximum gradient norm.</p> </li> <li> <code>NUM_HIDDEN_LAYERS</code>               (<code>int</code>)           \u2013            <p>The number of hidden layers (same for actor and critic).</p> </li> <li> <code>NUM_HIDDEN_UNITS</code>               (<code>int</code>)           \u2013            <p>The number of hidden units per layer (same for actor and critic).</p> </li> <li> <code>KERNEL_INIT_TYPE</code>               (<code>str</code>)           \u2013            <p>The kernel initialization type (same for actor and critic).</p> </li> <li> <code>HIDDEN_ACTIVATION</code>               (<code>str</code>)           \u2013            <p>The hidden activation function (same for actor and critic).</p> </li> <li> <code>STATE_INDEPENDENT_STD</code>               (<code>bool</code>)           \u2013            <p>Whether to use a state-independent standard deviation for the actor.</p> </li> <li> <code>SQUASH</code>               (<code>bool</code>)           \u2013            <p>Whether to squash the action output of the actor.</p> </li> <li> <code>ANNEAL_LR</code>               (<code>bool</code>)           \u2013            <p>Whether to anneal the learning rate.</p> </li> <li> <code>NORMALIZE_ENV</code>               (<code>bool</code>)           \u2013            <p>Whether to normalize the environment (observations and rewards); actions are always normalized.</p> </li> <li> <code>FIXED_INIT</code>               (<code>bool</code>)           \u2013            <p>Whether to use fixed initial states for each parallel environment.</p> </li> <li> <code>OFFSET_STEP</code>               (<code>bool</code>)           \u2013            <p>Whether to offset the step counter for each parallel environment to break temporal correlations.</p> </li> <li> <code>NUM_EVAL_ENVS</code>               (<code>int</code>)           \u2013            <p>The number of evaluation 
environments.</p> </li> <li> <code>EVAL_FREQ</code>               (<code>int</code>)           \u2013            <p>The number of evaluations to run per training run.</p> </li> <li> <code>VERBOSE</code>               (<code>bool</code>)           \u2013            <p>Whether to print verbose output.</p> </li> <li> <code>DEBUG</code>               (<code>bool</code>)           \u2013            <p>Whether to print debug output per step.</p> </li> </ul>"},{"location":"api/ppo.html#rex.ppo.Config.EVAL_METRICS_JAX_CB","title":"<code>EVAL_METRICS_JAX_CB(total_steps: Union[int, jax.Array], diagnostics: Diagnostics, eval_transitions: Transition = None) -&gt; Dict</code>","text":"<p>Compute evaluation metrics for the PPO algorithm.</p> <p>Parameters:</p> <ul> <li> <code>total_steps</code>               (<code>Union[int, Array]</code>)           \u2013            <p>The total number of steps run.</p> </li> <li> <code>diagnostics</code>               (<code>Diagnostics</code>)           \u2013            <p>The diagnostics from the training process.</p> </li> <li> <code>eval_transitions</code>               (<code>Transition</code>, default:                   <code>None</code> )           \u2013            <p>The transitions from the evaluation process.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Dict</code> (              <code>Dict</code> )          \u2013            <p>A dictionary containing the evaluation metrics.</p> </li> </ul>"},{"location":"api/ppo.html#rex.ppo.Config.EVAL_METRICS_HOST_CB","title":"<code>EVAL_METRICS_HOST_CB(metrics: Dict) -&gt; None</code>","text":"<p>Process the evaluation metrics for the PPO algorithm on the host.</p> <p>Because this callback runs on the host, it can perform side effects such as printing or logging the evaluation metrics.</p> <p>Parameters:</p> <ul> <li> <code>metrics</code>               (<code>Dict</code>)           \u2013            <p>The evaluation metrics.</p> </li> 
</ul>"},{"location":"api/ppo.html#rex.ppo.PPOResult","title":"<code>rex.ppo.PPOResult</code>","text":"<p>               Bases: <code>Base</code></p> <p>Represents the result of the PPO training process.</p> <p>Attributes:</p> <ul> <li> <code>config</code>               (<code>Config</code>)           \u2013            <p>Configuration for the PPO algorithm.</p> </li> <li> <code>runner_state</code>               (<code>RunnerState</code>)           \u2013            <p>The state of the runner after training.</p> </li> <li> <code>metrics</code>               (<code>Dict[str, Any]</code>)           \u2013            <p>Dictionary containing various metrics collected during training.</p> </li> </ul>"},{"location":"api/ppo.html#rex.ppo.PPOResult.obs_scaling","title":"<code>obs_scaling: SquashState</code>  <code>property</code>","text":"<p>Returns the observation scaling parameters.</p>"},{"location":"api/ppo.html#rex.ppo.PPOResult.act_scaling","title":"<code>act_scaling: SquashActionWrapper</code>  <code>property</code>","text":"<p>Returns the action scaling parameters.</p>"},{"location":"api/ppo.html#rex.ppo.PPOResult.policy","title":"<code>policy: Policy</code>  <code>property</code>","text":"<p>Returns the policy model.</p>"},{"location":"api/ppo.html#rex.ppo.Policy","title":"<code>rex.ppo.Policy</code>","text":"<p>               Bases: <code>Base</code></p> <p>Represents the policy model.</p> <p>Attributes:</p> <ul> <li> <code>act_scaling</code>               (<code>SquashState</code>)           \u2013            <p>The action scaling parameters.</p> </li> <li> <code>obs_scaling</code>               (<code>NormalizeVec</code>)           \u2013            <p>The observation scaling parameters.</p> </li> <li> <code>model</code>               (<code>Dict[str, Dict[str, Union[ArrayLike, Any]]]</code>)           \u2013            <p>The model parameters.</p> </li> <li> <code>hidden_activation</code>               (<code>str</code>)           \u2013            <p>The 
hidden activation function.</p> </li> <li> <code>output_activation</code>               (<code>str</code>)           \u2013            <p>The output activation function.</p> </li> <li> <code>state_independent_std</code>               (<code>bool</code>)           \u2013            <p>Whether the standard deviation of the actor is state-independent.</p> </li> </ul>"},{"location":"api/ppo.html#rex.ppo.Policy.apply_actor","title":"<code>apply_actor(norm_obs: jax.typing.ArrayLike, rng: jax.Array = None) -&gt; jax.Array</code>","text":"<p>Apply the actor model to the normalized observation.</p> <p>Parameters:</p> <ul> <li> <code>norm_obs</code>               (<code>ArrayLike</code>)           \u2013            <p>The normalized observation.</p> </li> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator key.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Array</code>           \u2013            <p>The unscaled action.</p> </li> </ul>"},{"location":"api/ppo.html#rex.ppo.Policy.get_action","title":"<code>get_action(obs: jax.typing.ArrayLike, rng: jax.Array = None) -&gt; jax.Array</code>","text":"<p>Get the action from the policy model.</p> <p>Parameters:</p> <ul> <li> <code>obs</code>               (<code>ArrayLike</code>)           \u2013            <p>The observation.</p> </li> <li> <code>rng</code>               (<code>Array</code>, default:                   <code>None</code> )           \u2013            <p>Random number generator key.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Array</code>           \u2013            <p>The action, scaled to the action space.</p> </li> </ul>"},{"location":"api/ppo.html#rex.ppo.RunnerState","title":"<code>rex.ppo.RunnerState</code>","text":"<p>               Bases: <code>Base</code></p> <p>Represents the state of the runner during training.</p> <p>Attributes:</p> <ul> <li> <code>train_state</code>               
(<code>TrainState</code>)           \u2013            <p>The state of the training process.</p> </li> <li> <code>env_state</code>               (<code>GraphState</code>)           \u2013            <p>The state of the environment.</p> </li> <li> <code>last_obs</code>               (<code>ArrayLike</code>)           \u2013            <p>The last observation.</p> </li> <li> <code>rng</code>               (<code>Array</code>)           \u2013            <p>Random number generator key</p> </li> </ul>"},{"location":"api/record.html","title":"Record","text":""},{"location":"api/record.html#rex.base.ExperimentRecord","title":"<code>rex.base.ExperimentRecord</code>","text":"<p>A data structure that holds recorded data of an experiment.</p> <p>Attributes:</p> <ul> <li> <code>episodes</code>               (<code>List[EpisodeRecord]</code>)           \u2013            <p>The episode records.</p> </li> </ul>"},{"location":"api/record.html#rex.base.ExperimentRecord.filter","title":"<code>filter(nodes: Dict[str, BaseNode], filter_connections: bool = False) -&gt; ExperimentRecord</code>","text":""},{"location":"api/record.html#rex.base.ExperimentRecord.to_graph","title":"<code>to_graph() -&gt; Graph</code>","text":""},{"location":"api/record.html#rex.base.ExperimentRecord.stack","title":"<code>stack(method: str = 'padded') -&gt; EpisodeRecord</code>","text":""},{"location":"api/record.html#rex.base.EpisodeRecord","title":"<code>rex.base.EpisodeRecord</code>","text":"<p>A data structure that holds recorded data of an episode.</p> <p>Attributes:</p> <ul> <li> <code>nodes</code>               (<code>Dict[str, NodeRecord]</code>)           \u2013            <p>The node records.</p> </li> </ul>"},{"location":"api/record.html#rex.base.EpisodeRecord.__getitem__","title":"<code>__getitem__(val: int) -&gt; EpisodeRecord</code>","text":"<p>Get the value of the episode record at a specific index.</p> <p>Parameters:</p> <ul> <li> <code>val</code>               (<code>int</code>)           
\u2013            <p>The index to get the value from.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>EpisodeRecord</code>           \u2013            <p>The episode record at the specific index.</p> </li> </ul>"},{"location":"api/record.html#rex.base.EpisodeRecord.filter","title":"<code>filter(nodes: Dict[str, BaseNode], filter_connections: bool = False) -&gt; EpisodeRecord</code>","text":"<p>Filter the episode record.</p> <p>Parameters:</p> <ul> <li> <code>nodes</code>               (<code>Dict[str, BaseNode]</code>)           \u2013            <p>Only keep the records of the nodes in the subgraph spanned by <code>nodes</code>.</p> </li> <li> <code>filter_connections</code>               (<code>bool</code>, default:                   <code>False</code> )           \u2013            <p>If True, only keep connections between nodes in <code>nodes</code>.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>EpisodeRecord</code>           \u2013            <p>A new episode record with only the nodes and connections in <code>nodes</code>.</p> </li> </ul>"},{"location":"api/record.html#rex.base.EpisodeRecord.to_graph","title":"<code>to_graph() -&gt; Graph</code>","text":"<p>Convert the episode record to a graph.</p> <p>Returns:</p> <ul> <li> <code>Graph</code>           \u2013            <p>The graph of the episode record.</p> </li> </ul>"},{"location":"api/record.html#rex.base.NodeInfo","title":"<code>rex.base.NodeInfo</code>","text":"<p>A data structure that holds information about the node.</p> <p>Attributes:</p> <ul> <li> <code>rate</code>               (<code>float</code>)           \u2013            <p>The rate of the node.</p> </li> <li> <code>advance</code>               (<code>bool</code>)           \u2013            <p>Whether the node advances the episode.</p> </li> <li> <code>scheduling</code>               (<code>Scheduling</code>)           \u2013            <p>The scheduling of the node.</p> </li> <li> <code>phase</code>               (<code>float</code>)           \u2013            <p>The phase of the node.</p> 
</li> <li> <code>delay_dist</code>               (<code>DelayDistribution</code>)           \u2013            <p>The delay distribution of the node.</p> </li> <li> <code>delay</code>               (<code>float</code>)           \u2013            <p>The delay of the node.</p> </li> <li> <code>inputs</code>               (<code>Dict[str, InputInfo]</code>)           \u2013            <p>The inputs of the node.</p> </li> <li> <code>name</code>               (<code>str</code>)           \u2013            <p>The name of the node.</p> </li> <li> <code>cls</code>               (<code>str</code>)           \u2013            <p>The class of the node.</p> </li> <li> <code>color</code>               (<code>str</code>)           \u2013            <p>The color of the node.</p> </li> <li> <code>order</code>               (<code>int</code>)           \u2013            <p>The order of the node.</p> </li> </ul>"},{"location":"api/record.html#rex.base.InputInfo","title":"<code>rex.base.InputInfo</code>","text":"<p>A data structure that holds information about the input connection.</p> <p>Attributes:</p> <ul> <li> <code>rate</code>               (<code>float</code>)           \u2013            <p>The rate of the connection.</p> </li> <li> <code>window</code>               (<code>int</code>)           \u2013            <p>The window size of the connection.</p> </li> <li> <code>blocking</code>               (<code>bool</code>)           \u2013            <p>Whether the connection is blocking.</p> </li> <li> <code>skip</code>               (<code>bool</code>)           \u2013            <p>Whether the connection is skipped.</p> </li> <li> <code>jitter</code>               (<code>Jitter</code>)           \u2013            <p>The jitter of the connection.</p> </li> <li> <code>phase</code>               (<code>float</code>)           \u2013            <p>The phase of the connection.</p> </li> <li> <code>delay_dist</code>               (<code>DelayDistribution</code>)           \u2013       
     <p>The delay distribution of the connection.</p> </li> <li> <code>delay</code>               (<code>float</code>)           \u2013            <p>The delay of the connection.</p> </li> <li> <code>name</code>               (<code>str</code>)           \u2013            <p>The name of the connection.</p> </li> <li> <code>output</code>               (<code>str</code>)           \u2013            <p>The name of the output node.</p> </li> </ul>"},{"location":"api/record.html#rex.base.NodeRecord","title":"<code>rex.base.NodeRecord</code>","text":"<p>A data structure that holds information about a node.</p> <p>Attributes:</p> <ul> <li> <code>info</code>               (<code>NodeInfo</code>)           \u2013            <p>The node information.</p> </li> <li> <code>clock</code>               (<code>Clock</code>)           \u2013            <p>The clock of the node.</p> </li> <li> <code>real_time_factor</code>               (<code>float</code>)           \u2013            <p>The real time factor of the node.</p> </li> <li> <code>ts_start</code>               (<code>float</code>)           \u2013            <p>The start time of the node.</p> </li> <li> <code>params</code>               (<code>Base</code>)           \u2013            <p>The parameters of the node.</p> </li> <li> <code>inputs</code>               (<code>Dict[str, InputRecord]</code>)           \u2013            <p>The input record.</p> </li> <li> <code>steps</code>               (<code>StepRecord</code>)           \u2013            <p>The step record.</p> </li> </ul>"},{"location":"api/record.html#rex.base.InputRecord","title":"<code>rex.base.InputRecord</code>","text":"<p>A data structure that holds information about the input connection.</p> <p>Attributes:</p> <ul> <li> <code>info</code>               (<code>InputInfo</code>)           \u2013            <p>The input information.</p> </li> <li> <code>messages</code>               (<code>MessageRecord</code>)           \u2013            <p>The message 
record.</p> </li> </ul>"},{"location":"api/record.html#rex.base.StepRecord","title":"<code>rex.base.StepRecord</code>","text":"<p>A data structure that holds information about a step.</p> <p>Attributes:</p> <ul> <li> <code>eps</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The episode number.</p> </li> <li> <code>seq</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The step number.</p> </li> <li> <code>ts_start</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The start time of the step.</p> </li> <li> <code>ts_end</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The end time of the step.</p> </li> <li> <code>delay</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The delay of the step.</p> </li> <li> <code>rng</code>               (<code>Array</code>)           \u2013            <p>The random number generator.</p> </li> <li> <code>inputs</code>               (<code>InputState</code>)           \u2013            <p>The input state.</p> </li> <li> <code>state</code>               (<code>Base</code>)           \u2013            <p>The state of the node.</p> </li> <li> <code>output</code>               (<code>Base</code>)           \u2013            <p>The output of the node.</p> </li> </ul>"},{"location":"api/record.html#rex.base.AsyncStepRecord","title":"<code>rex.base.AsyncStepRecord</code>","text":"<p>               Bases: <code>StepRecord</code></p> <p>A data structure that holds information about an asynchronous step.</p> <p>Attributes:</p> <ul> <li> <code>ts_scheduled</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The scheduled time of the step.</p> </li> <li> <code>ts_max</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The maximum time of the step.</p> </li> <li> 
<code>ts_end_prev</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The end time of the previous step.</p> </li> <li> <code>phase</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The phase of the step.</p> </li> <li> <code>phase_scheduled</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The scheduled phase of the step.</p> </li> <li> <code>phase_inputs</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The phase of the inputs.</p> </li> <li> <code>phase_last</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The last phase of the step.</p> </li> <li> <code>sent</code>               (<code>Header</code>)           \u2013            <p>The header of the sent message.</p> </li> <li> <code>phase_overwrite</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The phase overwrite.</p> </li> </ul>"},{"location":"api/record.html#rex.base.MessageRecord","title":"<code>rex.base.MessageRecord</code>","text":"<p>A data structure that holds information about a sent or received message.</p> <p>Attributes:</p> <ul> <li> <code>seq_out</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The sequence number of the sent message.</p> </li> <li> <code>seq_in</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The sequence number of the received message.</p> </li> <li> <code>ts_sent</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The time the message was sent.</p> </li> <li> <code>ts_recv</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The time the message was received.</p> </li> <li> <code>delay</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The 
delay of the message.</p> </li> </ul>"},{"location":"api/record.html#rex.base.Header","title":"<code>rex.base.Header</code>","text":"<p>A data structure that holds the header information of a record.</p> <p>Attributes:</p> <ul> <li> <code>eps</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The episode number.</p> </li> <li> <code>seq</code>               (<code>Union[int, ArrayLike]</code>)           \u2013            <p>The step number.</p> </li> <li> <code>ts</code>               (<code>Union[float, ArrayLike]</code>)           \u2013            <p>The time.</p> </li> </ul>"},{"location":"api/transforms.html","title":"Transforms","text":""},{"location":"api/transforms.html#rex.base.Transform","title":"<code>rex.base.Transform</code>","text":"<p>A transformation that can be applied to parameters.</p> <p>Can be used to normalize, denormalize, or transform parameters in any way.</p>"},{"location":"api/transforms.html#rex.base.Transform.init","title":"<code>init(*args: Any, **kwargs: Any) -&gt; Transform</code>  <code>classmethod</code>","text":"<p>Initialize the transform.</p> <p>Parameters:</p> <ul> <li> <code>*args</code>               (<code>Any</code>, default:                   <code>()</code> )           \u2013            <p>The arguments to initialize the transform.</p> </li> <li> <code>**kwargs</code>               (<code>Any</code>, default:                   <code>{}</code> )           \u2013            <p>The keyword arguments to initialize the transform.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Transform</code>           \u2013            <p>The initialized transform.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Transform.apply","title":"<code>apply(params: Dict[str, Params]) -&gt; Dict[str, Params]</code>","text":"<p>Apply the transformation to the parameters.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>The 
original parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Dict[str, Params]</code>           \u2013            <p>The transformed parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Transform.inv","title":"<code>inv(params: Dict[str, Params]) -&gt; Dict[str, Params]</code>","text":"<p>Invert the transformation.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Dict[str, Params]</code>)           \u2013            <p>The transformed parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Dict[str, Params]</code>           \u2013            <p>The original parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Denormalize","title":"<code>rex.base.Denormalize</code>","text":"<p>               Bases: <code>Transform</code></p> <p>(De)normalize the parameters to/from a [-1, 1] range.</p> <p>Attributes:</p> <ul> <li> <code>scale</code>               (<code>Params</code>)           \u2013            <p>The scale of the original parameters.</p> </li> <li> <code>offset</code>               (<code>Params</code>)           \u2013            <p>The offset of the original parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Denormalize.init","title":"<code>init(min_params: Params, max_params: Params) -&gt; Denormalize</code>  <code>classmethod</code>","text":"<p>Initialize the denormalize transformation.</p> Non-zero scale is required. 
<p>Therefore, the min and max values should be different for each parameter.</p> <p>Parameters:</p> <ul> <li> <code>min_params</code>               (<code>Params</code>)           \u2013            <p>The minimum values of the original parameters.</p> </li> <li> <code>max_params</code>               (<code>Params</code>)           \u2013            <p>The maximum values of the original parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Denormalize</code>           \u2013            <p>The denormalize transformation.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Denormalize.apply","title":"<code>apply(params: Params) -&gt; Params</code>","text":"<p>Apply the denormalize transformation to the parameters.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The normalized parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The denormalized parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Extend","title":"<code>rex.base.Extend</code>","text":"<p>               Bases: <code>Transform</code></p> <p>Extend the structure of a pytree with additional parameters from another pytree.</p> <p>Useful when you only want to optimize a subset of the parameters, but the full structure is required for simulation.</p> Example <pre><code>from rex.base import Extend\nbase_params = {\"a\": {\"b\": 0, \"c\": \"1\"}, \"d\": 2}\nopt_params = {\"a\": None, \"d\": 99}\n\ntransform = Extend.init(base_params, opt_params)\nextended = transform.apply(opt_params) # {\"a\": {\"b\": 0, \"c\": \"1\"}, \"d\": 99}\nfiltered = transform.inv(extended) # {\"a\": None, \"d\": 99}\n</code></pre> <p>Attributes:</p> <ul> <li> <code>base_params</code>               (<code>Params</code>)           \u2013            <p>The base parameters.</p> </li> <li> <code>mask</code>               (<code>Params</code>)           \u2013            <p>The mask of 
the extended parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Extend.init","title":"<code>init(base_params: Params, opt_params: Params = None) -&gt; Extend</code>  <code>classmethod</code>","text":"<p>Initialize the extend transformation.</p> <p>Parameters:</p> <ul> <li> <code>base_params</code>               (<code>Params</code>)           \u2013            <p>The base parameters.</p> </li> <li> <code>opt_params</code>               (<code>Params</code>, default:                   <code>None</code> )           \u2013            <p>The structure of the parameters that will be extended with the base parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Extend</code>           \u2013            <p>The extend transformation.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Extend.apply","title":"<code>apply(params: Params) -&gt; Params</code>","text":"<p>Apply the extend transformation to the parameters.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The parameters, extended to the structure of the base parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Chain","title":"<code>rex.base.Chain</code>","text":"<p>               Bases: <code>Transform</code></p> <p>Chain multiple transformations together.</p> <p>Attributes:</p> <ul> <li> <code>transforms</code>               (<code>Sequence[Transform]</code>)           \u2013            <p>The transformations to chain together.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Chain.init","title":"<code>init(*transforms: Sequence[Transform]) -&gt; Chain</code>  <code>classmethod</code>","text":"<p>Initialize the chain of transformations.</p> <p>Parameters:</p> <ul> <li> <code>*transforms</code>               (<code>Sequence[Transform]</code>, default:                   
<code>()</code> )           \u2013            <p>The transformations to chain together. The order of the transformations is important.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Chain</code>           \u2013            <p>The chain of transformations.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Chain.apply","title":"<code>apply(params: Params) -&gt; Params</code>","text":"<p>Apply the chain of transformations to the parameters.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The transformed parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Chain.inv","title":"<code>inv(params: Params) -&gt; Params</code>","text":"<p>Invert the chain of transformations.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The transformed parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The original parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Shared","title":"<code>rex.base.Shared</code>","text":"<p>               Bases: <code>Transform</code></p> <p>A shared transformation that can be applied to parameters.</p> <p>Useful to share parameters between different parts of the model.</p> Example <pre><code>from rex.base import Shared\nwhere_fn = lambda p: p[\"a\"]\nreplace_fn = lambda p: p[\"b\"]\ninverse_fn = lambda p: None\ntransform = Shared.init(where=where_fn, replace_fn=replace_fn, inverse_fn=inverse_fn)\n\nopt_params = {\"a\": 1, \"b\": 2}\napplied = transform.apply(opt_params) # {\"a\": 2, \"b\": 2}\ninverted = transform.inv(applied)  # {\"a\": None, \"b\": 2}\n</code></pre> <p>Attributes:</p> <ul> <li> <code>where</code>               (<code>Callable[[Any], Union[Any, Sequence[Any]]]</code>)           \u2013            <p>The function that 
determines where to apply the transformation.</p> </li> <li> <code>replace_fn</code>               (<code>Callable[[Any], Union[Any, Sequence[Any]]]</code>)           \u2013            <p>The function that replaces the parameters.</p> </li> <li> <code>inverse_fn</code>               (<code>Callable[[Any], Union[Any, Sequence[Any]]]</code>)           \u2013            <p>The function that inverts the transformation.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Shared.init","title":"<code>init(where: Callable[[Any], Union[Any, Sequence[Any]]], replace_fn: Callable[[Any], Union[Any, Sequence[Any]]], inverse_fn: Callable[[Any], Union[Any, Sequence[Any]]] = lambda _tree: None) -&gt; Shared</code>  <code>classmethod</code>","text":"<p>Initialize the shared transformation.</p> <p>Parameters:</p> <ul> <li> <code>where</code>               (<code>Callable[[Any], Union[Any, Sequence[Any]]]</code>)           \u2013            <p>The function that determines where to apply the transformation.</p> </li> <li> <code>replace_fn</code>               (<code>Callable[[Any], Union[Any, Sequence[Any]]]</code>)           \u2013            <p>The function that replaces the parameters.</p> </li> <li> <code>inverse_fn</code>               (<code>Callable[[Any], Union[Any, Sequence[Any]]]</code>, default:                   <code>lambda _tree: None</code> )           \u2013            <p>The function that inverts the transformation.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Shared</code>           \u2013            <p>The shared transformation.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Shared.apply","title":"<code>apply(params: Params) -&gt; Params</code>","text":"<p>Apply the shared transformation to the parameters.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The 
transformed parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Identity","title":"<code>rex.base.Identity</code>","text":"<p>               Bases: <code>Transform</code></p> <p>The identity transformation (NOOP).</p>"},{"location":"api/transforms.html#rex.base.Identity.init","title":"<code>init() -&gt; Identity</code>  <code>classmethod</code>","text":"<p>Initialize the identity transformation.</p> <p>Returns:</p> <ul> <li> <code>Identity</code>           \u2013            <p>The identity transformation.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Identity.apply","title":"<code>apply(params: Params) -&gt; Params</code>","text":"<p>Apply the identity transformation (NOOP).</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The same parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Identity.inv","title":"<code>inv(params: Params) -&gt; Params</code>","text":"<p>Invert the identity transformation (NOOP).</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The same parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Exponential","title":"<code>rex.base.Exponential</code>","text":"<p>               Bases: <code>Transform</code></p> <p>Apply the exponential transformation to the parameters.</p>"},{"location":"api/transforms.html#rex.base.Exponential.init","title":"<code>init() -&gt; Exponential</code>  <code>classmethod</code>","text":"<p>Create an exponential transformation.</p> <p>Returns:</p> <ul> <li> <code>Exponential</code>           \u2013            <p>The exponential transformation.</p> </li> 
</ul>"},{"location":"api/transforms.html#rex.base.Exponential.apply","title":"<code>apply(params: Params) -&gt; Params</code>","text":"<p>Apply the exponential transformation to the parameters.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The transformed parameters.</p> </li> </ul>"},{"location":"api/transforms.html#rex.base.Exponential.inv","title":"<code>inv(params: Params) -&gt; Params</code>","text":"<p>Invert the exponential transformation.</p> <p>Parameters:</p> <ul> <li> <code>params</code>               (<code>Params</code>)           \u2013            <p>The transformed parameters.</p> </li> </ul> <p>Returns:</p> <ul> <li> <code>Params</code>           \u2013            <p>The original parameters.</p> </li> </ul>"},{"location":"examples/graph_and_environment_creation.html","title":"Graphs and environments","text":"<pre><code># @title Install Necessary Libraries\n# @markdown This cell installs the required libraries for the project.\n# @markdown If you are running this notebook in Google Colab, most libraries should already be installed.\n\ntry:\n    import rex\n\n    print(\"Rex already installed\")\nexcept ImportError:\n    print(\n        \"Installing rex via `pip install rex-lib[examples]`. \"\n        \"If you are running this in a Colab notebook, you can ignore this message.\"\n    )\n    !pip install rex-lib[examples]\n    import rex\n\n# Check if we have a GPU\nimport itertools\n\nimport jax\n\n\ntry:\n    gpu = jax.devices(\"gpu\")\n    gpu = gpu[0] if len(gpu) &gt; 0 else None\n    print(\"GPU found!\")\nexcept RuntimeError:\n    print(\"Warning: No GPU found, falling back to CPU. 
Speedups will be less pronounced.\")\n    print(\n        \"Hint: if you are using Google Colab, try to change the runtime to GPU: \"\n        \"Runtime -&gt; Change runtime type -&gt; Hardware accelerator -&gt; GPU.\"\n    )\n    gpu = None\n\n# Check the number of available CPU cores\nprint(f\"CPU cores available: {len(jax.devices('cpu'))}\")\ncpus = itertools.cycle(jax.devices(\"cpu\"))\n\n# Set plot settings\nimport seaborn as sns\n\n\nsns.set()\n</code></pre> <pre>\n<code>Installing rex via `pip install rex-lib[examples]`. If you are running this in a Colab notebook, you can ignore this message.\nCollecting rex-lib[examples]\n  Downloading rex_lib-0.0.5-py3-none-any.whl.metadata (15 kB)\nCollecting dill&gt;=0.3.8 (from rex-lib[examples])\n  Downloading dill-0.3.9-py3-none-any.whl.metadata (10 kB)\nCollecting distrax&gt;=0.1.5 (from rex-lib[examples])\n  Downloading distrax-0.1.5-py3-none-any.whl.metadata (13 kB)\nCollecting equinox&gt;=0.11.4 (from rex-lib[examples])\n  Downloading equinox-0.11.7-py3-none-any.whl.metadata (18 kB)\nCollecting evosax&gt;=0.1.6 (from rex-lib[examples])\n  Downloading evosax-0.1.6-py3-none-any.whl.metadata (26 kB)\nRequirement already satisfied: flax&gt;=0.8.5 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.8.5)\nCollecting gymnasium&gt;=0.29.1 (from rex-lib[examples])\n  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)\nRequirement already satisfied: jax&gt;=0.4.30 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.4.33)\nRequirement already satisfied: matplotlib&gt;=3.7.0 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (3.7.1)\nRequirement already satisfied: networkx&gt;=3.2.1 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (3.3)\nRequirement already satisfied: optax&gt;=0.2.3 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.2.3)\nCollecting seaborn&gt;=0.13.2 (from rex-lib[examples])\n  Downloading 
seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)\nCollecting supergraph&gt;=0.0.8 (from rex-lib[examples])\n  Downloading supergraph-0.0.8-py3-none-any.whl.metadata (1.2 kB)\nRequirement already satisfied: termcolor&gt;=2.4.0 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (2.4.0)\nRequirement already satisfied: tqdm&gt;=4.66.4 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (4.66.5)\nCollecting brax&gt;=0.10.5 (from rex-lib[examples])\n  Downloading brax-0.11.0-py3-none-any.whl.metadata (7.7 kB)\nRequirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.4.0)\nCollecting dm-env (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading dm_env-1.6-py3-none-any.whl.metadata (966 bytes)\nRequirement already satisfied: etils in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.9.4)\nRequirement already satisfied: flask in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (2.2.5)\nCollecting flask-cors (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading Flask_Cors-5.0.0-py2.py3-none-any.whl.metadata (5.5 kB)\nRequirement already satisfied: grpcio in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.64.1)\nRequirement already satisfied: gym in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (0.25.2)\nRequirement already satisfied: jaxlib&gt;=0.4.6 in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (0.4.33)\nCollecting jaxopt (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading jaxopt-0.8.3-py3-none-any.whl.metadata (2.6 kB)\nRequirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.1.4)\nCollecting ml-collections (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading 
ml_collections-0.1.1.tar.gz (77 kB)\n     \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 77.9/77.9 kB 2.9 MB/s eta 0:00:00\n  Preparing metadata (setup.py) ... done\nCollecting mujoco (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading mujoco-3.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)\n     \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 44.4/44.4 kB 1.2 MB/s eta 0:00:00\nCollecting mujoco-mjx (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading mujoco_mjx-3.2.3-py3-none-any.whl.metadata (3.4 kB)\nRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.26.4)\nRequirement already satisfied: orbax-checkpoint in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (0.6.4)\nRequirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (10.4.0)\nCollecting pytinyrenderer (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading pytinyrenderer-0.0.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)\nRequirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.13.1)\nCollecting tensorboardX (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)\nCollecting trimesh (from brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading trimesh-4.4.9-py3-none-any.whl.metadata (18 kB)\nRequirement already satisfied: 
typing-extensions in /usr/local/lib/python3.10/dist-packages (from brax&gt;=0.10.5-&gt;rex-lib[examples]) (4.12.2)\nRequirement already satisfied: chex&gt;=0.1.8 in /usr/local/lib/python3.10/dist-packages (from distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.1.87)\nRequirement already satisfied: tensorflow-probability&gt;=0.15.0 in /usr/local/lib/python3.10/dist-packages (from distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.24.0)\nCollecting jaxtyping&gt;=0.2.20 (from equinox&gt;=0.11.4-&gt;rex-lib[examples])\n  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)\nRequirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from evosax&gt;=0.1.6-&gt;rex-lib[examples]) (6.0.2)\nCollecting dotmap (from evosax&gt;=0.1.6-&gt;rex-lib[examples])\n  Downloading dotmap-1.3.30-py3-none-any.whl.metadata (3.2 kB)\nRequirement already satisfied: msgpack in /usr/local/lib/python3.10/dist-packages (from flax&gt;=0.8.5-&gt;rex-lib[examples]) (1.0.8)\nRequirement already satisfied: tensorstore in /usr/local/lib/python3.10/dist-packages (from flax&gt;=0.8.5-&gt;rex-lib[examples]) (0.1.66)\nRequirement already satisfied: rich&gt;=11.1 in /usr/local/lib/python3.10/dist-packages (from flax&gt;=0.8.5-&gt;rex-lib[examples]) (13.9.1)\nRequirement already satisfied: cloudpickle&gt;=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium&gt;=0.29.1-&gt;rex-lib[examples]) (2.2.1)\nCollecting farama-notifications&gt;=0.0.1 (from gymnasium&gt;=0.29.1-&gt;rex-lib[examples])\n  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)\nRequirement already satisfied: ml-dtypes&gt;=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax&gt;=0.4.30-&gt;rex-lib[examples]) (0.4.1)\nRequirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax&gt;=0.4.30-&gt;rex-lib[examples]) (3.4.0)\nRequirement already satisfied: contourpy&gt;=1.0.1 in /usr/local/lib/python3.10/dist-packages (from 
matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (1.3.0)\nRequirement already satisfied: cycler&gt;=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (0.12.1)\nRequirement already satisfied: fonttools&gt;=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (4.54.1)\nRequirement already satisfied: kiwisolver&gt;=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (1.4.7)\nRequirement already satisfied: packaging&gt;=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (24.1)\nRequirement already satisfied: pyparsing&gt;=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (3.1.4)\nRequirement already satisfied: python-dateutil&gt;=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (2.8.2)\nRequirement already satisfied: pandas&gt;=1.2 in /usr/local/lib/python3.10/dist-packages (from seaborn&gt;=0.13.2-&gt;rex-lib[examples]) (2.2.2)\nRequirement already satisfied: toolz&gt;=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex&gt;=0.1.8-&gt;distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.12.1)\nCollecting typeguard==2.13.3 (from jaxtyping&gt;=0.2.20-&gt;equinox&gt;=0.11.4-&gt;rex-lib[examples])\n  Downloading typeguard-2.13.3-py3-none-any.whl.metadata (3.6 kB)\nRequirement already satisfied: pytz&gt;=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas&gt;=1.2-&gt;seaborn&gt;=0.13.2-&gt;rex-lib[examples]) (2024.2)\nRequirement already satisfied: tzdata&gt;=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas&gt;=1.2-&gt;seaborn&gt;=0.13.2-&gt;rex-lib[examples]) (2024.2)\nRequirement already satisfied: six&gt;=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil&gt;=2.7-&gt;matplotlib&gt;=3.7.0-&gt;rex-lib[examples]) (1.16.0)\nRequirement already satisfied: 
markdown-it-py&gt;=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich&gt;=11.1-&gt;flax&gt;=0.8.5-&gt;rex-lib[examples]) (3.0.0)\nRequirement already satisfied: pygments&lt;3.0.0,&gt;=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich&gt;=11.1-&gt;flax&gt;=0.8.5-&gt;rex-lib[examples]) (2.18.0)\nRequirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability&gt;=0.15.0-&gt;distrax&gt;=0.1.5-&gt;rex-lib[examples]) (4.4.2)\nRequirement already satisfied: gast&gt;=0.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability&gt;=0.15.0-&gt;distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.6.0)\nRequirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability&gt;=0.15.0-&gt;distrax&gt;=0.1.5-&gt;rex-lib[examples]) (0.1.8)\nRequirement already satisfied: Werkzeug&gt;=2.2.2 in /usr/local/lib/python3.10/dist-packages (from flask-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.0.4)\nRequirement already satisfied: itsdangerous&gt;=2.0 in /usr/local/lib/python3.10/dist-packages (from flask-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (2.2.0)\nRequirement already satisfied: click&gt;=8.0 in /usr/local/lib/python3.10/dist-packages (from flask-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (8.1.7)\nRequirement already satisfied: MarkupSafe&gt;=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (2.1.5)\nRequirement already satisfied: gym-notices&gt;=0.0.4 in /usr/local/lib/python3.10/dist-packages (from gym-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (0.0.8)\nRequirement already satisfied: contextlib2 in /usr/local/lib/python3.10/dist-packages (from ml-collections-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (21.6.0)\nCollecting glfw (from mujoco-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples])\n  Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl.metadata (5.4 
kB)\nRequirement already satisfied: pyopengl in /usr/local/lib/python3.10/dist-packages (from mujoco-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.1.7)\nRequirement already satisfied: nest_asyncio in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (1.6.0)\nRequirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.20.3)\nRequirement already satisfied: humanize in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (4.10.0)\nRequirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py&gt;=2.2.0-&gt;rich&gt;=11.1-&gt;flax&gt;=0.8.5-&gt;rex-lib[examples]) (0.1.2)\nRequirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (2024.6.1)\nRequirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (6.4.5)\nRequirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils-&gt;brax&gt;=0.10.5-&gt;rex-lib[examples]) (3.20.2)\nDownloading brax-0.11.0-py3-none-any.whl (998 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 998.6/998.6 kB 17.0 MB/s eta 0:00:00\nDownloading dill-0.3.9-py3-none-any.whl (119 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 119.4/119.4 kB 6.5 MB/s eta 0:00:00\nDownloading distrax-0.1.5-py3-none-any.whl (319 kB)\n   
\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 319.7/319.7 kB 13.9 MB/s eta 0:00:00\nDownloading equinox-0.11.7-py3-none-any.whl (178 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 178.4/178.4 kB 9.3 MB/s eta 0:00:00\nDownloading evosax-0.1.6-py3-none-any.whl (240 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 240.4/240.4 kB 13.8 MB/s eta 0:00:00\nDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 958.1/958.1 kB 22.7 MB/s eta 0:00:00\nDownloading seaborn-0.13.2-py3-none-any.whl (294 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 294.9/294.9 kB 18.2 MB/s eta 0:00:00\nDownloading supergraph-0.0.8-py3-none-any.whl (65 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 65.5/65.5 kB 2.1 MB/s eta 0:00:00\nDownloading rex_lib-0.0.5-py3-none-any.whl 
(115 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 115.1/115.1 kB 8.3 MB/s eta 0:00:00\nDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\nDownloading jaxtyping-0.2.34-py3-none-any.whl (42 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 42.4/42.4 kB 605.9 kB/s eta 0:00:00\nDownloading typeguard-2.13.3-py3-none-any.whl (17 kB)\nDownloading dm_env-1.6-py3-none-any.whl (26 kB)\nDownloading dotmap-1.3.30-py3-none-any.whl (11 kB)\nDownloading Flask_Cors-5.0.0-py2.py3-none-any.whl (14 kB)\nDownloading jaxopt-0.8.3-py3-none-any.whl (172 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 172.3/172.3 kB 7.5 MB/s eta 0:00:00\nDownloading mujoco-3.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.1 MB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 6.1/6.1 MB 43.5 MB/s eta 0:00:00\nDownloading mujoco_mjx-3.2.3-py3-none-any.whl (6.7 MB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 6.7/6.7 MB 23.4 MB/s eta 0:00:00\nDownloading 
pytinyrenderer-0.0.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 1.9/1.9 MB 20.8 MB/s eta 0:00:00\nDownloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 101.7/101.7 kB 3.8 MB/s eta 0:00:00\nDownloading trimesh-4.4.9-py3-none-any.whl (700 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 700.1/700.1 kB 15.6 MB/s eta 0:00:00\nDownloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)\n   \u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501 211.8/211.8 kB 15.1 MB/s eta 0:00:00\nBuilding wheels for collected packages: ml-collections\n  Building wheel for ml-collections (setup.py) ... 
done\n  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94507 sha256=c2ba0db03ffefa350aba3215509ced8bfaf78a8937ceaf29ffbf3655f21c333c\n  Stored in directory: /root/.cache/pip/wheels/7b/89/c9/a9b87790789e94aadcfc393c283e3ecd5ab916aed0a31be8fe\nSuccessfully built ml-collections\nInstalling collected packages: pytinyrenderer, glfw, farama-notifications, dotmap, typeguard, trimesh, tensorboardX, supergraph, ml-collections, gymnasium, dm-env, dill, jaxtyping, seaborn, mujoco, flask-cors, mujoco-mjx, jaxopt, equinox, distrax, evosax, brax, rex-lib\n  Attempting uninstall: typeguard\n    Found existing installation: typeguard 4.3.0\n    Uninstalling typeguard-4.3.0:\n      Successfully uninstalled typeguard-4.3.0\n  Attempting uninstall: seaborn\n    Found existing installation: seaborn 0.13.1\n    Uninstalling seaborn-0.13.1:\n      Successfully uninstalled seaborn-0.13.1\nERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\ninflect 7.4.0 requires typeguard&gt;=4.0.1, but you have typeguard 2.13.3 which is incompatible.\nSuccessfully installed brax-0.11.0 dill-0.3.9 distrax-0.1.5 dm-env-1.6 dotmap-1.3.30 equinox-0.11.7 evosax-0.1.6 farama-notifications-0.0.4 flask-cors-5.0.0 glfw-2.7.0 gymnasium-1.0.0 jaxopt-0.8.3 jaxtyping-0.2.34 ml-collections-0.1.1 mujoco-3.2.3 mujoco-mjx-3.2.3 pytinyrenderer-0.0.14 rex-lib-0.0.5 seaborn-0.13.2 supergraph-0.0.8 tensorboardX-2.6.2.2 trimesh-4.4.9 typeguard-2.13.3\nGPU found!\nCPU cores available: 1\n</code>\n</pre> <pre><code># @title Defining a Graph from Nodes\n# @markdown First, you need to define the nodes that will make up your graph.\n# @markdown These nodes represent different components of a system, such as sensors, agents, actuators, and the world.\n# @markdown **Note**: The `delay_dist` parameter is used to simulate computation delays, which is useful when modeling real-world systems.\n\n# Import necessary modules and node classes\nfrom distrax import Normal\n\nimport rex.examples.pendulum as pdm\n\n\n# Instantiate nodes with their respective parameters\nsensor = pdm.SimSensor(name=\"sensor\", rate=50, color=\"pink\", order=1, delay_dist=Normal(loc=0.0075, scale=0.003))\nagent = pdm.Agent(\n    name=\"agent\", rate=50, color=\"teal\", order=3, delay_dist=Normal(loc=0.01, scale=0.003)\n)  # Computation delay of the agent\nactuator = pdm.SimActuator(\n    name=\"actuator\", rate=50, color=\"orange\", order=2, delay_dist=Normal(loc=0.0075, scale=0.003)\n)  # Computation delay of the actuator\nworld = pdm.OdeWorld(name=\"world\", rate=50, color=\"grape\", order=0)  # ODE-based world that simulates the pendulum\nnodes = dict(world=world, sensor=sensor, agent=agent, actuator=actuator)\n</code></pre> <pre><code># @title Connecting Nodes\n# @markdown Now, we establish connections between the nodes using the `connect` method.\n# @markdown - **`window`**: Determines how many past 
messages are stored and accessible in the input buffer.\n# @markdown - **`blocking`**: If `True`, the receiving node waits for the input before proceeding.\n# @markdown - **`skip`**: Used to resolve cyclic dependencies by skipping the connection when messages arrive simultaneously.\n\n# Agent receives data from the sensor\nagent.connect(\n    output_node=sensor,\n    window=3,  # Use the last three sensor messages\n    name=\"sensor\",  # Input name in the agent\n    blocking=True,  # Wait for the sensor data before proceeding\n    delay_dist=Normal(loc=0.002, scale=0.002),\n)\n\n# Actuator receives commands from the agent\nactuator.connect(\n    output_node=agent,\n    window=1,  # Use the most recent action\n    name=\"agent\",\n    blocking=True,\n    delay_dist=Normal(loc=0.002, scale=0.002),\n)\n\n# World receives actions from the actuator\nworld.connect(\n    output_node=actuator,\n    window=1,\n    name=\"actuator\",\n    # Resolve cyclic dependency world-&gt;sensor-&gt;agent-&gt;actuator-&gt;world\n    skip=True,\n    blocking=False,  # Non-blocking connection (i.e. world does not wait for actuator)\n    delay_dist=Normal(loc=0.01, scale=0.002),\n)\n\n# Sensor receives state updates from the world\nsensor.connect(\n    output_node=world,\n    window=1,\n    name=\"world\",\n    blocking=False,  # Non-blocking connection (i.e. 
sensor does not wait for world)\n    delay_dist=Normal(loc=0.01, scale=0.002),\n)\n</code></pre> <pre><code># @title Visualizing the System\n# @markdown You can visualize the system to understand the structure of your graph.\n\nimport matplotlib.pyplot as plt\n\nfrom rex.utils import plot_system\n\n\n# Collect node information for visualization\nnode_infos = {node.name: node.info for node in [sensor, agent, actuator, world]}\n\n# Plot the system\nfig, ax = plt.subplots(figsize=(8, 3))\nplot_system(node_infos, ax=ax)\nax.legend()\nax.set_title(\"System Structure\")\nplt.show()\n</code></pre> <pre><code># @title Creating the Graph\n# @markdown With the nodes defined and connected, we can create a graph.\n\nfrom rex.asynchronous import AsyncGraph\nfrom rex.constants import Clock, RealTimeFactor\n\n\n# Create the graph by specifying the nodes and the supervisor node\ngraph = AsyncGraph(\n    nodes=nodes,\n    supervisor=agent,\n    # Settings for simulating as fast as possible according to the specified delays\n    clock=Clock.SIMULATED,\n    real_time_factor=RealTimeFactor.FAST_AS_POSSIBLE,\n    # Settings for simulating in real time according to the specified delays\n    # clock=Clock.SIMULATED, real_time_factor=RealTimeFactor.REAL_TIME,\n    # Settings for real-world deployment\n    # clock=Clock.WALL_CLOCK, real_time_factor=RealTimeFactor.REAL_TIME,\n)\n</code></pre> <pre><code># @title Initializing the Graph\n# @markdown Before starting an episode, we initialize the graph's state.\n# @markdown This prepares the graph for execution by initializing the parameters, states, and inputs of all nodes.\n# @markdown If nodes must be initialized in a specific order, we can specify that order.\n# @markdown We also compile the step functions of all nodes ahead of time to speed up execution, and we can specify the device for each node.\n\n# Import JAX (used here for its random number generator)\nimport jax\n\n\n# Initialize the graph state\nrng = jax.random.PRNGKey(0)  # 
Optional random number generator for reproducibility\n\n# Start initialization with the agent node. This is important as the world's state\n# depends on the initial theta and thdot sampled in agent.init_state(...)\ninitial_graph_state = graph.init(rng=rng, order=[\"agent\"])\n\n# Specify what to record for each node (here: params, state, and output)\ngraph.set_record_settings(params=True, inputs=False, state=True, output=True)\n\n# Ahead-of-time compilation of all step functions\n# Compile the step functions of all nodes to speed up execution.\n# Specify the devices for each node, placing them on the CPU or GPU.\nfrom rex.constants import LogLevel\nfrom rex.utils import set_log_level\n\n\nfor n in nodes.values():\n    set_log_level(LogLevel.DEBUG, n)  # Silence the log output\n\n# Place all nodes on the CPU, except the agent, which is placed on the GPU (if available).\n# Note: `cpus` and `gpu` are not defined in this cell; they are expected to come from the notebook's setup cell.\ndevices_step = {k: next(cpus) if k != \"agent\" or gpu is None else gpu for k in nodes}\ngraph.warmup(initial_graph_state, devices_step, jit_step=True, profile=True)  # profile=True profiles the step functions\n</code></pre> <pre><code># @title Graph interaction (Gym-Like API)\n# @markdown We use the graph state obtained with .init() and perform step-by-step execution with .reset() and .step().\n# @markdown Finally, we get the recorded episode data for analysis.\nimport jax.numpy as jnp\nimport tqdm  # Used for progress bars\n\n\n# Starts the graph with the initial state and returns the supervisor's initial step state.\n# If nodes have specified `.startup` methods, they will be called here as well.\ngraph_state, initial_step_state = graph.reset(initial_graph_state)\nstep_state = initial_step_state  # The supervisor's step state\nfor i in tqdm.tqdm(range(300), desc=\"gather data\"):\n    # Access the last sensor message of the input buffer\n    # -1 is the most recent message, -2 the second most recent, etc. 
up until the window size\n    sensor_msg = step_state.inputs[\"sensor\"][-1].data  # .data grabs the pytree message object\n    action = jnp.array([0.5])  # Replace with actual action\n    output = step_state.params.to_output(action)  # Convert the action to an output message\n    # Step the graph (i.e., executes the next time step by sending the output message to the actuator node)\n    graph_state, step_state = graph.step(graph_state, step_state, output)  # Step the graph with the agent's output\ngraph.stop()  # Stops all nodes that were running asynchronously in the background\n\n# Get the episode data (params, delays, outputs, etc.)\nrecord = graph.get_record()  # Gets the records of all nodes\n</code></pre> <pre>\n<code>gather data: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 300/300 [00:08&lt;00:00, 36.60it/s]\n</code>\n</pre> <pre><code># @title Visualizing the Dataflow of an Episode\n# @markdown We can visualize the dataflow of the episode to understand the interactions between nodes.\n# @markdown The top plot shows how long each node takes to process data and forward it to the next node.\n# @markdown The bottom plot provides a graph representation that will form the basis for the computational graph used for compilation.\n# @markdown - Each vertex represents a step call of a node, and each edge represents message transmission between two nodes.\n# @markdown - Edges between consecutive steps of the same node represent the transmission of the internal state of the node.\n# @markdown - Nodes start processing after an initial phase-shift, which can be controlled in the node definition.\n\nimport supergraph  # Used for visualizing the graph\n\nimport rex.utils as rutils\n\n\n# Convert the episode data to a data flow graph\ndf = record.to_graph()\ntiming_mode = \"arrival\"  # \"arrival\" or \"usage\"\nG = rutils.to_networkx_graph(df, nodes=nodes)\nfig, axes = plt.subplots(2, 1, figsize=(12, 6))\nrutils.plot_graph(\n    G,\n    max_x=0.5,\n 
   ax=axes[0],\n    message_arrow_timing_mode=timing_mode,\n    edge_linewidth=1.4,\n    arrowsize=10,\n    show_labels=True,\n    height=0.6,\n    label_loc=\"center\",\n)\nsupergraph.plot_graph(G, max_x=0.5, ax=axes[1])\nfig.suptitle(\"Data flow of one episode\")\naxes[-1].set_xlabel(\"Time [s]\");\n</code></pre> <pre><code># @title Creating a Compiled Graph\n# @markdown Next, we create a compiled graph to speed up execution by pre-compiling the dataflow graph.\n# @markdown This approach requires a recording of the dataflow graph during a simulation episode.\n# @markdown By simulating an episode according to the exact same dataflow graph, we include the asynchronous effects of delays.\n\n# Initialize a graph that can be compiled and enables parallelized execution\ncgraph = rex.graph.Graph(nodes, nodes[\"agent\"], df)\n</code></pre> <pre>\n<code>Growing supergraph: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 301/301 [00:00&lt;00:00, 487.16it/s, 1/1 graphs, 1210/1210 matched (67.00% efficiency, 6 nodes (pre-filtered: 6 nodes))]\n</code>\n</pre> <pre><code># @title Simulating a Compiled Graph\n# @markdown A compiled graph has the same API as a regular graph (init, reset, step, run).\n# @markdown However, we can also simulate entire rollouts in an optimized manner.\n# @markdown Here, we simulate multiple rollouts in parallel to speed up the simulation process.\nnum_rollouts = 10_000\n\n\n# Define a function for rolling out the graph that can be compiled and executed in parallel\ndef rollout_fn(rng):\n    # Initialize graph state\n    gs = cgraph.init(rng, order=(\"agent\",))\n    # Make sure to record the states\n    gs = cgraph.init_record(gs, params=False, state=True, output=False)\n    # Run the graph for a fixed number of steps\n    gs_final = cgraph.rollout(gs)\n    # This returns a record that may only be partially filled.\n    record = gs_final.aux[\"record\"]\n    is_filled = record.nodes[\"world\"].steps.seq &gt;= 0  # Unfilled steps 
are marked with -1\n    return is_filled, record.nodes[\"world\"].steps.state\n\n\n# Prepare timers\ntimer_jit = rutils.timer(f\"Vectorized evaluation of {num_rollouts} rollouts | compile\", log_level=100)\ntimer_run = rutils.timer(f\"Vectorized evaluation of {num_rollouts} rollouts | rollouts\", log_level=100)\n\n# Run the rollouts in parallel\nrng, rng_rollout = jax.random.split(rng)\nrngs_rollout = jax.random.split(rng_rollout, num=num_rollouts)\nwith timer_jit:\n    rollout_fn_jv = jax.jit(jax.vmap(rollout_fn))\n    rollout_fn_jv = rollout_fn_jv.lower(rngs_rollout)\n    rollout_fn_jv = rollout_fn_jv.compile()\nwith timer_run:\n    is_filled, final_states = rollout_fn_jv(rngs_rollout)\n    final_states.th.block_until_ready()\n\n# Only keep the filled rollouts (we did not run the full duration of the computation graph)\nfinal_states = final_states[is_filled]\nprint(\n    f\"sim. eval | fps: {(num_rollouts * cgraph.max_steps) / timer_run.duration / 1e6:.0f} Million steps/s | compile: {timer_jit.duration:.2f} s | run: {timer_run.duration:.2f} s\"\n)\n</code></pre> <pre>\n<code>[434  ][MainThread               ][tracer              ][Vectorized evaluation of 10000 rollouts | compile] Elapsed: 4.8439 sec\n[434  ][MainThread               ][tracer              ][Vectorized evaluation of 10000 rollouts | rollouts] Elapsed: 0.0998 sec\nsim. 
eval | fps: 30 Million steps/s | compile: 4.84 s | run: 0.10 s\n</code>\n</pre> <pre><code># @title Example: Pendulum swing-up environment\n\nfrom typing import Any, Dict, Union\n\nimport jax\nimport jax.numpy as jnp\n\nfrom rex import base\nfrom rex.examples.pendulum.agent import AgentParams\nfrom rex.graph import Graph\nfrom rex.rl import BaseEnv, Box, ResetReturn, StepReturn\n\n\nclass SwingUpEnv(BaseEnv):\n    def __init__(self, graph: Graph):\n        super().__init__(graph=graph)\n        self._init_params = {}\n\n    @property\n    def max_steps(self) -&gt; Union[int, jax.typing.ArrayLike]:\n        \"\"\"Maximum number of steps in an evaluation episode\"\"\"\n        return int(3.5 * self.graph.nodes[\"agent\"].rate)\n\n    def set_params(self, params: Dict[str, Any]):\n        \"\"\"Pre-set parameters for the environment\"\"\"\n        self._init_params.update(params)\n\n    def observation_space(self, graph_state: base.GraphState) -&gt; Box:\n        cdata = self.get_observation(graph_state)\n        low = jnp.full(cdata.shape, -1e6)\n        high = jnp.full(cdata.shape, 1e6)\n        return Box(low, high, shape=cdata.shape, dtype=cdata.dtype)\n\n    def action_space(self, graph_state: base.GraphState) -&gt; Box:\n        params: AgentParams = graph_state.params[\"agent\"]\n        high = jnp.array([params.max_torque], dtype=jnp.float32)\n        return Box(-high, high, shape=high.shape, dtype=high.dtype)\n\n    def get_observation(self, graph_state: base.GraphState) -&gt; jax.Array:\n        # Flatten all inputs and state of the supervisor as the observation\n        ss = graph_state.step_state[\"agent\"]\n        params: AgentParams = ss.params\n        obs = params.get_observation(ss)\n        return obs\n\n    def reset(self, rng: jax.Array = None) -&gt; ResetReturn:\n        # Initialize the graph state\n        init_gs = self.graph.init(rng=rng, params=self._init_params, order=(\"agent\",))\n        # Run the graph until the agent node\n        gs, 
_ = self.graph.reset(init_gs)\n        # Get observation\n        obs = self.get_observation(gs)\n        info = {}  # No info to return\n        return gs, obs, info\n\n    def step(self, graph_state: base.GraphState, action: jax.Array) -&gt; StepReturn:\n        params: AgentParams = graph_state.params[\"agent\"]\n        # Update the agent's state (i.e. action and observation history)\n        new_agent = params.update_state(graph_state.step_state[\"agent\"], action)\n        # The loss_task (i.e. the negative reward) is accumulated in the World node's step function\n        # Hence, we read out the loss_task from the world node and set it to 0 before stepping\n        # This is to ensure that the loss_task is only counted once\n        # Note that this is not obligatory, but it's a good way to ensure that the reward is consistent in the\n        # face of simulated asynchronous effects.\n        new_world = graph_state.state[\"world\"].replace(loss_task=0.0)\n        # Update the states in the graph state\n        gs = graph_state.replace(state=graph_state.state.copy({\"agent\": new_agent, \"world\": new_world}))\n        # Convert action to output (i.e. the one that the Agent node outputs)\n        ss = gs.step_state[\"agent\"]\n        output = params.to_output(action)\n        # Step the graph (i.e. 
all nodes except the Agent node)\n        next_gs, next_ss = self.graph.step(gs, ss, output)\n        # Get observation\n        obs = self.get_observation(next_gs)\n        info = {}\n        # Read out the loss_task from the world node's state\n        reward = -graph_state.state[\"world\"].loss_task\n        # Determine if the episode is truncated\n        terminated = False  # Infinite horizon task\n        truncated = params.tmax &lt;= next_ss.ts  # Truncate if the time limit is reached\n        # Mitigate truncation of infinite horizon tasks by adding a final reward\n        # Add the steady-state solution as if the agent had stayed in the same state for the rest of the episode\n        gamma = params.gamma\n        reward_final = truncated * (1 / (1 - gamma)) * reward  # Assumes that the reward is constant after truncation\n        reward = reward + reward_final\n        return next_gs, obs, reward, terminated, truncated, info\n</code></pre> <pre><code># @title Example: Training a PPO agent\n# @markdown We can now train a PPO agent on the defined environment.\n# @markdown In fact, we train 5 policies in parallel (one per random seed) by vectorizing the training function with `jax.vmap`.\n\nimport functools\n\nimport rex.ppo as ppo\n\n\n# Create the environment\nenv = SwingUpEnv(cgraph)\n\n# Configure the PPO agent\nconfig = ppo.Config(\n    LR=0.0003261962464827655,\n    NUM_ENVS=128,\n    NUM_STEPS=32,\n    TOTAL_TIMESTEPS=5e6,\n    UPDATE_EPOCHS=8,\n    NUM_MINIBATCHES=16,\n    GAMMA=0.9939508937435216,\n    GAE_LAMBDA=0.9712149137900143,\n    CLIP_EPS=0.16413213812946092,\n    ENT_COEF=0.01,\n    VF_COEF=0.8015258840683805,\n    MAX_GRAD_NORM=0.9630061315073456,\n    NUM_HIDDEN_LAYERS=2,\n    NUM_HIDDEN_UNITS=64,\n    KERNEL_INIT_TYPE=\"xavier_uniform\",\n    HIDDEN_ACTIVATION=\"tanh\",\n    STATE_INDEPENDENT_STD=True,\n    SQUASH=True,\n    ANNEAL_LR=False,\n    NORMALIZE_ENV=True,\n    FIXED_INIT=True,\n    OFFSET_STEP=False,\n    NUM_EVAL_ENVS=20,\n    EVAL_FREQ=20,\n    VERBOSE=True,\n    
DEBUG=False,\n)\n\n# Train 5 policies in parallel\nrng, rng_train = jax.random.split(rng)\nrngs_train = jax.random.split(rng_train, num=5)  # Train 5 policies in parallel\ntrain = functools.partial(ppo.train, env)\nwith rutils.timer(\"ppo | compile\"):\n    train_v = jax.vmap(train, in_axes=(None, 0))\n    train_vjit = jax.jit(train_v)\n    train_vjit = train_vjit.lower(config, rngs_train).compile()\nwith rutils.timer(\"ppo | train\"):\n    res = train_vjit(config, rngs_train)\n</code></pre> <pre>\n<code>[434  ][MainThread               ][tracer              ][ppo | compile       ] Elapsed: 40.6049 sec\ntrain_steps=249856 | eval_eps=20 | return=-776.5+-0.0 | length=147+-0.0 | approxkl=0.0033\ntrain_steps=249856 | eval_eps=20 | return=-725.1+-0.0 | length=147+-0.0 | approxkl=0.0038\ntrain_steps=249856 | eval_eps=20 | return=-722.7+-0.0 | length=147+-0.0 | approxkl=0.0038\ntrain_steps=249856 | eval_eps=20 | return=-1002.3+-0.0 | length=147+-0.0 | approxkl=0.0039\ntrain_steps=249856 | eval_eps=20 | return=-1490.4+-0.0 | length=147+-0.0 | approxkl=0.0037\ntrain_steps=499712 | eval_eps=20 | return=-696.3+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=499712 | eval_eps=20 | return=-702.2+-0.0 | length=147+-0.0 | approxkl=0.0031\ntrain_steps=499712 | eval_eps=20 | return=-627.2+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=499712 | eval_eps=20 | return=-613.8+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=499712 | eval_eps=20 | return=-686.7+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=749568 | eval_eps=20 | return=-630.4+-0.0 | length=147+-0.0 | approxkl=0.0031\ntrain_steps=749568 | eval_eps=20 | return=-648.7+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=749568 | eval_eps=20 | return=-685.1+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=749568 | eval_eps=20 | return=-619.9+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=749568 | eval_eps=20 | return=-1561.8+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=999424 | 
eval_eps=20 | return=-761.9+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=999424 | eval_eps=20 | return=-637.3+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=999424 | eval_eps=20 | return=-681.2+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=999424 | eval_eps=20 | return=-674.9+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=999424 | eval_eps=20 | return=-1585.7+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=1249280 | eval_eps=20 | return=-752.2+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=1249280 | eval_eps=20 | return=-699.8+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=1249280 | eval_eps=20 | return=-1091.1+-0.0 | length=147+-0.0 | approxkl=0.0031\ntrain_steps=1249280 | eval_eps=20 | return=-650.8+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=1249280 | eval_eps=20 | return=-1087.9+-0.0 | length=147+-0.0 | approxkl=0.0024\ntrain_steps=1499136 | eval_eps=20 | return=-548.7+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=1499136 | eval_eps=20 | return=-844.7+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=1499136 | eval_eps=20 | return=-657.3+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=1499136 | eval_eps=20 | return=-591.9+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=1499136 | eval_eps=20 | return=-645.2+-0.0 | length=147+-0.0 | approxkl=0.0024\ntrain_steps=1748992 | eval_eps=20 | return=-590.2+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=1748992 | eval_eps=20 | return=-851.7+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=1748992 | eval_eps=20 | return=-639.6+-0.0 | length=147+-0.0 | approxkl=0.0031\ntrain_steps=1748992 | eval_eps=20 | return=-554.2+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=1748992 | eval_eps=20 | return=-664.0+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=1998848 | eval_eps=20 | return=-638.6+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=1998848 | eval_eps=20 | return=-662.4+-0.0 | length=147+-0.0 | 
approxkl=0.0026\ntrain_steps=1998848 | eval_eps=20 | return=-690.1+-0.0 | length=147+-0.0 | approxkl=0.0032\ntrain_steps=1998848 | eval_eps=20 | return=-1450.2+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=1998848 | eval_eps=20 | return=-961.3+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=2248704 | eval_eps=20 | return=-687.5+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=2248704 | eval_eps=20 | return=-561.8+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=2248704 | eval_eps=20 | return=-508.9+-0.0 | length=147+-0.0 | approxkl=0.0034\ntrain_steps=2248704 | eval_eps=20 | return=-559.8+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=2248704 | eval_eps=20 | return=-604.6+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=2498560 | eval_eps=20 | return=-1182.2+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=2498560 | eval_eps=20 | return=-630.9+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=2498560 | eval_eps=20 | return=-720.0+-0.0 | length=147+-0.0 | approxkl=0.0031\ntrain_steps=2498560 | eval_eps=20 | return=-569.4+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=2498560 | eval_eps=20 | return=-419.6+-0.0 | length=147+-0.0 | approxkl=0.0034\ntrain_steps=2748416 | eval_eps=20 | return=-567.8+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=2748416 | eval_eps=20 | return=-552.8+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=2748416 | eval_eps=20 | return=-626.8+-0.0 | length=147+-0.0 | approxkl=0.0032\ntrain_steps=2748416 | eval_eps=20 | return=-563.6+-0.0 | length=147+-0.0 | approxkl=0.0026\ntrain_steps=2748416 | eval_eps=20 | return=-385.8+-0.0 | length=147+-0.0 | approxkl=0.0038\ntrain_steps=2998272 | eval_eps=20 | return=-1553.7+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=2998272 | eval_eps=20 | return=-765.8+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=2998272 | eval_eps=20 | return=-623.3+-0.0 | length=147+-0.0 | approxkl=0.0033\ntrain_steps=2998272 | eval_eps=20 | 
return=-688.6+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=2998272 | eval_eps=20 | return=-387.5+-0.0 | length=147+-0.0 | approxkl=0.0042\ntrain_steps=3248128 | eval_eps=20 | return=-614.8+-0.0 | length=147+-0.0 | approxkl=0.0025\ntrain_steps=3248128 | eval_eps=20 | return=-495.8+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=3248128 | eval_eps=20 | return=-569.3+-0.0 | length=147+-0.0 | approxkl=0.0034\ntrain_steps=3248128 | eval_eps=20 | return=-607.6+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=3248128 | eval_eps=20 | return=-395.1+-0.0 | length=147+-0.0 | approxkl=0.0044\ntrain_steps=3497984 | eval_eps=20 | return=-559.6+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=3497984 | eval_eps=20 | return=-861.6+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=3497984 | eval_eps=20 | return=-578.0+-0.0 | length=147+-0.0 | approxkl=0.0033\ntrain_steps=3497984 | eval_eps=20 | return=-613.4+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=3497984 | eval_eps=20 | return=-382.8+-0.0 | length=147+-0.0 | approxkl=0.0044\ntrain_steps=3747840 | eval_eps=20 | return=-575.2+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=3747840 | eval_eps=20 | return=-518.3+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=3747840 | eval_eps=20 | return=-565.4+-0.0 | length=147+-0.0 | approxkl=0.0032\ntrain_steps=3747840 | eval_eps=20 | return=-679.5+-0.0 | length=147+-0.0 | approxkl=0.0030\ntrain_steps=3747840 | eval_eps=20 | return=-379.2+-0.0 | length=147+-0.0 | approxkl=0.0047\ntrain_steps=3997696 | eval_eps=20 | return=-544.8+-0.0 | length=147+-0.0 | approxkl=0.0027\ntrain_steps=3997696 | eval_eps=20 | return=-556.3+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=3997696 | eval_eps=20 | return=-547.6+-0.0 | length=147+-0.0 | approxkl=0.0034\ntrain_steps=3997696 | eval_eps=20 | return=-955.2+-0.0 | length=147+-0.0 | approxkl=0.0030\ntrain_steps=3997696 | eval_eps=20 | return=-376.4+-0.0 | length=147+-0.0 | 
approxkl=0.0050\ntrain_steps=4247552 | eval_eps=20 | return=-595.6+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=4247552 | eval_eps=20 | return=-552.1+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=4247552 | eval_eps=20 | return=-1352.2+-0.0 | length=147+-0.0 | approxkl=0.0034\ntrain_steps=4247552 | eval_eps=20 | return=-552.7+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=4247552 | eval_eps=20 | return=-376.0+-0.0 | length=147+-0.0 | approxkl=0.0052\ntrain_steps=4497408 | eval_eps=20 | return=-549.1+-0.0 | length=147+-0.0 | approxkl=0.0029\ntrain_steps=4497408 | eval_eps=20 | return=-568.6+-0.0 | length=147+-0.0 | approxkl=0.0030\ntrain_steps=4497408 | eval_eps=20 | return=-453.7+-0.0 | length=147+-0.0 | approxkl=0.0035\ntrain_steps=4497408 | eval_eps=20 | return=-611.0+-0.0 | length=147+-0.0 | approxkl=0.0028\ntrain_steps=4497408 | eval_eps=20 | return=-374.9+-0.0 | length=147+-0.0 | approxkl=0.0057\ntrain_steps=4747264 | eval_eps=20 | return=-561.1+-0.0 | length=147+-0.0 | approxkl=0.0030\ntrain_steps=4747264 | eval_eps=20 | return=-480.1+-0.0 | length=147+-0.0 | approxkl=0.0030\ntrain_steps=4747264 | eval_eps=20 | return=-1391.0+-0.0 | length=147+-0.0 | approxkl=0.0036\ntrain_steps=4747264 | eval_eps=20 | return=-645.8+-0.0 | length=147+-0.0 | approxkl=0.0030\ntrain_steps=4747264 | eval_eps=20 | return=-372.7+-0.0 | length=147+-0.0 | approxkl=0.0068\ntrain_steps=4997120 | eval_eps=20 | return=-538.1+-0.0 | length=147+-0.0 | approxkl=0.0030\ntrain_steps=4997120 | eval_eps=20 | return=-650.3+-0.0 | length=147+-0.0 | approxkl=0.0030\ntrain_steps=4997120 | eval_eps=20 | return=-775.1+-0.0 | length=147+-0.0 | approxkl=0.0035\ntrain_steps=4997120 | eval_eps=20 | return=-564.2+-0.0 | length=147+-0.0 | approxkl=0.0033\ntrain_steps=4997120 | eval_eps=20 | return=-370.8+-0.0 | length=147+-0.0 | approxkl=0.0083\n[434  ][MainThread               ][tracer              ][ppo | train         ] Elapsed: 40.3077 sec\n</code>\n</pre> <pre><code># @title 
Visualize PPO Training Progress\n# @markdown The plots below show the training progress of the PPO algorithm in terms of returns and policy KL divergence.\n# @markdown Note that we are not solving the swing-up task here, but rather demonstrating the training process.\n# @markdown See `sim2real.ipynb` for a complete example of solving the swing-up task using PPO.\n\nfig_ppo, axes_ppo = plt.subplots(1, 2, figsize=(8, 3))\ntotal_steps = res.metrics[\"train/total_steps\"].transpose()\nmean, std = res.metrics[\"eval/mean_returns\"].transpose(), res.metrics[\"eval/std_returns\"].transpose()\naxes_ppo[0].plot(total_steps, mean, label=\"mean\")\naxes_ppo[0].set_title(\"Returns\")\naxes_ppo[0].set_xlabel(\"Total steps\")\naxes_ppo[0].set_ylabel(\"Cum. return\")\nmean, std = res.metrics[\"train/mean_approxkl\"].transpose(), res.metrics[\"train/std_approxkl\"].transpose()\naxes_ppo[1].plot(total_steps, mean, label=\"mean\")\naxes_ppo[1].set_title(\"Policy KL\")\naxes_ppo[1].set_xlabel(\"Total steps\")\naxes_ppo[1].set_ylabel(\"Approx. kl\");\n</code></pre>"},{"location":"examples/graph_and_environment_creation.html#defining-graphs-and-environments-in-rex-robotic-environments-with-jax","title":"Defining Graphs and Environments in rex (Robotic Environments with jaX)","text":"<p>This notebook offers an introductory tutorial for rex (Robotic Environments with jaX), a JAX-based framework for creating graph-based environments tailored for sim2real robotics.</p> <p>In this tutorial, we will guide you through the process of defining graphs and environments. 
Specifically, we will demonstrate how to define the nodes and the training environment used in the sim2real.ipynb notebook.</p>"},{"location":"examples/graph_and_environment_creation.html#introduction-to-graphs-and-environments-in-rex","title":"Introduction to Graphs and Environments in Rex","text":"<p>In Rex, a graph represents the interconnected structure of nodes, defining how data flows and computations are organized within a system. By assembling nodes into a graph, you can model complex systems that reflect real-world interactions or simulations. This section introduces how to define a graph using a set of nodes, interact with it using various APIs, understand the role of the supervisor node, and specify environments that interact with the graph.</p>"},{"location":"examples/graph_and_environment_creation.html#graphs-in-rex-simulated-and-wall_clock-runtimes","title":"Graphs in Rex (<code>SIMULATED</code> and <code>WALL_CLOCK</code> runtimes)","text":"<p>In Rex, a graph is created by connecting nodes to define the flow of data and execution between them. A graph serves as the backbone for modeling systems that involve multiple interacting components, such as sensors, actuators, and agents.</p>"},{"location":"examples/graph_and_environment_creation.html#key-components-of-a-graph","title":"Key Components of a Graph","text":"<ul> <li><code>nodes</code>: The nodes that form the building blocks of the graph, each performing specific tasks like sensing, acting, or controlling.</li> <li><code>supervisor</code>: A designated node that determines the step-by-step progression of the graph (more details in the next section).</li> <li><code>clock</code>: Determines how time is managed in the graph. Choices include <code>Clock.SIMULATED</code> for virtual simulations and <code>Clock.WALL_CLOCK</code> for real-time applications.</li> <li><code>real_time_factor</code>: Sets the speed of the simulation. 
It can simulate as fast as possible (<code>RealTimeFactor.FAST_AS_POSSIBLE</code>), in real time (<code>RealTimeFactor.REAL_TIME</code>), or at any custom speed relative to real time.</li> </ul>"},{"location":"examples/graph_and_environment_creation.html#real-time-and-simulated-clocks","title":"Real-Time and Simulated Clocks","text":"<p>Rex provides flexible control over the simulation's timing through two clock types and an adjustable real-time factor:</p> <ol> <li><code>Clock.SIMULATED</code>: The simulation advances based on the specified delays between nodes. This mode is ideal for running simulations in a controlled environment.</li> <li><code>Clock.WALL_CLOCK</code>: The graph progresses based on real-world time. This mode is essential for real-time systems and deployments.</li> </ol>"},{"location":"examples/graph_and_environment_creation.html#controlling-simulation-speed-with-real_time_factor","title":"Controlling Simulation Speed with <code>real_time_factor</code>","text":"<p>The <code>real_time_factor</code> modifies the simulation speed:</p> <ul> <li><code>RealTimeFactor.FAST_AS_POSSIBLE</code>: Simulates as quickly as the system allows, constrained only by computational limits.</li> <li><code>RealTimeFactor.REAL_TIME</code>: Simulates in real time, matching the speed of real-world processes. Combine with <code>Clock.WALL_CLOCK</code> for real-time applications.</li> <li>Custom Speed: Any positive float value sets a custom speed relative to real time.</li> </ul>"},{"location":"examples/graph_and_environment_creation.html#the-role-of-the-supervisor-node","title":"The Role of the Supervisor Node","text":"<p>A critical aspect of graph design in Rex is selecting a supervisor node, which dictates the execution flow. 
The supervisor node controls the step-by-step progression of the graph and determines the perspective from which the system is viewed.</p> <p>As a mental model, think of the graph as dividing the nodes into two groups:</p> <ol> <li>Supervisor Node: The designated node that controls the graph's execution flow.</li> <li>All Other Nodes: These nodes form the environment that the supervisor interacts with.</li> </ol> <p>This partitioning creates an agent-environment interface, where the supervisor node acts as the agent and the remaining nodes represent the environment. The graph provides gym-like <code>.reset</code> and <code>.step</code> methods that mirror reinforcement learning interfaces:</p> <ul> <li><code>.reset</code>: Starts the graph from an initialized graph state and returns the initial observation as seen by the supervisor node.</li> <li><code>.step</code>: Advances the simulation by one step (i.e., steps all nodes except the supervisor) and returns the next observation.</li> </ul> <p>This design is flexible: by selecting different supervisor nodes, you can create learning environments from different perspectives:</p> <ul> <li>Agent as Supervisor: Forms a traditional reinforcement learning environment.</li> <li>Sensor as Supervisor: Creates an interface where the <code>.reset</code> and <code>.step</code> methods return the sensor's inputs, simulating the I/O process from the sensor's viewpoint.</li> </ul>"},{"location":"examples/graph_and_environment_creation.html#interacting-with-the-graph","title":"Interacting with the Graph","text":"<p>After creating the graph, we can interact with it using the provided APIs to initialize, reset, and step through the graph.</p>"},{"location":"examples/graph_and_environment_creation.html#defining-an-environment","title":"Defining an Environment","text":"<p>To integrate your graph within a reinforcement learning environment or other systems, define an environment class that interacts with the graph. 
RL algorithms such as the one defined in <code>rex.ppo</code> require an environment that implements the following methods:</p>"},{"location":"examples/graph_and_environment_creation.html#implementing-the-environment-class","title":"Implementing the Environment Class","text":"<ul> <li><code>observation_space</code>: Describes the observation space of the environment.</li> <li><code>action_space</code>: Describes the action space of the environment.</li> <li><code>max_steps</code>: The maximum number of steps the environment can run (i.e., the episode length). When using a compiled Graph, this is constrained by the length of the recorded episode.</li> <li><code>reset</code>: Prepares the environment for a new episode by initializing and resetting the graph.</li> <li><code>step</code>: Advances the environment by one timestep, applying the provided action and returning the new observation and reward.</li> </ul>"},{"location":"examples/node_definitions.html","title":"How to define nodes","text":"<pre><code># @title Install Necessary Libraries\n# @markdown This cell installs the required libraries for the project.\n# @markdown If you are running this notebook in Google Colab, most libraries should already be installed.\n\ntry:\n    import rex  # noqa: F401\n\n    print(\"Rex already installed\")\nexcept ImportError:\n    print(\n        \"Installing rex via `pip install rex-lib[examples]`. 
\"\n        \"If you are running this in a Colab notebook, you can ignore this message.\"\n    )\n    !pip install rex-lib[examples]\n</code></pre> <pre><code># @title Example: Agent\n\nfrom typing import Tuple, Union\n\nimport jax\nfrom flax import struct\nfrom flax.core import FrozenDict\nfrom jax import numpy as jnp\n\nfrom rex import base\nfrom rex.base import GraphState, StepState\nfrom rex.node import BaseNode\nfrom rex.ppo import Policy\n\n\n@struct.dataclass\nclass AgentOutput(base.Base):\n    \"\"\"Agent's output\"\"\"\n\n    action: jax.typing.ArrayLike  # Torque to apply to the pendulum\n\n\n@struct.dataclass\nclass AgentParams(base.Base):\n    # Policy\n    policy: Policy\n    # Observations\n    num_act: Union[int, jax.typing.ArrayLike] = struct.field(pytree_node=False)  # Action history length\n    num_obs: Union[int, jax.typing.ArrayLike] = struct.field(pytree_node=False)  # Observation history length\n    # Action\n    max_torque: Union[float, jax.typing.ArrayLike]\n    # Initial state\n    init_method: str = struct.field(pytree_node=False)  # \"random\", \"parametrized\"\n    parametrized: jax.typing.ArrayLike\n    max_th: Union[float, jax.typing.ArrayLike]\n    max_thdot: Union[float, jax.typing.ArrayLike]\n    # Train\n    gamma: Union[float, jax.typing.ArrayLike]\n    tmax: Union[float, jax.typing.ArrayLike]\n\n    @staticmethod\n    def process_inputs(inputs: FrozenDict[str, base.InputState]) -&gt; jax.Array:\n        th, thdot = inputs[\"sensor\"][-1].data.th, inputs[\"sensor\"][-1].data.thdot\n        obs = 
jnp.array([jnp.cos(th), jnp.sin(th), thdot])\n        return obs\n\n    @staticmethod\n    def get_observation(step_state: StepState) -&gt; jax.Array:\n        # Unpack StepState\n        inputs, state = step_state.inputs, step_state.state\n\n        # Convert inputs to single observation\n        single_obs = AgentParams.process_inputs(inputs)\n\n        # Concatenate with previous observations\n        obs = jnp.concatenate([single_obs, state.history_obs.flatten(), state.history_act.flatten()])\n        return obs\n\n    @staticmethod\n    def update_state(step_state: StepState, action: jax.Array) -&gt; \"AgentState\":\n        # Unpack StepState\n        state, params, inputs = step_state.state, step_state.params, step_state.inputs\n\n        # Convert inputs to observation\n        single_obs = AgentParams.process_inputs(inputs)\n\n        # Update obs history\n        if params.num_obs &gt; 0:\n            history_obs = jnp.roll(state.history_obs, shift=1, axis=0)\n            history_obs = history_obs.at[0].set(single_obs)\n        else:\n            history_obs = state.history_obs\n\n        # Update act history\n        if params.num_act &gt; 0:\n            history_act = jnp.roll(state.history_act, shift=1, axis=0)\n            history_act = history_act.at[0].set(action)\n        else:\n            history_act = state.history_act\n\n        new_state = state.replace(history_obs=history_obs, history_act=history_act)\n        return new_state\n\n    @staticmethod\n    def to_output(action: jax.Array) -&gt; AgentOutput:\n        return AgentOutput(action=action)\n\n\n@struct.dataclass\nclass AgentState(base.Base):\n    history_act: jax.typing.ArrayLike  # History of actions\n    history_obs: jax.typing.ArrayLike  # History of observations\n    init_th: Union[float, jax.typing.ArrayLike]  # Initial angle of the pendulum\n    init_thdot: Union[float, jax.typing.ArrayLike]  # Initial angular velocity of the pendulum\n\n\nclass Agent(BaseNode):\n    def 
init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; AgentParams:\n        return AgentParams(\n            policy=None,  # Policy must be set by the user\n            num_act=4,  # Number of actions to keep in history\n            num_obs=4,  # Number of observations to keep in history\n            max_torque=2.0,  # Maximum torque that can be applied to the pendulum\n            init_method=\"parametrized\",  # \"random\" or \"parametrized\"\n            parametrized=jnp.array([jnp.pi, 0.0]),  # [th, thdot]\n            max_th=jnp.pi,  # Maximum initial angle of the pendulum\n            max_thdot=9.0,  # Maximum initial angular velocity of the pendulum\n            gamma=0.99,  # Discount factor  (used during training)\n            tmax=3.0,  # Maximum time for an episode (used during training)\n        )\n\n    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; AgentState:\n        graph_state = graph_state or base.GraphState()\n        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))\n        history_act = jnp.zeros((params.num_act, 1), dtype=jnp.float32)  # [torque]\n        history_obs = jnp.zeros((params.num_obs, 3), dtype=jnp.float32)  # [cos(th), sin(th), thdot]\n\n        # Set the initial state of the pendulum\n        if params.init_method == \"parametrized\":\n            init_th, init_thdot = params.parametrized\n        elif params.init_method == \"random\":\n            rng = rng if rng is not None else jax.random.PRNGKey(0)\n            rngs = jax.random.split(rng, num=2)\n            init_th = jax.random.uniform(rngs[0], shape=(), minval=-params.max_th, maxval=params.max_th)\n            init_thdot = jax.random.uniform(rngs[1], shape=(), minval=-params.max_thdot, maxval=params.max_thdot)\n        else:\n            raise ValueError(f\"Invalid init_method: {params.init_method}\")\n        return AgentState(history_act=history_act, history_obs=history_obs, 
init_th=init_th, init_thdot=init_thdot)\n\n    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; AgentOutput:\n        \"\"\"Default output of the node.\"\"\"\n        rng = jax.random.PRNGKey(0) if rng is None else rng\n        graph_state = graph_state or base.GraphState()\n        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))\n        action = jax.random.uniform(rng, shape=(1,), minval=-params.max_torque, maxval=params.max_torque)\n        return AgentOutput(action=action)\n\n    def step(self, step_state: StepState) -&gt; Tuple[StepState, AgentOutput]:\n        \"\"\"Step the node.\"\"\"\n        # Unpack StepState\n        rng, params = step_state.rng, step_state.params\n\n        # Prepare output\n        rng, rng_net = jax.random.split(rng)\n        if params.policy is not None:  # Use policy to get action\n            obs = AgentParams.get_observation(step_state)\n            action = params.policy.get_action(obs, rng=None)  # Supply rng for stochastic policies\n        else:  # Random action if no policy is set\n            action = jax.random.uniform(rng_net, shape=(1,), minval=-params.max_torque, maxval=params.max_torque)\n        output = AgentParams.to_output(action)  # Convert action to output message\n\n        # Update step_state (observation and action history)\n        new_state = params.update_state(step_state, action)  # Update state\n        new_step_state = step_state.replace(rng=rng, state=new_state)  # Update step_state\n        return new_step_state, output\n</code></pre> <pre><code># @title Example: Actuator\n\nfrom typing import Tuple, Union\n\nimport jax\nimport numpy as onp\nfrom flax import struct\nfrom jax import numpy as jnp\n\nfrom rex import base\nfrom rex.base import GraphState, StepState\nfrom rex.jax_utils import tree_dynamic_slice\nfrom rex.node import BaseNode\n\n\n@struct.dataclass\nclass ActuatorOutput(base.Base):\n    \"\"\"Pendulum actuator output\"\"\"\n\n    action: 
jax.typing.ArrayLike  # Torque to apply to the pendulum\n\n\n@struct.dataclass\nclass ActuatorParams(base.Base):\n    \"\"\"Pendulum actuator param definition\"\"\"\n\n    actuator_delay: Union[float, jax.typing.ArrayLike]\n\n\nclass Actuator(BaseNode):\n    \"\"\"This is a simple actuator node definition that could interface a real actuator.\n\n    When interfacing real hardware, you would send the action to real hardware in the .step method.\n    Optionally, you could also specify a startup routine that is called right before an episode starts.\n    Finally, a stop routine is called after the episode is done.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        \"\"\"No special initialization needed.\"\"\"\n        super().__init__(*args, **kwargs)\n\n    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; ActuatorParams:\n        \"\"\"Default params of the node.\"\"\"\n        actuator_delay = 0.05\n        return ActuatorParams(actuator_delay=actuator_delay)\n\n    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; ActuatorOutput:\n        \"\"\"Default output of the node.\"\"\"\n        return ActuatorOutput(action=jnp.array([0.0], dtype=jnp.float32))\n\n    def startup(self, graph_state: base.GraphState, timeout: float = None) -&gt; bool:\n        \"\"\"Starts the node in the state specified by graph_state.\n\n        This method is called right before an episode starts.\n        It can be used to move (a real) robot to a starting position as specified by the graph_state.\n\n        Not used when running in compiled mode.\n        :param graph_state: The graph state.\n        :param timeout: The timeout of the startup.\n        :return: Whether the node has started successfully.\n        \"\"\"\n        # Move robot to starting position specified by graph_state (e.g. 
graph_state.state[\"agent\"].init_th)\n        return True  # Not doing anything here\n\n    def step(self, step_state: StepState) -&gt; Tuple[StepState, ActuatorOutput]:\n        \"\"\"When controlling a real robot, send the action to the robot here.\"\"\"\n        # Prepare output\n        output = step_state.inputs[\"agent\"][-1].data\n        output = ActuatorOutput(action=output.action)\n\n        def _apply_action(action):\n            \"\"\"\n            Dummy implementation that does nothing.\n            Put the side-effecting code here (e.g. sending the action to a real robot).\n\n            Because the .step method may be jit-compiled, any side-effecting code must be wrapped in a JAX callback\n            (here, jax.experimental.io_callback). See the JAX documentation for details:\n            https://jax.readthedocs.io/en/latest/external-callbacks.html\n            \"\"\"\n            # print(f\"Applying action: {action}\") # Apply action to the robot\n            return onp.array(1.0, dtype=onp.float32)  # Must match the shape and dtype of return_shape\n\n        # Apply action to the robot\n        return_shape = jax.ShapeDtypeStruct((), jnp.float32)  # Expected shape and dtype of the callback's return value\n        _ = jax.experimental.io_callback(_apply_action, return_shape, output)\n\n        # Update state\n        new_step_state = step_state\n        return new_step_state, output\n\n    def stop(self, timeout: float = None) -&gt; bool:\n        \"\"\"Stopping routine that is called after the episode is done.\n\n        **IMPORTANT** It may happen that stop is called *before* the final .step call of an episode returns,\n        which may cause unsafe behavior when the final step undoes the work of the .stop method.\n        This should be handled by the user. 
For example, by keeping the stop command active for a while before returning here.\n\n        Only called when running asynchronously.\n        :param timeout: The timeout of the stop routine.\n        :return: Whether the node has stopped successfully.\n        \"\"\"\n        # Stop the robot (e.g. set the torque to 0)\n        return True\n\n\nclass SimActuator(BaseNode):\n    \"\"\"This is a simple simulated actuator node definition that can either\n    1. Pass through the agent's action (for normal operation, e.g., training).\n       Optionally, you could include some noise or other modifications to the action.\n    2. Reapply the recorded actuator outputs for system identification if available.\n    \"\"\"\n\n    def __init__(self, *args, outputs: ActuatorOutput = None, **kwargs):\n        \"\"\"Initialize a simulated actuator for system identification.\n\n        If outputs are provided, the recorded actuator outputs are reapplied instead of the agent's actions.\n\n        :param outputs: Recorded actuator outputs to be used for system identification.\n        \"\"\"\n        super().__init__(*args, **kwargs)\n        self._outputs = outputs\n\n    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; ActuatorParams:\n        \"\"\"Default params of the node.\"\"\"\n        actuator_delay = 0.05\n        return ActuatorParams(actuator_delay=actuator_delay)\n\n    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; ActuatorOutput:\n        \"\"\"Default output of the node.\"\"\"\n        return ActuatorOutput(action=jnp.array([0.0], dtype=jnp.float32))\n\n    def step(self, step_state: StepState) -&gt; Tuple[StepState, ActuatorOutput]:\n        # Get action from dataset if available, else use the one provided by the agent\n        if self._outputs is not None:  # Use the recorded action (for system identification)\n            output = tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))\n            output = 
jax.tree_util.tree_map(lambda _o: _o[0, 0], output)\n        else:  # Pass through the agent's action (for normal operation, e.g., training)\n            output = step_state.inputs[\"agent\"][-1].data\n            output = ActuatorOutput(action=output.action)\n        new_step_state = step_state\n        return new_step_state, output\n</code></pre> <pre><code># @title Example: Sensor\n\nfrom typing import Dict, Tuple, Union\n\nimport jax\nimport jax.experimental\nimport jax.numpy as jnp\nimport numpy as onp\nfrom flax import struct\n\nfrom rex import base\nfrom rex.base import GraphState, StepState\nfrom rex.jax_utils import tree_dynamic_slice\nfrom rex.node import BaseNode\n\n\n@struct.dataclass\nclass SensorOutput(base.Base):\n    \"\"\"Output message definition of the sensor node.\"\"\"\n\n    th: Union[float, jax.typing.ArrayLike]\n    thdot: Union[float, jax.typing.ArrayLike]\n\n\n@struct.dataclass\nclass SensorParams(base.Base):\n    \"\"\"\n    Other than the sensor delay, we don't have any other parameters.\n    You could add more parameters here if needed, such as noise levels etc.\n    \"\"\"\n\n    sensor_delay: Union[float, jax.typing.ArrayLike]\n\n\n@struct.dataclass\nclass SensorState:\n    \"\"\"We use this state to record the reconstruction loss.\"\"\"\n\n    loss_th: Union[float, jax.typing.ArrayLike]\n    loss_thdot: Union[float, jax.typing.ArrayLike]\n\n\nclass Sensor(BaseNode):\n    \"\"\"This is a simple sensor node definition that interfaces a real sensor.\n\n    When interfacing real hardware, you would grab the sensor measurement in the .step method.\n    Optionally, you could also specify a startup routine that is called right before an episode starts.\n    Finally, a stop routine is called after the episode is done.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        \"\"\"No special initialization needed.\"\"\"\n        super().__init__(*args, **kwargs)\n\n    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; SensorParams:\n        \"\"\"Default params of the node.\"\"\"\n        sensor_delay = 0.05\n        
return SensorParams(sensor_delay=sensor_delay)\n\n    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; SensorOutput:\n        \"\"\"Default output of the node.\"\"\"\n        # Define the initial sensor values (pendulum hanging down at rest)\n        th = jnp.pi\n        thdot = 0.0\n        return SensorOutput(th=th, thdot=thdot)\n\n    def startup(self, graph_state: base.GraphState, timeout: float = None) -&gt; bool:\n        \"\"\"Starts the node in the state specified by graph_state.\n\n        This method is called right before an episode starts.\n        It can be used to move (a real) robot to a starting position as specified by the graph_state.\n\n        Not used when running in compiled mode.\n        :param graph_state: The graph state.\n        :param timeout: The timeout of the startup.\n        :return: Whether the node has started successfully.\n        \"\"\"\n        return True  # Not doing anything here\n\n    def step(self, step_state: StepState) -&gt; Tuple[StepState, SensorOutput]:\n        \"\"\"When interfacing real hardware, grab the sensor measurement here.\n\n        Because the .step method may be jit-compiled, any side-effecting code must be wrapped in a JAX callback\n        (here, jax.experimental.io_callback). See the JAX documentation for details:\n        https://jax.readthedocs.io/en/latest/external-callbacks.html\n        \"\"\"\n        world = step_state.inputs[\"world\"][-1].data\n\n        def _grab_measurement():\n            \"\"\"\n            Dummy implementation that does nothing.\n            Put the side-effecting code here (e.g. 
grabbing a measurement from the sensor).\n\n            Because the .step method may be jit-compiled, any side-effecting code must be wrapped in a JAX callback\n            (here, jax.experimental.io_callback). See the JAX documentation for details:\n            https://jax.readthedocs.io/en/latest/external-callbacks.html\n            \"\"\"\n            # print(\"Grabbing sensor measurement\")\n            sensor_msg = onp.array(1.0, dtype=onp.float32)  # Dummy sensor measurement (not actually used)\n            return sensor_msg  # Must match the shape and dtype of return_shape\n\n        # Grab sensor measurement\n        return_shape = jax.ShapeDtypeStruct((), jnp.float32)  # Expected shape and dtype of the callback's return value\n        _ = jax.experimental.io_callback(_grab_measurement, return_shape)\n\n        # Prepare output\n        output = SensorOutput(th=world.th, thdot=world.thdot)\n\n        # Update state (NOOP)\n        new_step_state = step_state\n\n        return new_step_state, output\n\n    def stop(self, timeout: float = None) -&gt; bool:\n        \"\"\"Stopping routine that is called after the episode is done.\n\n        **IMPORTANT** It may happen that stop is called *before* the final .step call of an episode returns,\n        which may cause unsafe behavior when the final step undoes the work of the .stop method.\n        This should be handled by the user. For example, by keeping the stop command active for a while before returning here.\n\n        Only called when running asynchronously.\n        :param timeout: The timeout of the stop routine.\n        :return: Whether the node has stopped successfully.\n        \"\"\"\n        return True  # Not doing anything here\n\n\nclass SimSensor(BaseNode):\n    \"\"\"This is a simple simulated sensor node definition that can either\n    1. Convert the world state into a realistic sensor measurement (for normal operation, e.g., training).\n       Optionally, you could include some noise or other modifications to the sensor measurement.\n    2. 
Calculate a reconstruction loss based on the sensor measurement and the recorded sensor outputs.\n\n    By calculating and aggregating the reconstruction loss here, we take time-scale differences and delays into account.\n    \"\"\"\n\n    def __init__(self, *args, outputs: SensorOutput = None, **kwargs):\n        \"\"\"Initialize a simulated sensor for system identification.\n\n        If outputs are provided, we will calculate the reconstruction loss based on the recorded sensor outputs.\n\n        :param outputs: Recorded sensor outputs to be used for system identification.\n        \"\"\"\n        super().__init__(*args, **kwargs)\n        self._outputs = outputs\n\n    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; SensorParams:\n        \"\"\"Default params of the node.\"\"\"\n        sensor_delay = 0.05\n        return SensorParams(sensor_delay=sensor_delay)\n\n    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; SensorState:\n        \"\"\"Default state of the node.\"\"\"\n        return SensorState(loss_th=0.0, loss_thdot=0.0)  # Initialize reconstruction loss to zero at the start of the episode\n\n    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; SensorOutput:\n        \"\"\"Default output of the node.\"\"\"\n        # Define the initial sensor values (pendulum hanging down at rest)\n        th = jnp.pi\n        thdot = 0.0\n        return SensorOutput(th=th, thdot=thdot)  # Fix the initial sensor values\n\n    def init_delays(\n        self, rng: jax.Array = None, graph_state: base.GraphState = None\n    ) -&gt; Dict[str, Union[float, jax.typing.ArrayLike]]:\n        \"\"\"Initialize trainable communication delays.\n\n        **Note** These only include trainable delays that were specified while connecting the nodes.\n\n        :param rng: Random number generator.\n        :param graph_state: The graph state that may be used to get the default output.\n        :return: 
Trainable delays (e.g., {input_name: delay}). Can be an incomplete dictionary.\n                 Entries for non-trainable delays or non-existent connections are ignored.\n        \"\"\"\n        graph_state = graph_state or GraphState()\n        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))\n        delays = {\"world\": params.sensor_delay}\n        return delays\n\n    def step(self, step_state: StepState) -&gt; Tuple[StepState, SensorOutput]:\n        # Determine output\n        data = step_state.inputs[\"world\"][-1].data\n        output = SensorOutput(th=data.th, thdot=data.thdot)\n\n        # Calculate loss\n        if self._outputs is not None:  # Calculate reconstruction loss and aggregate in state\n            output_rec = tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))\n            output_rec = jax.tree_util.tree_map(lambda _o: _o[0, 0], output_rec)\n            th_rec, thdot_rec = output_rec.th, output_rec.thdot\n            state = step_state.state\n            loss_th = state.loss_th + (jnp.sin(output.th) - jnp.sin(th_rec)) ** 2 + (jnp.cos(output.th) - jnp.cos(th_rec)) ** 2\n            loss_thdot = state.loss_thdot + (output.thdot - thdot_rec) ** 2\n            new_state = state.replace(loss_th=loss_th, loss_thdot=loss_thdot)\n        else:  # NOOP\n            new_state = step_state.state\n\n        # Update step_state\n        new_step_state = step_state.replace(state=new_state)\n        return new_step_state, output\n</code></pre> <pre><code># @title Example: ODE simulation node\n\nfrom math import ceil\nfrom typing import Dict, Tuple, Union\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import struct\n\nfrom rex import base\nfrom rex.base import GraphState, StepState\nfrom rex.node import BaseWorld\n\n\n@struct.dataclass\nclass OdeParams(base.Base):\n    \"\"\"Pendulum ode param definition\"\"\"\n\n    max_speed: Union[float, jax.typing.ArrayLike]\n    J: Union[float, jax.typing.ArrayLike]\n    mass: 
Union[float, jax.typing.ArrayLike]\n    length: Union[float, jax.typing.ArrayLike]\n    b: Union[float, jax.typing.ArrayLike]\n    K: Union[float, jax.typing.ArrayLike]\n    R: Union[float, jax.typing.ArrayLike]\n    c: Union[float, jax.typing.ArrayLike]\n    dt_substeps_min: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)\n    dt: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)\n\n    @property\n    def substeps(self) -&gt; int:\n        substeps = ceil(self.dt / self.dt_substeps_min)\n        return int(substeps)\n\n    @property\n    def dt_substeps(self) -&gt; float:\n        substeps = self.substeps\n        dt_substeps = self.dt / substeps\n        return dt_substeps\n\n    def step(\n        self, substeps: int, dt_substeps: jax.typing.ArrayLike, x: \"OdeState\", us: jax.typing.ArrayLike\n    ) -&gt; Tuple[\"OdeState\", \"OdeState\"]:\n        \"\"\"Step the pendulum ode.\"\"\"\n\n        def _scan_fn(_x, _u):\n            next_x = self._runge_kutta4(dt_substeps, _x, _u)\n            # Clip velocity\n            clip_thdot = jnp.clip(next_x.thdot, -self.max_speed, self.max_speed)\n            next_x = next_x.replace(thdot=clip_thdot)\n            return next_x, next_x\n\n        x_final, x_substeps = jax.lax.scan(_scan_fn, x, us, length=substeps)\n        return x_final, x_substeps\n\n    def _runge_kutta4(self, dt: jax.typing.ArrayLike, x: \"OdeState\", u: jax.typing.ArrayLike) -&gt; \"OdeState\":\n        k1 = self._ode(x, u)\n        k2 = self._ode(x + k1 * dt * 0.5, u)\n        k3 = self._ode(x + k2 * dt * 0.5, u)\n        k4 = self._ode(x + k3 * dt, u)\n        return x + (k1 + k2 * 2 + k3 * 2 + k4) * (dt / 6)\n\n    def _ode(self, x: \"OdeState\", u: jax.typing.ArrayLike) -&gt; \"OdeState\":\n        \"\"\"dx function for the pendulum ode\"\"\"\n        # Downward := [pi, 0], Upward := [0, 0]\n        g, J, m, l, b, K, R, c = 9.81, self.J, self.mass, self.length, self.b, self.K, self.R, self.c  # noqa: E741\n   
     th, thdot = x.th, x.thdot\n        activation = jnp.sign(thdot)\n        ddx = (u * K / R + m * g * l * jnp.sin(th) - b * thdot - thdot * K * K / R - c * activation) / J\n        return OdeState(th=thdot, thdot=ddx, loss_task=0.0)  # No derivative for loss_task\n\n\n@struct.dataclass\nclass OdeState(base.Base):\n    \"\"\"Pendulum state definition\"\"\"\n\n    loss_task: Union[float, jax.typing.ArrayLike]\n    th: Union[float, jax.typing.ArrayLike]\n    thdot: Union[float, jax.typing.ArrayLike]\n\n\n@struct.dataclass\nclass OdeOutput(base.Base):\n    \"\"\"World output definition\"\"\"\n\n    th: Union[float, jax.typing.ArrayLike]\n    thdot: Union[float, jax.typing.ArrayLike]\n\n\nclass OdeWorld(BaseWorld):  # We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want\n    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; OdeParams:\n        \"\"\"Default params of the node.\"\"\"\n        return OdeParams(\n            max_speed=40.0,  # Clip angular velocity to this value\n            J=0.000159931461600856,  # 0.000159931461600856,\n            mass=0.0508581731919534,  # 0.0508581731919534,\n            length=0.0415233722862552,  # 0.0415233722862552,\n            b=1.43298488e-05,  # 1.43298488358436e-05,\n            K=0.03333912,  # 0.0333391179016334,\n            R=7.73125142,  # 7.73125142447252,\n            c=0.000975041213361349,  # 0.000975041213361349,\n            # Backend parameters\n            dt_substeps_min=1 / 100,  # Minimum substep size for ode integration\n            dt=1 / self.rate,  # Time step per .step() call\n        )\n\n    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; OdeState:\n        \"\"\"Default state of the node.\"\"\"\n        graph_state = graph_state or GraphState()\n\n        # Try to grab state from graph_state\n        state = graph_state.state.get(\"agent\", None)\n        init_th = state.init_th if state is not 
None else jnp.pi\n        init_thdot = state.init_thdot if state is not None else 0.0\n        return OdeState(th=init_th, thdot=init_thdot, loss_task=0.0)\n\n    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; OdeOutput:\n        \"\"\"Default output of the node.\"\"\"\n        graph_state = graph_state or GraphState()\n        # Grab output from state\n        world_state = graph_state.state.get(self.name, self.init_state(rng, graph_state))\n        return OdeOutput(th=world_state.th, thdot=world_state.thdot)\n\n    def init_delays(\n        self, rng: jax.Array = None, graph_state: base.GraphState = None\n    ) -&gt; Dict[str, Union[float, jax.typing.ArrayLike]]:\n        graph_state = graph_state or GraphState()\n        params = graph_state.params.get(\"actuator\")\n        delays = {}\n        if hasattr(params, \"actuator_delay\"):\n            delays[\"actuator\"] = params.actuator_delay\n        return delays\n\n    def step(self, step_state: StepState) -&gt; Tuple[StepState, OdeOutput]:\n        \"\"\"Step the node.\"\"\"\n        # Unpack StepState\n        _, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs\n\n        # Apply dynamics\n        u = inputs[\"actuator\"].data.action[-1][0]  # [-1] to get the latest action, [0] reduces the dimension to scalar\n        us = jnp.array([u] * params.substeps)\n        new_state = params.step(params.substeps, params.dt_substeps, state, us)[0]\n        next_th, next_thdot = new_state.th, new_state.thdot\n        output = OdeOutput(th=next_th, thdot=next_thdot)  # Prepare output\n\n        # Calculate cost (penalize angle error, angular velocity and input voltage)\n        norm_next_th = next_th - 2 * jnp.pi * jnp.floor((next_th + jnp.pi) / (2 * jnp.pi))\n        loss_task = state.loss_task + norm_next_th**2 + 0.1 * (next_thdot / (1 + 10 * abs(norm_next_th))) ** 2 + 0.01 * u**2\n\n        # Update state\n        new_state = 
new_state.replace(loss_task=loss_task)\n        new_step_state = step_state.replace(state=new_state)\n        return new_step_state, output\n</code></pre> <pre><code># @title Example: Brax simulation node\n\nfrom math import ceil\nfrom typing import Tuple, Union\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import struct\n\nfrom rex import base\nfrom rex.base import GraphState, StepState\nfrom rex.node import BaseWorld\n\n\ntry:\n    from brax.generalized import pipeline as gen_pipeline\n    from brax.io import mjcf\n    from brax.positional import pipeline as pos_pipeline\n    from brax.spring import pipeline as spring_pipeline\n\n    Systems = Union[gen_pipeline.System, spring_pipeline.System, pos_pipeline.System]\n    Pipelines = Union[gen_pipeline.State, spring_pipeline.State, pos_pipeline.State]\nexcept ModuleNotFoundError as e:\n    print(\"Brax not installed. Install it with `pip install brax`\")\n    raise e\n\n\n@struct.dataclass\nclass BraxParams(base.Base):\n    max_speed: Union[float, jax.typing.ArrayLike]\n    damping: Union[float, jax.typing.ArrayLike]\n    armature: Union[float, jax.typing.ArrayLike]\n    gear: Union[float, jax.typing.ArrayLike]\n    mass_weight: Union[float, jax.typing.ArrayLike]\n    radius_weight: Union[float, jax.typing.ArrayLike]\n    offset: Union[float, jax.typing.ArrayLike]\n    friction_loss: Union[float, jax.typing.ArrayLike]\n    backend: str = struct.field(pytree_node=False)\n    dt: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)\n\n    @property\n    def substeps(self) -&gt; int:\n        dt_substeps_per_backend = {\"generalized\": 1 / 100, \"spring\": 1 / 100, \"positional\": 1 / 100}[self.backend]\n        substeps = ceil(self.dt / dt_substeps_per_backend)\n        return int(substeps)\n\n    @property\n    def dt_substeps(self) -&gt; float:\n        substeps = self.substeps\n        dt_substeps = self.dt / substeps\n        return dt_substeps\n\n    @property\n    def pipeline(self) -&gt; Pipelines:\n        return {\"generalized\": 
gen_pipeline, \"spring\": spring_pipeline, \"positional\": pos_pipeline}[self.backend]\n\n    @property\n    def sys(self) -&gt; Systems:\n        base_sys = mjcf.loads(DISK_PENDULUM_XML)\n        # Appropriately replace parameters for the disk pendulum\n        itransform = base_sys.link.inertia.transform.replace(pos=jnp.array([[0.0, self.offset, 0.0]]))\n        i = base_sys.link.inertia.i.at[0, 0, 0].set(\n            0.5 * self.mass_weight * self.radius_weight**2\n        )  # inertia of cylinder in local frame.\n        inertia = base_sys.link.inertia.replace(transform=itransform, mass=jnp.array([self.mass_weight]), i=i)\n        link = base_sys.link.replace(inertia=inertia)\n        actuator = base_sys.actuator.replace(gear=jnp.array([self.gear]))\n        dof = base_sys.dof.replace(armature=jnp.array([self.armature]), damping=jnp.array([self.damping]))\n        opt = base_sys.opt.replace(timestep=self.dt_substeps)\n        new_sys = base_sys.replace(link=link, actuator=actuator, dof=dof, opt=opt)\n        return new_sys\n\n    def step(\n        self, substeps: int, dt_substeps: jax.typing.ArrayLike, x: Pipelines, us: jax.typing.ArrayLike\n    ) -&gt; Tuple[Pipelines, Pipelines]:\n        \"\"\"Step the disk pendulum with the Brax pipeline of the selected backend.\"\"\"\n        # Appropriately replace timestep for the disk pendulum (build the system only once)\n        base_sys = self.sys\n        sys = base_sys.replace(opt=base_sys.opt.replace(timestep=dt_substeps))\n        pipeline = self.pipeline\n\n        def _scan_fn(_x, _u):\n            # Add friction loss based on the velocity at the current substep\n            thdot = _x.qd[0]\n            activation = jnp.sign(thdot)\n            friction = self.friction_loss * activation / sys.actuator.gear[0]\n            _u_friction = _u - friction\n            # Step\n            next_x = pipeline.step(sys, _x, jnp.array(_u_friction)[None])\n            # Clip velocity\n            next_x = next_x.replace(qd=jnp.clip(next_x.qd, -self.max_speed, self.max_speed))\n            return next_x, next_x\n\n        x_final, x_substeps = jax.lax.scan(_scan_fn, x, us, length=substeps)\n        
return x_final, x_substeps\n\n\n@struct.dataclass\nclass BraxState(base.Base):\n    \"\"\"Pendulum state definition\"\"\"\n\n    loss_task: Union[float, jax.typing.ArrayLike]\n    pipeline_state: Pipelines\n\n    @property\n    def th(self):\n        return self.pipeline_state.q[..., 0]\n\n    @property\n    def thdot(self):\n        return self.pipeline_state.qd[..., 0]\n\n\n@struct.dataclass\nclass BraxOutput(base.Base):\n    \"\"\"World output definition\"\"\"\n\n    th: Union[float, jax.typing.ArrayLike]\n    thdot: Union[float, jax.typing.ArrayLike]\n\n\nclass BraxWorld(BaseWorld):  # We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want\n    def __init__(self, *args, backend: str = \"generalized\", **kwargs):\n        super().__init__(*args, **kwargs)\n        self.backend = backend\n\n    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; BraxParams:\n        \"\"\"Default params of the node.\"\"\"\n        return BraxParams(\n            # Realistic parameters for the disk pendulum\n            max_speed=40.0,\n            damping=0.00015877,\n            armature=6.4940527e-06,\n            gear=0.00428677,\n            mass_weight=0.05076142,\n            radius_weight=0.05121992,\n            offset=0.04161447,\n            friction_loss=0.00097525,\n            # Backend parameters\n            dt=1 / self.rate,\n            backend=self.backend,\n        )\n\n    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; BraxState:\n        \"\"\"Default state of the node.\"\"\"\n        graph_state = graph_state or GraphState()\n\n        # Try to grab state from graph_state\n        state = graph_state.state.get(\"agent\", None)\n        init_th = state.init_th if state is not None else jnp.pi\n        init_thdot = state.init_thdot if state is not None else 0.0\n\n        # Set the initial state of the disk pendulum\n        params = 
graph_state.params.get(self.name, self.init_params(rng, graph_state))\n        sys = params.sys\n        q = sys.init_q.at[0].set(init_th)\n        qd = jnp.array([init_thdot])\n        pipeline_state = params.pipeline.init(sys, q, qd)\n        return BraxState(pipeline_state=pipeline_state, loss_task=0.0)\n\n    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -&gt; BraxOutput:\n        \"\"\"Default output of the node.\"\"\"\n        graph_state = graph_state or GraphState()\n        # Grab output from state\n        state = graph_state.state.get(self.name, self.init_state(rng, graph_state))\n        return BraxOutput(th=state.pipeline_state.q[0], thdot=state.pipeline_state.qd[0])\n\n    def step(self, step_state: StepState) -&gt; Tuple[StepState, BraxOutput]:\n        \"\"\"Step the node.\"\"\"\n\n        # Unpack StepState\n        _, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs\n\n        # Apply dynamics\n        u = inputs[\"actuator\"].data.action[-1][0]  # [-1] to get the latest action, [0] reduces the dimension to scalar\n        us = jnp.array([u] * params.substeps)\n        x = state.pipeline_state\n        next_x = params.step(params.substeps, params.dt_substeps, x, us)[0]\n        new_state = state.replace(pipeline_state=next_x)\n        next_th, next_thdot = new_state.th, new_state.thdot\n        output = BraxOutput(th=next_th, thdot=next_thdot)  # Prepare output\n\n        # Calculate cost (penalize angle error, angular velocity and input voltage)\n        norm_next_th = next_th - 2 * jnp.pi * jnp.floor((next_th + jnp.pi) / (2 * jnp.pi))\n        loss_task = state.loss_task + norm_next_th**2 + 0.1 * (next_thdot / (1 + 10 * abs(norm_next_th))) ** 2 + 0.01 * u**2\n\n        # Update state\n        new_state = new_state.replace(loss_task=loss_task)\n        new_step_state = step_state.replace(state=new_state)\n        return new_step_state, output\n\n\nDISK_PENDULUM_XML = 
\"\"\"\n&lt;mujoco model=\"disk_pendulum\"&gt;\n    &lt;compiler inertiafromgeom=\"auto\" angle=\"radian\" coordinate=\"local\" eulerseq=\"xyz\" autolimits=\"true\"/&gt;\n    &lt;option gravity=\"0 0 -9.81\" timestep=\"0.01\" iterations=\"10\"/&gt;\n    &lt;custom&gt;\n        &lt;numeric data=\"10\" name=\"constraint_ang_damping\"/&gt; &lt;!-- positional &amp; spring --&gt;\n        &lt;numeric data=\"1\" name=\"spring_inertia_scale\"/&gt;  &lt;!-- positional &amp; spring --&gt;\n        &lt;numeric data=\"0\" name=\"ang_damping\"/&gt;  &lt;!-- positional &amp; spring --&gt;\n        &lt;numeric data=\"0\" name=\"spring_mass_scale\"/&gt;  &lt;!-- positional &amp; spring --&gt;\n        &lt;numeric data=\"0.5\" name=\"joint_scale_pos\"/&gt; &lt;!-- positional --&gt;\n        &lt;numeric data=\"0.1\" name=\"joint_scale_ang\"/&gt; &lt;!-- positional --&gt;\n        &lt;numeric data=\"3000\" name=\"constraint_stiffness\"/&gt;  &lt;!-- spring --&gt;\n        &lt;numeric data=\"10000\" name=\"constraint_limit_stiffness\"/&gt;  &lt;!-- spring --&gt;\n        &lt;numeric data=\"50\" name=\"constraint_vel_damping\"/&gt;  &lt;!-- spring --&gt;\n        &lt;numeric data=\"10\" name=\"solver_maxls\"/&gt;  &lt;!-- generalized --&gt;\n    &lt;/custom&gt;\n\n    &lt;asset&gt;\n        &lt;texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/&gt;\n        &lt;material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/&gt;\n    &lt;/asset&gt;\n\n    &lt;default&gt;\n        &lt;geom contype=\"0\" friction=\"1 0.1 0.1\" material=\"geom\"/&gt;\n    &lt;/default&gt;\n\n    &lt;worldbody&gt;\n        &lt;light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/&gt;\n        &lt;geom name=\"table\" type=\"plane\" pos=\"0 0.0 -0.1\" size=\"1 1 0.1\" contype=\"8\" conaffinity=\"11\" 
condim=\"3\"/&gt;\n        &lt;body name=\"disk\" pos=\"0.0 0.0 0.0\" euler=\"1.5708 0.0 0.0\"&gt;\n            &lt;joint name=\"hinge_joint\" type=\"hinge\" axis=\"0 0 1\" range=\"-180 180\" armature=\"0.00022993\" damping=\"0.0001\" limited=\"false\"/&gt;\n            &lt;geom name=\"disk_geom\" type=\"cylinder\" size=\"0.06 0.001\" contype=\"0\" conaffinity=\"0\" condim=\"3\" mass=\"0.0\"/&gt;\n            &lt;geom name=\"mass_geom\" type=\"cylinder\" size=\"0.02 0.005\" contype=\"0\" conaffinity=\"0\"  condim=\"3\" rgba=\"0.04 0.04 0.04 1\"\n                  pos=\"0.0 0.04 0.\" mass=\"0.05085817\"/&gt;\n        &lt;/body&gt;\n    &lt;/worldbody&gt;\n\n    &lt;actuator&gt;\n        &lt;motor joint=\"hinge_joint\" ctrllimited=\"false\" ctrlrange=\"-3.0 3.0\"  gear=\"0.01\"/&gt;\n    &lt;/actuator&gt;\n&lt;/mujoco&gt;\n\"\"\"\n</code></pre> <pre><code>\n</code></pre>"},{"location":"examples/node_definitions.html#defining-nodes-in-rex-robotic-environments-with-jax","title":"Defining Nodes in rex (Robotic Environments with jaX)","text":"<p>This notebook offers an introductory tutorial for rex (Robotic Environments with jaX), a JAX-based framework for creating graph-based environments tailored for sim2real robotics.</p> <p>In this tutorial, we will guide you through the process of defining nodes, which are the fundamental building blocks for constructing graph-based simulations and real-world systems within rex. Specifically, we will demonstrate how to define the nodes used in the sim2real.ipynb notebook.</p>"},{"location":"examples/node_definitions.html#introduction-to-nodes-in-rex","title":"Introduction to Nodes in Rex","text":"<p>In Rex, a node represents a fundamental computational unit within a graph-based system. Nodes encapsulate specific functionality and interact by passing data through connections, forming a network that can model complex systems. 
This tutorial introduces how to define nodes, specify their properties like rates and delays, and manage their interactions within a graph.</p>"},{"location":"examples/node_definitions.html#defining-nodes","title":"Defining Nodes","text":"<p>Nodes are defined by creating subclasses of the <code>BaseNode</code> class. This base class provides a standardized API and essential functionality that all nodes inherit. When defining a node, you can specify several parameters directly in the <code>__init__</code> method:</p> <ul> <li><code>name</code>: A unique identifier for the node.</li> <li><code>rate</code>: The frequency at which the node's <code>step</code> method is called (in Hz).</li> <li><code>delay</code> (optional): The expected computation delay of the node (in seconds).</li> <li><code>delay_dist</code>: A distribution representing variability in the node's computation delay, useful for simulations.</li> <li><code>advance</code>: If <code>True</code>, the node's <code>step</code> method triggers when all inputs are ready; if <code>False</code>, it throttles until the scheduled time.</li> <li><code>scheduling</code>: Determines how the node's execution is scheduled. Options include <code>Scheduling.FREQUENCY</code> and <code>Scheduling.PHASE</code>.</li> <li><code>color</code>: Used for visualization purposes.</li> <li><code>order</code>: Determines the node's order in visualizations.</li> </ul> <p>Here's a basic example of a node definition:</p> <pre><code>class MyNode(BaseNode):\n    def __init__(\n        self,\n        name: str,\n        rate: float,\n        delay: float = None,  # Expected computation delay (used for phase-shifting)\n        delay_dist: Union[DelayDistribution, distrax.Distribution] = None,  # Sim. 
computation delay\n        advance: bool = False,\n        scheduling: Scheduling = Scheduling.FREQUENCY,\n        color: str = None,\n        order: int = None\n    ):\n        super().__init__(name, rate, delay, delay_dist, advance, scheduling, color, order)\n        # Additional initialization if needed\n\n    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None):\n        # Initialize parameters\n        return MyParams()\n\n    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None):\n        # Initialize state\n        return MyState()\n\n    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None):\n        # Initialize default output\n        return MyOutput()\n\n    def step(self, step_state: StepState):\n        # Node's computation logic\n        new_state = ...\n        output = ...\n        return step_state.replace(state=new_state), output\n</code></pre>"},{"location":"examples/node_definitions.html#connecting-nodes","title":"Connecting Nodes","text":"<p>Nodes interact by passing outputs from one node to the inputs of another. This is achieved through the <code>connect</code> method, which establishes a connection between two nodes.</p>"},{"location":"examples/node_definitions.html#connection-api","title":"Connection API","text":"<p>When connecting nodes, you can specify several parameters that control the nature of the connection:</p> <ul> <li><code>output_node</code>: The node whose output becomes an input of the node on which <code>connect</code> is called.</li> <li><code>blocking</code>: If <code>True</code>, the receiving node waits for the input before proceeding. 
This can create dependencies between nodes.</li> <li><code>delay</code>: An additional delay introduced in the connection, which can control the phase shift between nodes.</li> <li><code>delay_dist</code>: Used in simulation to model communication delays between nodes.</li> <li><code>window</code>: Determines how many past messages are stored and accessible in the input buffer.</li> <li><code>skip</code>: If <code>True</code>, the connection is skipped when messages arrive simultaneously, helping resolve cyclic dependencies.</li> <li><code>jitter</code>: Controls how to handle irregularities in message timing (e.g., <code>Jitter.LATEST</code> uses the most recent message).</li> <li><code>name</code>: A shadow name for the input; defaults to the output node's name.</li> </ul>"},{"location":"examples/node_definitions.html#including-delay_dist-in-connection","title":"Including <code>delay_dist</code> in Connection","text":"<p>The <code>delay_dist</code> parameter allows you to specify a distribution that models the variability in communication delay between nodes. This is particularly useful in simulations where network latency or message passing delays are significant.</p>"},{"location":"examples/node_definitions.html#resolving-cyclic-dependencies-with-skip","title":"Resolving Cyclic Dependencies with <code>skip</code>","text":"<p>In graphs where nodes depend on each other's outputs (creating a cycle), the <code>skip</code> parameter can be used to resolve the dependency. By setting <code>skip=True</code> on a connection, you instruct the receiving node to proceed without waiting for the current message if it arrives simultaneously. 
This breaks the cycle and allows the system to function.</p>"},{"location":"examples/node_definitions.html#example-connection","title":"Example Connection","text":"<pre><code>node_a.connect(\n    output_node=node_b,\n    blocking=True,\n    delay=0.01,  # Expected communication delay (used for phase-shifting)\n    delay_dist=distrax.Normal(loc=0.01, scale=0.005), # Sim. communication delay\n    window=5,\n    skip=False,\n    jitter=Jitter.LATEST,\n    name=\"input_from_b\"\n)\n</code></pre> <p>In this example, <code>node_a</code> receives the output of <code>node_b</code> as the input <code>input_from_b</code>. The connection is blocking, has an expected delay of 0.01 seconds (used for phase-shifting), and uses a delay distribution to simulate communication delay. The <code>window</code> size is set to 5, meaning the last five messages are stored. Because <code>blocking=True</code> and <code>skip=False</code>, <code>node_a</code> waits for the input from <code>node_b</code>.</p>"},{"location":"examples/node_definitions.html#node-data-structure","title":"Node Data Structure","text":"<p>Nodes manage four main types of data (represented as pytrees), typically implemented as immutable dataclasses for efficiency and safety:</p> <ol> <li>Parameters: Static configurations that usually remain constant during execution.</li> <li>State: Dynamic data that evolves over time with each <code>step</code>.</li> <li>Outputs: Data produced by a node's <code>step</code> method and sent to connected nodes.</li> <li>Inputs: Buffers that hold incoming data from other nodes, respecting the specified window size.</li> </ol>"},{"location":"examples/node_definitions.html#immutable-dataclasses","title":"Immutable Dataclasses","text":"<p>Using immutable dataclasses (e.g., via <code>@struct.dataclass</code> from Flax) ensures that the data structures are compatible with JAX's JIT compilation and functional programming paradigms. 
Additionally, dataclasses allow you to define specific methods related to the data structure, providing encapsulation and clarity.</p> <pre><code>@struct.dataclass\nclass MyParams:\n    some_parameter: float\n\n    def adjust_parameter(self, factor: float):\n        return self.replace(some_parameter=self.some_parameter * factor)\n\n@struct.dataclass\nclass MyState:\n    some_state_variable: jax.Array\n\n    def update_state(self, delta: jax.Array):\n        return self.replace(some_state_variable=self.some_state_variable + delta)\n\n@struct.dataclass\nclass MyOutput:\n    some_output_data: jax.Array\n</code></pre> <p>In this example, <code>MyParams</code> and <code>MyState</code> include methods to adjust parameters and update state, respectively. This encapsulation enhances code organization and readability.</p>"},{"location":"examples/node_definitions.html#initialization","title":"Initialization","text":"<p>Node data is initialized using specific methods that you should override:</p> <ul> <li><code>init_params</code>: Initializes the node's parameters.</li> <li><code>init_state</code>: Initializes the node's state.</li> <li><code>init_output</code>: Provides a default output, useful for initializing input buffers in connected nodes.</li> </ul> <p>These methods are typically called during the graph's initialization phase using <code>graph.init()</code>.</p>"},{"location":"examples/node_definitions.html#the-step-method-in-detail","title":"The <code>step</code> Method in Detail","text":"<p>The <code>step</code> method defines how a node processes inputs and updates its state at each timestep. 
It receives a <code>StepState</code> object with all necessary information.</p>"},{"location":"examples/node_definitions.html#stepstate-attributes","title":"<code>StepState</code> Attributes","text":"<ul> <li><code>rng</code>: JAX random key (if you use it, split it and return the new key in the updated <code>StepState</code>).</li> <li><code>state</code>: Node's current state.</li> <li><code>params</code>: Static parameters influencing behavior.</li> <li><code>inputs</code>: Dictionary of <code>InputState</code> instances (keyed by input names).</li> <li><code>eps</code>: Episode number, which identifies the computation graph currently used for simulation (unrelated to the RL episode number).</li> <li><code>seq</code>: Current step number (auto-increments with each step).</li> <li><code>ts</code>: Timestamp at the start of the step.</li> </ul>"},{"location":"examples/node_definitions.html#accessing-inputs","title":"Accessing Inputs","text":"<p>Each <code>InputState</code> in <code>step_state.inputs</code> contains:</p> <ul> <li><code>data</code>: Messages from the connected node.</li> <li><code>seq</code>: Sequence numbers of the received messages.</li> <li><code>ts_sent</code>: Timestamps when messages were sent.</li> <li><code>ts_recv</code>: Timestamps when messages were received.</li> </ul> <p>For example, to access the most recent message:</p> <pre><code>latest_sensor_input = step_state.inputs['sensor'][-1].data\n</code></pre>"},{"location":"examples/node_definitions.html#implementing-the-step-method","title":"Implementing the <code>step</code> Method","text":"<p>The typical steps to implement the <code>step</code> method can be condensed into the following block:</p> <pre><code>def step(self, step_state: StepState):\n    # Unpack StepState\n    rng, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs\n\n    # Access latest input\n    control_signal = inputs['controller'][-1].data\n\n    # Update state\n    new_state_variable = state.some_state_variable + control_signal * params.gain\n    
new_state = state.replace(some_state_variable=new_state_variable)\n\n    # Produce output\n    output = MyOutput(some_output_data=new_state_variable)\n\n    # Update RNG if randomness is involved\n    rng, _ = jax.random.split(rng)\n\n    # Return updated StepState and output\n    return step_state.replace(state=new_state, rng=rng), output\n</code></pre>"},{"location":"examples/node_definitions.html#working-with-time-and-sequence","title":"Working with Time and Sequence","text":"<p>Use <code>eps</code>, <code>ts</code>, and <code>seq</code> for time-dependent logic. Because <code>step</code> is typically JIT-compiled, these values are usually traced arrays, so use <code>jnp.where</code> or <code>jax.lax.cond</code> instead of a Python <code>if</code>:</p> <pre><code># A Python `if` on a traced value fails under jax.jit; select values with jnp.where instead\nis_active = step_state.ts &gt; params.activation_time\ngain = jnp.where(is_active, params.gain, 0.0)\n</code></pre>"},{"location":"examples/node_definitions.html#handling-input-windows","title":"Handling Input Windows","text":"<p>If the input window size is greater than 1, you can access past messages:</p> <pre><code>recent_sensor_data = inputs['sensor_input'][-3:].data\n</code></pre>"},{"location":"examples/node_definitions.html#jit-compilation-and-side-effects-handling-with-external-callbacks","title":"JIT Compilation and Side Effects Handling with External Callbacks","text":"<p>Rex recommends JIT-compiling the <code>step</code> method of each node to enhance performance. However, interfacing with real hardware often involves side effects that JAX's JIT compilation doesn't handle natively.</p> <p>To include side-effecting code (e.g., sending commands to actuators, reading sensor data), you must use JAX's external callback mechanism. 
This involves wrapping side-effecting functions with <code>jax.experimental.io_callback</code> to ensure compatibility with JIT compilation.</p> <p>Refer to the JAX documentation on external callbacks for detailed guidance.</p> <pre><code>def step(self, step_state: StepState):\n    # Compute outputs\n    output = ...\n\n    # Side-effecting function (executed on the host with concrete NumPy values)\n    def _apply_action(action):\n        # Code that interacts with hardware\n        return np.array(1.0, dtype=np.float32)  # Dummy return value matching result_shape_dtypes\n\n    # Wrap side-effecting code: io_callback(callback, result_shape_dtypes, *args)\n    _ = jax.experimental.io_callback(\n        _apply_action,\n        jax.ShapeDtypeStruct((), jnp.float32),\n        output.some_output_data,\n    )\n\n    # Update state and return\n    return step_state, output\n</code></pre>"},{"location":"examples/node_definitions.html#real-world-nodes-and-lifecycle-methods","title":"Real-World Nodes and Lifecycle Methods","text":"<p>When nodes interface with real hardware or external systems, additional lifecycle management is necessary. The <code>BaseNode</code> API accommodates this through:</p> <ul> <li><code>startup</code>: Called before an episode starts, allowing the node to prepare (e.g., initialize hardware).</li> <li><code>stop</code>: Called after an episode ends, enabling the node to clean up resources or safely shut down hardware.</li> </ul> <pre><code>class RealWorldNode(BaseNode):\n    def __init__(\n        self,\n        name: str,\n        rate: float,\n        delay: float = None,\n        delay_dist: Union[DelayDistribution, distrax.Distribution] = None,\n        advance: bool = False,\n        scheduling: Scheduling = Scheduling.FREQUENCY,\n        color: str = None,\n        order: int = None\n    ):\n        super().__init__(name, rate, delay, delay_dist, advance, scheduling, color, order)\n        # Additional initialization if needed\n\n    def startup(self, graph_state: GraphState, timeout: float = None):\n        # Initialize hardware connections\n        return True  # Return True if successful\n\n  
  def stop(self, timeout: float = None):\n        # Safely shut down hardware\n        return True\n</code></pre>"},{"location":"examples/node_definitions.html#summary","title":"Summary","text":"<p>By following these guidelines, you can define robust and efficient nodes within the Rex framework. Nodes can be customized extensively through their parameters and state, connected flexibly to form complex graphs, and optimized using JIT compilation. Proper handling of side effects ensures that nodes interfacing with real-world systems remain performant and reliable.</p> <p>In the following examples, we'll implement specific nodes that illustrate these concepts in practice.</p>"},{"location":"examples/sim2real.html","title":"Sim2real with a pendulum","text":"<pre><code># @title Install Necessary Libraries\n# @markdown This cell installs the required libraries for the project.\n# @markdown If you are running this notebook in Google Colab, most libraries should already be installed.\n\nimport multiprocessing\nimport os\n\n\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count={}\".format(max(multiprocessing.cpu_count(), 1))\nos.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n\ntry:\n    import rex\n\n    print(\"Rex already installed\")\nexcept ImportError:\n    print(\n        \"Installing rex via `pip install rex-lib[examples]`. 
\"\n        \"If you are running this in a Colab notebook, you can ignore this message.\"\n    )\n    !pip install rex-lib[examples]\n    import rex\n</code></pre> <pre><code># @title Import Libraries &amp; Check GPU Availability\n# @markdown We import all necessary libraries here, including JAX, numpy, and others.\n# @markdown Additionally, we check if a GPU is available and display the number of CPU cores.\n\nimport functools\nimport itertools\n\nimport equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport supergraph\nimport tqdm\nfrom distrax import Normal\nfrom IPython.display import HTML\n\n\nsns.set()\n\nimport rex.base as base\nimport rex.utils as rutils\nfrom rex.base import TrainableDist\nfrom rex.constants import Clock, RealTimeFactor\nfrom rex.open_colors import ecolor, fcolor\n\n\n# Check if we have a GPU\ntry:\n    gpu = jax.devices(\"gpu\")\n    gpu = gpu[0] if len(gpu) &gt; 0 else None\n    print(\"GPU found!\")\nexcept RuntimeError:\n    print(\"Warning: No GPU found, falling back to CPU. 
Speedups will be less pronounced.\")\n    print(\n        \"Hint: if you are using Google Colab, try to change the runtime to GPU: \"\n        \"Runtime -&gt; Change runtime type -&gt; Hardware accelerator -&gt; GPU.\"\n    )\n    gpu = None\n\n# Check the number of available CPU cores\nprint(f\"CPU cores available: {len(jax.devices('cpu'))}\")\ncpus = itertools.cycle(jax.devices(\"cpu\"))\n</code></pre> <pre>\n<code>GPU found!\nCPU cores available: 16\n</code>\n</pre> <pre><code># @title Define Pendulum System as an Interconnection of Nodes\n# @markdown We will use nodes defined in the pendulum example to simulate the system.\n# @markdown Since we do not have access to a real-world pendulum, the Brax simulation will act as our \"real-world\" system.\n# @markdown Data from Brax will help us identify the delays and parameters of a simple ODE model.\n# @markdown In a separate notebook, we demonstrate how to define nodes.\n# @markdown Optionally, you can test the system with zero delays by uncommenting the relevant code.\nimport rex.examples.pendulum as pdm\n\n\n# The `color` and `order` arguments are only used for visualization.\n# Delay distributions are used to simulate the delays as if the nodes were real-world systems.\n# For real-world systems, it is normally not necessary to specify the delay distributions.\nsensor = pdm.Sensor(\n    name=\"sensor\",\n    rate=50,\n    color=\"pink\",\n    order=1,  # Sensor that reads the angle from the pendulum\n    delay_dist=Normal(loc=0.0075, scale=0.003),\n)  # Computation delay of the sensor\nagent = pdm.Agent(\n    name=\"agent\",\n    rate=50,\n    color=\"teal\",\n    order=3,  # Agent that generates random actions\n    delay_dist=Normal(loc=0.01, scale=0.003),\n)  # Computation delay of the agent\nactuator = pdm.Actuator(\n    name=\"actuator\",\n    rate=50,\n    color=\"orange\",\n    order=2,  # Actuator that applies the action to the pendulum\n    delay_dist=Normal(loc=0.0075, scale=0.003),\n)  # Computation 
delay of the actuator\n# Computation delay of the world is the world's step size (i.e. 1/rate)\nworld = pdm.BraxWorld(name=\"world\", rate=50, color=\"grape\", order=0)  # Brax world that simulates the pendulum\nnodes = dict(world=world, sensor=sensor, agent=agent, actuator=actuator)\n\n# Connect nodes\n# The window determines the buffer size, i.e., the number of previous messages that are stored and can be accessed\n# in the .step() method of the node. The window should be at least 1, as the most recent message is always stored.\n# Blocking connections are synchronous, i.e., the receiving node waits for the sending node to send a message.\nagent.connect(\n    sensor,\n    window=3,\n    name=\"sensor\",\n    blocking=True,  # Use the last three sensor messages as input (sync communication)\n    delay_dist=Normal(loc=0.002, scale=0.002),\n)  # Communication delay of the sensor\nactuator.connect(\n    agent,\n    window=1,\n    name=\"agent\",\n    blocking=True,  # Actuator receives the most recent action (sync communication)\n    delay_dist=Normal(loc=0.002, scale=0.002),\n)  # Communication delay of the agent\n\n# Connections below would not be necessary in a real-world system,\n# but are used to communicate the action to brax, and convert brax's state to a sensor message\n# Delay distributions are used to simulate the delays in the real-world system\nsensor_delay, actuator_delay = 0.01, 0.01\nstd_delay = 0.002\nworld.connect(\n    actuator,\n    window=1,\n    name=\"actuator\",\n    skip=True,  # Sends the action to the brax world (skip=True to resolve circular dependency)\n    delay_dist=Normal(loc=actuator_delay, scale=std_delay),\n)  # Actuator delay between applying the action and the action being effective in the world\nsensor.connect(\n    world,\n    window=1,\n    name=\"world\",  # Communicate brax's state to the sensor node\n    
delay_dist=Normal(loc=sensor_delay, scale=std_delay),\n)  # Sensor delay between reading the state and the world's state corresponding to the sensor reading.\n\n# If you want to test with zero delays, uncomment below.\n# from distrax import Deterministic\n# sensor_delay, actuator_delay = 0.0, 0.0\n# std_delay = 0.0\n# for n in [sensor, agent, actuator]:\n#     n.set_delay(delay_dist=Deterministic(loc=0.0), delay=0.0)\n#     for i in n.inputs.values():\n#         i.set_delay(delay_dist=Deterministic(loc=0.0), delay=0.0)\n# world.inputs[\"actuator\"].set_delay(delay_dist=Deterministic(loc=0.0), delay=0.0)\n\n# Visualize the system\nnode_infos = {name: n.info for name, n in nodes.items()}\nfig, ax = plt.subplots(1, 1, figsize=(8, 3))\nrutils.plot_system(node_infos, ax=ax, k=1)\nax.legend()\nax.set_title(\"Brax System\");\n</code></pre> <pre><code># @title Apply Open-Loop Control to the Pendulum System to Gather Data\n# @markdown This section collects data such as delays, actions, and sensor readings,\n# @markdown by applying open-loop control to the simulated pendulum.\n\n# Build the graph\n# Note that one of the nodes is designated as the supervisor (agent).\n# To make a comparison with the standard Gym-like approach, the supervisor node is the agent, and the other nodes are the environment.\n# This means that the graph will be executed in a step-by-step manner, where the agent's rate determines the rate of the environment.\nfrom rex.asynchronous import AsyncGraph\nfrom rex.constants import LogLevel\nfrom rex.utils import set_log_level\n\n\ngraph = AsyncGraph(\n    nodes=nodes,\n    supervisor=nodes[\"agent\"],\n    # Settings for simulating as fast as possible according to the specified delays\n    clock=Clock.SIMULATED,\n    real_time_factor=RealTimeFactor.FAST_AS_POSSIBLE,\n    # Settings for simulating at real-time speed according to the specified delays\n    # clock=Clock.SIMULATED, real_time_factor=RealTimeFactor.REAL_TIME,\n    # Settings for real-world deployment\n    # clock=Clock.WALL_CLOCK, real_time_factor=RealTimeFactor.REAL_TIME,\n)\n\n# Specify what we want to record (params, state, output) for each node\ngraph.set_record_settings(params=True, inputs=False, state=True, output=True)\n\n# Get initial graph state (aggregate of all node states)\nrng = jax.random.PRNGKey(2)\nrng, rng_init = jax.random.split(rng)\n# 'order' defines the order in which the nodes must be initialized (some node initialization procedures may depend on the result of others)\ngs_init = graph.init(rng_init, order=(\"agent\",))\ngs_init_real = gs_init  # Used later for evaluating the trained model from the same initial state\n\n# Ahead-of-time compilation of the step method of each node\n# Place 
all nodes on the CPU, except the agent, which is placed on the GPU (if available)\n[set_log_level(LogLevel.DEBUG, n) for n in nodes.values()]  # Silence the log output\ndevices_step = {k: next(cpus) if k != \"agent\" or gpu is None else gpu for k in nodes}\ngraph.warmup(gs_init, devices_step, jit_step=True, profile=True)  # profile=True profiles the step function\n\n# Prepare open-loop action sequence\nrng, rng_actions = jax.random.split(rng)\ndt_action = 2.0\nnum_actions = 6\nactions = jnp.array([-1.7, 1.7, -1, 1, 0.0, 0.1])[:, None]\nactions = jnp.repeat(\n    actions, int(jnp.ceil(dt_action * nodes[\"agent\"].rate)), axis=0\n)  # Hold each action for dt_action seconds (dt_action * rate agent steps)\nnum_steps = actions.shape[0]\n\n# Execution: Gym-like API with .reset() &amp; .step() methods\n# We use the graph state obtained with .init() and perform step-by-step execution with .reset() and .step().\ngs, ss = graph.reset(gs_init)  # Reset the graph to the initial state (returns the gs and the step state of the agent)\nfor i in tqdm.tqdm(range(num_steps), desc=\"brax | gather data\"):\n    # Access the last sensor message of the input buffer\n    # -1 is the most recent message, -2 the second most recent, etc. 
up until the window size\n    sensor_msg = ss.inputs[\"sensor\"][-1].data  # .data grabs the pytree message object\n    action = actions[i]  # Get the action for the current time step\n    output = ss.params.to_output(action)  # Convert the action to an output message\n    # Step the graph (i.e., executes the next time step by sending the output message to the actuator node)\n    gs, ss = graph.step(gs, ss, output)  # Step the graph with the agent's output\ngraph.stop()  # Stops all nodes that were running asynchronously in the background\n\n# Get the episode data (params, delays, outputs, etc.)\nrecord = graph.get_record()  # Gets the records of all nodes\n\n# Filter out the world node, as it would not be available in a real-world system\nrollout_real = record.nodes[\"world\"].steps.state\nnodes_real = {name: n for name, n in nodes.items() if name != \"world\"}\nrecord = record.filter(nodes_real)\n</code></pre> <pre>\n<code>brax | gather data: 100% | 600/600 [00:02&lt;00:00, 204.31it/s]\n</code>\n</pre> <pre><code># @title Visualize Actions and Sensor Readings\n# @markdown The plots below display the actions and sensor readings as might be observed in a real-world system.\n\nfig_data, axes = plt.subplots(1, 3, figsize=(12, 
3))\naxes[0].plot(record.nodes[\"agent\"].steps.ts_end[:-1], record.nodes[\"agent\"].steps.output.action, label=\"action\")\naxes[0].set_xlabel(\"Time [s]\")\naxes[0].set_ylabel(\"Torque [Nm]\")\naxes[0].legend()\n\naxes[1].plot(record.nodes[\"sensor\"].steps.ts_end, record.nodes[\"sensor\"].steps.output.th, label=\"th\")\naxes[1].set_xlabel(\"Time [s]\")\naxes[1].set_ylabel(\"Angle [rad]\")\naxes[1].legend()\n\naxes[2].plot(record.nodes[\"sensor\"].steps.ts_end, record.nodes[\"sensor\"].steps.output.thdot, label=\"thdot\")\naxes[2].set_xlabel(\"Time [s]\")\naxes[2].set_ylabel(\"Ang. Vel. [rad/s]\")\naxes[2].legend();\n</code></pre> <pre><code># @title Fit GMM to Communication and Computation Delays\n# @markdown We will fit a Gaussian Mixture Model (GMM) to the delays observed in the communication between the sensor and agent,\n# @markdown as well as the computation delay of the agent\u2019s step method.\n# @markdown Other delays, such as actuator delays, can be fitted similarly.\n\nfrom rex.gmm_estimator import GMMEstimator\n\n\n# Fit GMM to communication delay between sensor and agent\ndelay_comm = record.nodes[\"agent\"].inputs[\"sensor\"].messages.delay\ngmm_comm = GMMEstimator(delay_comm, \"communication_delay\")\ngmm_comm.fit(num_steps=100, num_components=2, step_size=0.05, seed=0)\ndist_comm = gmm_comm.get_dist()\n\n# Fit GMM to computation delay of the agent's step method\ndelay_comp = record.nodes[\"agent\"].steps.delay\ngmm_comp = GMMEstimator(delay_comp, \"computation_delay\")\ngmm_comp.fit(num_steps=100, num_components=2, step_size=0.05, seed=0)\ndist_comp = gmm_comp.get_dist()\n</code></pre> <pre>\n<code>communication_delay | Time taken: 1.38 seconds.\ncomputation_delay | Time taken: 1.22 seconds.\n</code>\n</pre> <pre><code># @title Visualize Fitted GMM for Delays\n# @markdown This cell plots the GMM fitting process to the delays for both the sensor-agent communication and the agent's computation.\n\n%matplotlib agg\n\n# Plot GMMs\n# with 
plt.ioff():\nfig_gmm, axes = plt.subplots(1, 2, figsize=(8, 3))\ngmm_comm.plot_hist(ax=axes[0], edgecolor=ecolor.communication, facecolor=fcolor.communication, plot_dist=True)\naxes[0].set_title(\"Delay (sensor-&gt;agent)\")\ngmm_comp.plot_hist(ax=axes[1], edgecolor=ecolor.computation, facecolor=fcolor.computation, plot_dist=False)\naxes[1].set_title(\"Delay (agent.step)\")\nfor ax, dist in zip(axes, [dist_comm, dist_comp]):\n    ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter(\"%.3f\"))\n    ax.tick_params(axis=\"both\", which=\"major\", labelsize=10)\n    ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n    ax.set_xlabel(\"delay (s)\", fontsize=10)\n    ax.set_ylabel(\"density\", fontsize=10)\n    ax.set_xlim([0, dist.quantile(0.99)])  # Limit the x-axis to the 99th percentile of the delay\n\n# Animate training\nani = gmm_comp.animate_training(fig=fig_gmm, ax=axes[1], num_frames=50)\n# If you are running into an AttributeError regarding \"_val_or_rc\", skip the HTML display and run the next cell.\n# This seems to be a Python 3.10 + matplotlib 3.9.x issue.\n# Resolve by downgrading matplotlib to 3.7.x. Run `!pip install matplotlib==3.7.5`.\nHTML(ani.to_html5_video())\n</code></pre> 
<pre><code># @title Visualize Data Flow in the Real-World System\n# @markdown The top plot shows how long each node takes to process data and forward it to the next node.\n# @markdown The bottom plot provides a graph representation that will form the basis for the computational graph used for system identification.\n# @markdown - Each vertex represents a step call of a node, and each edge represents message transmission between two nodes.\n# @markdown - Edges between consecutive steps of the same node represent the transmission of the internal state of the node.\n# @markdown - Nodes start processing after an initial phase-shift, which can be controlled in the node definition.\n\nplt.close(fig_gmm)  # Close fig_gmm to prevent it from displaying in the next cell\n%matplotlib inline\n\ndf = record.to_graph()\ntiming_mode = \"arrival\"  # \"arrival\" or \"usage\"\nG = rutils.to_networkx_graph(df, nodes=nodes)\nfig, axes = plt.subplots(2, 1, figsize=(12, 6))\nrutils.plot_graph(\n    G,\n    max_x=0.5,\n    ax=axes[0],\n    message_arrow_timing_mode=timing_mode,\n    edge_linewidth=1.4,\n    arrowsize=10,\n    show_labels=True,\n    height=0.6,\n    label_loc=\"center\",\n)\nsupergraph.plot_graph(G, max_x=0.5, ax=axes[1])\nfig.suptitle(\"Real-world data flow from recording\")\naxes[-1].set_xlabel(\"Time [s]\");\n</code></pre> <pre><code># @title Build an ODE Simulation Environment to Identify Hidden Delays and Parameters\n# @markdown We use the collected data to build a simple ODE model and identify its delays and parameters.\n# @markdown This model incorporates the communication and computation delays identified for the agent.\n\n# Prepare the recorded data that we are going to use for system identification\noutputs = {name: n.steps.output[None] for name, n in record.nodes.items()}\n\n# By reinitializing the nodes via the `from_info` method, we can reuse the exact same configuration (rate, delay_dist, etc.).\n# We can overwrite (e.g., delay_dist) or specify extra parameters 
(e.g., outputs) as keyword arguments.\n# The info data is stored in the record, but can also be obtained from the nodes themselves with node.info.\nactuator = pdm.SimActuator.from_info(\n    record.nodes[\"actuator\"].info, outputs=outputs[\"actuator\"]\n)  # Actuator data to replay the actions\nsensor = pdm.SimSensor.from_info(\n    record.nodes[\"sensor\"].info, outputs=outputs[\"sensor\"]\n)  # Sensor data to calculate reconstruction error\nagent = pdm.Agent.from_info(record.nodes[\"agent\"].info, delay_dist=dist_comp)\nnodes_sim = dict(sensor=sensor, agent=agent, actuator=actuator)\n\n# Connect nodes according to the real-world system\n[n.connect_from_info(record.nodes[name].info.inputs, nodes_sim) for name, n in nodes_sim.items()]\n\n# Create the world node that is going to simulate the ODE system\nworld = pdm.OdeWorld.from_info(\n    nodes[\"world\"].info\n)  # Initialize OdeWorld with the same parameters (rate, etc.) as the brax world\n\n# Next, we connect the world node to the nodes that interface with hardware (actuator and sensor)\n# We specify trainable delays to represent sensor and actuator delays that we want to identify in addition to the ODE parameters\nworld.connect(\n    actuator,\n    window=1,\n    name=\"actuator\",\n    skip=True,  # Sends the action to the ODE world (skip=True to resolve circular dependency)\n    # Trainable delay to represent the actuator delay\n    # delay, min, and max are in seconds, interp in [\"zoh\", \"linear\"]\n    delay_dist=TrainableDist.create(delay=0.0, min=0, max=0.3, interp=\"linear\"),\n)\nsensor.connect(\n    world,\n    window=1,\n    name=\"world\",  # Communicate the ODE world's state to the sensor node\n    # Trainable delay to represent the sensor delay\n    # delay, min, and max are in seconds, interp in [\"zoh\", \"linear\"]\n    delay_dist=TrainableDist.create(delay=0.0, min=0, max=0.3, interp=\"linear\"),\n)\nnodes_sim[\"world\"] = world  # Add the world node to the nodes\n\n# Visualize the system\nnode_infos 
= {name: n.info for name, n in nodes_sim.items()}\nfig, ax = plt.subplots(1, 1, figsize=(8, 3))\nrutils.plot_system(node_infos, ax=ax, k=1)\nax.legend()\nax.set_title(\"ODE System\");\n</code></pre> <pre><code># @title Build Computational Graph for System Identification\n# @markdown This graph includes vertices representing simulator (i.e. world) steps and edges representing sensor and actuator\n# @markdown delays between the world and the sensor/actuator nodes.\n# @markdown The min/max values from the trainable delay distributions are used to define these edges.\n\nrng, rng_aug = jax.random.split(rng)\ncg = rex.artificial.augment_graphs(df, nodes_sim, rng_aug)\ntiming_mode = \"arrival\"  # \"arrival\" or \"usage\"\nG = rutils.to_networkx_graph(cg, nodes=nodes)\nfig, axes = plt.subplots(2, 1, figsize=(12, 6))\nrutils.plot_graph(\n    G,\n    max_x=0.5,\n    ax=axes[0],\n    message_arrow_timing_mode=timing_mode,\n    edge_linewidth=1.4,\n    arrowsize=10,\n    show_labels=True,\n    height=0.6,\n    label_loc=\"center\",\n)\nsupergraph.plot_graph(G, max_x=0.5, ax=axes[1])\nfig.suptitle(\"Computation graph (extended with simulator nodes)\")\naxes[-1].set_xlabel(\"Time [s]\");\n</code></pre> <pre><code># @title Define Subset of Trainable Parameters (Delays and ODE Parameters)\n# @markdown The following loop describes the training process for identifying hidden delays and system parameters:\n# @markdown 1. Sample normalized parameters from a search distribution.\n# @markdown 2. Denormalize based on parameter min/max values.\n# @markdown 3. Extend trainable parameters with non-trainable ones.\n# @markdown 4. Run simulation and collect reconstruction errors.\n# @markdown 5. Update search distribution based on the error.\n# @markdown 6. 
Repeat until convergence.\n\n# Initialize a graph that can be compiled and parallelized for system identification\n# Note: we could skip running the agent node for computational efficiency,\n# because it does not affect the world node here (the actuator node replays the recorded actions).\ngraph_sim = rex.graph.Graph(nodes_sim, nodes_sim[\"agent\"], cg)\n\n# Get initial graph state (aggregate of all node states)\nrng, rng_init = jax.random.split(rng)\ngs_init = graph_sim.init(rng_init, order=(\"agent\",))\ngs_init_sim = gs_init\n\n# Define the set of trainable parameters and their initial values\n# We only want to optimize a subset of the parameters, e.g., the delays and the parameters of the ODE system.\n# Hence, we take all parameters, set them to None (i.e., not trainable),\n# and then assign initial values to the ones we want to optimize.\nbase_params = gs_init.params.unfreeze().copy()  # Get base structure for params\ninit_params = jax.tree_util.tree_map(lambda x: None, base_params)  # Set all parameters to None (i.e. 
not trainable)\ninit_params[\"world\"] = init_params[\"world\"].replace(\n    J=0.0001,  # Inertia of the pendulum (trainable)\n    mass=0.05,  # Mass of the pendulum (trainable)\n    length=0.03,  # Length of the pendulum (trainable)\n    b=1.0e-05,  # Damping of the pendulum (trainable)\n    K=0.02,  # Spring constant of the pendulum (trainable)\n    R=5.0,  # DC-motor resistance of the pendulum (trainable)\n    c=0.0007,  # Coulomb friction of the pendulum (trainable)\n)\ninit_params[\"sensor\"] = init_params[\"sensor\"].replace(sensor_delay=0.15)  # Sensor delay (trainable)\ninit_params[\"actuator\"] = init_params[\"actuator\"].replace(actuator_delay=0.15)  # Actuator delay (trainable)\ninit_params[\"agent\"] = init_params[\"agent\"].replace(\n    init_method=\"parametrized\",  # Set to \"parametrized\" to avoid random state initialization\n    parametrized=jnp.array([0.5 * jnp.pi, 0.0]),  # Initial state (trainable)\n)\n\n# Print the initial parameters\nprint(\"Initial parameters (None means not trainable, some are static):\")\neqx.tree_pprint(init_params, short_arrays=False)\n\n# It is also good practice to search over normalized parameters when a min and max are available for each parameter.\nmin_params, max_params = init_params.copy(), init_params.copy()  # Get base structure for min and max params\n# Set the min and max for the ODE parameters\nmin_params[\"world\"] = jax.tree_util.tree_map(lambda x: x * 0.25, min_params[\"world\"])  # Set the min for the ODE parameters\nmax_params[\"world\"] = jax.tree_util.tree_map(lambda x: x * 2.0, max_params[\"world\"])  # Set the max for the ODE parameters\n# Set the min and max for the delays\nmin_params[\"sensor\"] = min_params[\"sensor\"].replace(sensor_delay=0.0)  # Set the min for the sensor delay\nmax_params[\"sensor\"] = max_params[\"sensor\"].replace(sensor_delay=0.3)  # Set the max for the sensor delay\nmin_params[\"actuator\"] = min_params[\"actuator\"].replace(actuator_delay=0.0)  # Set the min 
for the actuator delay\nmax_params[\"actuator\"] = max_params[\"actuator\"].replace(actuator_delay=0.3)  # Set the max for the actuator delay\n# Ensure agent's initial state has a non-zero range, as 0.5*0 = 0, and 1.5*0 = 0\n# if max_params[\"agent\"].parametrized is not None:  # todo: remove\nmin_params = eqx.tree_at(\n    lambda _min: _min[\"agent\"].parametrized, min_params, jnp.array([-jnp.pi, -0.2])\n)  # Set the min for the initial state\nmax_params = eqx.tree_at(\n    lambda _max: _max[\"agent\"].parametrized, max_params, jnp.array([jnp.pi, 0.2])\n)  # Set the max for the initial state\n\n# Next, we define the transform that transforms the normalized candidate parameters to the full parameter structure\n# First, we denormalize the parameters, then extend with the non-trainable parameters (e.g., max_speed of the ODE world)\ndenorm = base.Denormalize.init(min_params, max_params)  # Create a transform to denormalize a set of normalized parameters\nextend = base.Extend.init(base_params, init_params)  # Create a transform to extend the trainable params with the non-trainable\ndenorm_extend = base.Chain.init(denorm, extend)\n\n# Normalize the initial, min, and max parameters\nnorm_init_params = denorm.inv(init_params)  # Normalize the initial parameters\nnorm_min_params = denorm.inv(min_params)  # Normalize the min parameters\nnorm_max_params = denorm.inv(max_params)  # Normalize the max parameters\n</code></pre> <pre>\n<code>Growing supergraph: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 601/601 [00:00&lt;00:00, 907.73it/s, 1/1 graphs, 2411/2411 matched (66.86% efficiency, 6 nodes (pre-filtered: 6 nodes))]\n</code>\n</pre> <pre>\n<code>Initial parameters (None means not trainable, some are static):\n{\n  'actuator':\n  ActuatorParams(actuator_delay=0.15),\n  'agent':\n  
AgentParams(\n    policy=None,\n    num_act=4,\n    num_obs=4,\n    max_torque=None,\n    init_method='parametrized',\n    parametrized=Array([1.5707964, 0.       ], dtype=float32),\n    max_th=None,\n    max_thdot=None,\n    gamma=None,\n    tmax=None\n  ),\n  'sensor':\n  SensorParams(sensor_delay=0.15),\n  'world':\n  OdeParams(\n    max_speed=None,\n    J=0.0001,\n    mass=0.05,\n    length=0.03,\n    b=1e-05,\n    K=0.02,\n    R=5.0,\n    c=0.0007,\n    dt_substeps_min=0.01,\n    dt=0.02\n  )\n}\n</code>\n</pre> <pre><code># @title Define Loss Function for Identifying Delays and Parameters\n# @markdown This function calculates the reconstruction error for a given set of normalized parameters.\n# @markdown The error is used to guide optimization during the training process.\n\n\ndef get_loss(norm_params, transform, rng):\n    # Transform normalized parameters to full parameter structure\n    params = transform.apply(norm_params)  # := denorm_extend.apply(norm_params)\n\n    # Initialize the graph state\n    # By supplying the params, we override the params generated by every node's init_params method\n    # This allows us to run the graph with the specified parameters\n    gs_init = graph_sim.init(rng=rng, params=params, order=(\"agent\",))\n\n    # Rollout graph\n    final_gs = graph_sim.rollout(gs_init, carry_only=True)\n\n    # Get the reconstruction error\n    loss_th = final_gs.state[\"sensor\"].loss_th\n    loss_thdot = final_gs.state[\"sensor\"].loss_thdot\n    loss = loss_th + loss_thdot\n    return loss\n\n\n# Get cost of initial guess\ninit_loss = get_loss(norm_init_params, denorm_extend, rng)  # Get the initial loss\nprint(f\"Loss of initial guess: {init_loss}\")  # Loss using the initial parameters\n</code></pre> <pre>\n<code>Loss of initial guess: 11958.76171875\n</code>\n</pre> <pre><code># @title Initialize CMA-ES Solver for System Identification\n# @markdown We will use the Covariance Matrix Adaptation Evolution Strategy (CMA-ES) to optimize 
parameters.\n# @markdown The solver is initialized with the normalized parameter bounds (min and max).\n\nimport rex.evo as evo\n\n\n# Initialize the solver\nmax_steps = 50  # Number of optimization steps\nstrategy_kwargs = dict(popsize=200, elite_ratio=0.1, sigma_init=0.4, mean_decay=0.0, n_devices=1)\nsolver = evo.EvoSolver.init(norm_min_params, norm_max_params, strategy=\"CMA_ES\", strategy_kwargs=strategy_kwargs)\ninit_sol_state = solver.init_state(norm_init_params)  # Initialize the solver state\n\n# Run the optimization\nrng, rng_sol = jax.random.split(rng)\ninit_log_state = solver.init_logger(num_generations=max_steps)\nwith rutils.timer(\"evo | compile + optimize\"):\n    sol_state, log_state, losses = evo.evo(\n        get_loss, solver, init_sol_state, denorm_extend, max_steps=max_steps, rng=rng_sol, verbose=True, logger=init_log_state\n    )\nnorm_opt_params = solver.unflatten(sol_state.best_member)\nopt_params = denorm_extend.apply(norm_opt_params)\n\n# Print identified delays vs true delays\n# Note that the sensor and actuator delays cannot be distinguished from each other with this data, but their sum can be estimated.\n# Hence, we compare the sum of the identified delays with the sum of the true delays.\n# print(f\"Sensor delay | true={sensor_delay:.3f}\\u00B1{std_delay:.3f}, opt={opt_params['sensor'].sensor_delay:.3f}, init={init_params['sensor'].sensor_delay:.3f}\")\n# print(f\"Actuator delay | true={actuator_delay:.3f}\\u00B1{std_delay:.3f}, opt={opt_params['actuator'].actuator_delay:.3f}, init={init_params['actuator'].actuator_delay:.3f}\")\nprint(\n    f\"Actuator+sensor delay | \"\n    f\"true={sensor_delay+actuator_delay:.3f}\\u00b1{std_delay * 2:.3f}, \"\n    f\"opt={opt_params['sensor'].sensor_delay+opt_params['actuator'].actuator_delay:.3f}, \"\n    f\"init={init_params['sensor'].sensor_delay+init_params['actuator'].actuator_delay:.3f}\"\n)\n\n\ndef rollout(params, rng, carry_only: bool = True):\n    # Initialize the graph state\n    # By supplying 
the params, we override the params generated by every node's init_params method\n    # This allows us to run the graph with the specified parameters\n    gs_init = graph_sim.init(rng=rng, params=params, order=(\"agent\",))\n\n    # Rollout graph\n    gs_rollout = graph_sim.rollout(gs_init, carry_only=carry_only)\n    return gs_rollout\n\n\nrng, rng_rollout = jax.random.split(rng)\ninit_rollout = rollout(extend.apply(init_params), rng_rollout, carry_only=False)\nopt_rollout = rollout(opt_params, rng_rollout, carry_only=False)\n</code></pre> <pre>\n<code>ParameterReshaper: 11 parameters detected for optimization.\nParameterReshaper: 11 parameters detected for optimization.\nstep: 0 | min_loss: 1127.36083984375 | mean_loss: 107012.1796875 | max_loss: 568760.6875 | bestsofar_loss: 1127.36083984375 | total_samples: 200\nstep: 1 | min_loss: 1204.197509765625 | mean_loss: 78481.2890625 | max_loss: 522921.8125 | bestsofar_loss: 1127.36083984375 | total_samples: 400\nstep: 2 | min_loss: 492.77667236328125 | mean_loss: 36297.34375 | max_loss: 560355.5 | bestsofar_loss: 492.77667236328125 | total_samples: 600\nstep: 3 | min_loss: 798.9423217773438 | mean_loss: 24496.9765625 | max_loss: 370595.75 | bestsofar_loss: 492.77667236328125 | total_samples: 800\nstep: 4 | min_loss: 205.81195068359375 | mean_loss: 28551.529296875 | max_loss: 472859.34375 | bestsofar_loss: 205.81195068359375 | total_samples: 1000\nstep: 5 | min_loss: 52.2267951965332 | mean_loss: 23809.375 | max_loss: 484273.875 | bestsofar_loss: 52.2267951965332 | total_samples: 1200\nstep: 6 | min_loss: 230.4680633544922 | mean_loss: 12314.4443359375 | max_loss: 296713.3125 | bestsofar_loss: 52.2267951965332 | total_samples: 1400\nstep: 7 | min_loss: 79.33822631835938 | mean_loss: 4005.635498046875 | max_loss: 144555.671875 | bestsofar_loss: 52.2267951965332 | total_samples: 1600\nstep: 8 | min_loss: 94.1173095703125 | mean_loss: 7078.138671875 | max_loss: 260285.21875 | bestsofar_loss: 52.2267951965332 | 
total_samples: 1800\nstep: 9 | min_loss: 43.1370849609375 | mean_loss: 5270.595703125 | max_loss: 281373.15625 | bestsofar_loss: 43.1370849609375 | total_samples: 2000\nstep: 10 | min_loss: 52.78352737426758 | mean_loss: 2540.908935546875 | max_loss: 216970.453125 | bestsofar_loss: 43.1370849609375 | total_samples: 2200\nstep: 11 | min_loss: 21.664243698120117 | mean_loss: 1918.5511474609375 | max_loss: 115976.4609375 | bestsofar_loss: 21.664243698120117 | total_samples: 2400\nstep: 12 | min_loss: 23.087072372436523 | mean_loss: 485.77374267578125 | max_loss: 3315.34619140625 | bestsofar_loss: 21.664243698120117 | total_samples: 2600\nstep: 13 | min_loss: 17.36151123046875 | mean_loss: 770.6312866210938 | max_loss: 92419.09375 | bestsofar_loss: 17.36151123046875 | total_samples: 2800\nstep: 14 | min_loss: 15.771574020385742 | mean_loss: 217.87913513183594 | max_loss: 2219.96484375 | bestsofar_loss: 15.771574020385742 | total_samples: 3000\nstep: 15 | min_loss: 13.985638618469238 | mean_loss: 156.47499084472656 | max_loss: 930.321044921875 | bestsofar_loss: 13.985638618469238 | total_samples: 3200\nstep: 16 | min_loss: 14.507802963256836 | mean_loss: 91.28081512451172 | max_loss: 345.5765075683594 | bestsofar_loss: 13.985638618469238 | total_samples: 3400\nstep: 17 | min_loss: 10.358619689941406 | mean_loss: 63.55476379394531 | max_loss: 610.5410766601562 | bestsofar_loss: 10.358619689941406 | total_samples: 3600\nstep: 18 | min_loss: 11.19699764251709 | mean_loss: 43.670433044433594 | max_loss: 201.0476837158203 | bestsofar_loss: 10.358619689941406 | total_samples: 3800\nstep: 19 | min_loss: 10.942270278930664 | mean_loss: 35.828548431396484 | max_loss: 147.9638214111328 | bestsofar_loss: 10.358619689941406 | total_samples: 4000\nstep: 20 | min_loss: 10.034892082214355 | mean_loss: 26.11482810974121 | max_loss: 97.91928100585938 | bestsofar_loss: 10.034892082214355 | total_samples: 4200\nstep: 21 | min_loss: 9.910079956054688 | mean_loss: 24.807268142700195 | 
max_loss: 62.91225814819336 | bestsofar_loss: 9.910079956054688 | total_samples: 4400\nstep: 22 | min_loss: 9.878521919250488 | mean_loss: 19.883729934692383 | max_loss: 67.5021743774414 | bestsofar_loss: 9.878521919250488 | total_samples: 4600\nstep: 23 | min_loss: 9.988247871398926 | mean_loss: 15.602483749389648 | max_loss: 28.748577117919922 | bestsofar_loss: 9.878521919250488 | total_samples: 4800\nstep: 24 | min_loss: 9.414042472839355 | mean_loss: 13.810941696166992 | max_loss: 25.670448303222656 | bestsofar_loss: 9.414042472839355 | total_samples: 5000\nstep: 25 | min_loss: 9.660612106323242 | mean_loss: 13.065614700317383 | max_loss: 24.568443298339844 | bestsofar_loss: 9.414042472839355 | total_samples: 5200\nstep: 26 | min_loss: 9.325304985046387 | mean_loss: 12.260498046875 | max_loss: 19.5605411529541 | bestsofar_loss: 9.325304985046387 | total_samples: 5400\nstep: 27 | min_loss: 9.467074394226074 | mean_loss: 11.700087547302246 | max_loss: 17.791507720947266 | bestsofar_loss: 9.325304985046387 | total_samples: 5600\nstep: 28 | min_loss: 9.31025218963623 | mean_loss: 11.482476234436035 | max_loss: 22.758207321166992 | bestsofar_loss: 9.31025218963623 | total_samples: 5800\nstep: 29 | min_loss: 9.248895645141602 | mean_loss: 11.012178421020508 | max_loss: 27.078828811645508 | bestsofar_loss: 9.248895645141602 | total_samples: 6000\nstep: 30 | min_loss: 9.093639373779297 | mean_loss: 10.656970024108887 | max_loss: 16.620359420776367 | bestsofar_loss: 9.093639373779297 | total_samples: 6200\nstep: 31 | min_loss: 8.997458457946777 | mean_loss: 10.727160453796387 | max_loss: 21.239608764648438 | bestsofar_loss: 8.997458457946777 | total_samples: 6400\nstep: 32 | min_loss: 8.898832321166992 | mean_loss: 10.218587875366211 | max_loss: 15.635639190673828 | bestsofar_loss: 8.898832321166992 | total_samples: 6600\nstep: 33 | min_loss: 8.828442573547363 | mean_loss: 10.003084182739258 | max_loss: 15.199397087097168 | bestsofar_loss: 8.828442573547363 | 
total_samples: 6800\nstep: 34 | min_loss: 8.915704727172852 | mean_loss: 9.74703598022461 | max_loss: 12.05509090423584 | bestsofar_loss: 8.828442573547363 | total_samples: 7000\nstep: 35 | min_loss: 8.837331771850586 | mean_loss: 9.452198028564453 | max_loss: 11.589494705200195 | bestsofar_loss: 8.828442573547363 | total_samples: 7200\nstep: 36 | min_loss: 8.703187942504883 | mean_loss: 9.30600357055664 | max_loss: 10.700343132019043 | bestsofar_loss: 8.703187942504883 | total_samples: 7400\nstep: 37 | min_loss: 8.815437316894531 | mean_loss: 9.199464797973633 | max_loss: 10.256194114685059 | bestsofar_loss: 8.703187942504883 | total_samples: 7600\nstep: 38 | min_loss: 8.725391387939453 | mean_loss: 9.130782127380371 | max_loss: 9.658961296081543 | bestsofar_loss: 8.703187942504883 | total_samples: 7800\nstep: 39 | min_loss: 8.653090476989746 | mean_loss: 9.083394050598145 | max_loss: 9.625836372375488 | bestsofar_loss: 8.653090476989746 | total_samples: 8000\nstep: 40 | min_loss: 8.663204193115234 | mean_loss: 9.016453742980957 | max_loss: 9.460453987121582 | bestsofar_loss: 8.653090476989746 | total_samples: 8200\nstep: 41 | min_loss: 8.673643112182617 | mean_loss: 8.977757453918457 | max_loss: 9.457200050354004 | bestsofar_loss: 8.653090476989746 | total_samples: 8400\nstep: 42 | min_loss: 8.62848949432373 | mean_loss: 8.939888954162598 | max_loss: 9.32252311706543 | bestsofar_loss: 8.62848949432373 | total_samples: 8600\nstep: 43 | min_loss: 8.626632690429688 | mean_loss: 8.919873237609863 | max_loss: 9.297188758850098 | bestsofar_loss: 8.626632690429688 | total_samples: 8800\nstep: 44 | min_loss: 8.698042869567871 | mean_loss: 8.914870262145996 | max_loss: 9.22347640991211 | bestsofar_loss: 8.626632690429688 | total_samples: 9000\nstep: 45 | min_loss: 8.631864547729492 | mean_loss: 8.902450561523438 | max_loss: 9.235267639160156 | bestsofar_loss: 8.626632690429688 | total_samples: 9200\nstep: 46 | min_loss: 8.659423828125 | mean_loss: 8.88314437866211 | 
max_loss: 9.170170783996582 | bestsofar_loss: 8.626632690429688 | total_samples: 9400\nstep: 47 | min_loss: 8.633057594299316 | mean_loss: 8.862720489501953 | max_loss: 9.144399642944336 | bestsofar_loss: 8.626632690429688 | total_samples: 9600\nstep: 48 | min_loss: 8.608020782470703 | mean_loss: 8.86151123046875 | max_loss: 9.157724380493164 | bestsofar_loss: 8.608020782470703 | total_samples: 9800\nstep: 49 | min_loss: 8.585338592529297 | mean_loss: 8.846010208129883 | max_loss: 9.186552047729492 | bestsofar_loss: 8.585338592529297 | total_samples: 10000\n[47821][MainThread               ][tracer              ][evo | compile + optimize] Elapsed: 12.5141 sec\nActuator+sensor delay | true=0.020\u00b10.004, opt=0.029, init=0.300\n</code>\n</pre> <pre><code># @title Plot Optimization Loss\n# @markdown This plot shows the loss during the parameter optimization process.\n# @markdown Lower losses indicate a better fit between the model and the collected data.\n\nfig_loss, ax_loss = plt.subplots(1, 1, figsize=(4, 3))\nlog_state.plot(\"Loss\", fig=fig_loss, ax=ax_loss)\nax_loss.set_yscale(\"log\");\n</code></pre> <pre><code># @title Visualize Reconstructed and True Sensor Readings\n# @markdown The following plots compare the true sensor readings with the reconstructed readings.\n# @markdown A close match, with the observed and optimized lines on top of each other, suggests that the model accurately captures the system behavior.\n\n# Extract the sensor and actuator data from the initial and optimized rollouts\ninit_sensor = init_rollout.inputs[\"agent\"][\"sensor\"].data[:, -1]\ninit_ts_sensor = init_rollout.inputs[\"agent\"][\"sensor\"].ts_sent[:, -1]\ninit_actuator = init_rollout.inputs[\"world\"][\"actuator\"].data[:, -1]\ninit_ts_actuator = init_rollout.inputs[\"world\"][\"actuator\"].ts_sent[:, -1]\n\nopt_sensor = opt_rollout.inputs[\"agent\"][\"sensor\"].data[:, -1]\nopt_ts_sensor = opt_rollout.inputs[\"agent\"][\"sensor\"].ts_sent[:, -1]\nopt_actuator = 
opt_rollout.inputs[\"world\"][\"actuator\"].data[:, -1]\nopt_ts_actuator = opt_rollout.inputs[\"world\"][\"actuator\"].ts_sent[:, -1]\n\nfig, axes = plt.subplots(1, 3, figsize=(12, 3))\naxes[0].plot(record.nodes[\"agent\"].steps.ts_end[:-1], record.nodes[\"agent\"].steps.output.action, label=\"action\")\naxes[0].plot(opt_ts_actuator, opt_actuator.action[:, 0], label=\"action (ode, opt)\")\naxes[0].plot(init_ts_actuator, init_actuator.action[:, 0], label=\"action (ode, init)\")\naxes[0].set_xlabel(\"Time [s]\")\naxes[0].set_ylabel(\"Torque [Nm]\")\naxes[0].legend()\n\naxes[1].plot(record.nodes[\"sensor\"].steps.ts_end, record.nodes[\"sensor\"].steps.output.th, label=\"th (brax)\")\naxes[1].plot(opt_ts_sensor, opt_sensor.th, label=\"th (ode, opt)\")\naxes[1].plot(init_ts_sensor, init_sensor.th, label=\"th (ode, init)\")\naxes[1].set_xlabel(\"Time [s]\")\naxes[1].set_ylabel(\"Angle [rad]\")\naxes[1].legend()\n\naxes[2].plot(record.nodes[\"sensor\"].steps.ts_end, record.nodes[\"sensor\"].steps.output.thdot, label=\"thdot (brax)\")\naxes[2].plot(opt_ts_sensor, opt_sensor.thdot, label=\"thdot (ode, opt)\")\naxes[2].plot(init_ts_sensor, init_sensor.thdot, label=\"thdot (ode, init)\")\naxes[2].set_xlabel(\"Time [s]\")\naxes[2].set_ylabel(\"Ang. Vel. 
[rad/s]\")\naxes[2].legend();\n</code></pre> <pre><code># @title Train a Policy to Swing Up the Pendulum Using PPO\n# @markdown We will train a policy to swing up the pendulum using Proximal Policy Optimization (PPO) on the identified system.\n# @markdown The success rate is the fraction of steps where the pendulum remains upright (cos(theta) &gt; 0.95 and |theta_dot| &lt; 0.5).\n# @markdown We train 5 policies in parallel and select the best one based on the mean return.\n\n# Reinitialize the graph with fresh nodes that neither replay recorded actions nor calculate the reconstruction error\ninfos_sim = {name: n.info for name, n in nodes_sim.items()}\nnodes_rl = {name: n.from_info(infos_sim[name]) for name, n in nodes_sim.items()}\n[n.connect_from_info(infos_sim[name].inputs, nodes_rl) for name, n in nodes_rl.items()]\ngraph_rl = rex.graph.Graph(nodes_rl, nodes_rl[\"agent\"], cg)\n\n# Define the environment\nenv = pdm.rl.SwingUpEnv(graph=graph_rl)\n\n# Set RL params\nrl_params = opt_params.copy()  # Get base structure for params\nrl_params[\"agent\"] = rl_params[\"agent\"].replace(init_method=\"random\")\nenv.set_params(rl_params)\n\n# Initialize PPO config\n# sweep_pmv2r1zf is a PPO hyperparameter configuration, found with a hyperparameter sweep, that works well for the pendulum swing-up task\nconfig = pdm.rl.sweep_pmv2r1zf\n\n# Train (success rate is the fraction of steps where the pendulum remains upright)\nimport rex.ppo as ppo\n\n\nrng, rng_train = jax.random.split(rng)\nrngs_train = jax.random.split(rng_train, num=5)  # Train 5 policies in parallel\ntrain = functools.partial(ppo.train, env)\nwith rutils.timer(\"ppo | compile\"):\n    train_v = jax.vmap(train, in_axes=(None, 0))\n    train_vjit = jax.jit(train_v)\n    train_vjit = train_vjit.lower(config, rngs_train).compile()\nwith rutils.timer(\"ppo | train\"):\n    res = train_vjit(config, rngs_train)\n\n# Get best policy (based on res.metrics[\"eval/mean_returns\"])\nbest_idx = jnp.argmax(res.metrics[\"eval/mean_returns\"][:, -1])\nbest_policy = 
res.policy[best_idx]\neval_params = rl_params.copy()\neval_params[\"agent\"] = eval_params[\"agent\"].replace(init_method=\"random\", policy=best_policy)\n</code></pre> <pre>\n<code>Growing supergraph: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 601/601 [00:00&lt;00:00, 1297.43it/s, 1/1 graphs, 2411/2411 matched (66.86% efficiency, 6 nodes (pre-filtered: 6 nodes))]\n</code>\n</pre> <pre>\n<code>[47821][MainThread               ][tracer              ][ppo | compile       ] Elapsed: 32.4849 sec\ntrain_steps=249856 | eval_eps=20 | return=-723.0+-131.5 | length=147+-0.0 | approxkl=0.0033 | success_rate=0.01\ntrain_steps=249856 | eval_eps=20 | return=-1036.5+-83.7 | length=147+-0.0 | approxkl=0.0035 | success_rate=0.00\ntrain_steps=249856 | eval_eps=20 | return=-466.0+-90.1 | length=147+-0.0 | approxkl=0.0034 | success_rate=0.04\ntrain_steps=249856 | eval_eps=20 | return=-385.1+-68.9 | length=147+-0.0 | approxkl=0.0035 | success_rate=0.04\ntrain_steps=249856 | eval_eps=20 | return=-369.4+-74.8 | length=147+-0.0 | approxkl=0.0034 | success_rate=0.03\ntrain_steps=499712 | eval_eps=20 | return=-429.4+-90.5 | length=147+-0.0 | approxkl=0.0032 | success_rate=0.03\ntrain_steps=499712 | eval_eps=20 | return=-1008.8+-84.3 | length=147+-0.0 | approxkl=0.0025 | success_rate=0.00\ntrain_steps=499712 | eval_eps=20 | return=-416.1+-83.5 | length=147+-0.0 | approxkl=0.0032 | success_rate=0.04\ntrain_steps=499712 | eval_eps=20 | return=-363.7+-61.2 | length=147+-0.0 | approxkl=0.0033 | success_rate=0.05\ntrain_steps=499712 | eval_eps=20 | return=-369.1+-80.7 | length=147+-0.0 | approxkl=0.0037 | success_rate=0.08\ntrain_steps=749568 | eval_eps=20 | return=-367.7+-92.7 | length=147+-0.0 | approxkl=0.0040 | success_rate=0.10\ntrain_steps=749568 | eval_eps=20 | return=-1052.2+-64.1 | length=147+-0.0 
| approxkl=0.0021 | success_rate=0.00\ntrain_steps=749568 | eval_eps=20 | return=-353.8+-72.0 | length=147+-0.0 | approxkl=0.0032 | success_rate=0.06\ntrain_steps=749568 | eval_eps=20 | return=-364.4+-77.3 | length=147+-0.0 | approxkl=0.0035 | success_rate=0.08\ntrain_steps=749568 | eval_eps=20 | return=-338.8+-98.4 | length=147+-0.0 | approxkl=0.0042 | success_rate=0.08\ntrain_steps=999424 | eval_eps=20 | return=-185.5+-108.5 | length=147+-0.0 | approxkl=0.0049 | success_rate=0.49\ntrain_steps=999424 | eval_eps=20 | return=-1036.0+-89.5 | length=147+-0.0 | approxkl=0.0019 | success_rate=0.00\ntrain_steps=999424 | eval_eps=20 | return=-367.8+-71.5 | length=147+-0.0 | approxkl=0.0035 | success_rate=0.07\ntrain_steps=999424 | eval_eps=20 | return=-369.3+-92.4 | length=147+-0.0 | approxkl=0.0037 | success_rate=0.08\ntrain_steps=999424 | eval_eps=20 | return=-322.1+-83.2 | length=147+-0.0 | approxkl=0.0043 | success_rate=0.11\ntrain_steps=1249280 | eval_eps=20 | return=-235.5+-229.1 | length=147+-0.0 | approxkl=0.0059 | success_rate=0.48\ntrain_steps=1249280 | eval_eps=20 | return=-1011.7+-101.8 | length=147+-0.0 | approxkl=0.0019 | success_rate=0.00\ntrain_steps=1249280 | eval_eps=20 | return=-348.9+-86.0 | length=147+-0.0 | approxkl=0.0040 | success_rate=0.06\ntrain_steps=1249280 | eval_eps=20 | return=-341.9+-60.0 | length=147+-0.0 | approxkl=0.0040 | success_rate=0.08\ntrain_steps=1249280 | eval_eps=20 | return=-314.3+-76.8 | length=147+-0.0 | approxkl=0.0048 | success_rate=0.17\ntrain_steps=1499136 | eval_eps=20 | return=-189.4+-267.2 | length=147+-0.0 | approxkl=0.0054 | success_rate=0.59\ntrain_steps=1499136 | eval_eps=20 | return=-1047.6+-75.7 | length=147+-0.0 | approxkl=0.0019 | success_rate=0.00\ntrain_steps=1499136 | eval_eps=20 | return=-370.2+-89.9 | length=147+-0.0 | approxkl=0.0039 | success_rate=0.08\ntrain_steps=1499136 | eval_eps=20 | return=-353.9+-97.6 | length=147+-0.0 | approxkl=0.0044 | success_rate=0.10\ntrain_steps=1499136 | eval_eps=20 | 
return=-325.4+-95.4 | length=147+-0.0 | approxkl=0.0054 | success_rate=0.10\ntrain_steps=1748992 | eval_eps=20 | return=-218.9+-240.5 | length=147+-0.0 | approxkl=0.0059 | success_rate=0.56\ntrain_steps=1748992 | eval_eps=20 | return=-1025.6+-91.8 | length=147+-0.0 | approxkl=0.0017 | success_rate=0.00\ntrain_steps=1748992 | eval_eps=20 | return=-356.3+-76.1 | length=147+-0.0 | approxkl=0.0047 | success_rate=0.06\ntrain_steps=1748992 | eval_eps=20 | return=-352.3+-82.7 | length=147+-0.0 | approxkl=0.0046 | success_rate=0.07\ntrain_steps=1748992 | eval_eps=20 | return=-383.8+-85.8 | length=147+-0.0 | approxkl=0.0058 | success_rate=0.05\ntrain_steps=1998848 | eval_eps=20 | return=-157.8+-234.5 | length=147+-0.0 | approxkl=0.0051 | success_rate=0.63\ntrain_steps=1998848 | eval_eps=20 | return=-1063.5+-79.0 | length=147+-0.0 | approxkl=0.0019 | success_rate=0.00\ntrain_steps=1998848 | eval_eps=20 | return=-342.3+-54.8 | length=147+-0.0 | approxkl=0.0051 | success_rate=0.11\ntrain_steps=1998848 | eval_eps=20 | return=-354.5+-72.3 | length=147+-0.0 | approxkl=0.0052 | success_rate=0.09\ntrain_steps=1998848 | eval_eps=20 | return=-365.6+-70.5 | length=147+-0.0 | approxkl=0.0061 | success_rate=0.07\ntrain_steps=2248704 | eval_eps=20 | return=-197.4+-236.4 | length=147+-0.0 | approxkl=0.0046 | success_rate=0.56\ntrain_steps=2248704 | eval_eps=20 | return=-997.9+-96.3 | length=147+-0.0 | approxkl=0.0018 | success_rate=0.01\ntrain_steps=2248704 | eval_eps=20 | return=-319.3+-75.6 | length=147+-0.0 | approxkl=0.0053 | success_rate=0.09\ntrain_steps=2248704 | eval_eps=20 | return=-350.2+-193.7 | length=147+-0.0 | approxkl=0.0054 | success_rate=0.10\ntrain_steps=2248704 | eval_eps=20 | return=-377.5+-242.6 | length=147+-0.0 | approxkl=0.0065 | success_rate=0.20\ntrain_steps=2498560 | eval_eps=20 | return=-187.4+-231.2 | length=147+-0.0 | approxkl=0.0047 | success_rate=0.59\ntrain_steps=2498560 | eval_eps=20 | return=-1011.7+-79.2 | length=147+-0.0 | approxkl=0.0017 | 
success_rate=0.00\ntrain_steps=2498560 | eval_eps=20 | return=-312.9+-78.5 | length=147+-0.0 | approxkl=0.0051 | success_rate=0.09\ntrain_steps=2498560 | eval_eps=20 | return=-309.7+-96.8 | length=147+-0.0 | approxkl=0.0057 | success_rate=0.11\ntrain_steps=2498560 | eval_eps=20 | return=-342.6+-111.9 | length=147+-0.0 | approxkl=0.0070 | success_rate=0.20\ntrain_steps=2748416 | eval_eps=20 | return=-126.0+-68.6 | length=147+-0.0 | approxkl=0.0047 | success_rate=0.63\ntrain_steps=2748416 | eval_eps=20 | return=-859.8+-186.3 | length=147+-0.0 | approxkl=0.0017 | success_rate=0.03\ntrain_steps=2748416 | eval_eps=20 | return=-353.9+-86.6 | length=147+-0.0 | approxkl=0.0050 | success_rate=0.08\ntrain_steps=2748416 | eval_eps=20 | return=-305.2+-84.7 | length=147+-0.0 | approxkl=0.0063 | success_rate=0.14\ntrain_steps=2748416 | eval_eps=20 | return=-273.4+-106.8 | length=147+-0.0 | approxkl=0.0076 | success_rate=0.34\ntrain_steps=2998272 | eval_eps=20 | return=-144.8+-86.9 | length=147+-0.0 | approxkl=0.0055 | success_rate=0.62\ntrain_steps=2998272 | eval_eps=20 | return=-400.5+-92.5 | length=147+-0.0 | approxkl=0.0026 | success_rate=0.05\ntrain_steps=2998272 | eval_eps=20 | return=-320.5+-91.8 | length=147+-0.0 | approxkl=0.0056 | success_rate=0.10\ntrain_steps=2998272 | eval_eps=20 | return=-320.8+-99.0 | length=147+-0.0 | approxkl=0.0070 | success_rate=0.13\ntrain_steps=2998272 | eval_eps=20 | return=-250.0+-109.5 | length=147+-0.0 | approxkl=0.0094 | success_rate=0.41\ntrain_steps=3248128 | eval_eps=20 | return=-110.1+-64.6 | length=147+-0.0 | approxkl=0.0060 | success_rate=0.64\ntrain_steps=3248128 | eval_eps=20 | return=-151.2+-111.2 | length=147+-0.0 | approxkl=0.0033 | success_rate=0.48\ntrain_steps=3248128 | eval_eps=20 | return=-298.4+-107.5 | length=147+-0.0 | approxkl=0.0063 | success_rate=0.16\ntrain_steps=3248128 | eval_eps=20 | return=-308.5+-89.5 | length=147+-0.0 | approxkl=0.0073 | success_rate=0.12\ntrain_steps=3248128 | eval_eps=20 | 
return=-154.7+-78.6 | length=147+-0.0 | approxkl=0.0102 | success_rate=0.55\ntrain_steps=3497984 | eval_eps=20 | return=-148.9+-93.6 | length=147+-0.0 | approxkl=0.0062 | success_rate=0.60\ntrain_steps=3497984 | eval_eps=20 | return=-182.0+-101.1 | length=147+-0.0 | approxkl=0.0046 | success_rate=0.55\ntrain_steps=3497984 | eval_eps=20 | return=-344.1+-81.6 | length=147+-0.0 | approxkl=0.0067 | success_rate=0.08\ntrain_steps=3497984 | eval_eps=20 | return=-246.4+-94.6 | length=147+-0.0 | approxkl=0.0084 | success_rate=0.30\ntrain_steps=3497984 | eval_eps=20 | return=-257.5+-269.4 | length=147+-0.0 | approxkl=0.0088 | success_rate=0.46\ntrain_steps=3747840 | eval_eps=20 | return=-118.4+-82.1 | length=147+-0.0 | approxkl=0.0070 | success_rate=0.65\ntrain_steps=3747840 | eval_eps=20 | return=-190.9+-117.0 | length=147+-0.0 | approxkl=0.0048 | success_rate=0.56\ntrain_steps=3747840 | eval_eps=20 | return=-350.0+-195.4 | length=147+-0.0 | approxkl=0.0070 | success_rate=0.16\ntrain_steps=3747840 | eval_eps=20 | return=-320.2+-88.5 | length=147+-0.0 | approxkl=0.0091 | success_rate=0.17\ntrain_steps=3747840 | eval_eps=20 | return=-153.3+-102.3 | length=147+-0.0 | approxkl=0.0075 | success_rate=0.58\ntrain_steps=3997696 | eval_eps=20 | return=-158.4+-109.6 | length=147+-0.0 | approxkl=0.0083 | success_rate=0.61\ntrain_steps=3997696 | eval_eps=20 | return=-111.6+-77.6 | length=147+-0.0 | approxkl=0.0051 | success_rate=0.64\ntrain_steps=3997696 | eval_eps=20 | return=-288.1+-67.9 | length=147+-0.0 | approxkl=0.0079 | success_rate=0.12\ntrain_steps=3997696 | eval_eps=20 | return=-288.1+-214.5 | length=147+-0.0 | approxkl=0.0095 | success_rate=0.31\ntrain_steps=3997696 | eval_eps=20 | return=-135.9+-92.0 | length=147+-0.0 | approxkl=0.0072 | success_rate=0.62\ntrain_steps=4247552 | eval_eps=20 | return=-134.3+-82.8 | length=147+-0.0 | approxkl=0.0081 | success_rate=0.63\ntrain_steps=4247552 | eval_eps=20 | return=-146.5+-108.1 | length=147+-0.0 | approxkl=0.0051 | 
success_rate=0.63\ntrain_steps=4247552 | eval_eps=20 | return=-280.4+-72.4 | length=147+-0.0 | approxkl=0.0084 | success_rate=0.16\ntrain_steps=4247552 | eval_eps=20 | return=-252.2+-219.3 | length=147+-0.0 | approxkl=0.0107 | success_rate=0.36\ntrain_steps=4247552 | eval_eps=20 | return=-192.1+-245.3 | length=147+-0.0 | approxkl=0.0068 | success_rate=0.60\ntrain_steps=4497408 | eval_eps=20 | return=-125.9+-81.8 | length=147+-0.0 | approxkl=0.0085 | success_rate=0.62\ntrain_steps=4497408 | eval_eps=20 | return=-209.4+-127.2 | length=147+-0.0 | approxkl=0.0051 | success_rate=0.54\ntrain_steps=4497408 | eval_eps=20 | return=-297.3+-90.2 | length=147+-0.0 | approxkl=0.0090 | success_rate=0.18\ntrain_steps=4497408 | eval_eps=20 | return=-250.1+-214.0 | length=147+-0.0 | approxkl=0.0115 | success_rate=0.39\ntrain_steps=4497408 | eval_eps=20 | return=-153.8+-98.9 | length=147+-0.0 | approxkl=0.0076 | success_rate=0.62\ntrain_steps=4747264 | eval_eps=20 | return=-147.7+-88.6 | length=147+-0.0 | approxkl=0.0088 | success_rate=0.62\ntrain_steps=4747264 | eval_eps=20 | return=-185.2+-114.9 | length=147+-0.0 | approxkl=0.0056 | success_rate=0.59\ntrain_steps=4747264 | eval_eps=20 | return=-395.1+-259.8 | length=147+-0.0 | approxkl=0.0101 | success_rate=0.09\ntrain_steps=4747264 | eval_eps=20 | return=-214.4+-81.5 | length=147+-0.0 | approxkl=0.0115 | success_rate=0.40\ntrain_steps=4747264 | eval_eps=20 | return=-194.1+-123.8 | length=147+-0.0 | approxkl=0.0079 | success_rate=0.57\ntrain_steps=4997120 | eval_eps=20 | return=-161.5+-102.1 | length=147+-0.0 | approxkl=0.0072 | success_rate=0.57\ntrain_steps=4997120 | eval_eps=20 | return=-170.6+-66.3 | length=147+-0.0 | approxkl=0.0070 | success_rate=0.58\ntrain_steps=4997120 | eval_eps=20 | return=-243.9+-118.5 | length=147+-0.0 | approxkl=0.0110 | success_rate=0.28\ntrain_steps=4997120 | eval_eps=20 | return=-244.7+-219.0 | length=147+-0.0 | approxkl=0.0115 | success_rate=0.40\ntrain_steps=4997120 | eval_eps=20 | 
return=-148.2+-111.6 | length=147+-0.0 | approxkl=0.0076 | success_rate=0.64\n[47821][MainThread               ][tracer              ][ppo | train         ] Elapsed: 50.7621 sec\n</code>\n</pre> <pre><code># @title Visualize PPO Training Progress\n# @markdown The plots below show the training progress of the PPO algorithm in terms of returns, success rate, and policy KL divergence.\n\nfig_ppo, axes_ppo = plt.subplots(1, 3, figsize=(12, 3))\ntotal_steps = res.metrics[\"train/total_steps\"].transpose()\nmean, std = res.metrics[\"eval/mean_returns\"].transpose(), res.metrics[\"eval/std_returns\"].transpose()\naxes_ppo[0].plot(total_steps, mean, label=\"mean\")\naxes_ppo[0].set_title(\"Returns\")\naxes_ppo[0].set_xlabel(\"Total steps\")\naxes_ppo[0].set_ylabel(\"Cum. return\")\nmean = res.metrics[\"eval/success_rate\"].transpose()\naxes_ppo[1].plot(total_steps, mean, label=\"mean\")\naxes_ppo[1].set_title(r\"Success ($\\cos(\\theta) &gt; 0.95$ &amp; $|\\dot{\\theta}| &lt; 0.5$)\")\naxes_ppo[1].set_xlabel(\"Total steps\")\naxes_ppo[1].set_ylabel(\"Upright [fraction of steps]\")\nmean, std = res.metrics[\"train/mean_approxkl\"].transpose(), res.metrics[\"train/std_approxkl\"].transpose()\naxes_ppo[2].plot(total_steps, mean, label=\"mean\")\naxes_ppo[2].set_title(\"Policy KL\")\naxes_ppo[2].set_xlabel(\"Total steps\")\naxes_ppo[2].set_ylabel(\"Approx. KL\");\n</code></pre> <pre><code># @title Evaluate the Learned Policy on the Simulated System (i.e. 
used during training)\n# @markdown We evaluate the learned policy by running multiple rollouts in parallel.\n# @markdown The success rate is calculated as the fraction of time steps in which the pendulum is upright and still.\n\nnum_rollouts = 20_000  # Lower if memory is an issue\nmax_steps = int(5 * nodes_sim[\"agent\"].rate)  # 5 seconds\n\n# Check if we have a GPU\ntry:\n    gpu = jax.devices(\"gpu\")\nexcept RuntimeError:\n    num_rollouts = 100  # Use far fewer rollouts when running on the CPU\n    print(\n        \"Warning: No GPU found, falling back to CPU. Speedups will be less pronounced. Lowering the number of rollouts to 100.\"\n    )\n    print(\n        \"Hint: if you are using Google Colab, try to change the runtime to GPU: \"\n        \"Runtime -&gt; Change runtime type -&gt; Hardware accelerator -&gt; GPU.\"\n    )\n\n\ndef rollout_fn(rng):\n    # Initialize graph state\n    _gs = graph_rl.init(rng, params=eval_params, order=(\"agent\",))\n    # Make sure to record the state\n    _gs = graph_rl.init_record(_gs, params=False, state=True, output=False)\n    # Run the graph for a fixed number of steps\n    _gs_rollout = graph_rl.rollout(_gs, carry_only=True, max_steps=max_steps)\n    # This returns a record that may only be partially filled.\n    record = _gs_rollout.aux[\"record\"]\n    is_filled = record.nodes[\"world\"].steps.seq &gt;= 0  # Unfilled steps are marked with -1\n    return is_filled, record.nodes[\"world\"].steps.state\n\n\nrng, rng_rollout = jax.random.split(rng)\nrngs_rollout = jax.random.split(rng_rollout, num=num_rollouts)\nt_jit = rutils.timer(\n    f\"Vectorized evaluation of {num_rollouts} rollouts | compile\", log_level=100\n)  # Timers are created before the `with` blocks so their durations can be read afterwards\nwith t_jit:\n    rollout_fn_jv = jax.jit(jax.vmap(rollout_fn))\n    rollout_fn_jv = rollout_fn_jv.lower(rngs_rollout)\n    rollout_fn_jv = rollout_fn_jv.compile()\nt_run = rutils.timer(f\"Vectorized evaluation of {num_rollouts} rollouts | rollouts\", log_level=100)\nwith 
t_run:\n    is_filled, final_states = rollout_fn_jv(rngs_rollout)\n    final_states.th.block_until_ready()\n\n# Only keep the filled steps (the rollout is shorter than the full duration of the computation graph, so trailing steps stay unfilled)\nfinal_states = final_states[is_filled]\n\n# Calculate success rate\nthr_upright = 0.95  # Threshold on cos(theta)\nthr_static = 0.5  # Angular velocity threshold [rad/s]\ncos_th = jnp.cos(final_states.th)\nthdot = final_states.thdot\nis_upright = cos_th &gt; thr_upright\nis_static = jnp.abs(thdot) &lt; thr_static\nis_valid = jnp.logical_and(is_upright, is_static)\nsuccess_rate = is_valid.sum() / is_valid.size\nprint(f\"sim. eval | Success rate: {success_rate:.2f}\")\nprint(\n    f\"sim. eval | fps: {(num_rollouts * max_steps) / t_run.duration / 1e6:.0f} Million steps/s | compile: {t_jit.duration:.2f} s | run: {t_run.duration:.2f} s\"\n)\n</code></pre> <pre>\n<code>[47821][MainThread               ][tracer              ][Vectorized evaluation of 20000 rollouts | compile] Elapsed: 5.4447 sec\n[47821][MainThread               ][tracer              ][Vectorized evaluation of 20000 rollouts | rollouts] Elapsed: 0.1476 sec\nsim. eval | Success rate: 0.81\nsim. eval | fps: 34 Million steps/s | compile: 5.44 s | run: 0.15 s\n</code>\n</pre> <pre><code># @title Evaluate the Learned Policy on the \"Real\" Brax System (i.e. 
sim2real transfer)\n# @markdown We now evaluate the learned policy on the \"real\" Brax system, which we used to collect data at the beginning.\n\n\n@jax.jit\ndef get_action(step_state: base.StepState):\n    # Compute the policy action and the updated agent state from the current step state\n    obs = eval_params[\"agent\"].get_observation(step_state)\n    action = eval_params[\"agent\"].policy.get_action(obs)\n    output = eval_params[\"agent\"].to_output(action)  # Convert the action to an output message\n    new_ss = step_state.replace(state=eval_params[\"agent\"].update_state(step_state, action))\n    return new_ss, output\n\n\n# Run for one episode\ngs, ss = graph.reset(gs_init_real)  # Reset the graph to the initial state (returns the graph state and the agent's step state)\nfor _ in tqdm.tqdm(range(max_steps), desc=\"brax | evaluate policy\"):\n    new_ss, output = get_action(ss)\n    gs, ss = graph.step(gs, new_ss, output)  # Step the graph with the agent's output\ngraph.stop()  # Stop all nodes that were running asynchronously in the background\neval_record = graph.get_record()  # Get the recorded data of the graph\neval_real_rollout = eval_record.nodes[\"world\"].steps.state\n# Filter out the world node, as it would not be available in a real-world system\neval_record = eval_record.filter(nodes_real)\n</code></pre> <pre>\n<code>brax | evaluate policy: 
100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 250/250 [00:01&lt;00:00, 208.06it/s]\n</code>\n</pre> <pre><code># @title Visualize Actions and Sensor Readings in the Real-World System\n# @markdown The following plots display the actions and sensor readings during the evaluation of the policy on the real-world system.\n\nfig, axes = plt.subplots(1, 3, figsize=(12, 3))\naxes[0].plot(eval_record.nodes[\"agent\"].steps.ts_end[:-1], eval_record.nodes[\"agent\"].steps.output.action, label=\"action\")\naxes[0].set_xlabel(\"Time [s]\")\naxes[0].set_ylabel(\"Torque [Nm]\")\naxes[0].legend()\n\naxes[1].plot(eval_record.nodes[\"sensor\"].steps.ts_end, eval_record.nodes[\"sensor\"].steps.output.th, label=\"th\")\naxes[1].set_xlabel(\"Time [s]\")\naxes[1].set_ylabel(\"Angle [rad]\")\naxes[1].legend()\n\naxes[2].plot(eval_record.nodes[\"sensor\"].steps.ts_end, eval_record.nodes[\"sensor\"].steps.output.thdot, label=\"thdot\")\naxes[2].set_xlabel(\"Time [s]\")\naxes[2].set_ylabel(\"Ang. Vel. 
[rad/s]\")\naxes[2].legend();\n</code></pre> <pre><code># @title Visualize the Rollout\n# @markdown The following visualization shows the rollout of the pendulum swing-up task, displaying the system's behavior over time.\n# @markdown Note: Html visualization may not work properly if rendering simultaneously in multiple cells.\n# @markdown In such cases, comment-out all but one HTML(pendulum.render(rollout)).\nHTML(pdm.render.render(eval_real_rollout, dt=float(1 / world.rate)))\n</code></pre> Brax visualizer <pre><code>&lt;script type=\"module\"&gt;\n  import {Viewer} from 'viewer';\n  const domElement = document.getElementById(\"brax-viewer\");\n  var viewer = new Viewer(domElement, system);\n&lt;/script&gt;\n</code></pre> <pre><code>\n</code></pre>"},{"location":"examples/sim2real.html#sim-to-real-with-rex-robotic-environments-with-jax","title":"Sim-to-Real with rex (Robotic Environments with jaX)","text":"<p>This notebook offers an introductory tutorial for rex (Robotic Environments with jaX), a JAX-based framework for building graph-based environments designed for sim2real robotics.</p> <p>In this tutorial, we will walk through a simple sim-to-real example using rex, where we will: 1. Define a simple pendulum system as an interconnected set of nodes, where:    - brax is used as a stand-in for the real-world system.    - We simulate real-world asynchronous effects by introducing communication and computation delays using predefined delay distributions.    - The node definitions used in this notebook are covered in detail in the node_definitions.ipynb notebook. 2. Apply open-loop control to the pendulum system to gather data. 3. Use the collected data to:    - Fit Gaussian Mixture Models (GMM) to estimate the delays introduced in step (1).    - Build an ODE simulation environment.    - Use evolutionary strategies to identify hidden delays and parameters in the ODE environment that best match the collected data. 4. 
Train an agent to balance the pendulum in the ODE environment using PPO (Proximal Policy Optimization). 5. Zero-shot transfer the trained agent to the real-world environment.</p> <p>A Colab runtime with GPU acceleration is required. If you're using a CPU-only runtime, you can switch using the menu \"Runtime &gt; Change runtime type\".</p>"}]}