"""This module provides a multi-headed continuous Q-function network."""

import torch

from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule


class MultiheadContinuousMLPQFunction(MultiHeadedMLPModule):
    """Multi-headed continuous MLP Q-function, Q(s, a).

    The network takes the concatenation of an observation and an action and
    has ``num_heads`` output heads, each with a single output. For every
    sample in a batch, ``policy_assigner`` picks which head's value is
    returned by :meth:`forward`.
    """

    def __init__(
        self,
        env_spec,
        num_heads,
        policy_assigner=None,
        split_observation=None,
        **kwargs
    ):
        """Initialize the Q-function.

        Args:
            env_spec (EnvSpec): Environment specification. The ``flat_dim``
                of its observation and action spaces sets the network's
                input size.
            num_heads (int): Number of network heads.
            policy_assigner: Object whose ``get_actions(tasks)`` returns a
                sequence whose first element is a tensor of head indices,
                one per sample. Must be provided (not None) before
                :meth:`forward` is called.
            split_observation (callable): Maps the observations passed to
                :meth:`forward` to an ``(obs, tasks)`` pair; ``obs`` is fed
                to the network and ``tasks`` to ``policy_assigner``.
                Defaults to using the observations unchanged for both.
            **kwargs: Keyword arguments forwarded to
                ``MultiHeadedMLPModule``.
        """
        self._env_spec = env_spec
        self._n_heads = num_heads
        self._obs_dim = env_spec.observation_space.flat_dim
        self._action_dim = env_spec.action_space.flat_dim

        MultiHeadedMLPModule.__init__(
            self,
            n_heads=self._n_heads,
            input_dim=self._obs_dim + self._action_dim,
            output_dims=1,
            **kwargs
        )

        self._policy_assigner = policy_assigner
        # A named static function (rather than a lambda) keeps the default
        # picklable, so the module can be saved with pickle / torch.save.
        self.split_observation = split_observation or self._no_split

    @staticmethod
    def _no_split(observations):
        """Default splitter: use the observations as both obs and tasks.

        Args:
            observations (torch.Tensor): Observations given to forward.

        Returns:
            tuple: ``(observations, observations)``.
        """
        return observations, observations

    # pylint: disable=arguments-differ
    def forward(self, observations, actions):
        """Return, per sample, the Q-value of the head assigned to it.

        Args:
            observations (torch.Tensor): Batch of observations; split by
                ``split_observation`` into network input and tasks.
            actions (torch.Tensor): Batch of actions, concatenated with the
                network part of the observations along dim 1.

        Returns:
            torch.Tensor: One Q-value per sample, taken from the head chosen
                for that sample by ``policy_assigner``.

        Raises:
            ValueError: If no ``policy_assigner`` was provided.
        """
        if self._policy_assigner is None:
            raise ValueError(
                'MultiheadContinuousMLPQFunction.forward requires a '
                'policy_assigner to select a head per sample, but none '
                'was provided.')

        obss, tasks = self.split_observation(observations)
        head_idx = self._policy_assigner.get_actions(tasks)[0]

        head_outputs = super().forward(torch.cat([obss, actions], 1))
        # Each head has a single output (output_dims=1), so concatenating
        # along the last dim yields one column per head.
        qvalues = torch.cat(head_outputs, dim=-1)

        # Pair row i with column head_idx[i]; keep both index tensors on the
        # same device as the values being indexed.
        rows = torch.arange(len(head_idx), device=qvalues.device)
        cols = head_idx.long().to(qvalues.device)
        return qvalues[rows, cols]
