import gym
import pytest
import torch

from offline_rl.rewards.learning.reward_model_networks import FullyConnectedRewardModel


class TestFullyConnectedRewardModel:
    """Shape checks for ``FullyConnectedRewardModel`` on discrete spaces."""

    @pytest.mark.parametrize("obs_size,act_size", [(4, 2), (2, 4)])
    def test_forward(self, obs_size, act_size):
        """Check that ``forward`` returns a ``(batch, 1)`` tensor.

        The model is built over ``Discrete(obs_size)`` observations and
        ``Discrete(act_size)`` actions, then fed a batch of randomly sampled
        (state, action, next_state) index tensors. The fourth ``forward``
        argument is passed as ``None``.
        """
        observation_space = gym.spaces.Discrete(obs_size)
        action_space = gym.spaces.Discrete(act_size)
        # NOTE(review): [8, 8] is presumably the hidden layer sizes — confirm
        # against the FullyConnectedRewardModel constructor.
        reward_model = FullyConnectedRewardModel(
            observation_space, action_space, [8, 8]
        )

        n_transitions = 60

        def draw_batch(space):
            # One independent sample per transition, packed as an int64 tensor.
            samples = [space.sample() for _ in range(n_transitions)]
            return torch.tensor(samples, dtype=torch.long)

        # Sampling order (states, actions, next states) is kept as-is so the
        # per-space random streams are consumed in the same sequence.
        states = draw_batch(observation_space)
        actions = draw_batch(action_space)
        next_states = draw_batch(observation_space)

        rewards = reward_model.forward(states, actions, next_states, None)
        assert rewards.shape == (n_transitions, 1)
