import torch

from offline_rl.data.dict_dataset import DictDataset
from offline_rl.data.labeled_merge_dataset import LabeledMergeDataset


class TestLabeledMergeDataset:
    """Tests for ``LabeledMergeDataset`` built from two ``DictDataset`` sources."""

    def test_length_and_getitem(self):
        """Check that the merged dataset spans both sources in order and labels each sample.

        ``first`` contributes ``first_len`` samples with keys ``a``/``b``,
        paired with label 0. ``second`` contributes ``second_len`` samples
        with keys ``c``/``d``, paired with label 1. The merged dataset is
        expected to expose them back to back.
        """
        first_len, second_len = 5, 10
        first = DictDataset({
            "a": torch.zeros(first_len),
            "b": torch.ones(first_len),
        })
        second = DictDataset({
            "c": torch.zeros(second_len),
            "d": torch.ones(second_len),
        })

        label_key = "label"
        dataset = LabeledMergeDataset(
            (
                (0, first),
                (1, second),
            ),
            label_key=label_key,
        )

        # The test name promises a length check; the merged length should be
        # the sum of the two sources.
        assert len(dataset) == first_len + second_len

        # First and last index of the `first` source.
        for index in (0, first_len - 1):
            sample = dataset[index]
            assert label_key in sample
            # int(...) accepts either a Python int or a single-element tensor
            # label; NOTE(review): confirm the label representation.
            assert int(sample[label_key]) == 0
            assert "a" in sample
            assert "b" in sample

        # `first_len` is the boundary where `second` begins; the final index
        # is its last sample.
        for index in (first_len, first_len + second_len - 1):
            sample = dataset[index]
            assert label_key in sample
            assert int(sample[label_key]) == 1
            assert "c" in sample
            assert "d" in sample
