#!/usr/bin/env bash
#
# Patch the Ray RLlib sources installed in the active Python environment.
#
# Usage: run with the target environment activated, so that `python` resolves
# to the interpreter whose site-packages contains the `ray` package.
#
# The embedded unified diff touches four files (paths relative to `ray/`):
#   * rllib/evaluation/collectors/simple_list_collector.py
#       Adds SampleBatch.INFOS to a list of data columns that already holds
#       OBS, ENV_ID, EPS_ID and AGENT_INDEX.
#   * rllib/evaluation/metrics.py
#       summarize_episodes() additionally computes sample standard deviations
#       (np.std with ddof=1) of episode rewards, per-policy rewards and custom
#       metrics, and reports `episode_reward_stddev` / `policy_reward_stddev`.
#   * rllib/policy/rnn_sequencing.py
#       Uses the SampleBatch.INFOS constant instead of the literal "infos";
#       add_time_dimension() builds a (batch, time, ...) view first and
#       transposes dims 0 and 1 for time-major output instead of reshaping
#       directly to (time, batch, ...).
#   * rllib/policy/sample_batch.py
#       Builds the slice map with list.extend() and keeps padded and unpadded
#       start/stop indices separately, so that the INFOS column is sliced with
#       the unpadded indices.
#
# Exit status: non-zero if a prerequisite is missing or `patch` fails.

set -euo pipefail

# Print an error message to stderr and abort.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

command -v python > /dev/null || die "python not found in PATH"
command -v patch > /dev/null || die "patch not found in PATH"

# Pure-Python site-packages directory of the active interpreter, e.g.
# /path/to/conda/envs/XXX/lib/pythonX.Y/site-packages
SITE_PACKAGES_DIR="$(python -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])')" \
  || die "could not query the site-packages directory from python"
[[ -n "${SITE_PACKAGES_DIR}" ]] || die "python reported an empty site-packages directory"

# Fail early with a clear message instead of letting patch(1) fail on a
# missing directory.
RAY_DIR="${SITE_PACKAGES_DIR}/ray"
[[ -d "${RAY_DIR}" ]] || die "ray is not installed under ${SITE_PACKAGES_DIR}"

