import math

import pytest
import torch

from offline_rl.rewards.noisy_reward_wrapper import NoisyRewardWrapper
from offline_rl.utils.testing.rewards import ConstantRewardModel


class TestNoisyRewardWrapper:
    """Statistical checks on the rewards produced by ``NoisyRewardWrapper``."""

    @pytest.mark.parametrize("sigma, seed", [
        (0.1, 0),
        (1, 1),
        (10, 2),
    ])
    def test_reward_values(self, sigma, seed):
        """Check that the sample std of the wrapped rewards matches ``sigma``.

        The wrapper is built around a ``ConstantRewardModel`` and queried on
        a large batch of all-zero inputs; the standard deviation of the
        returned rewards is then compared against ``sigma``.

        Args:
            sigma: Noise scale passed to ``NoisyRewardWrapper``; the test
                expects it to equal the standard deviation of the output.
            seed: Value for ``torch.manual_seed`` so the drawn sample (and
                therefore the test outcome) is reproducible.
        """
        torch.manual_seed(seed)
        base_reward_model = ConstantRewardModel()
        noisy_model = NoisyRewardWrapper(base_reward_model, sigma)

        # Large sample so the empirical std is a tight estimate: its relative
        # sampling error is roughly 1 / sqrt(2 * n), i.e. ~0.2% here
        # (assuming approximately Gaussian noise -- TODO confirm).
        num_samples = 100000
        shape = (num_samples, 1)

        states = torch.zeros(shape)
        actions = torch.zeros(shape)
        next_states = torch.zeros(shape)
        terminals = torch.zeros(shape)
        values = noisy_model.reward(states, actions, next_states, terminals)

        # Use a *relative* tolerance so the check is equally strict for every
        # sigma. A fixed absolute tolerance of 0.1 is vacuous for sigma=0.1:
        # a wrapper that adds no noise at all (std == 0) would still pass.
        assert math.isclose(float(values.std()), sigma, rel_tol=2e-2)
