"""
"Maker" functions for common types of networks/modules. These are only provided
for the sake of convenience. There is no additional functionality implemented
in this file.
"""

import gymnasium as gym

import torch

from . import model, op, policy

from .utils import make_mlp, get_activation


def _prod(iterable):
    from functools import reduce
    return reduce(lambda x,y:x*y, iterable, 1)


def _make_preconditioner(space, label="space", override_dim=None):
    """
    Builds the input layer for a space and reports the matching dimension.

    A ``Box`` space gets a ``torch.nn.Flatten`` over every dimension except
    the first one, a ``Discrete`` space gets a square ``torch.nn.Embedding``.

    Parameters
    ----------
    space : gym.spaces.Space
      The space to build an input layer for.
    label : str
      Name of the space, only used in the error message.
    override_dim : int, optional
      When not None, this value is returned as the dimension instead of the
      one derived from the space. For a ``Discrete`` space it is also used
      for both sizes of the embedding table.

    Returns
    -------
    tuple
      ``(dim, init)`` where ``init`` is the created torch module.

    Raises
    ------
    ValueError
      If the space is neither a ``Box`` nor a ``Discrete`` space.
    """
    if isinstance(space, gym.spaces.Box):
        flatten = torch.nn.Flatten(start_dim=1, end_dim=-1)
        natural_dim = _prod(space.shape)
        return (natural_dim if override_dim is None else override_dim, flatten)

    if isinstance(space, gym.spaces.Discrete):
        natural_dim = space.n
        size = natural_dim if override_dim is None else override_dim
        return (size, torch.nn.Embedding(size, size))

    raise ValueError(f"Unsupported {label} {space}")


