import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp

from tf_agents.networks import categorical_projection_network
from tf_agents.networks import normal_projection_network
from tf_agents.networks import network, encoding_network
from tf_agents.networks import sequential
from tf_agents.specs import tensor_spec
from tf_agents.utils import nest_utils


def _categorical_projection_net(action_spec, logits_init_output_factor=0.1):
    """Create a categorical projection network for ``action_spec``.

    Args:
        action_spec: Action spec handed to ``CategoricalProjectionNetwork``.
        logits_init_output_factor: Forwarded unchanged to the projection
            network's ``logits_init_output_factor`` argument.

    Returns:
        A ``CategoricalProjectionNetwork`` instance.
    """
    projection_cls = categorical_projection_network.CategoricalProjectionNetwork
    return projection_cls(
        action_spec,
        logits_init_output_factor=logits_init_output_factor,
    )


def _normal_projection_net(
    action_spec,
    init_action_stddev=0.35,
    init_means_output_factor=0.1,
    seed_stream_class=tfp.util.SeedStream,
    seed=None,
):
    """Create a normal (Gaussian) projection network for ``action_spec``.

    Args:
        action_spec: Action spec handed to ``NormalProjectionNetwork``.
        init_action_stddev: Value from which the std bias initializer is
            derived as ``log(exp(init_action_stddev) - 1)``.
        init_means_output_factor: Forwarded unchanged to the projection
            network.
        seed_stream_class: Forwarded unchanged to the projection network.
        seed: Forwarded unchanged to the projection network.

    Returns:
        A ``NormalProjectionNetwork`` built with ``scale_distribution=False``.
    """
    # log(exp(s) - 1) is the inverse of softplus evaluated at s.
    # NOTE(review): presumably the projection network passes its std output
    # through softplus so the initial stddev equals init_action_stddev --
    # confirm against the tf_agents implementation.
    initial_std_bias = np.log(np.exp(init_action_stddev) - 1)

    projection_kwargs = dict(
        init_means_output_factor=init_means_output_factor,
        std_bias_initializer_value=initial_std_bias,
        scale_distribution=False,
        seed_stream_class=seed_stream_class,
        seed=seed,
    )
    return normal_projection_network.NormalProjectionNetwork(
        action_spec, **projection_kwargs
    )


# Customized ActorDistributionNetwork
class my_network(network.DistributionNetwork):
    """Distribution network with an embedding-based observation encoder.

    Observations go through an ``Embedding`` layer, are averaged over axis 1,
    and then pass through a stack of ReLU ``Dense`` layers. The encoded state
    is fed to one projection network per leaf of ``output_tensor_spec``
    (categorical for discrete specs, normal otherwise), each of which yields
    a distribution.
    """

    def __init__(
        self,
        env,
        input_tensor_spec,
        output_tensor_spec,
        fc_layer_params,
        name="my_network",
    ):
        """Build the projection networks and the encoder.

        Args:
            env: Object exposing ``obs_shape``; that value is passed as the
                ``input_length`` of the embedding layer.
            input_tensor_spec: Spec of the observations; forwarded to the
                base class.
            output_tensor_spec: Possibly nested action spec. One projection
                network is created per leaf spec.
            fc_layer_params: Iterable of ints, the number of units of each
                ``Dense`` layer stacked after the embedding.
            name: Name forwarded to the base class.
        """

        def dense_layer(num_units):
            return tf.keras.layers.Dense(
                num_units,
                activation=tf.keras.activations.relu,
                kernel_initializer=tf.keras.initializers.VarianceScaling(
                    scale=2.0, mode="fan_in", distribution="truncated_normal"
                ),
            )

        def map_proj(spec):
            # Discrete action specs get a categorical head, all others a
            # normal head.
            if tensor_spec.is_discrete(spec):
                return _categorical_projection_net(spec)
            return _normal_projection_net(
                spec, seed=None, seed_stream_class=tfp.util.SeedStream
            )

        projection_networks = tf.nest.map_structure(map_proj, output_tensor_spec)
        output_spec = tf.nest.map_structure(
            lambda proj_net: proj_net.output_spec, projection_networks
        )

        # The base class must be initialised exactly once, and before any
        # attribute is assigned on ``self``. Re-running it after attributes
        # have been set would re-initialise the base-class state underneath
        # them.
        super(my_network, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=(),
            output_spec=output_spec,
            name=name,
        )

        self._projection_networks = projection_networks
        self._output_tensor_spec = output_tensor_spec

        dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
        # Embedding(input_dim=7, output_dim=20): 7 distinct input ids, each
        # mapped to a 20-dimensional vector.
        emb = tf.keras.layers.Embedding(7, 20, input_length=env.obs_shape)
        # Average the embedded vectors over axis 1.
        # NOTE(review): presumably axis 1 is the per-observation position
        # axis of a (batch, length) input -- confirm against the caller.
        mean_emb = tf.keras.layers.Lambda(lambda x: tf.keras.backend.mean(x, axis=1))
        self._encoder = sequential.Sequential([emb, mean_emb] + dense_layers)

    def call(
        self, observations, step_type=(), network_state=(), training=False, mask=None
    ):
        """Encode ``observations`` and return the action distribution(s).

        Args:
            observations: Observations matching ``input_tensor_spec``.
            step_type: Accepted for interface compatibility; not used here.
            network_state: Returned unchanged.
            training: Forwarded to every projection network.
            mask: Forwarded to every projection network.

        Returns:
            A tuple ``(output_actions, network_state)`` where
            ``output_actions`` has the structure of the projection networks
            and holds one distribution per leaf.
        """
        # The encoder's own returned state is discarded.
        state, _ = self._encoder(observations)

        outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)

        def call_projection_net(proj_net):
            distribution, _ = proj_net(state, outer_rank, training=training, mask=mask)
            return distribution

        output_actions = tf.nest.map_structure(
            call_projection_net, self._projection_networks
        )
        return output_actions, network_state
