"""
Contains torch Modules for value networks. These networks take an 
observation dictionary as input (and possibly additional conditioning, 
such as subgoal or goal dictionaries) and produce value or 
action-value estimates or distributions.
"""
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

import agents.models.robomimic.utils.tensor_utils as TensorUtils
from agents.models.robomimic.models.obs_nets import MIMO_MLP
from agents.models.robomimic.models.distributions import DiscreteValueDistribution


class ValueNetwork(MIMO_MLP):
    """
    State-value network: maps an observation dictionary (and, when goal shapes
    are supplied, a goal observation dictionary) to a single value estimate,
    optionally squashed into a bounded range.
    """
    def __init__(
        self,
        obs_shapes,
        mlp_layer_dims,
        value_bounds=None,
        goal_shapes=None,
        encoder_kwargs=None,
    ):
        """
        Args:
            obs_shapes (OrderedDict): maps each observation key to the expected
                shape of that observation.

            mlp_layer_dims ([int]): hidden layer sizes of the MLP, in order.

            value_bounds (tuple): (lowest, highest) return the network is capable of
                generating. When given, raw outputs pass through tanh and are affinely
                rescaled into this interval. When None, outputs are left unscaled.

            goal_shapes (OrderedDict): maps each goal observation key to its expected
                shape. If None or empty, the network is not goal-conditioned.

            encoder_kwargs (dict or None): If None, default encoder_kwargs are used. Otherwise, a
                nested dictionary with per-observation-key information for the encoder networks,
                of the form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        self.value_bounds = value_bounds
        if value_bounds is not None:
            lower = float(value_bounds[0])
            upper = float(value_bounds[1])
            # tanh output lives in [-1, 1]; these map it affinely onto [lower, upper]
            self._value_scale = (upper - lower) / 2.
            self._value_offset = (upper + lower) / 2.

        assert isinstance(obs_shapes, OrderedDict)
        self.obs_shapes = obs_shapes

        # MIMO_MLP takes named groups of inputs; the "obs" group is always present
        group_shapes = OrderedDict(obs=OrderedDict(obs_shapes))

        use_goal = goal_shapes is not None and len(goal_shapes) > 0
        if use_goal:
            assert isinstance(goal_shapes, OrderedDict)
            self.goal_shapes = OrderedDict(goal_shapes)
            group_shapes["goal"] = OrderedDict(self.goal_shapes)
        else:
            self.goal_shapes = OrderedDict()
        self._is_goal_conditioned = use_goal

        # _get_output_shapes is a hook, so subclasses can swap in other heads
        # (it may read attributes they set before calling this constructor)
        super().__init__(
            input_obs_group_shapes=group_shapes,
            output_shapes=self._get_output_shapes(),
            layer_dims=mlp_layer_dims,
            encoder_kwargs=encoder_kwargs,
        )

    def _get_output_shapes(self):
        """
        Output heads requested from @MIMO_MLP. The base network has a single
        "value" head of shape (1,); subclasses may override this to predict,
        for example, the parameters of a value distribution instead.
        """
        return OrderedDict([("value", (1,))])

    def output_shape(self, input_shape=None):
        """
        Output shape of this module, excluding the batch dimension.

        Args:
            input_shape (iterable of int): shape of input, without batch dimension.
                Unused here, since the output size does not depend on it.

        Returns:
            out_shape ([int]): always [1]
        """
        return [1]

    def forward(self, obs_dict, goal_dict=None):
        """
        Run the MLP on the observation (and goal) groups and return the "value"
        head, squashed with tanh into @value_bounds when bounds were provided.
        """
        head_outputs = super().forward(obs=obs_dict, goal=goal_dict)
        raw_values = head_outputs["value"]
        if self.value_bounds is None:
            return raw_values
        return self._value_offset + self._value_scale * torch.tanh(raw_values)

    def _to_string(self):
        return f"value_bounds={self.value_bounds}"


class ActionValueNetwork(ValueNetwork):
    """
    Action-value (Q) network: like @ValueNetwork, but the action is fed to the
    network as an extra input key alongside the observations. Can optionally
    be goal-conditioned.
    """
    def __init__(
        self,
        obs_shapes,
        ac_dim,
        mlp_layer_dims,
        value_bounds=None,
        goal_shapes=None,
        encoder_kwargs=None,
    ):
        """
        Args:
            obs_shapes (OrderedDict): maps each observation key to the expected
                shape of that observation.

            ac_dim (int): dimension of the action space.

            mlp_layer_dims ([int]): hidden layer sizes of the MLP, in order.

            value_bounds (tuple): (lowest, highest) return the network is capable of
                generating. When given, raw outputs pass through tanh and are affinely
                rescaled into this interval. When None, outputs are left unscaled.

            goal_shapes (OrderedDict): maps each goal observation key to its expected
                shape. If None or empty, the network is not goal-conditioned.

            encoder_kwargs (dict or None): If None, default encoder_kwargs are used. Otherwise, a
                nested dictionary with per-observation-key information for the encoder networks,
                of the form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        # register the action as one more input key of shape (ac_dim,)
        shapes_with_action = OrderedDict(obs_shapes)
        shapes_with_action["action"] = (ac_dim,)
        self.ac_dim = ac_dim

        super().__init__(
            obs_shapes=shapes_with_action,
            mlp_layer_dims=mlp_layer_dims,
            value_bounds=value_bounds,
            goal_shapes=goal_shapes,
            encoder_kwargs=encoder_kwargs,
        )

    def forward(self, obs_dict, acts, goal_dict=None):
        """
        Merge @acts into a copy of @obs_dict under the "action" key, then
        evaluate the parent value network on the merged inputs.
        """
        merged_inputs = dict(obs_dict, action=acts)
        return super().forward(merged_inputs, goal_dict)

    def _to_string(self):
        return f"action_dim={self.ac_dim}\nvalue_bounds={self.value_bounds}"


