from typing import List, Optional

import gym
import torch
import torch.nn as nn

from offline_rl.envs.bouncing_balls_env import BouncingBallsEnv
from offline_rl.rewards.reward_model import RewardModel
from offline_rl.utils.space_utils import clip_actions_to_space_bounds
from offline_rl.utils.torch_utils import build_fc_network


class LearnedBouncingBallsEnvRewardModel(RewardModel, nn.Module):
    """A learnable reward model specific to the bouncing balls env.

    This model extracts hand-designed features specific to the env, normalizes
    them with fixed mean / standard deviation constants, and learns a
    fully-connected reward model on top of those features.

    Args:
        obs_space: Input observation space.
        act_space: Input action space.
        hidden_sizes: List of layer sizes.
        output_size: Dimensionality of output. Should almost certainly be 1.
        discount: Discount factor of the env. Used in computing some features.
    """
    def __init__(
            self,
            obs_space: gym.spaces.Space,
            act_space: gym.spaces.Space,
            hidden_sizes: List[int],
            output_size: int = 1,
            discount: float = 0.95,
    ):
        # Calls `nn.Module.__init__` because `RewardModel` is abstract.
        super().__init__()
        self.obs_space = obs_space
        self.act_space = act_space
        self.discount = discount

        # Compute the size of the manually-extracted feature vector, which depends on the number
        # of other balls in the environment, which can be derived from the obs space.
        # NOTE(review): the `4 + 2` offset is presumably the ego ball state plus the goal
        # position, with 4 entries per other ball -- confirm against `BouncingBallsEnv`.
        length_of_ball_features = self.obs_space.shape[0] - (4 + 2)
        assert length_of_ball_features % 4 == 0, (
            "Observation size is inconsistent with 4 entries per other ball.")
        num_balls = length_of_ball_features // 4
        # Currently 6 features per other ball (see `extract_bouncing_ball_env_features`):
        # distance and collision indicator for both the state and the next state, plus
        # two squared position-difference components.
        num_features_per_other_ball = 6
        num_ego_features = 6
        self.features_size = num_ego_features + num_balls * num_features_per_other_ball

        # Accept any sequence of hidden sizes, not only a list.
        layer_sizes = [*hidden_sizes, output_size]
        self.network = build_fc_network(
            self.features_size,
            layer_sizes,
        )

        # Mean and standard deviation tensors with which to normalize.
        #
        # `repeat_interleave` repeats each element `num_balls` times in a row, i.e. the
        # layout is grouped by feature type ([f0] * num_balls, [f1] * num_balls, ...).
        # This must match the order in which `extract_bouncing_ball_env_features`
        # concatenates the per-ball feature groups.
        self.mean = torch.cat((
            # Ego features.
            torch.tensor([25, 5, 25, 5, 0.5, 0]),
            # Ado features.
            torch.repeat_interleave(torch.tensor([25, 25, 0.5, 0.5, 0.0, 0.0]), num_balls),
        ))
        self.std = torch.cat((
            # Ego features.
            torch.tensor([10, 2, 10, 2, 0.5, 1.0]),
            # Ado features.
            torch.repeat_interleave(torch.tensor([10, 10, 0.5, 0.5, 1.0, 1.0]), num_balls),
        ))
        assert len(self.mean) == self.features_size
        assert len(self.std) == self.features_size

    def forward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Same as `reward()`.

        Args:
            states: States of the transitions.
            actions: Actions of the transitions. Clipped to the action space bounds.
            next_states: Next states of the transitions. Required by this model even
                though the interface allows `None`.
            terminals: Unused.

        Returns:
            The output of the network applied to the normalized features.

        Raises:
            ValueError: If `next_states` is `None`.
        """
        del terminals
        if next_states is None:
            # Most features are computed from the next state, so fail early with a
            # clear message instead of an obscure error inside the env helpers.
            raise ValueError(
                "`next_states` is required by LearnedBouncingBallsEnvRewardModel.")
        actions = clip_actions_to_space_bounds(actions, self.act_space)
        inputs = self.extract_bouncing_ball_env_features(states, actions, next_states)
        # Call the module rather than `.forward` so that registered hooks run.
        return self.network(inputs)

    def reward(self, *args, **kwargs) -> torch.Tensor:
        """See base class documentation."""
        return self.forward(*args, **kwargs)

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space passed at construction."""
        return self.obs_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space passed at construction."""
        return self.act_space

    def extract_bouncing_ball_env_features(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: torch.Tensor,
    ) -> torch.Tensor:
        """Manually extracts features from the states, actions, and next states.

        The features are concatenated along dim 1 in this order: six ego features
        (goal distance and its square root for the state and the next state, the
        reached-goal indicator, and the negated discounted difference of the square-root
        goal distances), followed by the other-ball feature groups (distances and
        collisions for the state and the next state, and squared position differences).
        The result is cast to float32 and normalized with `self.mean` / `self.std`.

        Args:
            states: States from which to extract features.
            actions: Actions from which to extract features. Currently unused.
            next_states: Next states from which to extract features.

        Returns:
            Extracted features of shape (batch size, feature size).
        """
        # No features for actions currently.
        del actions

        # Ego features.
        states_goal_distances = BouncingBallsEnv.get_distances_from_goal(states)
        states_goal_distances_sqrt = states_goal_distances**(1 / 2)
        next_states_goal_distances = BouncingBallsEnv.get_distances_from_goal(next_states)
        next_states_goal_distances_sqrt = next_states_goal_distances**(1 / 2)
        reached_goal = BouncingBallsEnv.get_reached_goal_indicators(next_states)
        # Equals sqrt(d(s)) - discount * sqrt(d(s')).
        discounted_diff_goal_distances_sqrt = -(next_states_goal_distances_sqrt * self.discount -
                                                states_goal_distances_sqrt)

        # Other-ball features.
        states_other_distances = BouncingBallsEnv.get_distances_from_other_balls_from_flat_states(states)
        next_states_other_distances = BouncingBallsEnv.get_distances_from_other_balls_from_flat_states(next_states)

        states_collisions = BouncingBallsEnv.get_collisions(states)
        next_states_collisions = BouncingBallsEnv.get_collisions(next_states)

        states_other_positions = BouncingBallsEnv.get_other_positions_from_flat_states(states)
        next_states_other_positions = BouncingBallsEnv.get_other_positions_from_flat_states(next_states)
        diff_other_positions_sq = (next_states_other_positions - states_other_positions).flatten(1)**2

        # The order here must stay in sync with the layout of `self.mean` / `self.std`.
        features = torch.cat(
            (
                states_goal_distances,
                states_goal_distances_sqrt,
                next_states_goal_distances,
                next_states_goal_distances_sqrt,
                reached_goal,
                discounted_diff_goal_distances_sqrt,
                states_other_distances,
                next_states_other_distances,
                states_collisions,
                next_states_collisions,
                diff_other_positions_sq,
            ),
            dim=1,
        ).to(torch.float32)

        return self._normalize(features)

    def _normalize(self, x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        """Standardizes `x` with the fixed mean / std tensors.

        Args:
            x: Features to normalize; the last dimension must match `self.mean`.
            eps: Small constant added to the std to avoid division by zero.

        Returns:
            The normalized features, on the same device as `x`.
        """
        # `self.mean` / `self.std` are plain attributes (not buffers), so they are
        # moved to the device of the input on each call.
        return (x - self.mean.to(x.device)) / (self.std.to(x.device) + eps)


class LearnedReacherRewardModel(RewardModel, nn.Module):
    """A learnable reward model specific to the reacher env.

    The model input is the concatenation of state features, clipped actions, and
    next-state features, normalized with fixed mean / standard deviation constants.

    Args:
        obs_space: Input observation space.
        act_space: Input action space.
        hidden_sizes: List of layer sizes.
        output_size: Dimensionality of output. Should almost certainly be 1.
    """
    def __init__(self,
                 obs_space: gym.spaces.Space,
                 act_space: gym.spaces.Space,
                 hidden_sizes: List[int],
                 output_size: int = 1):
        # Calls `nn.Module.__init__` because `RewardModel` is abstract.
        super().__init__()
        self.obs_space = obs_space
        self.act_space = act_space

        # One entry per feature returned by `_extract_state_features`.
        state_features_means = torch.tensor([0.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        state_features_stds = torch.tensor([1, 1, 1, 1, 0.2, 0.2, 35, 14, 0.2, 0.2, 1])
        assert state_features_means.shape == state_features_stds.shape

        # Float literals keep this tensor float32 like the others; integer literals
        # would create an int64 tensor and rely on dtype promotion in `torch.cat`.
        action_features_means = torch.tensor([0.0, 0.0])
        action_features_stds = torch.tensor([0.5, 0.5])
        assert action_features_means.shape == action_features_stds.shape

        state_features_size = len(state_features_means)
        actions_features_size = len(action_features_means)
        # State features are used twice: once for the state, once for the next state.
        self.features_size = state_features_size * 2 + actions_features_size

        # Accept any sequence of hidden sizes, not only a list.
        layer_sizes = [*hidden_sizes, output_size]
        self.network = build_fc_network(
            self.features_size,
            layer_sizes,
        )

        # Mean and standard deviation tensors with which to normalize. The order
        # (state, action, next state) must match `extract_features`.
        self.mean = torch.cat((
            state_features_means,
            action_features_means,
            state_features_means,
        ))
        self.std = torch.cat((
            state_features_stds,
            action_features_stds,
            state_features_stds,
        ))
        assert len(self.mean) == self.features_size
        assert len(self.std) == self.features_size

    def forward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Same as `reward()`.

        Args:
            states: States of the transitions.
            actions: Actions of the transitions. Clipped to the action space bounds.
            next_states: Next states of the transitions. Required by this model even
                though the interface allows `None`.
            terminals: Unused.

        Returns:
            The output of the network applied to the normalized features.

        Raises:
            ValueError: If `next_states` is `None`.
        """
        del terminals
        if next_states is None:
            # Fail early with a clear message instead of an `AttributeError` raised
            # while extracting next-state features.
            raise ValueError("`next_states` is required by LearnedReacherRewardModel.")
        actions = clip_actions_to_space_bounds(actions, self.act_space)
        inputs = self.extract_features(states, actions, next_states)
        # Call the module rather than `.forward` so that registered hooks run.
        return self.network(inputs)

    def reward(self, *args, **kwargs) -> torch.Tensor:
        """See base class documentation."""
        return self.forward(*args, **kwargs)

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space passed at construction."""
        return self.obs_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space passed at construction."""
        return self.act_space

    def _extract_state_features(self, states: torch.Tensor) -> torch.Tensor:
        """Extracts the obs features used in the original environment.

        See https://github.com/openai/gym/blob/master/gym/envs/mujoco/reacher.py#L41.

        The first two state entries are replaced by their cosine and sine; entries
        2:6 and the last three entries are passed through unchanged.

        Args:
            states: The states from which to extract features. The last dimension
                must have size 11.

        Returns:
            The extracted features of shape (num_states, feature_dim).
        """
        assert states.shape[-1] == 11, "Only states from CustomReacherEnv supported."
        # Index and concatenate along the last dimension so that the slices always
        # refer to the state entries, whatever the number of leading dimensions.
        return torch.cat(
            (
                torch.cos(states[..., :2]),
                torch.sin(states[..., :2]),
                states[..., 2:4],
                states[..., 4:6],
                states[..., -3:],
            ),
            dim=-1,
        )

    def extract_features(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: torch.Tensor,
    ) -> torch.Tensor:
        """Manually extracts features from the states, actions, and next states.

        Args:
            states: States from which to extract features.
            actions: Actions from which to extract features. Cast to the dtype of
                `states` before concatenation.
            next_states: Next states from which to extract features.

        Returns:
            Extracted features of shape (batch size, feature size).
        """
        # The order here must stay in sync with the layout of `self.mean` / `self.std`.
        features = torch.cat(
            (
                self._extract_state_features(states),
                actions.to(states.dtype),
                self._extract_state_features(next_states),
            ),
            dim=-1,
        )
        return self._normalize(features)

    def _normalize(self, x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        """Standardizes `x` with the fixed mean / std tensors.

        Args:
            x: Features to normalize; the last dimension must match `self.mean`.
            eps: Small constant added to the std to avoid division by zero.

        Returns:
            The normalized features, on the same device as `x`.
        """
        # `self.mean` / `self.std` are plain attributes (not buffers), so they are
        # moved to the device of the input on each call.
        return (x - self.mean.to(x.device)) / (self.std.to(x.device) + eps)
