import tensorflow as tf
import tensorflow_gnn as tfgnn


def create_model(params: dict, training: bool) -> tf.keras.Model:
    """Assemble a sequential GNN model that ends in a single-unit dense head.

    The model stacks, in order: a ``Convolve`` step for the ``"genre"`` node
    set, a ``Convolve`` step for the ``"user"`` node set, a readout of the
    ``"user"`` node set, and a ``Dense(1)`` output layer.

    Args:
        params: hyperparameter dictionary; it is indexed by node set name and
            the value found is passed to ``tf.keras.layers.Dense`` for that
            node set's next-state layer.
        training: not referenced anywhere in this function.

    Returns:
        The assembled, uncompiled ``tf.keras.models.Sequential`` model.
    """

    def make_convolution(edge_set_name):
        # The same weighted-sum convolution is used for every edge set.
        return WeightedSumConvolution()

    def make_next_state(node_set_name):
        # The Dense argument is looked up per node set; a missing key raises
        # KeyError.
        state_transform = tf.keras.layers.Dense(params[node_set_name])
        return tfgnn.keras.layers.NextStateFromConcat(state_transform)

    builder = tfgnn.keras.ConvGNNBuilder(make_convolution, make_next_state)

    stages = [
        builder.Convolve({"genre"}),
        builder.Convolve({"user"}),
        tfgnn.keras.layers.Readout(node_set_name="user"),
        tf.keras.layers.Dense(1),
    ]
    return tf.keras.models.Sequential(stages)


class WeightedSumConvolution(tf.keras.layers.Layer):
    """Convolution that sums edge-weighted source-node states at target nodes."""

    def call(self, graph: tfgnn.GraphTensor, edge_set_name: tfgnn.EdgeSetName) -> tfgnn.Field:
        """Pool weighted source-node states onto the target nodes of an edge set.

        Args:
            graph: graph tensor holding the node states and the edge set.
            edge_set_name: name of the edge set to convolve over; that edge
                set must provide a ``"weight"`` feature.

        Returns:
            The per-target-node sum of source states, each multiplied by the
            ``"weight"`` feature of the edge it travelled along.
        """
        # NOTE(review): newer TF-GNN releases appear to call this constant
        # tfgnn.HIDDEN_STATE -- confirm against the pinned version.
        source_states = tfgnn.broadcast_node_to_edges(
            graph,
            edge_set_name,
            tfgnn.SOURCE,
            feature_name=tfgnn.DEFAULT_STATE_NAME,
        )
        edge_weights = graph.edge_sets[edge_set_name]["weight"]
        # A trailing axis is added to the weights before multiplying;
        # presumably so a per-edge scalar broadcasts across the state's
        # feature axis -- TODO confirm the feature shapes.
        scaled_states = tf.expand_dims(edge_weights, axis=-1) * source_states
        return tfgnn.pool_edges_to_node(
            graph,
            edge_set_name,
            tfgnn.TARGET,
            reduce_type="sum",
            feature_value=scaled_states,
        )


def test():
    """Smoke-check that ``create_model`` runs with a sample set of sizes.

    The returned model is discarded; nothing is asserted about it.
    """
    layer_sizes = {"user": 256, "movie": 64, "genre": 128}
    create_model(layer_sizes, training=True)
