from typing import Any, Callable, Sequence, Optional
import jax
from jax import numpy as jnp
import flax
import jraph


class MLP_layer(flax.linen.Module):
    """Multi-layer perceptron with LayerNorm + ReLU between hidden layers.

    Applies ``num_layers - 1`` hidden blocks, each consisting of
    ``Dense(hidden_dim)`` -> ``LayerNorm`` -> ReLU. These are followed by a final
    ``Dense(output_dim)`` projection with no activation. When
    ``num_layers <= 1`` only the final projection is applied.

    Attributes:
        num_layers: Total number of Dense layers, including the output layer.
        hidden_dim: Width of every hidden Dense layer.
        output_dim: Width of the final Dense layer.
    """

    num_layers: int
    hidden_dim: int
    output_dim: int

    @flax.linen.compact
    def __call__(self, x):
        """Run the MLP on ``x``; Dense layers act on the last axis."""
        # A single path covers every depth: with num_layers <= 1 the loop body
        # never runs, so only the output projection (Dense_0) is created.
        hidden_blocks = max(self.num_layers - 1, 0)
        h = x
        for _ in range(hidden_blocks):
            h = flax.linen.Dense(self.hidden_dim)(h)
            h = flax.linen.LayerNorm()(h)
            h = flax.linen.relu(h)
        return flax.linen.Dense(self.output_dim)(h)


class PositionwiseFeedForwardNetwork(flax.linen.Module):
    """Two-layer position-wise feed-forward block.

    A biased Dense layer expands the last axis by ``widening_factor``. ReLU is
    applied, and a second biased Dense layer projects back to the original
    width. The output therefore has the same shape as the input.

    Attributes:
        widening_factor: Multiplier for the hidden width relative to the
            input's last-axis size.
    """

    widening_factor: int

    @flax.linen.compact
    def __call__(self, inputs):
        """Apply expand -> ReLU -> project along the last axis of ``inputs``."""
        model_dim = inputs.shape[-1]
        expanded = flax.linen.Dense(model_dim * self.widening_factor, use_bias=True)(inputs)
        activated = flax.linen.relu(expanded)
        projected = flax.linen.Dense(model_dim, use_bias=True)(activated)
        return projected


class GINLayer(flax.linen.Module):
    """GIN-style message-passing layer built on ``jraph.GraphNetwork``.

    One call performs a single graph-network pass:

    * edges: ``relu(Dense(node_embedding_size)(concat[sender, receiver, edge]))``
    * nodes: ``MLP_layer((1 + eps) * node + segment-summed received messages)``
    * globals: ``Dense(node_embedding_size, use_bias=False)`` applied to the
      segment-summed node features.

    ``eps`` is a learnable parameter of shape ``(1,)`` initialised to zero.
    Node features must broadcast against the ``node_embedding_size``-wide
    edge messages for the node update's addition to succeed.

    Attributes:
        mlp_num_layers: ``num_layers`` of the node-update ``MLP_layer``.
        mlp_hidden_dim: ``hidden_dim`` of the node-update ``MLP_layer``.
        node_embedding_size: Output width of the edge, node and global updates.
    """

    mlp_num_layers: int
    mlp_hidden_dim: int
    node_embedding_size: int

    def make_network(self):
        """Register the ``eps`` parameter and return the configured GraphNetwork.

        Intended to be called from inside a compact method (as ``__call__``
        does). The returned network's update functions create their Dense /
        MLP_layer submodules lazily, when the network is applied.
        """
        eps = self.param('eps', lambda rng, shape: jnp.zeros(shape), (1,))
        width = self.node_embedding_size

        def sum_by_segment(data, indices, num_segments):
            """Sum rows of ``data`` that share the same segment index."""
            return jax.ops.segment_sum(data, indices, num_segments)

        def edge_update(edges, sent_attributes, received_attributes, global_edge_attributes):
            """Compute one message per edge from its sender, receiver and edge features."""
            stacked = jnp.concatenate([sent_attributes, received_attributes, edges], axis=-1)
            return flax.linen.relu(flax.linen.Dense(width)(stacked))

        def node_update(nodes, sent_messages, received_messages, global_attributes):
            """GIN update: MLP over (1 + eps) * node plus the aggregated incoming messages."""
            combined = (1 + eps) * nodes + received_messages
            mlp = MLP_layer(self.mlp_num_layers, self.mlp_hidden_dim, width)
            return mlp(combined)

        def global_update(node_attributes, edge_attributes, globals_):
            """Project the aggregated node features to the new globals (no bias)."""
            return flax.linen.Dense(width, use_bias=False)(node_attributes)

        return jraph.GraphNetwork(
            update_edge_fn=edge_update,
            update_node_fn=node_update,
            update_global_fn=global_update,
            aggregate_edges_for_nodes_fn=sum_by_segment,
            aggregate_nodes_for_globals_fn=sum_by_segment,
            attention_logit_fn=None,
            attention_reduce_fn=None,
        )

    @flax.linen.compact
    def __call__(self, G):
        """Apply one message-passing step to graph ``G``.

        Returns:
            A tuple ``(graph, new_globals)``. ``graph`` is the updated graph with
            its ``globals`` field restored to the input ``G.globals``.
            ``new_globals`` is the globals output produced by the network.
        """
        original_globals = G.globals
        updated = self.make_network()(G)
        new_globals = updated.globals
        return updated._replace(globals=original_globals), new_globals