class DistributionalActionValueNetwork(ActionValueNetwork):
    """
    Distributional Q (action-value) network that outputs a categorical distribution over
    a discrete grid of value atoms. See https://arxiv.org/pdf/1707.06887.pdf for
    more details.
    """
    def __init__(
        self,
        obs_shapes,
        ac_dim,
        mlp_layer_dims,
        value_bounds,
        num_atoms,
        goal_shapes=None,
        encoder_kwargs=None,
    ):
        """
        Args:
            obs_shapes (OrderedDict): a dictionary that maps observation keys to
                expected shapes for observations.

            ac_dim (int): dimension of action space.

            mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.

            value_bounds (tuple): a 2-tuple corresponding to the lowest and highest possible return
                that the network should be capable of generating. This defines the support
                of the value distribution. Unlike the parent classes, no tanh rescaling is
                applied to the outputs, since @forward is overridden here.

            num_atoms (int): number of value atoms to use for the categorical distribution - which
                is the representation of the value distribution.

            goal_shapes (OrderedDict): a dictionary that maps observation keys to
                expected shapes for goal observations.

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-observation key information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """

        # Must be set before the super constructor runs: it calls
        # self._get_output_shapes(), which reads self.num_atoms.
        self.num_atoms = num_atoms
        # evenly spaced support of the categorical distribution, kept as a numpy array
        self._atoms = np.linspace(value_bounds[0], value_bounds[1], num_atoms)

        # pass to super class to instantiate network
        super(DistributionalActionValueNetwork, self).__init__(
            obs_shapes=obs_shapes,
            ac_dim=ac_dim,
            mlp_layer_dims=mlp_layer_dims,
            value_bounds=value_bounds,
            goal_shapes=goal_shapes,
            encoder_kwargs=encoder_kwargs,
        )

    def _get_output_shapes(self):
        """
        Network outputs log probabilities for categorical distribution over discrete value grid.
        """
        return OrderedDict(log_probs=(self.num_atoms,))

    def forward_train(self, obs_dict, acts, goal_dict=None):
        """
        Return full critic categorical distribution.

        Args:
            obs_dict (dict): batch of observations
            acts (torch.Tensor): batch of actions
            goal_dict (dict): if not None, batch of goal observations

        Returns:
            value_distribution (DiscreteValueDistribution instance): built from the
                atom grid (with a leading singleton dimension) and the network logits
        """

        # add in actions
        inputs = dict(obs_dict)
        inputs["action"] = acts

        # Call MIMO_MLP.forward directly: the parent forward() methods expect a
        # "value" head and apply tanh scaling, neither of which applies here.
        # The network returns unnormalized log probabilities (logits) per value atom.
        logits = MIMO_MLP.forward(self, obs=inputs, goal=goal_dict)["log_probs"]

        # Turn these logits into a categorical distribution over the value atoms.
        # torch.as_tensor builds the grid directly on the logits' device instead of going
        # through the discouraged legacy torch.Tensor(ndarray) constructor plus .to();
        # the default dtype is what that legacy constructor produced, so values are unchanged.
        # (unsqueeze to make sure atoms are compatible with batch operations)
        value_atoms = torch.as_tensor(
            self._atoms, dtype=torch.get_default_dtype(), device=logits.device
        ).unsqueeze(0)
        return DiscreteValueDistribution(values=value_atoms, logits=logits)

    def forward(self, obs_dict, acts, goal_dict=None):
        """
        Return mean of critic categorical distribution. Useful for obtaining
        point estimates of critic values.

        Args:
            obs_dict (dict): batch of observations
            acts (torch.Tensor): batch of actions
            goal_dict (dict): if not None, batch of goal observations

        Returns:
            mean_value (torch.Tensor): expectation of value distribution
        """
        vd = self.forward_train(obs_dict=obs_dict, acts=acts, goal_dict=goal_dict)
        return vd.mean()

    def _to_string(self):
        return "action_dim={}\nvalue_bounds={}\nnum_atoms={}".format(self.ac_dim, self.value_bounds, self.num_atoms)