from typing import Optional

import gym
import torch

from offline_rl.rewards.reward_model import RewardModel


class ConstantRewardModel(RewardModel):
    """Reward model that assigns one fixed value to every transition.

    The value is stored on the instance as ``constant`` and is read each
    time ``reward`` is called.
    """

    def __init__(self, constant: float = 1.0):
        self.constant = constant

    # pylint: disable=unused-argument
    def reward(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        next_states: Optional[torch.Tensor],
        terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return a column of rewards, all equal to ``self.constant``.

        Only ``states`` is inspected: its length sets the number of rows and
        its device sets where the result is allocated. ``actions``,
        ``next_states`` and ``terminals`` are ignored.

        Returns:
            A tensor of shape ``(len(states), 1)`` on ``states.device``.
        """
        batch_size = len(states)
        unit_rewards = torch.ones((batch_size, 1), device=states.device)
        return unit_rewards * self.constant

    @property
    def observation_space(self) -> gym.spaces.Space:
        """This model does not define an observation space; always ``None``."""
        return None

    @property
    def action_space(self) -> gym.spaces.Space:
        """This model does not define an action space; always ``None``."""
        return None


class RandomRewardModel(RewardModel):
    """Reward model whose rewards are freshly sampled noise on every call."""

    # pylint: disable=unused-argument
    def reward(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        next_states: Optional[torch.Tensor],
        terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Draw one random reward per transition with ``torch.rand``.

        Only ``states`` is inspected: its length sets the number of rows and
        its device sets where the samples are generated. ``actions``,
        ``next_states`` and ``terminals`` are ignored. Values come from
        ``torch.rand``, i.e. uniform on ``[0, 1)``, so repeated calls with
        the same inputs generally return different rewards.

        Returns:
            A tensor of shape ``(len(states), 1)`` on ``states.device``.
        """
        sample_shape = (len(states), 1)
        return torch.rand(sample_shape, device=states.device)

    @property
    def observation_space(self) -> gym.spaces.Space:
        """This model does not define an observation space; always ``None``."""
        return None

    @property
    def action_space(self) -> gym.spaces.Space:
        """This model does not define an action space; always ``None``."""
        return None


class TabularPotentialShaping(RewardModel):
    """Adds tabular potential shaping to a reward model.

    The shaped reward for a transition is::

        base_reward + discount * potential[next_state] - potential[state]

    ``potential`` is indexed directly with ``states`` and ``next_states``, so
    those tensors must be valid index tensors for it (presumably integer
    state ids -- confirm against the caller).

    Note that ``terminals`` is only forwarded to the wrapped reward model;
    the shaping term itself is NOT masked on terminal transitions.
    """

    def __init__(self, base_reward: RewardModel, potential: torch.Tensor, discount: float):
        self.base_reward = base_reward
        self.potential = potential
        self.discount = discount

    @staticmethod
    def _match_shape(values: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Reshape ``values`` to ``reference``'s shape if they differ only by singleton axes.

        Without this, adding e.g. an ``(N,)`` potential term to an ``(N, 1)``
        base reward would silently broadcast to ``(N, N)``. Shapes that
        differ in any other way are returned untouched so that ordinary
        broadcasting (or its error) still applies.
        """
        if values.shape != reference.shape and values.squeeze().shape == reference.squeeze().shape:
            return values.reshape(reference.shape)
        return values

    def reward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return the base reward plus the potential-shaping term.

        Args:
            states: Indices into ``potential`` for the current states.
            actions: Passed through to the base reward model.
            next_states: Indices into ``potential`` for the successor states.
                Required here even though the interface marks it optional.
            terminals: Passed through to the base reward model; not used in
                the shaping term.

        Returns:
            The shaped rewards, with the same shape as the base rewards
            whenever the potential terms differ from them only by singleton
            axes.

        Raises:
            ValueError: If ``next_states`` is ``None``.
        """
        if next_states is None:
            # Indexing a tensor with None inserts a new axis instead of
            # failing, which would silently yield meaningless rewards.
            raise ValueError(
                "TabularPotentialShaping requires next_states to compute the shaping term."
            )

        original_rewards = self.base_reward.reward(states, actions, next_states, terminals)
        next_potential = self._match_shape(self.potential[next_states], original_rewards)
        current_potential = self._match_shape(self.potential[states], original_rewards)
        shaped_rewards = original_rewards + self.discount * next_potential - current_potential
        return shaped_rewards

    @property
    def observation_space(self) -> gym.spaces.Space:
        """Observation space of the wrapped base reward model."""
        return self.base_reward.observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """Action space of the wrapped base reward model."""
        return self.base_reward.action_space
