import pytest
import torch

from offline_rl.rewards.evaluation.point_maze_env_transition_samplers import PointMazeEnvActionSamplingTransitionSampler
from offline_rl.rewards.evaluation.transition_sampler import BoundaryActionSampler


class TestPointMazeEnvActionSamplingTransitionSampler:
    """Output-shape checks for ``PointMazeEnvActionSamplingTransitionSampler.sample``."""

    @pytest.mark.parametrize(
        "dt,num_states",
        [
            (1.0, 10),
            # A single state guards against accidental squeezing of the batch axis.
            (1.0, 1),
        ],
    )
    def test_sample(self, dt, num_states, max_magnitude=1.0):
        """Check that sampling returns one row per input state and one column per action.

        ``max_magnitude`` has a default value, so pytest passes it through
        as-is instead of resolving it as a fixture.
        """
        obs_dim = 6
        boundary_sampler = BoundaryActionSampler(max_magnitude)
        transition_sampler = PointMazeEnvActionSamplingTransitionSampler(dt, boundary_sampler)
        batch = torch.ones(num_states, obs_dim)

        sampled_actions, sampled_next_states, sampled_weights = transition_sampler.sample(batch)

        n_actions = boundary_sampler.num_actions
        # The trailing 2 is the per-action width this test expects from the sampler.
        assert sampled_actions.shape == (num_states, n_actions, 2)
        # Next states keep the same feature width as the input states.
        assert sampled_next_states.shape == (num_states, n_actions, obs_dim)
        # One scalar weight per (state, action) pair.
        assert sampled_weights.shape == (num_states, n_actions)
