import haiku as hk
import jax
import jax.numpy as jnp
from haiku_geometric.nn import SAGEConv


class GraphQNetwork(hk.Module):
    """GraphSAGE-based Q-network that outputs one scalar value per node.

    Each input graph is augmented with ``num_registers`` learned "register"
    nodes. Every register node is connected, in both directions, to every
    original node. The per-node value is computed from:

    * the node's own embedding,
    * a sum-pooled embedding of the whole augmented graph, and
    * an embedding of the auxiliary feature vector ``aux``.

    The module processes a single (unbatched) graph; batch it with
    ``jax.vmap``.
    """

    def __init__(self, hidden_channels: int, num_registers: int = 1):
        """Create the submodules and the learned register-node embeddings.

        Args:
            hidden_channels: Width of every hidden layer.
            num_registers: Number K of learned register nodes appended to each
                graph. ``0`` disables registers.

        Raises:
            ValueError: If ``num_registers`` is negative.
        """
        if num_registers < 0:
            raise ValueError(f"num_registers must be >= 0, got {num_registers}")
        super().__init__()
        self.hidden_channels = hidden_channels

        # NOTE: Haiku derives module (and therefore parameter) names from
        # construction order. Keep this order stable so saved params load.
        self.conv1 = SAGEConv(hidden_channels)
        self.conv2 = SAGEConv(hidden_channels)
        self.conv3 = SAGEConv(hidden_channels)
        self.conv4 = SAGEConv(hidden_channels)

        self.node_features_proj = hk.Linear(hidden_channels)

        self.linear_node_features = hk.Linear(hidden_channels)
        self.linear_pool = hk.Linear(hidden_channels)
        self.linear_node = hk.Linear(hidden_channels)
        self.linear_aux = hk.Linear(hidden_channels)
        self.linear_value = hk.Linear(1)

        self.num_registers = num_registers
        # One learned embedding per register node: (K, hidden_channels).
        self.register_node_init = hk.get_parameter(
            "register_node_init",
            shape=(num_registers, hidden_channels),
            init=hk.initializers.RandomNormal(),
        )

    def __call__(self, nodes, senders, receivers, aux):
        """Compute a value for every node of a single graph.

        Args:
            nodes: Node feature matrix, shape ``(n_nodes, node_features)``.
            senders: Integer edge source indices, shape ``(n_edges,)``.
            receivers: Integer edge target indices, shape ``(n_edges,)``.
            aux: Auxiliary feature vector, shape ``(aux_features,)``.

        Returns:
            Array of shape ``(n_nodes,)`` holding one value per original node.
            Register nodes are dropped.
        """
        num_nodes = nodes.shape[0]

        # --- Node encoder ---------------------------------------------------
        h = jax.nn.relu(self.linear_node_features(nodes))  # (n_nodes, hidden)
        # Append hidden_channels constant ones to each node before projecting.
        # NOTE(review): the ones are constant, so through the Linear they only
        # add a learned constant offset on top of its bias. Kept for
        # compatibility with existing parameter shapes.
        h = jnp.concatenate(
            [h, jnp.ones((num_nodes, self.hidden_channels))], axis=-1
        )
        h = jax.nn.relu(self.node_features_proj(h))  # (n_nodes, hidden)

        # --- Append register nodes (indices num_nodes .. num_nodes + K - 1) --
        h = jnp.concatenate([h, self.register_node_init], axis=0)  # (n_nodes + K, hidden)

        # --- Edges --------------------------------------------------------------
        # Add one edge register_k -> node_i for every register k and every
        # original node i, in register-major order. For K == 1 this is exactly
        # senders = [num_nodes] * num_nodes, receivers = arange(num_nodes).
        register_ids = num_nodes + jnp.arange(self.num_registers, dtype=jnp.int32)
        register_senders = jnp.repeat(register_ids, num_nodes)
        register_receivers = jnp.tile(
            jnp.arange(num_nodes, dtype=jnp.int32), self.num_registers
        )
        directed_senders = jnp.concatenate([senders, register_senders], axis=0)
        directed_receivers = jnp.concatenate([receivers, register_receivers], axis=0)

        # Make the graph undirected by adding every edge in reverse as well.
        senders = jnp.concatenate([directed_senders, directed_receivers], axis=0)
        receivers = jnp.concatenate([directed_receivers, directed_senders], axis=0)

        # --- Message passing ----------------------------------------------------
        def _layer_norm(y):
            # Fresh LayerNorm per call; Haiku names them layer_norm, _1, _2, _3.
            return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(y)

        x = _layer_norm(jax.nn.relu(self.conv1(h, senders, receivers)))
        for conv in (self.conv2, self.conv3, self.conv4):
            # Residual block: x + LN(relu(conv(x))).
            x = x + _layer_norm(jax.nn.relu(conv(x, senders, receivers)))
        # x: (n_nodes + K, hidden)

        # --- Readout ------------------------------------------------------------
        n_total = x.shape[0]
        pool_rep = self.linear_pool(jnp.sum(x, axis=0))  # (hidden,)
        node_rep = self.linear_node(x)  # (n_nodes + K, hidden)
        aux_rep = self.linear_aux(aux)  # (hidden,)

        # Broadcast the graph-level and aux embeddings to every node.
        pool_rep = jnp.broadcast_to(pool_rep, (n_total, pool_rep.shape[-1]))
        aux_rep = jnp.broadcast_to(aux_rep, (n_total, aux_rep.shape[-1]))

        node_pool = jax.nn.relu(
            jnp.concatenate([pool_rep, node_rep], axis=-1)
        )  # (n_nodes + K, 2 * hidden)
        features = jnp.concatenate([node_pool, aux_rep], axis=-1)  # (n_nodes + K, 3 * hidden)
        values = self.linear_value(features)  # (n_nodes + K, 1)

        # Keep only the original nodes. Slicing by num_nodes (rather than
        # [:-K]) stays correct when num_registers == 0.
        return values[:num_nodes, 0]  # (n_nodes,)


def create_graph_q_network(hidden_channels: int, num_registers: int = 1) -> hk.Transformed:
    """Build a Haiku-transformed, batched ``GraphQNetwork``.

    Args:
        hidden_channels: Hidden width forwarded to ``GraphQNetwork``.
        num_registers: Number of register nodes forwarded to
            ``GraphQNetwork``. Defaults to 1, the value previously used
            implicitly.

    Returns:
        An ``hk.Transformed`` whose ``apply`` takes no RNG argument. It maps
        batched ``(nodes, senders, receivers, aux)`` to per-node values of
        shape ``(batch, n_nodes)``. Every input carries a leading batch axis
        and is mapped with ``jax.vmap`` over axis 0. All graphs in a batch
        must therefore share the same node and edge counts.
    """

    def model_fn(
        nodes: jax.Array, senders: jax.Array, receivers: jax.Array, aux: jax.Array
    ) -> jax.Array:
        net = GraphQNetwork(
            hidden_channels=hidden_channels, num_registers=num_registers
        )
        # vmap the per-graph module over the leading batch axis of every input.
        return jax.vmap(net)(nodes, senders, receivers, aux)

    return hk.without_apply_rng(hk.transform(model_fn))
