import pytest

from offline_rl.rewards.learning.env_specific_networks import LearnedBouncingBallsEnvRewardModel
from offline_rl.utils.testing.bouncing_balls_env import (
    dummy_bouncing_balls_act_space,
    dummy_bouncing_balls_obs_space,
    get_dist_from_goal_states,
    get_noop_actions,
)


class TestLearnedBouncingBallsEnvRewardModel:
    """Tests for ``LearnedBouncingBallsEnvRewardModel``."""

    @pytest.mark.parametrize(
        "batch_size,num_balls",
        [(10, 5), (1, 5), (10, 1)],
    )
    def test_feature_extraction(self, batch_size, num_balls):
        """Feature extraction yields a ``(batch_size, features_size)`` result.

        The cases cover a multi-element batch, a single-element batch, and an
        environment with only one ball.
        """
        obs_space = dummy_bouncing_balls_obs_space(num_balls=num_balls)
        act_space = dummy_bouncing_balls_act_space()
        # NOTE(review): ``[16]`` is passed straight through to the model's
        # constructor; presumably hidden-layer sizes — confirm against the model.
        model = LearnedBouncingBallsEnvRewardModel(obs_space, act_space, [16])

        # Build one transition per batch element. The trailing 10 / 15 are
        # forwarded to ``get_dist_from_goal_states`` unchanged (presumably a
        # distance from the goal — TODO confirm), so the two states differ.
        obs = get_dist_from_goal_states(batch_size, num_balls, 10)
        acts = get_noop_actions(batch_size)
        next_obs = get_dist_from_goal_states(batch_size, num_balls, 15)

        feats = model.extract_bouncing_ball_env_features(obs, acts, next_obs)

        expected_shape = (batch_size, model.features_size)
        assert feats.shape == expected_shape
