{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0545110e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/vllm082/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2025-05-25 12:55:57,639\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
     ]
    },
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'verl'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 37\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensordict\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TensorDict\n\u001b[1;32m     35\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m DataLoader\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mverl\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpy_functional\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m union_two_dict\n\u001b[1;32m     38\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mverl\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtorch_functional\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m allgather_dict_tensors\n\u001b[1;32m     40\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataProto\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124munion_tensor_dict\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'verl'"
     ]
    }
   ],
   "source": [
    "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "\"\"\"\n",
    "Implements the base data transfer protocol between any two functions or modules.\n",
    "DataProto can be subclassed to define more detailed batch info with specific keys.\n",
    "\"\"\"\n",
    "\n",
    "import contextlib\n",
    "import copy\n",
    "import logging\n",
    "import os\n",
    "import pickle\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Callable, Dict, List, Union\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import ray\n",
    "import tensordict\n",
    "import torch\n",
    "import torch.distributed\n",
    "from packaging import version\n",
    "from tensordict import TensorDict\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from verl.utils.py_functional import union_two_dict\n",
    "from verl.utils.torch_functional import allgather_dict_tensors\n",
    "\n",
    "# Names exported by a star-import of this module.\n",
    "__all__ = [\"DataProto\", \"union_tensor_dict\"]\n",
    "\n",
    "# Best-effort configuration: any Exception raised by the call below is silently suppressed.\n",
    "# NOTE(review): presumably this switches off tensordict's legacy lazy behavior globally - confirm against the tensordict docs.\n",
    "with contextlib.suppress(Exception):\n",
    "    tensordict.set_lazy_legacy(False).set()\n",
    "\n",
    "\n",
    "class _DataProtoConfigMeta(type):\n",
    "    _config = {}\n",
    "\n",
    "    auto_padding_key = \"_verl_auto_padding\"\n",
    "\n",
    "    @property\n",
    "    def auto_padding(cls):\n",
    "        enabled_by_env = os.getenv(\"VERL_AUTO_PADDING\", \"FALSE\").upper() in [\"TRUE\", \"1\"]\n",
    "        return enabled_by_env or cls._config.get(cls.auto_padding_key, False)\n",
    "\n",
    "    @auto_padding.setter\n",
    "    def auto_padding(cls, enabled: bool):\n",
    "        assert isinstance(enabled, bool), f\"enabled must be a boolean, got {enabled} as {type(enabled)}\"\n",
    "        cls._config[cls.auto_padding_key] = enabled\n",
    "\n",
    "\n",
    "class DataProtoConfig(metaclass=_DataProtoConfigMeta):\n",
    "    \"\"\"Global configuration holder; its ``auto_padding`` property and ``auto_padding_key`` come from the metaclass.\"\"\"\n",
    "\n",
    "    pass\n",
    "\n",
    "\n",
    "# NOTE(review): not referenced in the visible part of this file; presumably a meta-info key\n",
    "# used to record how much padding was added - confirm against callers.\n",
    "_padding_size_key = \"_padding_size_key_x123d\"\n",
    "\n",
    "\n",
    "def pad_dataproto_to_divisor(data: \"DataProto\", size_divisor: int):\n",
    "    \"\"\"Grow a DataProto until its length is a multiple of ``size_divisor``.\n",
    "\n",
    "    Padding items are copies of the leading items of ``data``; when more padding is\n",
    "    needed than ``data`` holds, the leading items are taken repeatedly.\n",
    "\n",
    "    Args:\n",
    "        data (DataProto): the DataProto to pad\n",
    "        size_divisor (int): size divisor\n",
    "\n",
    "    Returns:\n",
    "        data: (DataProto): the padded DataProto (the input object itself when no padding is needed)\n",
    "        pad_size (int): number of items that were appended\n",
    "    \"\"\"\n",
    "    assert isinstance(data, DataProto), \"data must be a DataProto\"\n",
    "    total = len(data)\n",
    "    remainder = total % size_divisor\n",
    "\n",
    "    # Already divisible (this includes the empty case): hand the input back untouched.\n",
    "    if remainder == 0:\n",
    "        if total == 0:\n",
    "            logging.warning(\"padding a DataProto with no item, no changed made\")\n",
    "        return data, 0\n",
    "\n",
    "    pad_size = size_divisor - remainder\n",
    "    fillers = []\n",
    "    still_needed = pad_size\n",
    "    while still_needed > 0:\n",
    "        # At most len(data) items can be taken per round.\n",
    "        take = min(still_needed, total)\n",
    "        fillers.append(data[:take])\n",
    "        still_needed -= take\n",
    "    return DataProto.concat([data] + fillers), pad_size\n",
    "\n",
    "\n",
    "def unpad_dataproto(data: \"DataProto\", pad_size):\n",
    "    if pad_size != 0:\n",
    "        data = data[:-pad_size]\n",
    "    return data\n",
    "\n",
    "\n",
    "def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:\n",
    "    \"\"\"Merge ``tensor_dict2`` into ``tensor_dict1`` in place and return ``tensor_dict1``.\n",
    "\n",
    "    Both inputs must have the same batch size. A key present in both must hold equal\n",
    "    tensors, otherwise an AssertionError is raised.\n",
    "    \"\"\"\n",
    "    size1, size2 = tensor_dict1.batch_size, tensor_dict2.batch_size\n",
    "    assert size1 == size2, f\"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}\"\n",
    "    for key in tensor_dict2.keys():\n",
    "        if key in tensor_dict1.keys():\n",
    "            # Shared key: values must match, nothing is overwritten.\n",
    "            assert tensor_dict1[key].equal(tensor_dict2[key]), f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n",
    "            continue\n",
    "        tensor_dict1[key] = tensor_dict2[key]\n",
    "\n",
    "    return tensor_dict1\n",
    "\n",
    "\n",
    "def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]:\n",
    "    for key, val in tensor_dict2.items():\n",
    "        if key in tensor_dict1:\n",
    "            assert isinstance(tensor_dict2[key], np.ndarray)\n",
    "            assert isinstance(tensor_dict1[key], np.ndarray)\n",
    "            # to properly deal with nan and object type\n",
    "            assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n",
    "        tensor_dict1[key] = val\n",
    "\n",
    "    return tensor_dict1\n",
    "\n",
    "\n",
    "def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):\n",
    "    if len(list_of_dict) == 0:\n",
    "        return {}\n",
    "    keys = list_of_dict[0].keys()\n",
    "    output = {key: [] for key in keys}\n",
    "    for data in list_of_dict:\n",
    "        for key, item in data.items():\n",
    "            assert key in output\n",
    "            output[key].append(item)\n",
    "    return output\n",
    "\n",
    "\n",
    "def fold_batch_dim(data: \"DataProto\", new_batch_size):\n",
    "    \"\"\"\n",
    "    Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]\n",
    "    \"\"\"\n",
    "    batch_size = data.batch.batch_size[0]\n",
    "\n",
    "    assert batch_size % new_batch_size == 0\n",
    "\n",
    "    tensor: TensorDict = data.batch\n",
    "    non_tensor = data.non_tensor_batch\n",
    "\n",
    "    tensor = tensor.view(new_batch_size, -1)\n",
    "    tensor.auto_batch_size_(batch_dims=1)\n",
    "\n",
    "    for key, val in non_tensor.items():\n",
    "        non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))\n",
    "\n",
    "    return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)\n",
    "\n",
    "\n",
    "def unfold_batch_dim(data: \"DataProto\", batch_dims=2):\n",
    "    \"\"\"\n",
    "    Unfold the first n dims as new batch dim\n",
    "    \"\"\"\n",
    "    tensor: TensorDict = data.batch\n",
    "    non_tensor = data.non_tensor_batch\n",
    "    tensor.auto_batch_size_(batch_dims=batch_dims)\n",
    "    tensor = tensor.view(-1)\n",
    "\n",
    "    batch_size = tensor.batch_size[0]\n",
    "\n",
    "    non_tensor_new = {}\n",
    "\n",
    "    for key, val in non_tensor.items():\n",
    "        non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:]))\n",
    "\n",
    "    return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)\n",
    "\n",
    "\n",
    "def collate_fn(x: list[\"DataProtoItem\"]):\n",
    "    \"\"\"Assemble a list of DataProtoItem into one DataProto.\n",
    "\n",
    "    The per-item ``batch`` entries are stacked and made contiguous; the per-item non-tensor\n",
    "    dicts are transposed into one object-dtype numpy array per key. The ``meta_info`` of\n",
    "    the items is not carried over.\n",
    "    \"\"\"\n",
    "    tensor_parts = [item.batch for item in x]\n",
    "    non_tensor_parts = [item.non_tensor_batch for item in x]\n",
    "\n",
    "    stacked = torch.stack(tensor_parts).contiguous()\n",
    "    columns = list_of_dict_to_dict_of_list(non_tensor_parts)\n",
    "    non_tensor_batch = {key: np.array(values, dtype=object) for key, values in columns.items()}\n",
    "    return DataProto(batch=stacked, non_tensor_batch=non_tensor_batch)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DataProtoItem:\n",
    "    \"\"\"Single-element counterpart of DataProto, produced when a DataProto is indexed with one integer.\"\"\"\n",
    "\n",
    "    # TODO(zhangchi.usc1992) add consistency check\n",
    "    # Tensor data of the single item.\n",
    "    batch: TensorDict = None\n",
    "    # Non-tensor values of the single item, keyed like DataProto.non_tensor_batch.\n",
    "    non_tensor_batch: Dict = field(default_factory=dict)\n",
    "    # Meta information; DataProto.__getitem__ passes the parent's dict by reference.\n",
    "    meta_info: Dict = field(default_factory=dict)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DataProto:\n",
    "    \"\"\"\n",
    "    A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.\n",
    "    It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.\n",
    "    TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the\n",
    "    same batch size should be put inside batch.\n",
    "    \"\"\"\n",
    "\n",
    "    batch: TensorDict = None\n",
    "    non_tensor_batch: Dict = field(default_factory=dict)\n",
    "    meta_info: Dict = field(default_factory=dict)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        \"\"\"Dataclass hook: validate the freshly constructed instance.\"\"\"\n",
    "        # perform necessary checking\n",
    "        self.check_consistency()\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of items: the batch's first dimension, else the first dimension of a non-tensor array, else 0.\"\"\"\n",
    "        if self.batch is not None:\n",
    "            return self.batch.batch_size[0]\n",
    "\n",
    "        has_non_tensor = self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0\n",
    "        if not has_non_tensor:\n",
    "            return 0\n",
    "\n",
    "        # Any entry will do; take the first key.\n",
    "        first_key = next(iter(self.non_tensor_batch.keys()))\n",
    "        return self.non_tensor_batch[first_key].shape[0]\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        \"\"\"\n",
    "        Enhanced indexing for DataProto objects.\n",
    "\n",
    "        Args:\n",
    "            item: Can be one of:\n",
    "                - int: A single index\n",
    "                - slice: A slice object (start:stop:step)\n",
    "                - list: A list of indices\n",
    "                - numpy.ndarray: An array of indices\n",
    "                - torch.Tensor: A tensor of indices\n",
    "\n",
    "        Returns:\n",
    "            DataProto: For all indexing types except single integers\n",
    "            DataProtoItem: Only for single integer indices\n",
    "\n",
    "        Raises:\n",
    "            TypeError: if ``item`` is of any other type\n",
    "        \"\"\"\n",
    "        # Case 1: Slice object - use the slice method\n",
    "        if isinstance(item, slice):\n",
    "            return self.slice(item.start, item.stop, item.step)\n",
    "\n",
    "        # Case 2: List, numpy array, or torch tensor - use select_idxs\n",
    "        if isinstance(item, (list, np.ndarray, torch.Tensor)):\n",
    "            return self.select_idxs(item)\n",
    "\n",
    "        # Case 3: Single integer - return DataProtoItem for backward compatibility\n",
    "        if isinstance(item, (int, np.integer)):\n",
    "            # A DataProto may carry no tensor batch; the slice and select_idxs paths accept\n",
    "            # that, so integer indexing does too instead of failing on None[item].\n",
    "            tensor_data = self.batch[item] if self.batch is not None else None\n",
    "            non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}\n",
    "            return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)\n",
    "\n",
    "        # Case 4: Unsupported type\n",
    "        raise TypeError(f\"Indexing with {type(item)} is not supported\")\n",
    "\n",
    "    def __getstate__(self):\n",
    "        \"\"\"Pickle hook: serialize ``batch`` through ``torch.save`` into bytes.\n",
    "\n",
    "        Returns:\n",
    "            tuple: ``(batch_bytes, non_tensor_batch, meta_info)``, the layout ``__setstate__`` unpacks.\n",
    "\n",
    "        Note:\n",
    "            With tensordict >= 0.5.0 and a non-None batch, ``self.batch`` is replaced by its\n",
    "            contiguous, consolidated form, so pickling changes the instance.\n",
    "        \"\"\"\n",
    "        import io\n",
    "\n",
    "        buffer = io.BytesIO()\n",
    "        if version.parse(tensordict.__version__) >= version.parse(\"0.5.0\") and self.batch is not None:\n",
    "            self.batch = self.batch.contiguous()\n",
    "            self.batch = self.batch.consolidate()\n",
    "        # A None batch is saved as well; torch.save is called unconditionally.\n",
    "        torch.save(self.batch, buffer)\n",
    "        buffer_bytes = buffer.getvalue()\n",
    "        return buffer_bytes, self.non_tensor_batch, self.meta_info\n",
    "\n",
    "    def __setstate__(self, data):\n",
    "        \"\"\"Unpickle hook: rebuild the instance from the tuple produced by ``__getstate__``.\n",
    "\n",
    "        Args:\n",
    "            data (tuple): ``(batch_bytes, non_tensor_batch, meta_info)``\n",
    "        \"\"\"\n",
    "        import io\n",
    "\n",
    "        batch_deserialized_bytes, non_tensor_batch, meta_info = data\n",
    "        batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)\n",
    "        # Tensors are mapped to CPU when CUDA is unavailable; otherwise torch.load's default placement applies.\n",
    "        # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects - only load trusted data.\n",
    "        batch = torch.load(batch_deserialized, weights_only=False, map_location=\"cpu\" if not torch.cuda.is_available() else None)\n",
    "        self.batch = batch\n",
    "        self.non_tensor_batch = non_tensor_batch\n",
    "        self.meta_info = meta_info\n",
    "\n",
    "    def save_to_disk(self, filepath):\n",
    "        \"\"\"Pickle this DataProto to ``filepath`` (opened in binary write mode, so an existing file is overwritten).\"\"\"\n",
    "        with open(filepath, \"wb\") as f:\n",
    "            pickle.dump(self, f)\n",
    "\n",
    "    @staticmethod\n",
    "    def load_from_disk(filepath) -> \"DataProto\":\n",
    "        \"\"\"Load the object pickled at ``filepath`` and return it.\n",
    "\n",
    "        SECURITY: this uses ``pickle.load``, which can execute arbitrary code; only open trusted files.\n",
    "        \"\"\"\n",
    "        with open(filepath, \"rb\") as f:\n",
    "            data = pickle.load(f)\n",
    "            return data\n",
    "\n",
    "    def print_size(self, prefix=\"\"):\n",
    "        \"\"\"Print the memory footprint of the tensor batch and the non-tensor batch in GB.\n",
    "\n",
    "        Args:\n",
    "            prefix (str): optional text put in front of the message, separated by \", \"\n",
    "        \"\"\"\n",
    "        size_of_tensordict = 0\n",
    "        # A DataProto may have no tensor batch at all; it then counts as 0 bytes.\n",
    "        if self.batch is not None:\n",
    "            for _, tensor in self.batch.items():\n",
    "                size_of_tensordict += tensor.element_size() * tensor.numel()\n",
    "        size_of_numpy_array = sum(numpy_array.nbytes for numpy_array in self.non_tensor_batch.values())\n",
    "\n",
    "        # bytes -> GB\n",
    "        size_of_numpy_array /= 1024**3\n",
    "        size_of_tensordict /= 1024**3\n",
    "\n",
    "        message = f\"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB\"\n",
    "\n",
    "        if prefix:\n",
    "            message = f\"{prefix}, \" + message\n",
    "        print(message)\n",
    "\n",
    "    def check_consistency(self):\n",
    "        \"\"\"Check the consistency of the DataProto. Mainly for batch and non_tensor_batch\n",
    "        We expose this function as a public one so that user can call themselves directly\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if the batch has more than one batch dimension, if a non-tensor\n",
    "                value is not a numpy array, or if a non-tensor array's first dimension\n",
    "                differs from the batch size.\n",
    "        \"\"\"\n",
    "        if self.batch is not None:\n",
    "            assert len(self.batch.batch_size) == 1, \"only support num_batch_dims=1\"\n",
    "\n",
    "        # Treat a missing non_tensor_batch like an empty one for the checks below.\n",
    "        non_tensor_batch = self.non_tensor_batch if self.non_tensor_batch is not None else {}\n",
    "\n",
    "        # Type check first, for every entry, with the descriptive message.\n",
    "        for key, val in non_tensor_batch.items():\n",
    "            assert isinstance(val, np.ndarray), f\"data in the non_tensor_batch must be a numpy.array with dtype=object, but for {key=}, got {type(val)=}\"\n",
    "\n",
    "        if self.batch is not None:\n",
    "            # TODO: we can actually lift the single-batch-dim restriction if needed\n",
    "            batch_size = self.batch.batch_size[0]\n",
    "            for key, val in non_tensor_batch.items():\n",
    "                assert val.shape[0] == batch_size, f\"key {key} length {len(val)} is not equal to batch size {batch_size}\"\n",
    "\n",
    "    @classmethod\n",
    "    def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None, auto_padding=False):\n",
    "        \"\"\"Create a DataProto from a dict of tensors and non_tensors\n",
    "\n",
    "        Values that are ``torch.Tensor`` go to the tensor batch and ``np.ndarray`` values go to\n",
    "        the non-tensor batch; any other value type raises ``ValueError``.\n",
    "        \"\"\"\n",
    "        tensors, non_tensors = {}, {}\n",
    "\n",
    "        for key, val in data.items():\n",
    "            if isinstance(val, torch.Tensor):\n",
    "                bucket = tensors\n",
    "            elif isinstance(val, np.ndarray):\n",
    "                bucket = non_tensors\n",
    "            else:\n",
    "                raise ValueError(f\"Unsupported type in data {type(val)}\")\n",
    "            bucket[key] = val\n",
    "\n",
    "        return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding)\n",
    "\n",
    "    @classmethod\n",
    "    def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1, auto_padding=False):\n",
    "        \"\"\"Create a DataProto from a dict of tensors. This assumes that\n",
    "        1. All the tensor in tensors have the same dim0\n",
    "        2. Only dim0 is the batch dim\n",
    "\n",
    "        Entries of ``non_tensors`` are converted to object-dtype numpy arrays, written back\n",
    "        into the dict that was passed in. With ``auto_padding=True`` the auto-padding flag is\n",
    "        set in ``meta_info``.\n",
    "        \"\"\"\n",
    "        assert len(tensors) > 0, \"tensors must not be empty\"\n",
    "        assert num_batch_dims > 0, \"num_batch_dims must be greater than zero\"\n",
    "        if non_tensors is not None:\n",
    "            assert num_batch_dims == 1, \"only support num_batch_dims=1 when non_tensors is not None.\"\n",
    "\n",
    "        meta_info = {} if meta_info is None else meta_info\n",
    "        non_tensors = {} if non_tensors is None else non_tensors\n",
    "\n",
    "        assert isinstance(non_tensors, dict)\n",
    "\n",
    "        # get and check batch size: the first tensor is the reference, all others must match it\n",
    "        batch_size, pivot_key = None, None\n",
    "        for key, tensor in tensors.items():\n",
    "            current_batch = tensor.shape[:num_batch_dims]\n",
    "            if batch_size is None:\n",
    "                batch_size, pivot_key = current_batch, key\n",
    "                continue\n",
    "            assert batch_size == current_batch, f\"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}\"\n",
    "\n",
    "        for key in list(non_tensors):\n",
    "            non_tensors[key] = np.array(non_tensors[key], dtype=object)\n",
    "\n",
    "        tensor_dict = TensorDict(source=tensors, batch_size=batch_size)\n",
    "        if auto_padding:\n",
    "            meta_info[DataProtoConfig.auto_padding_key] = True\n",
    "        return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)\n",
    "\n",
    "    def to(self, device) -> \"DataProto\":\n",
    "        \"\"\"Move the tensor batch to ``device`` in place.\n",
    "\n",
    "        Args:\n",
    "            device (torch.device, str): torch device\n",
    "\n",
    "        Returns:\n",
    "            DataProto: the current DataProto (``self``); nothing happens when there is no batch\n",
    "\n",
    "        \"\"\"\n",
    "        if self.batch is None:\n",
    "            return self\n",
    "        self.batch = self.batch.to(device)\n",
    "        return self\n",
    "\n",
    "    def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> \"DataProto\":\n",
    "        \"\"\"Select a subset of the DataProto via batch_keys, non_tensor_batch_keys and meta_info_keys\n",
    "\n",
    "        Args:\n",
    "            batch_keys (list, optional): a list of strings indicating the keys in batch to select\n",
    "            non_tensor_batch_keys (list, optional): keys of non_tensor_batch to keep\n",
    "            meta_info_keys (list, optional): a list of keys indicating the meta info to select\n",
    "            deepcopy (bool): when True, the selected non_tensor_batch and meta_info are deep-copied;\n",
    "                the tensor batch is not copied by this flag\n",
    "\n",
    "        Returns:\n",
    "            DataProto: the DataProto with the selected keys. A ``None`` key list keeps that part whole.\n",
    "        \"\"\"\n",
    "        # TODO (zhangchi.usc1992) whether to copy\n",
    "        sub_batch = self.batch if batch_keys is None else self.batch.select(*tuple(batch_keys))\n",
    "\n",
    "        if non_tensor_batch_keys is None:\n",
    "            non_tensor_batch = self.non_tensor_batch\n",
    "        else:\n",
    "            non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}\n",
    "\n",
    "        if meta_info_keys is None:\n",
    "            sub_meta_info = self.meta_info\n",
    "        else:\n",
    "            sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}\n",
    "\n",
    "        if deepcopy:\n",
    "            non_tensor_batch = copy.deepcopy(non_tensor_batch)\n",
    "            sub_meta_info = copy.deepcopy(sub_meta_info)\n",
    "\n",
    "        return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)\n",
    "\n",
    "    def select_idxs(self, idxs):\n",
    "        \"\"\"\n",
    "        Pick the given positions out of the DataProto.\n",
    "\n",
    "        Args:\n",
    "            idxs (torch.Tensor or numpy.ndarray or list): Indices to select. A boolean\n",
    "                index is treated as a mask; a non-boolean list is converted to an int32 tensor.\n",
    "\n",
    "        Returns:\n",
    "            DataProto: A new DataProto containing only the selected indices; ``meta_info`` is shared\n",
    "        \"\"\"\n",
    "        if isinstance(idxs, list):\n",
    "            idxs = torch.tensor(idxs)\n",
    "            if idxs.dtype != torch.bool:\n",
    "                idxs = idxs.type(torch.int32)\n",
    "\n",
    "        # Keep one torch and one numpy form of the index: tensors are indexed with the\n",
    "        # former, non-tensor arrays with the latter.\n",
    "        if isinstance(idxs, np.ndarray):\n",
    "            idxs_np, idxs_torch = idxs, torch.from_numpy(idxs)\n",
    "        else:  # torch.Tensor\n",
    "            idxs_torch, idxs_np = idxs, idxs.detach().cpu().numpy()\n",
    "\n",
    "        # A mask selects as many rows as it has True entries.\n",
    "        if idxs_np.dtype == bool:\n",
    "            batch_size = idxs_np.sum()\n",
    "        else:\n",
    "            batch_size = idxs_np.shape[0]\n",
    "\n",
    "        selected_batch = None\n",
    "        if self.batch is not None:\n",
    "            picked = {key: tensor[idxs_torch] for key, tensor in self.batch.items()}\n",
    "            selected_batch = TensorDict(source=picked, batch_size=(batch_size,))\n",
    "\n",
    "        selected_non_tensor = {key: val[idxs_np] for key, val in self.non_tensor_batch.items()}\n",
    "\n",
    "        return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)\n",
    "\n",
    "    def slice(self, start=None, end=None, step=None):\n",
    "        \"\"\"\n",
    "        Return a new DataProto holding the ``start:end:step`` range of this one.\n",
    "        Unlike indexing with a single integer, which yields a DataProtoItem, the result is a DataProto.\n",
    "\n",
    "        Args:\n",
    "            start (int, optional): Start index. Defaults to None (start from beginning).\n",
    "            end (int, optional): End index (exclusive). Defaults to None (go to end).\n",
    "            step (int, optional): Step size. Defaults to None (step=1).\n",
    "\n",
    "        Returns:\n",
    "            DataProto: A new DataProto containing the sliced data; ``meta_info`` is shared\n",
    "\n",
    "        Examples:\n",
    "            # Using the slice method directly\n",
    "            sliced_data = data_proto.slice(10, 20)\n",
    "\n",
    "            # Using enhanced indexing (returns DataProto)\n",
    "            sliced_data = data_proto[10:20]\n",
    "            sliced_data = data_proto[::2]  # Every other element\n",
    "\n",
    "            # Using list indexing (returns DataProto)\n",
    "            indices = [1, 5, 10]\n",
    "            selected_data = data_proto[indices]\n",
    "\n",
    "            # Single index still returns DataProtoItem\n",
    "            single_item = data_proto[5]\n",
    "        \"\"\"\n",
    "        window = slice(start, end, step)\n",
    "\n",
    "        # The tensor batch may be absent; the non-tensor arrays are sliced either way.\n",
    "        sliced_batch = None if self.batch is None else self.batch[window]\n",
    "        sliced_non_tensor = {key: val[window] for key, val in self.non_tensor_batch.items()}\n",
    "\n",
    "        return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info)\n",
    "\n",
    "    def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> \"DataProto\":\n",
    "        \"\"\"Remove the given keys from this DataProto and return them as a new DataProto\n",
    "\n",
    "        Args:\n",
    "            batch_keys (list): a list of strings indicating the keys in batch to pop; must not be None\n",
    "            non_tensor_batch_keys (list, optional): keys of non_tensor_batch to pop\n",
    "            meta_info_keys (list, optional): a list of keys indicating the meta info to pop\n",
    "\n",
    "        Returns:\n",
    "            DataProto: the DataProto with the poped batch_keys, non_tensor_batch_keys and meta_info_keys\n",
    "        \"\"\"\n",
    "        assert batch_keys is not None\n",
    "        meta_info_keys = [] if meta_info_keys is None else meta_info_keys\n",
    "        non_tensor_batch_keys = [] if non_tensor_batch_keys is None else non_tensor_batch_keys\n",
    "\n",
    "        def _take(container, keys):\n",
    "            # Pop every requested key out of `container`; each key must be present.\n",
    "            taken = {}\n",
    "            for key in keys:\n",
    "                assert key in container.keys()\n",
    "                taken[key] = container.pop(key)\n",
    "            return taken\n",
    "\n",
    "        tensors = _take(self.batch, batch_keys)\n",
    "        non_tensors = _take(self.non_tensor_batch, non_tensor_batch_keys)\n",
    "        meta_info = _take(self.meta_info, meta_info_keys)\n",
    "        return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)\n",
    "\n",
    "    def rename(self, old_keys=None, new_keys=None) -> \"DataProto\":\n",
    "        \"\"\"\n",
    "        Note that this function only rename the key in the batch\n",
    "\n",
    "        Args:\n",
    "            old_keys (str or list): existing key(s) in batch\n",
    "            new_keys (str or list): replacement key(s), one per entry of ``old_keys``\n",
    "\n",
    "        Returns:\n",
    "            DataProto: the current DataProto (``self``), renamed in place\n",
    "\n",
    "        Raises:\n",
    "            TypeError: if either argument is neither a string nor a list (this includes ``None``)\n",
    "            ValueError: if the two key lists differ in length\n",
    "        \"\"\"\n",
    "\n",
    "        def validate_input(keys):\n",
    "            # Normalize to a list. None is rejected here with a clear message; it used to\n",
    "            # slip through and fail later on len(None).\n",
    "            if isinstance(keys, str):\n",
    "                return [keys]\n",
    "            if isinstance(keys, list):\n",
    "                return keys\n",
    "            raise TypeError(f\"keys must be a list or a string, but got {type(keys)}\")\n",
    "\n",
    "        old_keys = validate_input(old_keys)\n",
    "        new_keys = validate_input(new_keys)\n",
    "\n",
    "        if len(new_keys) != len(old_keys):\n",
    "            raise ValueError(f\"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}\")\n",
    "\n",
    "        self.batch.rename_key_(tuple(old_keys), tuple(new_keys))\n",
    "\n",
    "        return self\n",
    "\n",
    "    def union(self, other: \"DataProto\") -> \"DataProto\":\n",
    "        \"\"\"Union with another DataProto. Union batch, non_tensor_batch and meta_info separately.\n",
    "        Throw an error if\n",
    "\n",
    "        - there are conflict keys in batch and they are not equal\n",
    "        - the batch size of two data batch is not the same\n",
    "        - there are conflict keys in meta_info and they are not the same.\n",
    "\n",
    "        Args:\n",
    "            other (DataProto): another DataProto to union\n",
    "\n",
    "        Returns:\n",
    "            DataProto: the DataProto after union (this is ``self``, updated in place)\n",
    "        \"\"\"\n",
    "        # union_tensor_dict and union_numpy_dict write into their first argument and return it.\n",
    "        self.batch = union_tensor_dict(self.batch, other.batch)\n",
    "        self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)\n",
    "        # NOTE(review): union_two_dict comes from verl.utils.py_functional; presumably it raises on conflicting values - confirm.\n",
    "        self.meta_info = union_two_dict(self.meta_info, other.meta_info)\n",
    "        return self\n",
    "\n",
    "    def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):\n",
    "        r\"\"\"Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch\n",
    "        dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.\n",
    "\n",
    "\n",
    "        Args:\n",
    "            mini_batch_size (int): mini-batch size when iterating the dataset. We require that ``batch.batch_size[0] % mini_batch_size == 0``.\n",
    "            epochs (int): number of epochs when iterating the dataset.\n",
    "            seed (int, optional): when given, seeds the ``torch.Generator`` handed to the DataLoader.\n",
    "            dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The dataloader_kwargs is the kwargs passed to the DataLoader.\n",
    "\n",
    "        Returns:\n",
    "            Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is ``self.batch.batch_size * epochs // mini_batch_size``\n",
    "        \"\"\"\n",
    "        total = self.batch.batch_size[0]\n",
    "        assert total % mini_batch_size == 0, f\"{self.batch.batch_size[0]} % {mini_batch_size} != 0\"\n",
    "        # we can directly create a dataloader from TensorDict\n",
    "        dataloader_kwargs = {} if dataloader_kwargs is None else dataloader_kwargs\n",
    "\n",
    "        generator = None\n",
    "        if seed is not None:\n",
    "            generator = torch.Generator()\n",
    "            generator.manual_seed(seed)\n",
    "\n",
    "        assert isinstance(dataloader_kwargs, Dict)\n",
    "        train_dataloader = DataLoader(dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs)\n",
    "\n",
    "        def get_data():\n",
    "            # Walk the loader once per epoch; every mini-batch gets this object's meta_info attached.\n",
    "            for _epoch in range(epochs):\n",
    "                for mini_batch in train_dataloader:\n",
    "                    mini_batch.meta_info = self.meta_info\n",
    "                    yield mini_batch\n",
    "\n",
    "        return iter(get_data())\n",
    "\n",
    "    def is_padding_enabled(self):\n",
    "        \"\"\"\n",
    "        Check if padding is enabled for the DataProto.\n",
    "        Returns:\n",
    "            bool: True if padding is enabled, False otherwise.\n",
    "        \"\"\"\n",
    "        dataproto_specific_padding = self.meta_info.get(DataProtoConfig.auto_padding_key, False)\n",
    "        return dataproto_specific_padding or DataProtoConfig.auto_padding\n",
    "\n",
    "    def padding(self, padding_size, padding_candidate=\"\"):\n",
    "        \"\"\"Pad the DataProto by concating with padding_candidate.repeat(padding_size)\n",
    "\n",
    "        Args:\n",
    "            padding_size (int): the number of repeated padding_candidate\n",
    "            padding_candidate: the item to be repeated and appended to the DataProto, only supporting [\"first\", \"last\"]\n",
    "        \"\"\"\n",
    "        if padding_size == 0:\n",
    "            return\n",
    "        padding_candidate = self.select_idxs([0 if padding_candidate == \"first\" else len(self) - 1])\n",
    "        padding_part = padding_candidate.repeat(padding_size)\n",
    "        padded_dp = DataProto.concat([self, padding_part])\n",
    "        self.batch = padded_dp.batch\n",
    "        self.non_tensor_batch = padded_dp.non_tensor_batch\n",
    "\n",
    "    def chunk(self, chunks: int) -> List[\"DataProto\"]:\n",
    "        \"\"\"Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.\n",
    "\n",
    "        Args:\n",
    "            chunks (int): the number of chunks to split on dim=0\n",
    "\n",
    "        Returns:\n",
    "            List[DataProto]: a list of DataProto after splitting\n",
    "        \"\"\"\n",
    "        if not self.is_padding_enabled():\n",
    "            assert len(self) % chunks == 0, f\"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.\"\n",
    "\n",
    "        bsz_in_batch = None\n",
    "        if self.batch is not None:\n",
    "            batch_lst = self.batch.chunk(chunks=chunks, dim=0)\n",
    "            bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst])\n",
    "            chunk_indices = np.cumsum(bsz_in_batch)[:-1]\n",
    "        else:\n",
    "            batch_lst = [None for _ in range(chunks)]\n",
    "\n",
    "        non_tensor_batch_lst = [{} for _ in range(chunks)]\n",
    "        for key, val in self.non_tensor_batch.items():\n",
    "            assert isinstance(val, np.ndarray)\n",
    "            if bsz_in_batch is not None:\n",
    "                non_tensor_lst = np.array_split(val, chunk_indices.tolist())\n",
    "            else:\n",
    "                non_tensor_lst = np.array_split(val, chunks)\n",
    "            assert len(non_tensor_lst) == chunks\n",
    "            for i in range(chunks):\n",
    "                non_tensor_batch_lst[i][key] = non_tensor_lst[i]\n",
    "\n",
    "        output = []\n",
    "        for i in range(chunks):\n",
    "            output.append(type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info))\n",
    "\n",
    "        return output\n",
    "\n",
    "    @staticmethod\n",
    "    def concat(data: List[\"DataProto\"]) -> \"DataProto\":\n",
    "        \"\"\"Concat a list of DataProto. The batch is concatenated among dim=0.\n",
    "        The meta_info is assumed to be identical and will use the first one.\n",
    "\n",
    "        Args:\n",
    "            data (List[DataProto]): list of DataProto\n",
    "\n",
    "        Returns:\n",
    "            DataProto: concatenated DataProto\n",
    "        \"\"\"\n",
    "        batch_lst = []\n",
    "        for batch in data:\n",
    "            batch_lst.append(batch.batch)\n",
    "        new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None\n",
    "\n",
    "        non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])\n",
    "        for key, val in non_tensor_batch.items():\n",
    "            non_tensor_batch[key] = np.concatenate(val, axis=0)\n",
    "\n",
    "        cls = type(data[0]) if len(data) > 0 else DataProto\n",
    "        return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)\n",
    "\n",
    "    def reorder(self, indices):\n",
    "        \"\"\"\n",
    "        Note that this operation is in-place\n",
    "        \"\"\"\n",
    "        indices_np = indices.detach().numpy()\n",
    "        self.batch = self.batch[indices]\n",
    "        self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}\n",
    "\n",
    "    def repeat(self, repeat_times=2, interleave=True):\n",
    "        \"\"\"\n",
    "        Repeat the batch data a specified number of times.\n",
    "\n",
    "        Args:\n",
    "            repeat_times (int): Number of times to repeat the data.\n",
    "            interleave (bool): Whether to interleave the repeated data.\n",
    "\n",
    "        Returns:\n",
    "            DataProto: A new DataProto with repeated data.\n",
    "        \"\"\"\n",
    "        if self.batch is not None:\n",
    "            if interleave:\n",
    "                # Interleave the data\n",
    "                repeated_tensors = {key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()}\n",
    "            else:\n",
    "                # Stack the data\n",
    "                repeated_tensors = {key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) for key, tensor in self.batch.items()}\n",
    "\n",
    "            repeated_batch = TensorDict(\n",
    "                source=repeated_tensors,\n",
    "                batch_size=(self.batch.batch_size[0] * repeat_times,),\n",
    "            )\n",
    "        else:\n",
    "            repeated_batch = None\n",
    "\n",
    "        repeated_non_tensor_batch = {}\n",
    "        for key, val in self.non_tensor_batch.items():\n",
    "            if interleave:\n",
    "                repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)\n",
    "            else:\n",
    "                repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))\n",
    "\n",
    "        return type(self)(\n",
    "            batch=repeated_batch,\n",
    "            non_tensor_batch=repeated_non_tensor_batch,\n",
    "            meta_info=self.meta_info,\n",
    "        )\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DataProtoFuture:\n",
    "    \"\"\"\n",
    "    DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait\n",
    "    for data so that asynchronous execution becomes possible.\n",
    "    DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.\n",
    "    - collect_fn is a Callable that reduces the list of futures to a DataProto\n",
    "    - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select\n",
    "\n",
    "    Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination\n",
    "    - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any\n",
    "    operation on the DataProtoFuture in driver.\n",
    "    \"\"\"\n",
    "\n",
    "    collect_fn: Callable\n",
    "    futures: List[ray.ObjectRef]\n",
    "    dispatch_fn: Callable = None\n",
    "\n",
    "    @staticmethod\n",
    "    def concat(data: List[ray.ObjectRef]) -> \"DataProtoFuture\":\n",
    "        output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)\n",
    "        return output\n",
    "\n",
    "    def chunk(self, chunks: int) -> List[\"DataProtoFuture\"]:\n",
    "        from functools import partial\n",
    "\n",
    "        arg_future_lst = []\n",
    "        for i in range(chunks):\n",
    "            # note that we can't directly pass i and chunks\n",
    "            def dispatch_fn(x, i, chunks):\n",
    "                return x.chunk(chunks=chunks)[i]\n",
    "\n",
    "            arg_future = DataProtoFuture(collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures)\n",
    "            arg_future_lst.append(arg_future)\n",
    "        return arg_future_lst\n",
    "\n",
    "    def get(self):\n",
    "        output = ray.get(self.futures)  # dp_size.\n",
    "        for o in output:\n",
    "            assert isinstance(o, DataProto)\n",
    "        output = self.collect_fn(output)  # select dp, concat\n",
    "        if self.dispatch_fn is not None:\n",
    "            output = self.dispatch_fn(output)  # split in batch dim, select using dp\n",
    "        return output\n",
    "\n",
    "\n",
    "def all_gather_data_proto(data: DataProto, process_group):\n",
    "    # Note that this is an inplace operator just like torch.distributed.all_gather\n",
    "    group_size = torch.distributed.get_world_size(group=process_group)\n",
    "    assert isinstance(data, DataProto)\n",
    "    prev_device = data.batch.device\n",
    "    data.batch = data.batch.cuda(device=torch.cuda.current_device())\n",
    "    data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0)\n",
    "    data.batch = data.batch.to(prev_device)\n",
    "    # all gather non_tensor_batch\n",
    "    all_non_tensor_batch = [None for _ in range(group_size)]\n",
    "    torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group)\n",
    "    data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31fc598d",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'verl'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mverl\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m DataProto\n\u001b[1;32m      2\u001b[0m data1 \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m      4\u001b[0m test_batch16 \u001b[38;5;241m=\u001b[39m DataProto\u001b[38;5;241m.\u001b[39mfrom_single_dict(data1)\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'verl'"
     ]
    }
   ],
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vllm082",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
