"""PyTorch model for DQN"""

from typing import Sequence
import gym
from models.torch.misc import SlimFC
from models.torch.modules.noisy_layer import NoisyLayer
from models.torch.torch_modelv2 import TorchModelV2
from utils.framework import try_import_torch
from utils.typing import ModelConfigDict

torch, nn = try_import_torch()


class DQNTorchModel(TorchModelV2, nn.Module):
    """Extension of standard TorchModelV2 to provide dueling-Q functionality.

    On top of an input embedding of size `num_outputs`, this model builds an
    advantage (A) branch and, if `dueling=True`, a separate state-value (V)
    branch. The branches can optionally use noisy layers, and can output a
    categorical distribution over `num_atoms` support points (distributional
    DQN) instead of scalar Q-values.
    """

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
            *,
            q_hiddens: Sequence[int] = (256, ),
            dueling: bool = False,
            dueling_activation: str = "relu",
            num_atoms: int = 1,
            use_noisy: bool = False,
            v_min: float = -10.0,
            v_max: float = 10.0,
            sigma0: float = 0.5,
            # TODO(sven): Move `add_layer_norm` into ModelCatalog as
            #  generic option, then error if we use ParameterNoise as
            #  Exploration type and do not have any LayerNorm layers in
            #  the net.
            add_layer_norm: bool = False):
        """Initialize variables of this model.

        Args:
            obs_space (gym.spaces.Space): Observation space (forwarded to
                TorchModelV2).
            action_space (gym.spaces.Space): Action space. Its `n` attribute
                sets the number of actions, so it must provide one (e.g. a
                Discrete space).
            num_outputs (int): Size of the embedding fed into the A- and
                V-branches.
            model_config (ModelConfigDict): Model config (forwarded to
                TorchModelV2).
            name (str): Name of this model (forwarded to TorchModelV2).

        Extra model kwargs:
            q_hiddens (Sequence[int]): List of layer-sizes after(!) the
                Advantages(A)/Value(V)-split. Hence, each of the A- and V-
                branches will have this structure of Dense layers. To define
                the NN before this A/V-split, use - as always -
                config["model"]["fcnet_hiddens"].
            dueling (bool): Whether to build the advantage(A)/value(V) heads
                for Dueling DQN. If True, Q-values are calculated as:
                Q = (A - mean[A]) + V. If False, raw NN output is interpreted
                as Q-values.
            dueling_activation (str): The activation to use for all dueling
                layers (A- and V-branch). One of "relu", "tanh", "linear".
            num_atoms (int): If >1, enables distributional DQN.
            use_noisy (bool): Use noisy layers.
            v_min (float): Min value support for distributional DQN.
            v_max (float): Max value support for distributional DQN.
            sigma0 (float): Initial value of noisy layers.
            add_layer_norm (bool): Enable layer norm (for param noise). Only
                applied after plain (non-noisy) hidden layers.
        """
        nn.Module.__init__(self)
        super(DQNTorchModel, self).__init__(obs_space, action_space,
                                            num_outputs, model_config, name)

        self.dueling = dueling
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.sigma0 = sigma0
        ins = num_outputs

        def dense(in_size, out_size, activation):
            # One fully connected layer: a NoisyLayer if `use_noisy`,
            # otherwise a plain SlimFC.
            if use_noisy:
                return NoisyLayer(
                    in_size,
                    out_size,
                    sigma0=self.sigma0,
                    activation=activation)
            return SlimFC(in_size, out_size, activation_fn=activation)

        advantage_module = nn.Sequential()
        value_module = nn.Sequential()

        # Hidden layers of the A- and V-branches. The branches do NOT share
        # weights: each gets its own stack of `q_hiddens` layers. The
        # V-branch is built even when `dueling=False`; it is then simply
        # never attached to the model (see below).
        for i, n in enumerate(q_hiddens):
            advantage_module.add_module("dueling_A_{}".format(i),
                                        dense(ins, n, dueling_activation))
            value_module.add_module("dueling_V_{}".format(i),
                                    dense(ins, n, dueling_activation))
            # Add LayerNorm after each plain Dense (never after noisy ones).
            if add_layer_norm and not use_noisy:
                advantage_module.add_module("LayerNorm_A_{}".format(i),
                                            nn.LayerNorm(n))
                value_module.add_module("LayerNorm_V_{}".format(i),
                                        nn.LayerNorm(n))
            ins = n

        # Output heads. Without noisy layers and with empty `q_hiddens`, no
        # head is added, so the (empty) Sequential passes its input through
        # unchanged. Presumably the base model then already emits
        # `action_space.n * num_atoms` values -- TODO confirm with the caller.
        add_output_heads = use_noisy or bool(q_hiddens)

        # Actual Advantages layer (nodes = num-actions * num-atoms).
        if add_output_heads:
            advantage_module.add_module(
                "A", dense(ins, self.action_space.n * self.num_atoms, None))

        self.advantage_module = advantage_module

        # Value layer (nodes = num-atoms); only attached in the dueling case.
        if self.dueling:
            if add_output_heads:
                value_module.add_module("V",
                                        dense(ins, self.num_atoms, None))
            self.value_module = value_module

    def get_q_value_distributions(self, model_out):
        """Returns distributional values for Q(s, a) given a state embedding.

        Override this in your custom model to customize the Q output head.

        Args:
            model_out (Tensor): Embedding from the model layers.

        Returns:
            If num_atoms == 1: (action_scores, logits, dist), where
                `action_scores` is the advantage-branch output and `logits`
                and `dist` are the same placeholder tensor of ones with shape
                `action_scores.shape + (1,)`.
            Otherwise: (action_scores, z, support_logits_per_action, logits,
                probs), where `z` is the (num_atoms,) support,
                `support_logits_per_action` (== `logits`) has shape
                (-1, action_space.n, num_atoms), `probs` is its softmax over
                the atom axis, and `action_scores` is the expected value
                sum(z * probs) per action.
        """
        action_scores = self.advantage_module(model_out)

        if self.num_atoms > 1:
            # Distributional Q-learning uses a discrete support z of
            # num_atoms evenly spaced points in [v_min, v_max] to represent
            # the action value distribution. `torch.arange` (end-exclusive)
            # replaces the deprecated, end-inclusive `torch.range`, and the
            # tensor is created directly on the scores' device instead of
            # being built on CPU and copied on every call.
            z = torch.arange(
                self.num_atoms,
                dtype=torch.float32,
                device=action_scores.device)
            z = self.v_min + \
                z * (self.v_max - self.v_min) / float(self.num_atoms - 1)

            support_logits_per_action = torch.reshape(
                action_scores, shape=(-1, self.action_space.n, self.num_atoms))
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1)
            action_scores = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs = support_prob_per_action
            return action_scores, z, support_logits_per_action, logits, probs
        else:
            # Non-distributional case: return a dummy "one atom" tensor so
            # callers get a consistent tuple structure.
            logits = torch.unsqueeze(torch.ones_like(action_scores), -1)
            return action_scores, logits, logits

    def get_state_value(self, model_out):
        """Returns the state value prediction for the given state embedding.

        Only available when the model was built with `dueling=True`;
        otherwise `self.value_module` is never set and this raises
        AttributeError.

        Args:
            model_out (Tensor): Embedding from the model layers.

        Returns:
            Tensor: Output of the V-branch (`num_atoms` values per input).
        """
        return self.value_module(model_out)
