<h1 align="center">MIKASA-Robo</h1>

<h3 align="center">Benchmark for robotic tabletop manipulation memory-intensive tasks</h3>

---

<div align="center" style="display: flex; justify-content: center; gap: 10px;">
    <img src="assets/shell-game-touch-v0.gif" width="200" />
    <img src="assets/chain-of-colors-7.gif" width="200" />
    <img src="assets/take-it-back-v0.gif" width="200" />
    <img src="assets/remember-shape-and-color-5x3.gif" width="200" />
</div>
<p align="center"><i>Example tasks from the MIKASA-Robo benchmark</i></p>

## Table of Contents
- [Overview](#overview)
- [Key Features](#key-features)
- [List of Tasks](#list-of-tasks)
- [Quick Start](#quick-start)
  - [Installation](#installation)
  - [Basic Usage](#basic-usage)
  - [Advanced Usage: Debug Wrappers](#advanced-usage-debug-wrappers)
- [Training](#training)
- [MIKASA-Robo Ideology](#mikasa-robo-ideology)
- [Datasets for Offline RL](#datasets-for-offline-rl)
  - [Download ready-made datasets](#download-ready-made-datasets)
  - [Collect datasets using oracle agents checkpoints](#collect-datasets-using-oracle-agents-checkpoints)
- [Citation](#citation)

## Overview

MIKASA-Robo is a comprehensive benchmark suite for memory-intensive robotic manipulation tasks, part of the MIKASA (Memory-Intensive Skills Assessment Suite for Agents) framework. It features:

- 12 distinct task types with varying difficulty levels
- 32 total tasks covering different memory aspects
- 32 visual-based datasets for Offline RL
- First benchmark specifically designed for testing agent memory in robotic manipulation

## Key Features

- **Diverse Memory Testing**: Covers four fundamental memory types:
  - Object Memory
  - Spatial Memory
  - Sequential Memory
  - Memory Capacity

- **Built on ManiSkill3**: Leverages the powerful [ManiSkill3](https://maniskill.readthedocs.io/en/latest/) framework, providing:
  - GPU parallelization
  - User-friendly interface
  - Customizable environments


## List of Tasks

| Preview | Memory Task | Mode | Brief Description | T | Memory Task Type |
|--------------------------|------------|------|------|---|--|
| <img src="assets/shell-game-touch-v0.gif" width="200"/> | `ShellGame[Mode]-v0` | `Touch`<br>`Push`<br>`Pick` | Memorize the position of the ball before it is covered by the cups, then, after a delay, interact with the cup the ball is under. | 90 | Object |
| <img src="assets/intercept-medium-v0.gif" width="200"/> | `Intercept[Mode]-v0` | `Slow`<br>`Medium`<br>`Fast` | Memorize the positions of the rolling ball, estimate its velocity through those positions, and then aim the ball at the target. | 90| Spatial |
| <img src="assets/intercept-grab-medium.gif" width="200"/> | `InterceptGrab[Mode]-v0` | `Slow`<br>`Medium`<br>`Fast` | Memorize the positions of the rolling ball, estimate its velocity through those positions, and then catch the ball with the gripper and lift it up. | 90 | Spatial |
| <img src="assets/rotate-lenient-pos-v0.gif" width="200"/> | `RotateLenient[Mode]-v0` | `Pos`<br>`PosNeg` | Memorize the initial position of the peg and rotate it by a given angle. | 90| Spatial |
| <img src="assets/rotate-strict-pos.gif" width="200"/> | `RotateStrict[Mode]-v0` | `Pos`<br>`PosNeg` | Memorize the initial position of the peg and rotate it to a given angle without shifting its center. | 90 | Object |
| <img src="assets/take-it-back-v0.gif" width="200"/> | `TakeItBack-v0` | --- | Memorize the initial position of the cube, move it to the target region, and then return it to its initial position. | 180 | Spatial |
| <img src="assets/remember-color-9-v0.gif" width="200"/> | `RememberColor[Mode]-v0` | `3`/`5`/`9` | Memorize the color of the cube and choose among other colors. | 60 | Object |
| <img src="assets/remember-shape-9-v0.gif" width="200"/> | `RememberShape[Mode]-v0` | `3`/`5`/`9` | Memorize the shape of the cube and choose among other shapes. | 60 | Object |
| <img src="assets/remember-shape-and-color-5x3.gif" width="200"/> | `RememberShapeAndColor[Mode]-v0` | `3×2`/`3×3`<br>`5×3` | Memorize the shape and color of the cube and choose among other shapes and colors. | 60 | Object |
| <img src="assets/bunch-of-colors-7.gif" width="200"/> | `BunchOfColors[Mode]-v0` | `3`/`5`/`7` | Remember the colors of the set of cubes shown simultaneously in the bunch and touch them in any order. | 120 | Capacity |
| <img src="assets/seq-of-colors-7.gif" width="200"/> | `SeqOfColors[Mode]-v0` | `3`/`5`/`7` | Remember the colors of the set of cubes shown sequentially and then select them in any order. | 120 | Capacity |
| <img src="assets/chain-of-colors-7.gif" width="200"/> | `ChainOfColors[Mode]-v0` | `3`/`5`/`7` | Remember the colors of the set of cubes shown sequentially and then select them in the same order. | 120 | Sequential |

**Total: 32 tabletop robotic manipulation memory-intensive tasks in 12 groups**. T - episode timeout.
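
The `[Mode]` placeholders in the table expand to concrete environment IDs such as `ShellGameTouch-v0` or `RememberColor9-v0`. As a rough sketch, the snippet below enumerates them from the table. `TASK_MODES` and `task_ids` are illustrative names, not part of `mikasa_robo_suite`, and writing the `RememberShapeAndColor` modes with an ASCII `x` (e.g. `3x2`) is an assumption.

```python
# Task groups and their modes, copied from the table above.
# An empty string means the task has no mode suffix.
TASK_MODES = {
    "ShellGame": ["Touch", "Push", "Pick"],
    "Intercept": ["Slow", "Medium", "Fast"],
    "InterceptGrab": ["Slow", "Medium", "Fast"],
    "RotateLenient": ["Pos", "PosNeg"],
    "RotateStrict": ["Pos", "PosNeg"],
    "TakeItBack": [""],
    "RememberColor": ["3", "5", "9"],
    "RememberShape": ["3", "5", "9"],
    # Assumption: the "×" in the table is written as ASCII "x" in the ID.
    "RememberShapeAndColor": ["3x2", "3x3", "5x3"],
    "BunchOfColors": ["3", "5", "7"],
    "SeqOfColors": ["3", "5", "7"],
    "ChainOfColors": ["3", "5", "7"],
}


def task_ids():
    """Return every environment ID implied by the table."""
    return [
        f"{group}{mode}-v0"
        for group, modes in TASK_MODES.items()
        for mode in modes
    ]


ids = task_ids()
print(len(TASK_MODES), "groups,", len(ids), "tasks")
```

Each resulting string can be passed as the first argument to `gym.make()` once `mikasa_robo_suite` has been imported.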


## Quick Start


## Basic Usage
```python
import mikasa_robo_suite
from mikasa_robo_suite.utils.wrappers import StateOnlyTensorToDictWrapper
from tqdm.notebook import tqdm
import torch
import gymnasium as gym

# Create the environment via gym.make()
# obs_mode="rgb" for modes "RGB", "RGB+joint", "RGB+oracle" etc.
# obs_mode="state" for mode "state"
episode_timeout = 90
env = gym.make("RememberColor9-v0", num_envs=4, obs_mode="rgb", render_mode="all")
env = StateOnlyTensorToDictWrapper(env) # * always use this wrapper!

obs, _ = env.reset(seed=42)
print(obs.keys())
for i in tqdm(range(episode_timeout)):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(torch.from_numpy(action))

env.close()
```

## Advanced Usage: Debug Wrappers
MIKASA-Robo provides task-specific and task-agnostic wrappers that let you track training progress, the reward agents receive, the number of steps taken, and the contribution of each reward component. These wrappers are optional, but if you decide not to use them, remember that `env = StateOnlyTensorToDictWrapper(env)` **must always be used** to get the correct observation keys! For more details, see `quick_start.ipynb`.

### With all task-predefined wrappers
```python
import mikasa_robo_suite
from mikasa_robo_suite.dataset_collectors.get_mikasa_robo_datasets import env_info
from tqdm.notebook import tqdm
import torch
import gymnasium as gym

env_name = "RememberColor9-v0"
obs_mode = "rgb" # or "state"
num_envs = 4
seed = 42

env = gym.make(env_name, num_envs=num_envs, obs_mode=obs_mode, render_mode="all")

state_wrappers_list, episode_timeout = env_info(env_name)
print(f"Episode timeout: {episode_timeout}")
for wrapper_class, wrapper_kwargs in state_wrappers_list:
    env = wrapper_class(env, **wrapper_kwargs)

obs, _ = env.reset(seed=seed)
print(obs.keys())
for i in tqdm(range(episode_timeout)):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(torch.from_numpy(action))

env.close()
```
### With selective wrappers
```python
import mikasa_robo_suite
from mikasa_robo_suite.utils.wrappers import *
from mikasa_robo_suite.memory_envs import *
import torch  # needed for torch.from_numpy() in the loop below
import gymnasium as gym
from gymnasium.envs.registration import registry
from tqdm.notebook import tqdm

env_name = "ShellGameTouch-v0"
obs_mode = "state"
num_envs = 4
seed = 42

env = gym.make(env_name, num_envs=num_envs, obs_mode=obs_mode, render_mode="all")
max_steps = registry.get(env_name).max_episode_steps
print(f"Episode timeout: {max_steps}")

env = StateOnlyTensorToDictWrapper(env)
env = InitialZeroActionWrapper(env, n_initial_steps=1)
env = ShellGameRenderCupInfoWrapper(env)
env = RenderStepInfoWrapper(env)
env = RenderRewardInfoWrapper(env)
env = DebugRewardWrapper(env)

obs, _ = env.reset(seed=seed)
print(obs.keys())
for i in tqdm(range(max_steps)):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(torch.from_numpy(action))

env.close()
```

## Training
MIKASA-Robo supports multiple training configurations:

### PPO with MLP (State-Based)
```bash
python3 baselines/ppo/ppo_memtasks.py \
    --env_id=RememberColor9-v0 \
    --exp-name=remember-color-9-v0 \
    --num-steps=60 \
    --num_eval_steps=180 \
    --include-state
```

### PPO with MLP (RGB + Joint)
```bash
python3 baselines/ppo/ppo_memtasks.py \
    --env_id=RememberColor9-v0 \
    --exp-name=remember-color-9-v0 \
    --num-steps=60 \
    --num_eval_steps=180 \
    --include-rgb \
    --include-joints
```

### PPO with LSTM (RGB + Joint)
```bash
python3 baselines/ppo/ppo_memtasks_lstm.py \
    --env_id=RememberColor9-v0 \
    --exp-name=remember-color-9-v0 \
    --num-steps=60 \
    --num_eval_steps=180 \
    --include-rgb \
    --include-joints
```

To train with sparse rewards, add `--reward-mode=sparse`.

## MIKASA-Robo Ideology
An agent's memory capabilities can be assessed only when the environment demands memory and, in addition, the observations are provided in an appropriate format. Currently, we have implemented several training modes:

- `state`: In this mode, the agent receives comprehensive, vectorized information about the environment, joints, and TCP pose, along with oracle data that is essential for solving memory-intensive tasks. When trained in this way, the agent addresses the MDP problem and **does not require memory**.

- `RGB+joints`: Here, the agent receives image data from a camera mounted above and from the manipulator's gripper, along with the position and velocity of its joints. This mode provides no additional information, meaning the agent must learn to store and utilize oracle data. It is designed to **test the agent's memory** capabilities.

These training modes are selected with the corresponding flags:
```bash
# To train in `state` mode:
--include-state

# To train in `RGB+joints` mode:
--include-rgb \
--include-joints

# Additionally, for debugging you can add oracle information to the observation:
--include-oracle
```
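
When sweeping over tasks or modes, it can be convenient to assemble these flags from Python. The sketch below only builds an argument list from flags documented in this README. `MODE_FLAGS` and `build_ppo_command` are hypothetical helpers, not part of the repository.

```python
import shlex

# Observation-mode flags as documented above.
MODE_FLAGS = {
    "state": ["--include-state"],
    "rgb+joints": ["--include-rgb", "--include-joints"],
}


def build_ppo_command(env_id, mode, exp_name, num_steps,
                      lstm=False, sparse=False, oracle=False):
    """Assemble the argv list for a baseline PPO run (hypothetical helper)."""
    if mode not in MODE_FLAGS:
        raise ValueError(f"unknown mode: {mode!r}")
    script = "ppo_memtasks_lstm.py" if lstm else "ppo_memtasks.py"
    cmd = [
        "python3", f"baselines/ppo/{script}",
        f"--env_id={env_id}",
        f"--exp-name={exp_name}",
        f"--num-steps={num_steps}",
    ]
    cmd += MODE_FLAGS[mode]
    if oracle:
        cmd.append("--include-oracle")  # debugging only
    if sparse:
        cmd.append("--reward-mode=sparse")
    return cmd


cmd = build_ppo_command(
    "RememberColor9-v0", "rgb+joints", "remember-color-9-v0", 60,
    lstm=True, sparse=True,
)
print(shlex.join(cmd))
```

Passing `cmd` to `subprocess.run(cmd, check=True)` from the repository root would start the run.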

## Datasets for Offline RL
For Offline RL, we provide ready-made datasets that can be used immediately after download, as well as checkpoints of trained oracle agents for collecting datasets of any size for all MIKASA-Robo tasks using a single script.

### Download ready-made datasets
To let you start offline training quickly, we provide datasets for all 32 MIKASA-Robo tasks, each consisting of 1000 episodes (~200 GB in total). The non-anonymous download link will be revealed later.

Example of the dataset structure for `ShellGameTouch-v0` with episode timeout T = 90:
```python
import numpy as np
episode = np.load('ShellGameTouch-v0/train_data_781.npz')

print(episode['rgb'].shape) # (90, 128, 128, 6) - two RGB images (view from above and from the gripper)
print(episode['joints'].shape) # (90, 25) - joint positions and velocities, and Tool Center Point (TCP) position and rotation
print(episode['action'].shape) # (90, 8) - action (8-dimensional vector)
print(episode['reward'].shape) # (90, ) - (dense) reward for each step
print(episode['success'].shape) # (90,) - (sparse) success flag for each step
print(episode['done'].shape) # (90, ) - done flag for each step
```
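
As a minimal sketch of how an episode with this layout might be consumed, the snippet below writes a synthetic episode with the shapes listed above to a temporary file (so it runs without the real download), reloads it, splits the 6-channel `rgb` array into two 3-channel images, and computes the undiscounted return. The dtypes, the order of the two cameras within the channel axis, and treating an episode as solved when `success` is set at any step are assumptions, not documented behavior.

```python
import os
import tempfile

import numpy as np

T = 90  # episode timeout for ShellGameTouch-v0

# Synthetic stand-in for one episode, using the shapes listed above.
rng = np.random.default_rng(0)
fake_episode = {
    "rgb": rng.integers(0, 256, size=(T, 128, 128, 6), dtype=np.uint8),
    "joints": rng.standard_normal((T, 25)).astype(np.float32),
    "action": rng.uniform(-1, 1, size=(T, 8)).astype(np.float32),
    "reward": rng.uniform(0, 1, size=(T,)).astype(np.float32),
    "success": np.zeros(T, dtype=bool),
    "done": np.zeros(T, dtype=bool),
}
fake_episode["success"][-1] = True
fake_episode["done"][-1] = True

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train_data_0.npz")
    np.savez(path, **fake_episode)
    with np.load(path) as data:
        episode = {key: data[key] for key in data.files}

# Split the stacked cameras; which view comes first is an assumption.
first_view = episode["rgb"][..., :3]
second_view = episode["rgb"][..., 3:]

episode_return = float(episode["reward"].sum())
solved = bool(episode["success"].any())
print(first_view.shape, second_view.shape, solved)
```

With a real dataset, replacing the temporary file with a path such as `ShellGameTouch-v0/train_data_781.npz` should be the only change needed, provided the keys match those shown above.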

### Collect datasets using oracle agents checkpoints
Download the checkpoints (160 MB) of pretrained oracle agents for dataset collection:
```bash
cd MIKASA-Robo

wget secret/website/oracle_checkpoints.zip

unzip oracle_checkpoints.zip
```

Or, if you want to train oracle agents from scratch, use this code:
```bash
# For single task:
python3 mikasa_robo_suite/dataset_collectors/get_dataset_collectors_ckpt.py --env_id=ShellGameTouch-v0

# For all tasks:
python3 mikasa_robo_suite/dataset_collectors/parallel_training_manager.py
```

Once you have downloaded or trained the oracle agent checkpoints, you can build datasets of arbitrary size (in multiples of 250 episodes) for any MIKASA-Robo task:
```bash
# For single task:
python3 mikasa_robo_suite/dataset_collectors/get_mikasa_robo_datasets.py \
    --env-id=ShellGameTouch-v0 \
    --path-to-save-data="data" \
    --ckpt-dir="." \
    --num-train-data=1000

# For all tasks:
python3 mikasa_robo_suite/dataset_collectors/parallel_dataset_collection_manager.py \
    --path-to-save-data="data" \
    --ckpt-dir="." \
    --num-train-data=1000
```
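
Because dataset sizes come in multiples of 250 episodes, a requested size may need rounding before it is passed to `--num-train-data`. The helper below is a small sketch of that arithmetic. `round_up_to_batch` is a hypothetical name, and whether the collection script itself rounds or rejects other values is not specified here.

```python
EPISODES_PER_BATCH = 250  # dataset sizes are multiples of 250 episodes


def round_up_to_batch(num_episodes, batch=EPISODES_PER_BATCH):
    """Round a requested episode count up to the next multiple of `batch`."""
    if num_episodes <= 0:
        raise ValueError("num_episodes must be positive")
    # Ceiling division using integer arithmetic only.
    return -(-num_episodes // batch) * batch


for requested in (1, 250, 900, 1000):
    print(requested, "->", round_up_to_batch(requested))
```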

# Dataset Collection and Conversion to RLDS Format

This repository provides tools and scripts for dataset collection and conversion to RLDS format, tailored specifically for fine-tuning and evaluation of VLA models within the MIKASA-Robo tasks.

## Dataset Collection for MIKASA-Robo

To begin, collect datasets for each MIKASA-Robo task:

1. Follow the dataset collection instructions provided in the original repository.
2. Replace the scripts in the directory `MIKASA-Robo/mikasa_robo_suite/dataset_collectors` with scripts located in `dataset_builder/dataset_collectors` from this repository.

## Conversion to RLDS Format

After collecting the datasets, convert them to the standardized RLDS format as follows:

1. Clone the RLDS dataset builder repository:

```bash
git clone https://github.com/kpertsch/rlds_dataset_builder
```

2. Follow the instructions in the cloned repository to build the `example_dataset`.
3. Copy the folder `dataset_builder/mikasa_dataset` from this repository into the RLDS repository directory.
4. Execute the following command in the `mikasa_dataset` directory to build the dataset:

```bash
tfds build --overwrite
```

# OpenVLA

This section provides instructions for fine-tuning and evaluating the OpenVLA model.

## Fine-tuning OpenVLA

1. Set up the conda environment and install all required packages by running:

```bash
bash openvla/env_setup.sh
```

2. Execute the fine-tuning script (`openvla/vla-scripts/finetune.py`) with your specific dataset paths and parameters. Modify `data_root_dir`, `run_root_dir`, and adjust `--nproc-per-node` to match your GPU availability:

```bash
torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
  --vla_path openvla/openvla-7b \
  --data_root_dir "datasets/remember_color_9" \
  --dataset_name mikasa_dataset \
  --run_root_dir "mikasa_finetune_remember_color_9" \
  --use_l1_regression True \
  --use_diffusion False \
  --use_film False \
  --num_images_in_input 2 \
  --use_proprio False \
  --batch_size 8 \
  --learning_rate 5e-4 \
  --num_steps_before_decay 100000 \
  --max_steps 50000 \
  --save_freq 10000 \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --run_id_note parallel_dec--8_acts_chunk--continuous_acts--L1_regression--base_img--wrist_img
```

## Evaluation of OpenVLA

For evaluation on all environments reported in the paper, run the following evaluation script, replacing `pretrained_checkpoint` with the path to your trained model checkpoint:

```bash
bash openvla/eval_all.sh
```

# Octo

This section covers fine-tuning and evaluation steps for the Octo model.

## Fine-tuning Octo

1. Set up the conda environment and required packages:

```bash
bash octo/env_setup.sh
```

2. Run the fine-tuning script (`octo/mikasa_finetune/finetune_new_observation_action_mikasa.py`) specifying your dataset (`data_dir`) and the save location (`save_dir`):

```bash
python mikasa_finetune/finetune_new_observation_action_mikasa.py \
  --pretrained_path=hf://rail-berkeley/octo-small \
  --data_dir="datasets/remember_color_5" \
  --save_dir=mikasa_octo_finetuned_remcol5 \
  --batch_size=32 \
  --freeze_transformer
```

## Evaluation of Octo

Run the evaluation script provided, specifying the path to your fine-tuned checkpoint (`ckpt_path`):

```bash
bash octo/mikasa_finetune/eval_all.sh
```

# Pi0

This section covers fine-tuning and evaluation steps for the Pi0 model.

## Fine-tuning Pi0

1. Set up the conda environment and clone the `lerobot` repository:

```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e .
```

2. Convert the RLDS MIKASA dataset to LeRobot format using `pi0/convert_to_lerobot.py`.

3. Run the fine-tuning script (`pi0/finetune.py`), specifying your dataset (`dataset.root`, `dataset.repo_id`, `policy.repo_id`):

```bash
python pi0/finetune.py \
  --policy.path=lerobot/pi0 \
  --dataset.root="" \
  --dataset.repo_id="" \
  --policy.repo_id="" \
  --save_freq=500
```

4. Run the evaluation script, specifying the path to your fine-tuned checkpoint:

```bash
python pi0_mikasa_eval.py
```


# SpatialVLA

This section provides instructions for fine-tuning and evaluating the SpatialVLA model.

## Fine-tuning SpatialVLA

1. Set up the conda environment and install all required packages by running:

```bash
bash spatialvla/env_setup.sh
```

2. Execute the fine-tuning script (`spatialvla/train/spatialvla_finetune.py`) with your specific dataset paths and parameters. Modify `data_root_dir` and `run_root_dir`, and adjust the other parameters to match your GPU availability:

```bash

GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
NODES=$((GPUS / GPUS_PER_NODE))
PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-32}
BATCH_SIZE=${BATCH_SIZE:-$((GPUS * PER_DEVICE_BATCH_SIZE))}
GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS))

NUM_WORKERS=${NUM_WORKERS:-1}
shuffle_buffer_size=${shuffle_buffer_size:-8192} 

lr=5e-4
lora=32
lora_alpha=32
lora_target="linear"

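model_name_or_path=""
OUTPUT_DIR=""
# Also required by the torchrun command below; set these before running.
epoch=""
save_steps=""

export LAUNCHER="pytorch"
# MASTER_ADDR and MASTER_PORT must be set in the environment unless TORCH_RUN_ARGS is overridden.
TORCH_RUN_ARGS=${TORCH_RUN_ARGS:-"--nnodes $NODES --nproc-per-node $GPUS_PER_NODE --master_addr $MASTER_ADDR --master_port $MASTER_PORT"}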

torchrun $TORCH_RUN_ARGS \
  train/spatialvla_finetune.py \
  --model_name_or_path ${model_name_or_path} \
  --lora ${lora} \
  --lora_alpha ${lora_alpha} \
  --lora_target ${lora_target} \
  --ignore_data_skip True \
  --data_root_dir "datasets/remember_color_9" \
  --data_mix "mikasa" \
  --shuffle_buffer_size ${shuffle_buffer_size} \
  --obs_backward_steps 0 \
  --obs_backward_delta 1 \
  --action_forward_steps 3 \
  --flash_attn True \
  --output_dir ${OUTPUT_DIR} \
  --overwrite_output_dir False \
  --freeze_vision_tower False \
  --dataloader_num_workers ${NUM_WORKERS} \
  --bf16 True \
  --tf32 True \
  --num_train_epochs ${epoch} \
  --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
  --gradient_accumulation_steps ${GRADIENT_ACC} \
  --save_strategy steps \
  --save_steps ${save_steps} \
  --save_total_limit 3 \
  --learning_rate ${lr} \
  --weight_decay 0.0 \
  --warmup_ratio 0.005 \
  --lr_scheduler_type linear \
  --logging_steps 500 \
  --do_train True \
  --grad_checkpoint True \
  --deepspeed scripts/zero1.json \
  --report_to tensorboard \
  --log_level warning
```
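
The batch-size arithmetic at the top of this script can be checked in isolation. The sketch below mirrors the shell integer division in Python. `launch_config` is an illustrative helper, not part of the SpatialVLA code.

```python
def launch_config(gpus=8, gpus_per_node=8, per_device_batch_size=32,
                  batch_size=None):
    """Mirror the shell arithmetic for NODES, BATCH_SIZE and GRADIENT_ACC."""
    if batch_size is None:
        batch_size = gpus * per_device_batch_size  # script default
    nodes = gpus // gpus_per_node
    gradient_acc = batch_size // per_device_batch_size // gpus
    return {
        "nodes": nodes,
        "batch_size": batch_size,
        "gradient_acc": gradient_acc,
    }


# Defaults from the script: 8 GPUs on one node, 32 samples per device.
print(launch_config())
# Keeping a global batch of 256 on a single GPU needs 8 accumulation steps.
print(launch_config(gpus=1, gpus_per_node=1, batch_size=256))
```

Note that lowering `GPUS` without also lowering `GPUS_PER_NODE` makes the integer division yield `NODES=0`, so the two should be changed together.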

---
