from dataclasses import dataclass
from typing import Literal

from ruamel.yaml import YAML, yaml_object

from .._model_config import ModelConfig
from ._model import ResidualGatedGCN


@yaml_object(YAML())
@dataclass
class ResidualGatedGCNConfig(ModelConfig):
    """
    Configuration for the residual gated GCN model.

    The class is registered with ruamel.yaml through `yaml_object`, and its
    fields are listed here in declaration order (which is also the order of
    positional constructor arguments):

    - `num_nodes`: The number of nodes expected in the input graph.
    - `voc_edges_in`: Input edge signal vocabulary size.
    - `hidden_channels`: The number of hidden channels for the GNN and MLP.
    - `gnn_layers`: The number of layers for the GNN.
    - `mlp_layers`: The number of layers for the MLP.
    - `node_channels`: The dimensionality of the node features (default 2).
    - `gnn_aggregation`: Aggregation scheme for GCN. Must be `"mean"` or
      `"sum"` (default `"mean"`). The `Literal` annotation documents this
      but is not enforced at runtime here.
    """

    # Required fields (no defaults) must precede the defaulted ones below.
    num_nodes: int
    voc_edges_in: int
    hidden_channels: int
    gnn_layers: int
    mlp_layers: int

    # Optional fields with defaults.
    node_channels: int = 2
    gnn_aggregation: Literal["mean", "sum"] = "mean"

    # Base-class constants overridden for this model. They are deliberately
    # left unannotated so that @dataclass treats them as plain class
    # attributes rather than as constructor fields.
    MODEL_NAME = "residual-gated-gcn"
    MODEL_CLASS = ResidualGatedGCN
