from typing import Optional

import gym
import torch

from offline_rl.rewards.reward_model import RewardModel


class NoisyRewardWrapper(RewardModel):
    """Adds zero-mean Gaussian noise to the outputs of a base reward model.

    Fresh noise is drawn on every call to :meth:`reward`, independently of the
    inputs, and is added to the base rewards. One noise value is drawn per
    entry along the first axis of the reward tensor and broadcast across any
    remaining axes.

    Args:
        base_reward: The base reward model to perturb.
        sigma: Standard deviation of the noise to apply additively to the rewards.
    """

    def __init__(self, base_reward: RewardModel, sigma: float):
        self.base_reward = base_reward
        self.sigma = sigma

    def _sample_noise(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        """Draws ``size`` independent noise values scaled by ``sigma``.

        Args:
            size: Number of noise values to draw.
            dtype: Data type of the returned tensor.
            device: Device on which to allocate the returned tensor.

        Returns:
            A tensor of shape ``(size, 1)`` holding standard-normal samples
            multiplied by ``self.sigma``.
        """
        return torch.randn((size, 1), dtype=dtype, device=device) * self.sigma

    def reward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Applies noise to the rewards from a base reward model.

        See base class documentation for argument descriptions.

        Returns:
            A new tensor with the same shape as the base rewards, equal to the
            base rewards plus the sampled noise. The tensor returned by the
            base model is left unmodified.
        """
        rewards = self.base_reward.reward(states, actions, next_states, terminals)
        batch_size = len(rewards)
        noise = self._sample_noise(batch_size, rewards.dtype, rewards.device)
        # Align the noise with the rank of the rewards: (N,) for 1-D rewards,
        # (N, 1) for 2-D rewards, (N, 1, 1, ...) beyond that. This keeps the
        # result the same shape as ``rewards`` instead of relying on a fixed
        # (N, 1) layout that only broadcasts correctly against 2-D rewards.
        noise = noise.reshape((batch_size,) + (1,) * (rewards.dim() - 1))
        # Add out of place: the base model may hand back a tensor it still
        # owns (e.g. a cached or stored buffer), which an in-place ``+=``
        # would silently corrupt across calls.
        return rewards + noise

    @property
    def observation_space(self) -> gym.spaces.Space:
        """Observation space of the wrapped base reward model."""
        return self.base_reward.observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """Action space of the wrapped base reward model."""
        return self.base_reward.action_space
