"""Models for GNN experiments."""

import abc

import haiku as hk
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import jraph


def get_module(num_layers: int, hidden_size: int,
               activation=jax.nn.relu) -> hk.Sequential:
  """Builds an MLP made of `hk.Linear(hidden_size)` layers.

  Args:
    num_layers: Number of linear layers. Any value below 1 still yields a
      single linear layer.
    hidden_size: Output width of every linear layer.
    activation: Non-linearity inserted between consecutive linear layers
      (not applied after the last one). Defaults to ReLU.

  Returns:
    An `hk.Sequential` applying the layers in order. Must be built and called
    inside an `hk.transform`, since it creates Haiku modules.
  """
  layers = []
  for _ in range(num_layers - 1):
    # Linear module is constructed before the next one, so Haiku's
    # auto-generated parameter names follow layer order.
    layers += [hk.Linear(hidden_size), activation]
  layers.append(hk.Linear(hidden_size))
  return hk.Sequential(layers)


class BaseModel(abc.ABC):
  """Base class for encode-process-decode graph models.

  `net_fn` encodes every node/edge/global feature with a linear layer,
  applies the subclass's `processor_fn` `mp_steps` times, then decodes with
  linear layers sized to the label counts. `net_fn` creates Haiku modules, so
  it must run inside an `hk.transform`.
  """

  # Rows in each degree-embedding table. NOTE(review): larger degree ids are
  # out of range for the table (presumably clamped by JAX indexing — confirm).
  _DEGREE_VOCAB_SIZE = 500
  # Node-feature keys that hold integer degree ids for centrality encoding.
  _DEGREE_KEYS = ('in_degs', 'out_degs')

  def __init__(self, num_node_labels: int, num_edge_labels: int,
               num_graph_labels: int, hidden_size: int, mp_steps: int,
               num_layers: int, use_centrality_encoding: bool):
    """Stores the model hyperparameters.

    Args:
      num_node_labels: Output width of the node decoder.
      num_edge_labels: Output width of the edge decoder.
      num_graph_labels: Output width of the global (graph-level) decoder.
      hidden_size: Width of the encoder outputs and of every MLP layer.
      mp_steps: Number of times `processor_fn` is applied in `net_fn`.
      num_layers: Number of linear layers in each update MLP.
      use_centrality_encoding: Whether `net_fn` adds in/out-degree
        embeddings to the node features before encoding.
    """
    self._num_node_labels = num_node_labels
    self._num_edge_labels = num_edge_labels
    self._num_graph_labels = num_graph_labels
    self._hidden_size = hidden_size
    self._num_layers = num_layers
    self._mp_steps = mp_steps
    self._use_centrality_encoding = use_centrality_encoding

  def _make_mlp_update_fn(self):
    """Returns a fn that concatenates its inputs and applies a fresh MLP.

    A new MLP (and thus new parameters) is created each time the returned
    function is called.
    """

    @jraph.concatenated_args
    def update_fn(feats: jnp.ndarray) -> jnp.ndarray:
      net = get_module(self._num_layers, self._hidden_size)
      return net(feats)

    return update_fn

  def get_update_global_fn(self):
    """Returns the global update function for a graph net."""
    return self._make_mlp_update_fn()

  def get_edge_update_fn(self):
    """Returns the edge update function for a graph net."""
    return self._make_mlp_update_fn()

  def get_node_update_fn(self):
    """Returns the node update function for a graph net."""
    return self._make_mlp_update_fn()

  def _add_centrality_encoding(
      self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Adds learned in-/out-degree embeddings to the regular node features.

    NOTE(review): assumes `graph.nodes` is a mapping holding integer degree
    ids under 'in_degs' and 'out_degs' (shape [num_nodes] or [num_nodes, 1])
    alongside the regular node features. Every other entry is treated as a
    regular feature array; several are concatenated along the last axis in
    sorted-key order. Confirm this matches the data pipeline.

    Raises:
      ValueError: if `graph.nodes` is not a mapping or has no entries besides
        the degree ids.
    """
    nodes = graph.nodes
    if not hasattr(nodes, 'keys'):
      raise ValueError(
          'Centrality encoding requires graph.nodes to be a mapping with '
          f'{self._DEGREE_KEYS} entries, got {type(nodes).__name__}.')
    feat_keys = sorted(k for k in nodes if k not in self._DEGREE_KEYS)
    if not feat_keys:
      raise ValueError('Centrality encoding found no node features besides '
                       f'{self._DEGREE_KEYS}.')
    feat_parts = [nodes[k] for k in feat_keys]
    regular_feats = (feat_parts[0] if len(feat_parts) == 1 else
                     jnp.concatenate(feat_parts, axis=-1))
    num_nodes = regular_feats.shape[0]
    centrality_dim = regular_feats.shape[-1]

    def embed_degrees(degs: jnp.ndarray) -> jnp.ndarray:
      # Flatten ids to [num_nodes] so the lookup yields
      # [num_nodes, centrality_dim], matching the regular features.
      ids = jnp.reshape(degs, (num_nodes,))
      return hk.Embed(vocab_size=self._DEGREE_VOCAB_SIZE,
                      embed_dim=centrality_dim)(ids)

    # In-degree table is created before the out-degree one (Haiku names
    # parameters by construction order).
    in_embed = embed_degrees(nodes['in_degs'])
    out_embed = embed_degrees(nodes['out_degs'])
    return graph._replace(nodes=regular_feats + in_embed + out_embed)

  def net_fn(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Encodes `graph`, runs `mp_steps` processor steps, then decodes it.

    Args:
      graph: Input graph. With centrality encoding enabled, `graph.nodes`
        must be a mapping (see `_add_centrality_encoding`); otherwise the
        node features are fed directly to the node encoder.

    Returns:
      A graph whose node, edge and global features have `num_node_labels`,
      `num_edge_labels` and `num_graph_labels` output units respectively.
    """
    # Module construction order fixes Haiku's auto-generated parameter
    # names: degree embeddings, then encoder, then decoder, then processors.
    if self._use_centrality_encoding:
      graph = self._add_centrality_encoding(graph)

    encoder = jraph.GraphMapFeatures(
        embed_node_fn=hk.Linear(self._hidden_size),
        embed_edge_fn=hk.Linear(self._hidden_size),
        embed_global_fn=hk.Linear(self._hidden_size))
    decoder = jraph.GraphMapFeatures(
        embed_node_fn=hk.Linear(self._num_node_labels),
        embed_edge_fn=hk.Linear(self._num_edge_labels),
        embed_global_fn=hk.Linear(self._num_graph_labels))

    graph = encoder(graph)
    for _ in range(self._mp_steps):
      graph = self.processor_fn(graph)
    return decoder(graph)

  @abc.abstractmethod
  def processor_fn(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies one message-passing step to an encoded graph."""
    raise NotImplementedError('processor_fn needs to be implemented')


class MPNN(BaseModel):
  """MPNN model: every step is a full jraph.GraphNetwork update."""

  def processor_fn(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Runs one GraphNetwork step that updates edges, nodes and globals."""
    node_fn = self.get_node_update_fn()
    edge_fn = self.get_edge_update_fn()
    global_fn = self.get_update_global_fn()
    network = jraph.GraphNetwork(update_node_fn=node_fn,
                                 update_edge_fn=edge_fn,
                                 update_global_fn=global_fn)
    updated = network(graph)
    return updated


class GCN(BaseModel):
  """GCN model: every step is a jraph.GraphConvolution with self-edges."""

  def processor_fn(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Runs one graph-convolution step using the MLP node update."""
    node_fn = self.get_node_update_fn()
    return jraph.GraphConvolution(update_node_fn=node_fn,
                                  add_self_edges=True)(graph)


class GAT(BaseModel):
  """GAT model: every step is a jraph.GAT attention update."""

  def get_edge_logit_fn(self):
    """Returns the attention-logit function for jraph.GAT.

    The returned function concatenates all of its array arguments along the
    last axis (via jraph.concatenated_args) and maps them through
    Linear(hidden_size) -> ReLU -> Linear(1), giving one logit per edge.
    """

    @jraph.concatenated_args
    def edge_logit_fn(feats: jnp.ndarray) -> jnp.ndarray:
      """Edge logit function for attention mechanism."""
      net = hk.Sequential(
          [hk.Linear(self._hidden_size), jax.nn.relu,
           hk.Linear(1)])
      return net(feats)

    return edge_logit_fn

  def processor_fn(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Runs one GAT step with a self-loop added for every node.

    The self-loops (with all-zero edge features) are used only for this
    step's attention. The returned graph keeps the input's senders,
    receivers and edges. Without that, repeated calls (one per
    message-passing step) would stack duplicate self-loops, and the edge
    count would drift away from `n_edge`.
    """
    num_nodes = tree.tree_leaves(graph.nodes)[0].shape[0]
    self_loop_ids = jnp.arange(num_nodes, dtype=graph.senders.dtype)
    self_loop_feats = jnp.zeros((num_nodes, graph.edges.shape[-1]),
                                dtype=graph.edges.dtype)
    augmented = graph._replace(
        senders=jnp.concatenate((graph.senders, self_loop_ids), axis=0),
        receivers=jnp.concatenate((graph.receivers, self_loop_ids), axis=0),
        edges=jnp.concatenate((graph.edges, self_loop_feats), axis=0))

    processor = jraph.GAT(
        attention_query_fn=self.get_node_update_fn(),
        attention_logit_fn=self.get_edge_logit_fn())
    processed = processor(augmented)

    # Keep the updated nodes/globals but drop the temporary self-loops.
    return processed._replace(senders=graph.senders,
                              receivers=graph.receivers,
                              edges=graph.edges)


# Registry mapping a model-name string to its BaseModel subclass.
MODELS_DICT = dict(mpnn=MPNN, gcn=GCN, gat=GAT)
