import gym
import torch

from offline_rl.rewards.learning.discriminative_reward_model import DiscriminativeRewardModel
from offline_rl.rewards.learning.reward_model_networks import FullyConnectedRewardModel
from offline_rl.utils.testing.torch import get_random_state_action_next_state_batch


class TestDiscriminativeRewardModel:
    """Unit tests for ``DiscriminativeRewardModel``."""

    def test_generic_step(self):
        """Smoke-test a single ``generic_step`` call made with ``"train"``.

        Wraps ``FullyConnectedRewardModel`` on small discrete spaces, feeds
        the model one batch whose labels are all zero, and only checks that
        the call hands back something other than ``None``.
        """
        observation_space = gym.spaces.Discrete(4)
        action_space = gym.spaces.Discrete(2)
        reward_model = DiscriminativeRewardModel(
            FullyConnectedRewardModel,
            {
                "obs_space": observation_space,
                "act_space": action_space,
                "hidden_sizes": [8, 8],
            },
            use_terminals=False,
        )

        n_samples = 10
        transitions = get_random_state_action_next_state_batch(observation_space, action_space, n_samples)
        # Attach an all-zero label column of shape (n_samples, 1).
        transitions["label"] = torch.zeros((n_samples, 1))

        # The second positional argument is presumably the batch index -- TODO confirm.
        step_output = reward_model.generic_step(transitions, 0, "train")
        assert step_output is not None

    def test_overfit(self):
        """Placeholder that currently passes without checking anything."""
        # TODO(redacted): Implement.
