import gym
import pytest
from ray.rllib.policy.sample_batch import SampleBatch

from offline_rl.data.transition_shuffling_dataset import TransitionShufflingDataset
from offline_rl.utils.testing.rllib import write_random_sample_batches_to_json


@pytest.fixture
def sample_batches(tmp_path):
    """Populate ``tmp_path`` via ``write_random_sample_batches_to_json``.

    The helper is called with a ``Discrete(4)`` observation space, a
    ``Discrete(2)`` action space and ``num_samples=5``.

    Returns:
        A ``(num_samples, tmp_path)`` tuple: the ``num_samples`` value handed
        to the helper and the directory it was pointed at.
    """
    batch_count = 5
    write_random_sample_batches_to_json(
        str(tmp_path),
        gym.spaces.Discrete(4),
        gym.spaces.Discrete(2),
        num_samples=batch_count,
    )
    return batch_count, tmp_path


class TestSampleBatchJsonReaderDataset:
    """Tests for ``TransitionShufflingDataset`` using the ``sample_batches`` fixture."""

    @pytest.mark.parametrize("num_pairs", [1, 2, 10])
    # pylint: disable=redefined-outer-name
    def test_length_and_getitem(self, sample_batches, num_pairs):
        """Check the dataset length and the keys present in its first item.

        Args:
            sample_batches: ``(num_samples, tmp_path)`` tuple from the fixture.
            num_pairs: Value forwarded to the dataset's ``num_pairs`` argument.
        """
        num_samples, tmp_path = sample_batches
        # Original pairing is only requested when more than one pair is used.
        maintain_original_pairing = num_pairs > 1
        dataset = TransitionShufflingDataset(
            str(tmp_path / "*"),  # glob over everything the fixture wrote
            num_pairs=num_pairs,
            maintain_original_pairing=maintain_original_pairing,
        )
        assert len(dataset) == num_samples * num_pairs
        sample = dataset[0]
        for key in (SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS):
            assert key in sample
