import gym
import pytest
from ray.rllib.policy.sample_batch import SampleBatch

from offline_rl.data.paired_trajectory_segment_dataset import PairedTrajectorySegmentDataset
from offline_rl.utils.testing.rllib import write_random_sample_batches_to_json


@pytest.fixture
def sample_batches(tmp_path):
    """Write random RLlib sample batches as JSON into a per-test temporary directory.

    Observations come from ``Discrete(4)`` and actions from ``Discrete(2)``.

    Returns:
        A ``(num_samples, tmp_path)`` tuple: the sample count handed to the writer and the
        directory the JSON files were written into.
    """
    total_samples = 51
    observation_space, action_space = gym.spaces.Discrete(4), gym.spaces.Discrete(2)
    write_random_sample_batches_to_json(
        str(tmp_path),
        observation_space,
        action_space,
        num_samples=total_samples,
    )
    return total_samples, tmp_path


class TestPairedTrajectorySegmentDataset:
    """Tests for ``PairedTrajectorySegmentDataset`` built from the ``sample_batches`` fixture."""

    # Every test below takes the ``sample_batches`` fixture as a parameter, which shadows the
    # module-level fixture function, and some tests call a private helper of the dataset.
    # A standalone pylint pragma applies from its line to the end of the enclosing scope, so it
    # sits at the top of the class to cover all methods rather than only the last one.
    # pylint: disable=redefined-outer-name,protected-access

    def test_convert_sample_index_to_segment_indices(self, sample_batches):
        """Each flat sample index maps to the expected ``(segment_1, segment_2)`` index pair.

        With the fixture's timesteps split into segments of length 10, the expected pairs are
        every ``(i, j)`` with ``i < j`` over 5 segments, in lexicographic order.
        """
        segment_length = 10
        _, tmp_path = sample_batches
        dataset = PairedTrajectorySegmentDataset(str(tmp_path / "*"), segment_length)

        expected_segment_1_indices = [0, 0, 0, 0, 1, 1, 1, 2, 2, 3]
        expected_segment_2_indices = [1, 2, 3, 4, 2, 3, 4, 3, 4, 4]
        # Pin the length first: otherwise a dataset that is too short would pass silently, and
        # one that is too long would fail with an unhelpful IndexError on the expected lists.
        assert len(dataset) == len(expected_segment_1_indices)

        expected_pairs = zip(expected_segment_1_indices, expected_segment_2_indices)
        for sample_index, (expected_segment_1_index, expected_segment_2_index) in enumerate(expected_pairs):
            actual_segment_1_index, actual_segment_2_index = \
                dataset._convert_sample_index_to_segment_indices(sample_index)
            assert actual_segment_1_index == expected_segment_1_index, f"sample index {sample_index}"
            assert actual_segment_2_index == expected_segment_2_index, f"sample index {sample_index}"

    def test_convert_sample_index_to_segment_indices_with_max_num_pairs(self, sample_batches):
        """Passing ``max_num_pairs`` changes the dataset length to ``num_segments * max_num_pairs``.

        Only the length is checked here; the index conversion itself is not exercised.
        """
        max_num_pairs = 3
        segment_length = 10
        num_timesteps, tmp_path = sample_batches
        dataset = PairedTrajectorySegmentDataset(str(tmp_path / "*"), segment_length, max_num_pairs=max_num_pairs)
        # NOTE(review): presumably each segment contributes at most ``max_num_pairs`` pairs --
        # confirm against the dataset implementation.
        expected_length = (num_timesteps // segment_length) * max_num_pairs
        assert len(dataset) == expected_length

    @pytest.mark.parametrize("segment_length", [5, 1, 25])
    def test_length_and_getitem(self, sample_batches, segment_length):
        """Dataset length equals the number of segment pairs, and items hold stacked pairs.

        The expected length is ``n * (n - 1) // 2`` for ``n = num_timesteps // segment_length``
        full segments, i.e. timesteps that do not fill a whole segment are not counted.
        """
        num_timesteps, tmp_path = sample_batches
        dataset = PairedTrajectorySegmentDataset(str(tmp_path / "*"), segment_length)

        num_segments = num_timesteps // segment_length
        num_samples = (num_segments * (num_segments - 1)) // 2
        assert len(dataset) == num_samples

        # Inspect both ends of the index range so an off-by-one at the last pair is caught too
        # (for a single-pair dataset both indices are 0, which is harmless).
        for sample_index in (0, len(dataset) - 1):
            sample = dataset[sample_index]
            for key in [SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS]:
                assert key in sample, f"missing {key!r} in sample {sample_index}"
                # Presumably one row per segment of the pair -- see the dataset's __getitem__.
                assert sample[key].shape == (2, segment_length)
