import torch

from offline_rl.envs.line_env import LineEnvReward


class TestLineEnvReward:
    """Tests for the ``LineEnvReward`` factory constructors."""

    def test_make_ground_truth_reward(self):
        """Score three hand-picked transitions with the ground-truth model.

        The transitions are expected to receive ``left_reward``, ``0`` and
        ``right_reward`` respectively.
        """
        side_length = 2
        right_reward = 1
        left_reward = -1
        reward_model = LineEnvReward.make_ground_truth_reward(
            side_length=side_length,
            right_reward=right_reward,
            left_reward=left_reward,
        )

        # One row per transition: (state, action, next_state, expected reward).
        cases = [
            (1, 0, 0, left_reward),
            (0, 1, 1, 0),
            (side_length * 2 - 2, 1, side_length * 2, right_reward),
        ]
        src, act, dst, expected = zip(*cases)

        # The fourth positional argument is passed as None, exactly as the
        # model is called elsewhere in this test.
        scored = reward_model.reward(
            torch.LongTensor(list(src)),
            torch.LongTensor(list(act)),
            torch.LongTensor(list(dst)),
            None,
        )
        for position, want in enumerate(expected):
            assert scored[position] == want

    def test_default_make_functions_execute_without_exception(self):
        """Each factory must build a ``LineEnvReward`` with no arguments."""
        factories = (
            LineEnvReward.make_ground_truth_reward,
            LineEnvReward.make_reverse_reward,
            LineEnvReward.make_zero_reward,
            LineEnvReward.make_center_reward,
        )
        for factory in factories:
            built = factory()
            assert isinstance(built, LineEnvReward)
