import warnings
from typing import List, Optional

import gym
import numpy as np
import torch
import torch.nn as nn

from offline_rl.rewards.reward_model import RewardModel
from offline_rl.utils.space_utils import (
    BoxSpaceInputFormatter,
    DiscreteSpaceInputFormatter,
    MultiDiscreteSpaceInputFormatter,
    clip_actions_to_space_bounds,
)
from offline_rl.utils.torch_utils import build_fc_network, parse_norm_layer


class FullyConnectedRewardModel(RewardModel, nn.Module):
    """A learnable reward model using a fully-connected network.

    The network input is the concatenation of the formatted state, the
    formatted action and the formatted next state. Therefore `forward`
    requires `next_states` to be provided.

    Args:
        obs_space: Input observation space (used to format both states and next states).
        act_space: Input action space.
        hidden_sizes: Hidden layer sizes. Any sequence of ints is accepted
            (it is copied into a list internally).
        output_size: Dimensionality of output. Should almost certainly be 1.
    """

    def __init__(self,
                 obs_space: gym.spaces.Space,
                 act_space: gym.spaces.Space,
                 hidden_sizes: List[int],
                 output_size: int = 1):
        # Calls `nn.Module.__init__` because `RewardModel` is abstract.
        super().__init__()
        self.obs_space = obs_space
        self.act_space = act_space
        self.obs_formatter = self._get_formatter_for_space(obs_space)
        self.act_formatter = self._get_formatter_for_space(act_space)

        # States and next states both go through the observation formatter,
        # so its output size is counted twice.
        input_size = (
            self.obs_formatter.output_size
            + self.act_formatter.output_size
            + self.obs_formatter.output_size
        )
        # `list(...)` so tuples and other sequences are accepted, not only lists.
        layer_sizes = list(hidden_sizes) + [output_size]
        # NOTE(review): normalization is hard-coded to "batch_norm"; if this maps to
        # BatchNorm, training-mode batches of size 1 may fail — confirm in torch_utils.
        self.network = build_fc_network(input_size, layer_sizes, norm_layer=parse_norm_layer("batch_norm"))

    def forward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Computes rewards for the given transitions. Same as `reward()`.

        Args:
            states: States, in a form accepted by the observation formatter.
            actions: Actions, in a form accepted by the action formatter. For
                `Box` action spaces they are first clipped to the space bounds.
            next_states: Next states. Required, because the network consumes them.
            terminals: Ignored by this model.

        Returns:
            The network output for the formatted (state, action, next state) input.

        Raises:
            ValueError: If `next_states` is None.
        """
        del terminals  # This model does not condition on termination flags.
        if next_states is None:
            raise ValueError(
                f"{type(self).__name__} requires `next_states`: "
                "its network input includes the formatted next state."
            )
        inputs = self._format_inputs(states, actions, next_states)
        # Invoke the module (not `.forward`) so any registered hooks run.
        return self.network(inputs)

    def reward(self, *args, **kwargs) -> torch.Tensor:
        """See base class documentation."""
        return self.forward(*args, **kwargs)

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space this model was built for."""
        return self.obs_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space this model was built for."""
        return self.act_space

    def _format_inputs(self, states: torch.Tensor, actions: torch.Tensor, next_states: torch.Tensor) -> torch.Tensor:
        """Formats and concatenates (state, action, next state) as network input.

        Raises:
            TypeError: If the formatted states and next states have different dtypes.
        """
        formatted_states = self.obs_formatter(states)
        if isinstance(self.act_space, gym.spaces.Box):
            # Continuous actions outside the space bounds are clipped before formatting.
            actions = clip_actions_to_space_bounds(actions, self.act_space)
        formatted_actions = self.act_formatter(actions)
        formatted_next_states = self.obs_formatter(next_states)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if formatted_states.dtype != formatted_next_states.dtype:
            raise TypeError(
                "The states and next states should be the same data type, got "
                f"{formatted_states.dtype} and {formatted_next_states.dtype}."
            )
        # The action formatter may produce a different dtype, so match the states.
        formatted_actions = formatted_actions.to(formatted_states.dtype)
        return torch.hstack([
            formatted_states,
            formatted_actions,
            formatted_next_states,
        ])

    @staticmethod
    def _get_formatter_for_space(space: gym.spaces.Space):
        """Gets the formatter for the provided space.

        This function is implemented on this class as opposed to being a reusable function
        defined elsewhere because this logic may be specific to this class.
        This is a @staticmethod to avoid bugs due to using self.

        Args:
            space: The space for which to return a formatter (observation or action space).

        Returns:
            A formatter to use for the given space. For `Box` spaces with any
            infinite bound, normalization is disabled and a warning is emitted.

        Raises:
            ValueError: If the space is not `Discrete`, `MultiDiscrete` or `Box`.
        """
        if isinstance(space, gym.spaces.Discrete):
            return DiscreteSpaceInputFormatter(space)
        if isinstance(space, gym.spaces.MultiDiscrete):
            return MultiDiscreteSpaceInputFormatter(space)
        if isinstance(space, gym.spaces.Box):
            # Normalizing by the bounds is impossible when any bound is infinite.
            has_infinite_bounds = bool(np.any(np.isinf(space.low)) or np.any(np.isinf(space.high)))
            if has_infinite_bounds:
                warnings.warn("Skipping normalization due to infinite bounds in Box space.", stacklevel=2)
            return BoxSpaceInputFormatter(space, should_normalize=not has_infinite_bounds)
        # This helper formats both observation and action spaces, so keep the message generic.
        raise ValueError(f"Unsupported space: {space}")
