"""Tests the TabularRewardModel class."""
import torch

from offline_rl.rewards.tabular_reward_model import TabularRewardModel


class TestTabularRewardModel:
    """Tests for reward lookups performed by ``TabularRewardModel``."""

    def test_reward_values(self):
        """Check that ``reward`` returns the table entry for each transition.

        The reward table has shape ``(num_states, num_actions, num_states)``
        and is filled with distinct integers, so a lookup at any wrong
        ``(state, action, next_state)`` index yields a different value.
        """
        num_states = 2
        num_actions = 3
        table_shape = (num_states, num_actions, num_states)
        rewards = torch.arange(num_states * num_actions * num_states).reshape(table_shape)
        model = TabularRewardModel(rewards)

        # Three transitions, given as parallel index tensors.
        states = torch.tensor([0, 1, 0], dtype=torch.long)
        actions = torch.tensor([2, 1, 0], dtype=torch.long)
        next_states = torch.tensor([1, 0, 1], dtype=torch.long)
        terminals = None  # no terminal flags are supplied in this test
        values = model.reward(states, actions, next_states, terminals)

        # zip() stops at the shortest iterable, so without this check a
        # truncated or empty result would let the loop below pass vacuously.
        assert len(values) == len(states)
        for r, s, a, ns in zip(values, states, actions, next_states):
            assert rewards[s, a, ns] == r
