# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
import io
import json
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Tuple

import torch

from tensordict import make_tensordict, NonTensorData, pad, TensorDict
from tensordict.utils import _is_non_tensor

from torchrl.data.datasets.common import BaseDatasetExperienceReplay
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import (
    Sampler,
    SliceSampler,
    SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import _collate_id, Storage, TensorStorage
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer

# Optional third-party dependencies are probed without importing them, so this
# module stays importable when they are missing.
_has_datasets, _has_tv = (
    importlib.util.find_spec(_module_name) is not None
    for _module_name in ("datasets", "torchvision")
)


class OpenXExperienceReplay(BaseDatasetExperienceReplay):
    """Open X-Embodiment datasets experience replay.

    The Open X-Embodiment Dataset contains 1M+ real robot trajectories
    spanning 22 robot embodiments, collected through a collaboration between
    21 institutions, demonstrating 527 skills (160266 tasks).

    Website: https://robotics-transformer-x.github.io/

    GitHub: https://github.com/google-deepmind/open_x_embodiment

    Paper: https://arxiv.org/abs/2310.08864

    The data format follows the :ref:`TED convention <TED-format>`.

    .. note::
        Non-tensor data will be written in the tensordict data using the
        :class:`~tensordict.tensorclass.NonTensorData` primitive.
        For instance, the `language_instruction` field in the data will
        be stored in `data.get_non_tensor("language_instruction")` (or equivalently
        `data.get("language_instruction").data`). See the documentation of this
        class for more information on how to interact with non-tensor data
        stored in a :class:`~tensordict.TensorDict`.

    Args:
        dataset_id (str): The dataset to be downloaded.
            Must be part of ``OpenXExperienceReplay.available_datasets``.
        batch_size (int, optional): Batch-size used during sampling.
            Can be overridden by `data.sample(batch_size)` if necessary.
            See ``num_slices`` and ``slice_len`` keyword arguments for a refined
            sampling strategy.
            If the ``batch_size`` is ``None`` (default), iterating over the
            dataset will deliver trajectories one at a time *whereas* calling
            :meth:`~.sample` will *still* require a batch-size to be provided.

    Keyword Args:
        shuffle (bool, optional): if ``True``, trajectories are delivered in a
            random order when the dataset is iterated over.
            If ``False``, the dataset is iterated over in the pre-defined order.
            Defaults to ``True``.

            .. warning::
              shuffle=False will also impact the sampling. We advise users to
              create a copy of the dataset where the ``shuffle`` attribute of the
              sampler is set to ``False`` if they wish to enjoy the two different
              behaviours (shuffled and not) within the same code base.

        num_slices (int, optional): the number of slices in a batch. This
            corresponds to the number of trajectories present in a batch.
            Once collected, the batch is presented as a concatenation of
            sub-trajectories that can be recovered through `batch.reshape(num_slices, -1)`.
            The `batch_size` must be divisible by `num_slices` if provided.
            This argument is exclusive with ``slice_len``.
            If the ``num_slices`` argument equals the ``batch_size``, each sample
            will belong to a different trajectory.
            If neither ``slice_len`` nor ``num_slices`` are provided:
            whenever a trajectory has a length longer than the
            batch-size, a contiguous slice of it of length `batch_size` will be
            sampled. If the trajectory length is insufficient, an exception will
            be raised unless `pad` is not `None`.
        slice_len (int, optional): the length of slices in a batch. This
            corresponds to the length of trajectories present in a batch.
            Once collected, the batch is presented as a concatenation of
            sub-trajectories that can be recovered through `batch.reshape(-1, slice_len)`.
            The `batch_size` must be divisible by `slice_len` if provided.
            This argument is exclusive with ``num_slices``.
            If the ``slice_len`` argument equals ``1``, each sample
            will belong to a different trajectory.
            If neither ``slice_len`` nor ``num_slices`` are provided:
            whenever a trajectory has a length longer than the
            batch-size, a contiguous slice of it of length `batch_size` will be
            sampled. If the trajectory length is insufficient, an exception will
            be raised unless `pad` is not `None`.

            .. note::
              The ``slice_len`` (but not ``num_slices``) can be used when
              iterating over a dataset without passing a batch-size in the
              constructor. In these cases, a random sub-sequence of the
              trajectory will be chosen.

        replacement (bool, optional): if ``False``, sampling will be done
            without replacement. Defaults to ``True`` for downloaded datasets,
            ``False`` for streamed datasets.
        pad (bool, float or None): if ``True``, trajectories of insufficient length
            given the `slice_len` or `num_slices` arguments will be padded with
            0s. If another value is provided, it will be used for padding. If
            ``False`` or ``None`` (default) any encounter with a trajectory of
            insufficient length will raise an exception.
            Only supported for streamed datasets.
        root (Path or str, optional): The OpenX dataset root directory.
            The actual dataset memory-mapped files will be saved under
            `<root>/<dataset_id>`. If none is provided, it defaults to
            ``~/.cache/torchrl/openx``.
        streaming (bool, optional): if ``True``, the data won't be downloaded but
            read from a stream instead. If neither ``streaming`` nor ``download``
            is provided, the dataset is streamed. If only ``download`` is
            provided, ``streaming`` defaults to ``not download``.

            .. note:: The formatting of the data __will change__ when `download=True`
                compared to `streaming=True`. If the data is downloaded and
                the sampler is left untouched (ie, `num_slices=None`, `slice_len=None`
                and `sampler=None`), transitions will be sampled randomly from
                the dataset. This isn't possible at a reasonable cost with
                `streaming=True`: in this case, trajectories will be sampled
                one at a time and delivered as such (with cropping to comply with
                the batch-size etc). The behaviour of the two modalities is
                much more similar when `num_slices` and `slice_len` are specified,
                as in these cases, views of sub-episodes will be returned in both
                cases.

        download (bool or str, optional): Whether the dataset should be downloaded if
            not found. Defaults to ``not streaming`` (hence ``False`` when neither
            ``download`` nor ``streaming`` is provided). Download can also be
            passed as "force", in which case the downloaded data will be overwritten.
        sampler (Sampler, optional): the sampler to be used. If none is provided,
            a slice sampler is built when ``num_slices`` or ``slice_len`` is
            given for a downloaded dataset, a streaming sampler is used for
            streamed datasets, and the default sampler of the replay buffer is
            used otherwise.
        writer (Writer, optional): the writer to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs.  Used when using batched
            loading from a map-style dataset.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading.
        transform (Transform, optional): Transform to be executed when sample() is called.
            To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
        split_trajs (bool, optional): if ``True``, the trajectories will be split
            along the first dimension and padded to have a matching shape.
            To split the trajectories, the ``"done"`` signal will be used, which
            is recovered via ``done = truncated | terminated``. In other words,
            it is assumed that any ``truncated`` or ``terminated`` signal is
            equivalent to the end of a trajectory.
            Defaults to ``False``. Not implemented yet: passing ``True`` raises
            a ``NotImplementedError``.
        strict_length (bool, optional): if ``False``, trajectories of length
            shorter than `slice_len` (or `batch_size // num_slices`) will be
            allowed to appear in the batch.
            Be mindful that this can result in effective `batch_size` shorter
            than the one asked for! Trajectories can be split using
            :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.

    Examples:
        >>> from torchrl.data.datasets import OpenXExperienceReplay
        >>> import tempfile
        >>> # Download the data, and sample 128 elements in each batch out of two trajectories
        >>> num_slices = 2
        >>> with tempfile.TemporaryDirectory() as root:
        ...     dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128,
        ...         num_slices=num_slices, download=True, streaming=False,
        ...         root=root,
        ...         )
        ...     for batch in dataset:
        ...         print(batch.reshape(num_slices, -1))
        ...         break
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, 64, 8]), device=cpu, dtype=torch.float64, is_shared=False),
                discount: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                episode: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False),
                index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False),
                is_init: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False),
                language_embedding: Tensor(shape=torch.Size([2, 64, 512]), device=cpu, dtype=torch.float64, is_shared=False),
                language_instruction: NonTensorData(
                    data='lift open green garbage can lid',
                    batch_size=torch.Size([2, 64]),
                    device=cpu,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: TensorDict(
                            fields={
                                image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False),
                                state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)},
                            batch_size=torch.Size([2, 64]),
                            device=cpu,
                            is_shared=False),
                        reward: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([2, 64]),
                    device=cpu,
                    is_shared=False),
                observation: TensorDict(
                    fields={
                        image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False),
                        state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)},
                    batch_size=torch.Size([2, 64]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2, 64]),
            device=cpu,
            is_shared=False)
        >>> # Read data from a stream. Deliver entire trajectories when iterating
        >>> dataset = OpenXExperienceReplay("cmu_stretch",
        ...     num_slices=num_slices, download=False, streaming=True)
        >>> for data in dataset: # data does not have a consistent shape
        ...     break
        >>> # Define batch-size dynamically
        >>> data = dataset.sample(128)  # delivers 2 sub-trajectories of length 64

    """

    # Names of the OpenX sub-datasets accepted as ``dataset_id``.
    available_datasets = [
        "fractal20220817_data",
        "kuka",
        "bridge",
        "taco_play",
        "jaco_play",
        "berkeley_cable_routing",
        "roboturk",
        "nyu_door_opening_surprising_effectiveness",
        "viola",
        "berkeley_autolab_ur5",
        "toto",
        "language_table",
        "columbia_cairlab_pusht_real",
        "stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
        "nyu_rot_dataset_converted_externally_to_rlds",
        "stanford_hydra_dataset_converted_externally_to_rlds",
        "austin_buds_dataset_converted_externally_to_rlds",
        "nyu_franka_play_dataset_converted_externally_to_rlds",
        "maniskill_dataset_converted_externally_to_rlds",
        "furniture_bench_dataset_converted_externally_to_rlds",
        "cmu_franka_exploration_dataset_converted_externally_to_rlds",
        "ucsd_kitchen_dataset_converted_externally_to_rlds",
        "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
        "austin_sailor_dataset_converted_externally_to_rlds",
        "austin_sirius_dataset_converted_externally_to_rlds",
        "bc_z",
        "usc_cloth_sim_converted_externally_to_rlds",
        "utokyo_pr2_opening_fridge_converted_externally_to_rlds",
        "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
        "utokyo_saytap_converted_externally_to_rlds",
        "utokyo_xarm_pick_and_place_converted_externally_to_rlds",
        "utokyo_xarm_bimanual_converted_externally_to_rlds",
        "robo_net",
        "berkeley_mvp_converted_externally_to_rlds",
        "berkeley_rpt_converted_externally_to_rlds",
        "kaist_nonprehensile_converted_externally_to_rlds",
        "stanford_mask_vit_converted_externally_to_rlds",
        "tokyo_u_lsmo_converted_externally_to_rlds",
        "dlr_sara_pour_converted_externally_to_rlds",
        "dlr_sara_grid_clamp_converted_externally_to_rlds",
        "dlr_edan_shared_control_converted_externally_to_rlds",
        "asu_table_top_converted_externally_to_rlds",
        "stanford_robocook_converted_externally_to_rlds",
        "eth_agent_affordances",
        "imperialcollege_sawyer_wrist_cam",
        "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
        "uiuc_d3field",
        "utaustin_mutex",
        "berkeley_fanuc_manipulation",
        "cmu_playing_with_food",
        "cmu_play_fusion",
        "cmu_stretch",
        "berkeley_gnm_recon",
        "berkeley_gnm_cory_hall",
        "berkeley_gnm_sac_son",
    ]

    # some very high number that should be above all trajectory lengths in the dataset
    _MAX_TRAJ_LEN = 1_000_000

    def __init__(
        self,
        dataset_id,
        batch_size: int | None = None,
        *,
        shuffle: bool = True,
        num_slices: int | None = None,
        slice_len: int | None = None,
        pad: float | bool | None = None,
        replacement: bool | None = None,
        streaming: bool | None = None,
        root: str | Path | None = None,
        download: bool | None = None,
        sampler: Sampler | None = None,
        writer: Writer | None = None,
        collate_fn: Callable | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        transform: "torchrl.envs.Transform" | None = None,  # noqa-F821
        split_trajs: bool = False,
        strict_length: bool = True,
    ):
        # Resolve the download/streaming defaults: stream when neither is
        # given, otherwise the missing flag is the negation of the other one.
        if download is None and streaming is None:
            download = False
            streaming = True
        elif download is None:
            download = not streaming
        elif streaming is None:
            streaming = not download
        self.download = download
        self.streaming = streaming
        self.dataset_id = dataset_id
        self.split_trajs = split_trajs
        self.shuffle = shuffle
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad
        self.strict_length = strict_length
        if (self.num_slices is not None) and (self.slice_len is not None):
            raise ValueError("num_slices or slice_len can be not None, but not both.")
        if split_trajs:
            raise NotImplementedError
        if not streaming:
            if replacement is None:
                replacement = True
            if pad is not None:
                raise RuntimeError(
                    "the `pad` argument is to be used only with streaming datasets."
                )
            if root is None:
                root = _get_root_dir("openx")
                os.makedirs(root, exist_ok=True)
            self.root = Path(root)
            if self.download == "force" or (
                self.download and not self._is_downloaded()
            ):
                if download == "force" and os.path.exists(self.data_path_root):
                    shutil.rmtree(self.data_path_root)

                storage = self._download_and_preproc()
            else:
                storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
            if num_slices is not None or slice_len is not None:
                if sampler is not None:
                    raise ValueError(
                        "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
                    )

                if replacement:
                    if not self.shuffle:
                        raise RuntimeError(
                            "shuffle=False can only be used when replacement=False."
                        )
                    sampler = SliceSampler(
                        num_slices=num_slices,
                        slice_len=slice_len,
                        strict_length=strict_length,
                    )
                else:
                    sampler = SliceSamplerWithoutReplacement(
                        num_slices=num_slices,
                        slice_len=slice_len,
                        strict_length=strict_length,
                        shuffle=self.shuffle,
                    )

        else:
            if replacement is True:
                # replacement can be False or None
                raise RuntimeError(
                    "replacement=True is not available with streamed datasets."
                )
            self.root = None
            if download:
                raise ValueError(
                    "download and streaming cannot be set to ``True`` concomitantly."
                )
            storage = _StreamingStorage(
                dataset_id=dataset_id,
                shuffle=self.shuffle,
                num_slices=self.num_slices,
                slice_len=self.slice_len,
                pad=self.pad,
            )
            if sampler is None:
                sampler = _StreamingSampler()
        if writer is None:
            writer = ImmutableDatasetWriter()
        if collate_fn is None:
            collate_fn = _collate_id
        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    def __iter__(self):
        """Iterate over the dataset.

        With a batch-size, this defers to the parent class. Without one:
        streamed datasets yield trajectories as read from the stream;
        downloaded datasets with only ``slice_len`` yield one slice of length
        ``slice_len`` at a time; otherwise whole trajectories are delivered,
        ``num_slices`` (or 1) at a time.
        """
        if self._batch_size is None:
            # we can still iterate over the dataset
            if isinstance(self._storage, _StreamingStorage):
                yield from self._storage
            elif self.slice_len is not None and self.num_slices is None:
                try:
                    # truncate the trajs with slice_len: temporarily pretend a
                    # batch of one slice of length slice_len was requested.
                    self._batch_size = self.slice_len
                    self.num_slices = 1
                    self.slice_len = None
                    yield from self
                finally:
                    # restore the user-facing configuration
                    self.slice_len = self._batch_size
                    self._batch_size = None
                    self.num_slices = None
            else:
                # if we don't have a batch size but we know how many trajectories
                # we want in each batch, we can build that on the fly.
                # The only time we can do this is if num_slices is given but not
                # slice_len.
                num_slices = self.num_slices
                if not num_slices:
                    num_slices = 1
                sampler = SliceSamplerWithoutReplacement(
                    num_slices=num_slices,
                    strict_length=False,
                    shuffle=self.shuffle,
                )
                # a batch-size larger than any trajectory makes each batch
                # contain entire trajectories
                batch_size = self._MAX_TRAJ_LEN
                yield from TensorDictReplayBuffer(
                    storage=self._storage,
                    sampler=sampler,
                    batch_size=batch_size,
                    transform=self._transform,
                )
        else:
            yield from super().__iter__()

    @property
    def data_path(self):
        """Directory of the (possibly split) memory-mapped data, ``None`` when streaming."""
        if self.streaming:
            return None
        if self.split_trajs:
            return Path(self.root) / (self.dataset_id + "_split")
        return self.data_path_root

    @property
    def data_path_root(self):
        """Directory ``<root>/<dataset_id>`` of the memory-mapped data, ``None`` when streaming."""
        if self.streaming:
            return None
        return self.root / self.dataset_id

    def _is_downloaded(self):
        """Return whether the dataset directory already exists on disk."""
        return os.path.exists(self.data_path_root)

    def _download_and_preproc(self):
        """Download the dataset and write it to a memory-mapped TensorDict.

        The dataset is iterated over twice: a first pass counts the frames and
        builds a zeroed template from the first non-empty episode; a second
        pass formats each episode and writes it into the memory-mapped data.
        If the second stage fails, the partially written directory is removed
        so that it is not mistaken for a complete download later on.

        Returns:
            a :class:`~torchrl.data.replay_buffers.storages.TensorStorage`
            wrapping the locked, memory-mapped data.

        Raises:
            ImportError: if the ``datasets`` library is not installed.
            RuntimeError: if the dataset does not contain any frame.
        """
        if not _has_datasets:
            raise ImportError(
                f"the `datasets` library is required for the dataset {self.dataset_id}."
            )
        import datasets

        try:
            import tqdm

            _has_tqdm = True
        except ImportError:
            _has_tqdm = False

        with tempfile.TemporaryDirectory() as cache_dir:
            dataset = datasets.load_dataset(
                "jxu124/OpenX-Embodiment",
                self.dataset_id,
                streaming=False,
                split="train",
                cache_dir=cache_dir,
            )
            # iterate over the dataset a first time to count elements
            total_frames = 0
            counting_iter = (
                tqdm.tqdm(dataset, desc="counting") if _has_tqdm else dataset
            )
            for data in counting_iter:
                steps = data["data.pickle"]["steps"]
                if total_frames == 0:
                    for step in steps:
                        td = _make_tensordict_image_conv(step).zero_()
                        # format td: requires td to have a non-null batch_size
                        td = td.expand(2, *td.shape)
                        _format_data(td, 0)
                        td = td[0]
                total_frames += len(steps)
            if total_frames == 0:
                # without any frame there is no template to build the storage from
                raise RuntimeError(
                    f"The dataset {self.dataset_id} does not contain any frame."
                )
            td_data = td.expand(total_frames)

            def expand_non_tensor(x):
                if isinstance(x, NonTensorData):
                    return x.maybe_to_stack()
                return x

            td_data = td_data._apply_nest(
                expand_non_tensor,
                is_leaf=lambda x: issubclass(x, torch.Tensor) or _is_non_tensor(x),
            )
            data_path = self.root / self.dataset_id
            pbar = None
            try:
                td_data = td_data.memmap_like(data_path)
                if _has_tqdm:
                    # The bar counts frames and is advanced manually with each
                    # episode length (it must not also wrap the iterable).
                    pbar = tqdm.tqdm(desc="preproc", total=total_frames)
                idx0 = 0
                episode = 0
                for data in dataset:
                    current_ep = torch.stack(
                        [
                            _make_tensordict_image_conv(step)
                            for step in data["data.pickle"]["steps"]
                        ]
                    ).contiguous()
                    _format_data(current_ep, episode)
                    episode += 1
                    idx1 = idx0 + len(current_ep)
                    td_data[idx0:idx1] = current_ep
                    idx0 = idx1
                    if pbar is not None:
                        pbar.update(current_ep.shape[0])
            except BaseException:
                # _is_downloaded() only checks that the directory exists: drop
                # partially written data so the next call downloads it again.
                shutil.rmtree(data_path, ignore_errors=True)
                raise
            finally:
                if pbar is not None:
                    pbar.close()
            return TensorStorage(td_data.lock_())


class _StreamingStorage(Storage):
    """Read-only storage that streams OpenX trajectories with ``datasets``.

    Trajectories are decoded on the fly, formatted with :func:`_format_data`
    and, when required, sliced with :func:`_slice_data`. The storage has no
    length and does not support arbitrary indexing: :meth:`get` accepts a
    ``range`` (or a contiguous index tensor) whose ``stop`` is interpreted as
    the number of elements requested.

    Args:
        dataset_id (str): the OpenX sub-dataset name.
        repo (str, optional): the repository passed to ``datasets.load_dataset``.
        split (str, optional): the dataset split.
        base_path (str, optional): key under which each record stores the
            episode; if falsy, the record itself is used.
        shuffle (bool, optional): whether the stream is shuffled.
        truncate (bool, optional): if ``True``, :meth:`get` returns only the
            first ``index.stop`` elements of the gathered slices.
        num_slices (int, optional): number of trajectories per batch.
        slice_len (int, optional): length of each trajectory slice.
        pad (bool, float or None, optional): padding value for trajectories
            shorter than the slice length (see :func:`_slice_data`).
    """

    SLICE_MISMATCH = "The batch_size {} must be divisible by num_slices {} or slice_len {} if provided."

    def __init__(
        self,
        dataset_id: str,
        repo: str = "jxu124/OpenX-Embodiment",
        split="train",
        base_path="data.pickle",
        shuffle: bool = True,
        truncate: bool = True,
        num_slices=None,
        slice_len=None,
        pad=None,
    ):
        self.shuffle = shuffle
        self.dataset_id = dataset_id
        self.repo = repo
        self.split = split
        self._init()
        self.base_path = base_path
        self.truncate = truncate
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad

    def _init(self):
        """(Re)open the stream and reset the iterator used by :meth:`get`."""
        if not _has_datasets:
            raise ImportError(
                f"the `datasets` library is required for the dataset {self.dataset_id}."
            )
        import datasets

        dataset = datasets.load_dataset(
            self.repo, self.dataset_id, streaming=True, split=self.split
        )
        if self.shuffle:
            dataset = dataset.shuffle()
        self.dataset = dataset
        self.dataset_iter = iter(dataset)

    def _episode_to_tensordict(self, data, episode: int):
        """Decode one streamed record into a formatted episode TensorDict."""
        if self.base_path:
            data = data[self.base_path]
        data = torch.stack(
            [_make_tensordict_image_conv(step) for step in data["steps"]]
        ).contiguous()
        _format_data(data, episode)
        return data

    def __iter__(self):
        """Yield every trajectory of the stream, sliced if ``slice_len`` is set."""
        for episode, data in enumerate(self.dataset):
            data = self._episode_to_tensordict(data, episode)
            if self.slice_len is not None:
                yield _slice_data(data, slice_len=self.slice_len, pad_value=self.pad)
            else:
                yield data

    def get(self, index: range | torch.Tensor) -> Any:
        """Gather ``index.stop`` elements made of trajectory slices.

        If ``num_slices`` is set, slices have length ``batch_size // num_slices``;
        if ``slice_len`` is set, slices have that length; otherwise a single
        slice of length ``batch_size`` is drawn.

        Raises:
            RuntimeError: if ``index`` is a non-contiguous tensor.
            ValueError: if the batch size is not divisible by ``num_slices``
                or ``slice_len``.
        """
        if not isinstance(index, range):
            if (index[1:] != index[:-1] + 1).any():
                # we use a range to indicate how much data we want
                raise RuntimeError("iterable datasets do not support indexing.")
            index = range(index.shape[0])
        batch_size = index.stop
        if self.num_slices is not None:
            if batch_size % self.num_slices != 0:
                raise ValueError(
                    self.SLICE_MISMATCH.format(
                        batch_size, self.num_slices, self.slice_len
                    )
                )
            slice_len = batch_size // self.num_slices
        elif self.slice_len is not None:
            if batch_size % self.slice_len != 0:
                raise ValueError(
                    self.SLICE_MISMATCH.format(
                        batch_size, self.num_slices, self.slice_len
                    )
                )
            slice_len = self.slice_len
        else:
            # Neither num_slices nor slice_len: the batch is a single
            # contiguous sub-trajectory of length batch_size.
            slice_len = batch_size

        total = 0
        data_list = []
        episode = 0
        while total < batch_size:
            try:
                data = next(self.dataset_iter)
            except StopIteration:
                # the stream is exhausted: start over from its beginning
                self.dataset_iter = iter(self.dataset)
                data = next(self.dataset_iter)

            data = self._episode_to_tensordict(data, episode)
            data = _slice_data(data, slice_len=slice_len, pad_value=self.pad)
            data_list.append(data)
            total += data.numel()
            episode += 1
        data = torch.cat(data_list)
        if self.truncate:
            return data[: index.stop]
        return data

    def dumps(self, path):
        """Write the state dict as ``<path>/state_dict.json``."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        with open(path / "state_dict.json", "w") as file:
            json.dump(self.state_dict(), file)

    def state_dict(self) -> Dict[str, Any]:
        """Return the constructor configuration (no data is stored)."""
        return {
            "repo": self.repo,
            "split": self.split,
            "dataset_id": self.dataset_id,
            "shuffle": self.shuffle,
            "base_path": self.base_path,
            "truncated": self.truncate,
            "num_slices": self.num_slices,
            "slice_len": self.slice_len,
            "pad": self.pad,
        }

    def loads(self, path):
        """Read ``<path>/state_dict.json`` and load it."""
        path = Path(path)
        with open(path / "state_dict.json") as file:
            state_dict = json.load(file)
        self.load_state_dict(state_dict)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the configuration and reopen the stream."""
        for key, val in state_dict.items():
            # state_dict() stores the ``truncate`` attribute under "truncated"
            if key == "truncated":
                key = "truncate"
            setattr(self, key, val)
        self._init()

    def __len__(self):
        raise RuntimeError(
            f"{type(self)} does not have a length. Use a downloaded dataset to "
            f"access this property."
        )


def _slice_data(data: TensorDict, slice_len, pad_value):
    """Extract a random contiguous window of length ``slice_len`` from a trajectory.

    Args:
        data: a formatted trajectory. Slicing only supports a 1-dimensional
            batch size.
        slice_len: length of the returned window.
        pad_value: value used to pad trajectories shorter than ``slice_len``.
            ``True`` pads with ``0``; ``None`` or ``False`` raise instead.

    Returns:
        the padded trajectory if it was too short. Otherwise, a window whose
        start is drawn uniformly among all valid positions, with the last step
        flagged as truncated (and therefore done).

    Raises:
        RuntimeError: if the trajectory is shorter than ``slice_len`` and no
            padding value is provided.
        NotImplementedError: if slicing is required and ``data`` is not
            1-dimensional.
    """
    traj_len = data.shape[-1]
    if traj_len < slice_len:
        if pad_value is None or pad_value is False:
            raise RuntimeError(
                f"The trajectory length ({data.shape[-1]}) is shorter than the slice length ({slice_len}). "
                f"Decrease the slice length or provide a padding value."
            )
        if pad_value is True:
            pad_value = 0
        return pad(data, [0, slice_len - traj_len], value=pad_value)

    if data.ndim != 1:
        raise NotImplementedError(data)
    # Valid starts are 0 .. traj_len - slice_len (inclusive), so the window
    # ending on the last step of the trajectory can be drawn too.
    start = torch.randint(traj_len - slice_len + 1, ()).item()
    data = data[..., slice(start, start + slice_len)]
    truncated = data.get(("next", "truncated"))
    truncated = torch.index_fill(
        truncated,
        dim=data.ndim - 1,
        value=True,
        index=torch.as_tensor(-1, device=truncated.device),
    )
    done = data.get(("next", "done"))
    data.set(("next", "truncated"), truncated)
    data.set(("next", "done"), truncated | done)
    return data


class _StreamingSampler(Sampler):
    """Stateless sampler paired with :class:`_StreamingStorage`.

    Sampling returns ``range(batch_size)`` and an empty info dict; the
    streaming storage reads the range's ``stop`` as the number of elements to
    deliver. Serialization hooks are no-ops because there is nothing to save.
    """

    def __init__(self):
        pass

    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        indices = range(batch_size)
        info: dict = {}
        return indices, info

    def _empty(self):
        return None

    def dumps(self, path):
        pass

    def loads(self, path):
        pass

    def state_dict(self) -> Dict[str, Any]:
        return dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass


# Raw OpenX step keys and the TED-format keys they are renamed to. The
# renames are applied in this insertion order.
OPENX_KEY_MAP = dict(
    is_first="is_init",
    is_last=("next", "done"),
    is_terminal=("next", "terminated"),
    reward=("next", "reward"),
)


def _format_data(data: TensorDict, episode: int):
    """Convert a stacked OpenX episode to the TED layout, in place.

    Steps performed:
      * ``("next", "observation")`` is the observation shifted by one step,
        padded at the end;
      * raw keys are renamed according to ``OPENX_KEY_MAP``;
      * ``("next", "truncated")`` is computed as ``done & ~terminated``;
      * the ``done``/``terminated``/``truncated``/``reward`` entries under
        ``"next"`` get a trailing singleton dimension, and root-level
        ``done``/``terminated``/``truncated`` are set to zeros of the same
        shape;
      * an ``"episode"`` entry filled with ``episode`` is added.
    """
    shifted_obs = pad(data.get("observation")[1:], [0, 1])
    data.set(("next", "observation"), shifted_obs)
    for raw_key, ted_key in OPENX_KEY_MAP.items():
        data.rename_key_(raw_key, ted_key)

    next_done = data.get(("next", "done"))
    next_terminated = data.get(("next", "terminated"))
    data.set(("next", "truncated"), next_done & ~next_terminated)

    signal_names = ("done", "terminated", "truncated", "reward")
    for name in signal_names:
        data.set(("next", name), data.get(("next", name)).unsqueeze(-1))
    # root-level signals are zero placeholders shaped like their "next" twins
    for name in signal_names[:-1]:
        data.set(name, torch.zeros_like(data.get(("next", name))))

    episode_ids = torch.full(
        data.shape, episode, device=data.device, dtype=torch.int
    )
    data.set("episode", episode_ids)


def _make_tensordict_image_conv(data):
    """Build a TensorDict from a raw OpenX step, decoding encoded images first.

    In some datasets, the images are not well converted: the image observation
    is stored as encoded bytes under ``data["observation"]["image"]["bytes"]``.
    When that entry exists, the bytes are decoded with PIL and converted to a
    tensor with torchvision, replacing ``data["observation"]["image"]`` in place.
    Otherwise, ``data`` is converted unchanged.

    Raises:
        ImportError: if encoded image bytes are found but ``torchvision`` is
            not installed.
    """
    # Only the lookup is guarded: a KeyError raised while decoding the image
    # must surface instead of being mistaken for "no image in this step".
    try:
        img_bytes = data["observation"]["image"]["bytes"]
    except KeyError:
        pass
    else:
        if not _has_tv:
            raise ImportError(
                "the `torchvision` library is required to read the image observation."
            )
        import torchvision.transforms.v2.functional
        from PIL import Image

        with Image.open(io.BytesIO(img_bytes)) as img:
            tensor = torchvision.transforms.v2.functional.pil_to_tensor(img)
        data["observation"]["image"] = tensor
    return make_tensordict(data)
