from typing import Optional

import gym
import torch

from offline_rl.rewards.reward_model import RewardModel


class TabularRewardModel(RewardModel):
    """A reward model based on a table of values.

    This works for MDPs with discrete state and action spaces.

    Args:
        rewards: The tabular reward function.
            This should be a tensor of shape (num_states, num_actions, num_next_states).
            Indexing into this tensor should yield the reward value associated with the
            corresponding transition.
    """

    def __init__(self, rewards: torch.Tensor):
        assert rewards.ndim == 3, (
            "rewards must have shape (num_states, num_actions, num_next_states); "
            f"got a tensor with {rewards.ndim} dimension(s)"
        )
        num_states, num_actions, num_next_states = rewards.shape
        # States and next states are drawn from the same space, so the first
        # and last dimensions of the table must agree.
        assert num_states == num_next_states, (
            "first and last dimensions of rewards must match; "
            f"got {num_states} and {num_next_states}"
        )
        self.rewards = rewards
        self._observation_space = gym.spaces.Discrete(num_states)
        self._action_space = gym.spaces.Discrete(num_actions)

    def reward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """See base class documentation.

        Looks up the reward of each (state, action, next_state) transition in the
        table. The inputs are used directly as indices into ``self.rewards``, so
        they are expected to be integer index tensors of equal length.

        There is no simple, general way of handling terminals in the tabular
        case other than adding another dimension to rewards, which seems too
        inefficient, so for now they are ignored.

        Args:
            states: Indices of the current states.
            actions: Indices of the actions taken.
            next_states: Indices of the resulting states. Although the signature
                allows ``None``, the table is indexed by next state, so a value
                is required here.
            terminals: Ignored.

        Returns:
            The table entries ``self.rewards[states, actions, next_states]``.

        Raises:
            ValueError: If ``next_states`` is ``None``.
        """
        del terminals
        if next_states is None:
            # Must be an explicit raise rather than an assert: with assertions
            # disabled, indexing with None would silently insert a new axis and
            # return a tensor of the wrong shape instead of failing.
            raise ValueError(
                "TabularRewardModel requires next_states, but got None"
            )
        assert len(states) > 0, "states must not be empty"
        assert len(states) == len(actions), (
            f"states and actions differ in length: {len(states)} != {len(actions)}"
        )
        assert len(states) == len(next_states), (
            "states and next_states differ in length: "
            f"{len(states)} != {len(next_states)}"
        )
        return self.rewards[states, actions, next_states]

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The discrete space with one element per row of the reward table."""
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The discrete space with one element per action in the reward table."""
        return self._action_space