# The heredoc delimiter is quoted so the diff is passed to patch(1) verbatim
# (no parameter or command expansion). -p1 strips the leading a/ and b/ path
# components; -d makes the paths relative to the installed ray package.
patch -p1 -d "${RAY_DIR}" << 'EOF'
--- a/rllib/evaluation/collectors/simple_list_collector.py
+++ b/rllib/evaluation/collectors/simple_list_collector.py
@@ -734,6 +734,7 @@ class SimpleListCollector(SampleCollector):
                 if data_col
                 in [
                     SampleBatch.OBS,
+                    SampleBatch.INFOS,
                     SampleBatch.ENV_ID,
                     SampleBatch.EPS_ID,
                     SampleBatch.AGENT_INDEX,

--- a/rllib/evaluation/metrics.py
+++ b/rllib/evaluation/metrics.py
@@ -174,10 +174,12 @@ def summarize_episodes(
         min_reward = min(episode_rewards)
         max_reward = max(episode_rewards)
         avg_reward = np.mean(episode_rewards)
+        stddev_reward = np.std(episode_rewards, ddof=1)
     else:
         min_reward = float("nan")
         max_reward = float("nan")
         avg_reward = float("nan")
+        stddev_reward = float("nan")
     if episode_lengths:
         avg_length = np.mean(episode_lengths)
     else:
@@ -190,10 +192,12 @@ def summarize_episodes(
     policy_reward_min = {}
     policy_reward_mean = {}
     policy_reward_max = {}
+    policy_reward_stddev = {}
     for policy_id, rewards in policy_rewards.copy().items():
         policy_reward_min[policy_id] = np.min(rewards)
         policy_reward_mean[policy_id] = np.mean(rewards)
         policy_reward_max[policy_id] = np.max(rewards)
+        policy_reward_stddev[policy_id] = np.std(rewards, ddof=1)

         # Show as histogram distributions.
         hist_stats["policy_{}_reward".format(policy_id)] = rewards
@@ -204,6 +208,7 @@ def summarize_episodes(
             custom_metrics[k] = filt
         else:
             custom_metrics[k + "_mean"] = np.mean(filt)
+            custom_metrics[k + "_stddev"] = np.std(filt, ddof=1)
             if filt:
                 custom_metrics[k + "_min"] = np.min(filt)
                 custom_metrics[k + "_max"] = np.max(filt)
@@ -229,12 +234,14 @@ def summarize_episodes(
         episode_reward_max=max_reward,
         episode_reward_min=min_reward,
         episode_reward_mean=avg_reward,
+        episode_reward_stddev=stddev_reward,
         episode_len_mean=avg_length,
         episode_media=dict(episode_media),
         episodes_this_iter=len(new_episodes),
         policy_reward_min=policy_reward_min,
         policy_reward_max=policy_reward_max,
         policy_reward_mean=policy_reward_mean,
+        policy_reward_stddev=policy_reward_stddev,
         custom_metrics=dict(custom_metrics),
         hist_stats=dict(hist_stats),
         sampler_perf=dict(perf_stats),

--- a/rllib/policy/rnn_sequencing.py
+++ b/rllib/policy/rnn_sequencing.py
@@ -115,7 +115,7 @@ def pad_batch_to_sequences_of_same_size(
         elif (
             not feature_keys
             and not k.startswith("state_out_")
-            and k not in ["infos", SampleBatch.SEQ_LENS]
+            and k not in [SampleBatch.INFOS, SampleBatch.SEQ_LENS]
         ):
             feature_keys_.append(k)

@@ -204,11 +204,13 @@ def add_time_dimension(

         # Dynamically reshape the padded batch to introduce a time dimension.
         new_batch_size = padded_batch_size // max_seq_len
+        batch_major_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
+        padded_outputs = padded_inputs.view(batch_major_shape)
+
         if time_major:
-            new_shape = (max_seq_len, new_batch_size) + padded_inputs.shape[1:]
-        else:
-            new_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
-        return torch.reshape(padded_inputs, new_shape)
+            # Swap the batch and time dimensions
+            padded_outputs = padded_outputs.transpose(0, 1)
+        return padded_outputs


 @DeveloperAPI

--- a/rllib/policy/sample_batch.py
+++ b/rllib/policy/sample_batch.py
@@ -913,26 +913,30 @@ class SampleBatch(dict):
             # Build our slice-map, if not done already.
             if not self._slice_map:
                 sum_ = 0
-                for i, l in enumerate(self[SampleBatch.SEQ_LENS]):
-                    for _ in range(l):
-                        self._slice_map.append((i, sum_))
-                    sum_ += l
+                for i, l in enumerate(map(int, self[SampleBatch.SEQ_LENS])):
+                    self._slice_map.extend([(i, sum_)] * l)
+                    sum_ = sum_ + l
                 # In case `stop` points to the very end (lengths of this
                 # batch), return the last sequence (the -1 here makes sure we
                 # never go beyond it; would result in an index error below).
                 self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

-            start_seq_len, start = self._slice_map[start]
-            stop_seq_len, stop = self._slice_map[stop]
+            start_seq_len, start_unpadded = self._slice_map[start]
+            stop_seq_len, stop_unpadded = self._slice_map[stop]
+            start_padded = start_unpadded
+            stop_padded = stop_unpadded
             if self.zero_padded:
-                start = start_seq_len * self.max_seq_len
-                stop = stop_seq_len * self.max_seq_len
+                start_padded = start_seq_len * self.max_seq_len
+                stop_padded = stop_seq_len * self.max_seq_len

             def map_(path, value):
                 if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
                     "state_in_"
                 ):
-                    return value[start:stop]
+                    if path[0] != SampleBatch.INFOS:
+                        return value[start_padded:stop_padded]
+                    else:
+                        return value[start_unpadded:stop_unpadded]
                 else:
                     return value[start_seq_len:stop_seq_len]

EOF