def q_critic_mlp(env, hidden_sizes, activation="relu"):
    """
    Builds a fully connected state-action critic (Q function), as used by
    algorithms such as DDPG and SAC.

    The returned module takes a batch of states and a batch of actions and
    yields a 1-D vector holding one critic value per state-action pair.
    The observation preprocessing is wired in as the left input and the
    action preprocessing as the right input of ``op.NNLayerConcat2``.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces.
    hidden_sizes : List[int]
      A list of positive integers which contain the sizes of the hidden layers.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the MLP.

    Raises
    ------
    ValueError
      If `hidden_sizes` is empty, or if one of the spaces is not supported.
    """
    if len(hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    nonlinearity = get_activation(activation)

    obs_dim, obs_init = _make_preconditioner(
        env.observation_space,
        label="observation space",
    )
    action_dim, action_init = _make_preconditioner(
        env.action_space,
        label="action space",
    )

    # The MLP consumes observation and action features side by side and ends
    # in a single output unit whose trailing axis is squeezed away.
    joint_dim = obs_dim + action_dim
    squeeze = op.NNLayerSqueeze(-1)
    value_net = make_mlp(
        indim=joint_dim,
        hidden_dims=hidden_sizes,
        outdim=1,
        activation=nonlinearity,
        following_layers=[squeeze],
    )

    return op.NNLayerConcat2(
        dim=-1,
        init_left=obs_init,
        init_right=action_init,
        next=value_net,
    )


def v_critic_mlp(env, hidden_sizes, activation="relu"):
    """
    Builds a fully connected state critic (V function), as used by
    algorithms such as PPO.

    The returned network takes a batch of states and yields a 1-D vector
    holding one critic value per state.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces.
    hidden_sizes : List[int]
      A list of positive integers which contain the sizes of the hidden layers.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the MLP.

    Raises
    ------
    ValueError
      If `hidden_sizes` is empty, or if the observation space is not supported.
    """
    if len(hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    nonlinearity = get_activation(activation)

    obs_dim, obs_init = _make_preconditioner(
        env.observation_space,
        label="observation space",
    )

    # Observation preprocessing goes first; the single output unit has its
    # trailing axis squeezed away at the end.
    head = [obs_init]
    tail = [op.NNLayerSqueeze(-1)]
    value_net = make_mlp(
        indim=obs_dim,
        hidden_dims=hidden_sizes,
        outdim=1,
        activation=nonlinearity,
        initial_layers=head,
        following_layers=tail,
    )
    return value_net


def regressive_policy_mlp(env, hidden_sizes,
                          activation="relu"):
    """
    Builds an MLP that maps an observation directly to an action, i.e. a
    regressive approximator for a deterministic policy.

    The final linear layer is followed by a tanh, an unflatten back to the
    action space's shape and an ``op.NNTanhRefit`` built from the action
    space's ``low``/``high`` bounds.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces.
    hidden_sizes : List[int]
      A list of positive integers which contain the sizes of the hidden layers.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the MLP.

    Raises
    ------
    ValueError
      If `hidden_sizes` is empty, the observation space is not supported, or
      the action space is not a ``Box``.
    """
    if len(hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    nonlinearity = get_activation(activation)

    obs_dim, obs_init = _make_preconditioner(
        env.observation_space,
        label="observation space",
    )

    action_space = env.action_space
    if not isinstance(action_space, gym.spaces.Box):
        raise ValueError("A regressive policy can only be applied on a continuous action space.")
    action_dim = _prod(action_space.shape)

    # Walk over consecutive (input width, output width) pairs so every hidden
    # layer is a Linear followed by the chosen non-linearity.
    widths = [obs_dim, *hidden_sizes]
    modules = [obs_init]
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(torch.nn.Linear(fan_in, fan_out))
        modules.append(nonlinearity())

    # Output projection, then squash and fit into the action space.
    modules.append(torch.nn.Linear(widths[-1], action_dim))
    modules.append(torch.nn.Tanh())
    modules.append(torch.nn.Unflatten(1, action_space.shape))
    modules.append(
        op.NNTanhRefit(low=action_space.low, high=action_space.high, shape=action_space.shape)
    )

    return torch.nn.Sequential(*modules)


def gaussian_policy_mlp(env, hidden_sizes,
                        activation="relu",
                        use_tanh_refit=True,
                        override_obs_dim=None):
    """
    Builds an MLP policy that outputs the parameters of a gaussian
    distribution over the action space from an observation input.

    The network consists of a shared trunk built from `hidden_sizes`, which
    is then split into two single-layer heads: a "mu head" for the mean and a
    "log std head" for the log standard deviation of the distribution. Both
    heads are unflattened to the shape of the action space.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces.
    hidden_sizes : List[int]
      A list of positive integers which contain the sizes of the hidden layers.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the MLP.
    use_tanh_refit : bool
      Allow usage of a tanh refit to ensure that the generated actions fit
      within a bounded action space. When False, no refit is passed to the
      policy.
    override_obs_dim : int, optional
      When not None, used as the observation dimension instead of the one
      derived from the observation space.

    Raises
    ------
    ValueError
      If `hidden_sizes` is empty, the observation space is not supported, or
      the action space is not a ``Box``.
    """
    if len(hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    nonlinearity = get_activation(activation)

    obs_dim, obs_init = _make_preconditioner(
        env.observation_space,
        label="observation space",
        override_dim=override_obs_dim,
    )

    if not isinstance(env.action_space, gym.spaces.Box):
        raise ValueError("A gaussian policy can only be applied on a continuous action space.")

    action_dim = _prod(env.action_space.shape)
    trunk_width = hidden_sizes[-1]

    def _action_head():
        # One linear projection from the trunk to the (unflattened) action shape.
        return make_mlp(
            indim=trunk_width, hidden_dims=[], outdim=action_dim,
            activation=nonlinearity,
            following_layers=[torch.nn.Unflatten(1, env.action_space.shape)],
        )

    # The last hidden size is the trunk's output; a trailing non-linearity is
    # appended so the heads see activated features.
    trunk = make_mlp(
        indim=obs_dim, hidden_dims=hidden_sizes[:-1], outdim=trunk_width,
        activation=nonlinearity,
        initial_layers=[obs_init],
        following_layers=[nonlinearity()],
    )
    mean_layers = _action_head()
    log_std_layers = _action_head()

    refit = (
        op.NNTanhRefit(low=env.action_space.low, high=env.action_space.high, shape=env.action_space.shape)
        if use_tanh_refit
        else None
    )

    return policy.NNGaussianPolicy(
        common_head=torch.nn.Sequential(*trunk),
        mu_head=torch.nn.Sequential(*mean_layers),
        log_std_head=torch.nn.Sequential(*log_std_layers),
        tanh_refit=refit,
    )


def gaussian_multivar_policy_mlp(env, hidden_sizes, n_encode, maxpred,
                                 activation="relu",
                                 use_tanh_refit=True,
                                 override_obs_dim=None):
    """
    Constructs a gaussian MLP policy with one encoder head per prediction
    length, followed by a shared trunk and a mu / log std head.

    Encoder head ``i`` (``0 <= i < maxpred``) takes an input of width
    ``obs_dim + (i + 1) * action_dim`` and outputs ``hidden_sizes[n_encode]``
    features, which are then fed to the shared trunk.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces.
    hidden_sizes : List[int]
      A list of positive integers which contain the sizes of the hidden layers.
    n_encode : int
      Number of leading entries of `hidden_sizes` used as hidden layers of
      each encoder head. Must be smaller than ``len(hidden_sizes)``.
    maxpred : int
      The maximum prediction length to the policy, i.e. the number of encoder
      heads that are created.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the MLP.
    use_tanh_refit : bool
      Whether an ``op.NNTanhRefit`` built from the action space bounds is
      passed to the policy.
    override_obs_dim : int, optional
      When not None, used as the observation dimension for the encoder input
      widths instead of the one derived from the observation space.

    Raises
    ------
    ValueError
      If `hidden_sizes` is empty, `n_encode` is too large, the observation
      space is not supported, or the action space is not a ``Box``.
    """
    if len(hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    if n_encode >= len(hidden_sizes):
        raise ValueError(f"Too many encode layers. Maximum: {(len(hidden_sizes) - 1)}")
    act = get_activation(activation)

    obs_dim, obs_init = _make_preconditioner(
        env.observation_space,
        label="observation space",
        override_dim=override_obs_dim,
    )

    if not isinstance(env.action_space, gym.spaces.Box):
        raise ValueError("A gaussian policy can only be applied on a continuous action space.")

    # Previously obs_dim was always recomputed from the observation shape,
    # which silently discarded override_obs_dim. Only fall back to the shape
    # when no override was requested.
    if override_obs_dim is None:
        obs_dim = _prod(env.observation_space.shape)
    action_dim = _prod(env.action_space.shape)

    encoder_heads = []
    for i in range(maxpred):
        # Head i sees the observation plus (i + 1) flattened actions.
        enc_indim = obs_dim + ((i + 1) * action_dim)
        enc_head = make_mlp(
            indim=enc_indim, hidden_dims=hidden_sizes[:n_encode], outdim=hidden_sizes[n_encode],
            activation=act,
            following_layers=[act()],
        )
        encoder_heads.append(torch.nn.Sequential(*enc_head))

    # NOTE(review): hidden_dims starts at n_encode and outdim repeats
    # hidden_sizes[-1], so hidden_sizes[n_encode] and hidden_sizes[-1] each
    # appear twice in the trunk. Also obs_init is applied to the encoder
    # output here rather than to the raw observation. Both are kept as-is;
    # confirm whether they are intended.
    common_head = make_mlp(
        indim=hidden_sizes[n_encode], hidden_dims=hidden_sizes[n_encode:], outdim=hidden_sizes[-1],
        activation=act,
        initial_layers=[obs_init],
        following_layers=[act()],
    )
    mu_head = make_mlp(
        indim=hidden_sizes[-1], hidden_dims=[], outdim=action_dim,
        activation=act,
        following_layers=[torch.nn.Unflatten(1, env.action_space.shape)],
    )
    log_std_head = make_mlp(
        indim=hidden_sizes[-1], hidden_dims=[], outdim=action_dim,
        activation=act,
        following_layers=[torch.nn.Unflatten(1, env.action_space.shape)],
    )

    tanh_refit = None
    if use_tanh_refit:
        tanh_refit = op.NNTanhRefit(low=env.action_space.low, high=env.action_space.high, shape=env.action_space.shape)

    return policy.NNGaussianMultiVarLengthPolicy(
        encoder_heads=encoder_heads,
        common_head=torch.nn.Sequential(*common_head),
        mu_head=torch.nn.Sequential(*mu_head),
        log_std_head=torch.nn.Sequential(*log_std_head),
        tanh_refit=tanh_refit,
    )


def gaussian_multivar_policy_transformer(
        env, enc_hidden_sizes, maxpred,
        n_layers=3,
        n_heads=1,
        dim_feedforward=None,
        activation="relu",
        transformer_activation="gelu",
        use_tanh_refit=True,
        override_obs_dim=None):
    """
    Constructs a gaussian policy that embeds the observation and the actions
    with MLPs, adds a learned position embedding and runs a transformer
    encoder before a mu / log std head.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces.
    enc_hidden_sizes : List[int]
      Layer sizes of the observation and action embedding MLPs. The last
      entry is the embedding width used by the transformer.
    maxpred : int
      The maximum prediction length to the policy; this is the number of
      entries in the position embedding table.
    n_layers : int
      Number of transformer encoder layers.
    n_heads : int
      Number of attention heads per transformer encoder layer.
    dim_feedforward : int, optional
      Width of the transformer feed-forward block. Defaults to four times the
      embedding width when None.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the embedding MLPs.
    transformer_activation : str
      Activation passed on to ``torch.nn.TransformerEncoderLayer``.
    use_tanh_refit : bool
      Whether an ``op.NNTanhRefit`` built from the action space bounds is
      passed to the policy.
    override_obs_dim : int, optional
      When not None, used as the observation dimension instead of the one
      derived from the observation space.

    Raises
    ------
    ValueError
      If `enc_hidden_sizes` is empty, one of the spaces is not supported, or
      the action space is not a ``Box``.
    """
    if len(enc_hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    nonlinearity = get_activation(activation)

    obs_dim, obs_init = _make_preconditioner(
        env.observation_space,
        label="observation space",
        override_dim=override_obs_dim,
    )
    action_dim, action_init = _make_preconditioner(
        env.action_space,
        label="action space",
    )

    if not isinstance(env.action_space, gym.spaces.Box):
        raise ValueError("A gaussian policy can only be applied on a continuous action space.")

    embed_width = enc_hidden_sizes[-1]
    ff_width = 4 * embed_width if dim_feedforward is None else dim_feedforward
    inner_sizes = enc_hidden_sizes[:-1]

    # Observation and action embeddings share the same layer layout and only
    # differ in their input width and preprocessing layer.
    obs_embed = make_mlp(
        indim=obs_dim, hidden_dims=inner_sizes, outdim=embed_width,
        activation=nonlinearity,
        initial_layers=[obs_init],
        following_layers=[nonlinearity()],
    )
    action_embed = make_mlp(
        indim=action_dim, hidden_dims=inner_sizes, outdim=embed_width,
        activation=nonlinearity,
        initial_layers=[action_init],
        following_layers=[nonlinearity()],
    )
    pos_embed = torch.nn.Embedding(maxpred, embed_width)

    encoder_layer = torch.nn.TransformerEncoderLayer(
        embed_width,
        nhead=n_heads,
        dim_feedforward=ff_width,
        activation=transformer_activation,
        batch_first=True,
    )
    encoder = torch.nn.TransformerEncoder(encoder_layer, n_layers)

    def _action_head():
        # One linear projection from the embedding width to the action shape.
        return make_mlp(
            indim=embed_width, hidden_dims=[], outdim=action_dim,
            activation=nonlinearity,
            following_layers=[torch.nn.Unflatten(1, env.action_space.shape)],
        )

    mean_head = _action_head()
    log_std_head = _action_head()

    refit = (
        op.NNTanhRefit(low=env.action_space.low, high=env.action_space.high, shape=env.action_space.shape)
        if use_tanh_refit
        else None
    )

    return policy.NNGaussianTransformerVarLengthPolicy(
        state_embed=obs_embed,
        act_embed=action_embed,
        pos_embed=pos_embed,
        enc_transformer=encoder,
        n_heads=n_heads,
        mu_head=mean_head,
        log_std_head=log_std_head,
        tanh_refit=refit,
    )


def categorical_policy_mlp(env, hidden_sizes,
                           activation="relu"):
    """
    Builds an MLP policy that outputs the parameters of a categorical
    distribution over the action space from an observation input.

    The final linear layer has one output per discrete action and is not
    followed by any activation.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces.
    hidden_sizes : List[int]
      A list of positive integers which contain the sizes of the hidden layers.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the MLP.

    Raises
    ------
    ValueError
      If `hidden_sizes` is empty, the observation space is not supported, or
      the action space is not ``Discrete``.
    """
    if len(hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    nonlinearity = get_activation(activation)

    obs_dim, obs_init = _make_preconditioner(
        env.observation_space,
        label="observation space",
    )

    if not isinstance(env.action_space, gym.spaces.Discrete):
        raise ValueError("A categorical policy can only be applied on a discrete action space.")

    # Walk over consecutive (input width, output width) pairs so every hidden
    # layer is a Linear followed by the chosen non-linearity.
    widths = [obs_dim, *hidden_sizes]
    modules = [obs_init]
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(torch.nn.Linear(fan_in, fan_out))
        modules.append(nonlinearity())

    modules.append(torch.nn.Linear(widths[-1], env.action_space.n))

    network = torch.nn.Sequential(*modules)
    return policy.NNCategoricalPolicy(net=network)


def auto_stochastic_policy_mlp(env, hidden_sizes,
                               activation="relu",
                               use_tanh_refit=True):
    """
    Picks the stochastic policy kind that matches the action space.

    A ``Discrete`` action space yields a categorical policy (see
    `categorical_policy_mlp`); every other action space is handed to
    `gaussian_policy_mlp`. `use_tanh_refit` is only forwarded to the
    gaussian policy.
    """
    if not isinstance(env.action_space, gym.spaces.Discrete):
        return gaussian_policy_mlp(
            env, hidden_sizes,
            activation=activation,
            use_tanh_refit=use_tanh_refit,
        )
    return categorical_policy_mlp(env, hidden_sizes, activation=activation)


def modelrnn_mlp(env, hidden_sizes, h_size,
                 activation="relu",
                 rnn_layers=1,
                 rnn_activation="tanh",
                 additive_output=False,):
    """
    Constructs a RNN Model with an MLP as its output network.

    The output MLP maps the `h_size` wide hidden state through `hidden_sizes`
    to a vector as wide as the observation space.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces. Both spaces must have a one-dimensional shape.
    hidden_sizes : List[int]
      A list of positive integers which contain the sizes of the hidden layers
      of the output MLP.
    h_size : int
      Input width of the output MLP; also passed to the recurrent model.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the output MLP.
    rnn_layers : int
      Forwarded to the recurrent model as ``num_layers``.
    rnn_activation : str
      Forwarded to the recurrent model as ``activation``.
    additive_output : bool
      Forwarded unchanged to the recurrent model.

    Raises
    ------
    ValueError
      If `hidden_sizes` is empty.
    NotImplementedError
      If the observation or action space shape is not one-dimensional.
    """
    if len(hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    nonlinearity = get_activation(activation)

    spaces_are_flat = len(env.observation_space.shape) == 1 and len(env.action_space.shape) == 1
    if not spaces_are_flat:
        raise NotImplementedError(f"Only linear spaces supported at the moment. Unsupported spaces: S = {env.observation_space}, A = {env.action_space}")

    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Hidden layers: Linear + non-linearity for each consecutive width pair,
    # then a plain Linear projection back to the observation width.
    widths = [h_size, *hidden_sizes]
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(torch.nn.Linear(fan_in, fan_out))
        modules.append(nonlinearity())
    modules.append(torch.nn.Linear(widths[-1], obs_dim))

    return model.NNRecurrentModelRNN(
        obs_dim, action_dim, h_size,
        output_net=torch.nn.Sequential(*modules),
        num_layers=rnn_layers,
        activation=rnn_activation,
        additive_output=additive_output,
    )


def modelfeedfwd_mlp(env, hidden_sizes,
                     activation="relu",):
    """
    Constructs a Feed-Forward Model with an MLP network.

    The MLP takes an input of width ``obs_dim + action_dim``, passes it
    through `hidden_sizes` and outputs a vector as wide as the observation
    space.

    Parameters
    ----------
    env : gym.core.Env
      The environment that encodes information about the observation and
      action spaces. Both spaces must have a one-dimensional shape.
    hidden_sizes : List[int]
      A list of positive integers which contain the sizes of the hidden layers.
    activation : str or torch.nn.Module class
      Non-linearity to have between the layers in the MLP.

    Raises
    ------
    ValueError
      If `hidden_sizes` is empty.
    NotImplementedError
      If the observation or action space shape is not one-dimensional.
    """
    # Same guard as the other makers in this file; without it an empty list
    # only surfaced as an IndexError on hidden_sizes[0] further down.
    if len(hidden_sizes) == 0:
        raise ValueError("No hidden sizes.")
    act = get_activation(activation)
    if len(env.observation_space.shape) != 1 or len(env.action_space.shape) != 1:
        raise NotImplementedError(f"Only linear spaces supported at the moment. Unsupported spaces: S = {env.observation_space}, A = {env.action_space}")

    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Hidden layers: Linear + non-linearity for each consecutive width pair,
    # then a plain Linear projection back to the observation width.
    widths = [obs_dim + action_dim, *hidden_sizes]
    output_net = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        output_net += [
            torch.nn.Linear(fan_in, fan_out),
            act(),
        ]
    output_net.append(torch.nn.Linear(widths[-1], obs_dim))

    return model.NNRecurrentModelFeedForward(
        obs_dim, action_dim,
        net=torch.nn.Sequential(*output_net),
    )
