from dataclasses import dataclass
import pydoc
from typing import Any, Optional

from ruamel.yaml import YAML, yaml_object
from torch_geometric.nn.models.basic_gnn import BasicGNN

from constants import TORCH_GEOMETRIC_MODELS_MODULE
from .._model_config import ModelConfig
from ._model import SimpleGNN


@yaml_object(YAML())
@dataclass()
class SimpleGNNConfig(ModelConfig):
    """
    Configuration for the `SimpleGNN` model (a GNN backbone followed by an MLP edge classifier).

    Fields:

    - `gnn_architecture`: The type of GNN to use.
                          Must be the name of a class in `torch_geometric.nn.models` that inherits from `BasicGNN`,
                          e.g. `"GCN"`.
    - `in_channels`: The number of input channels for the GNN.
    - `gnn_hidden_channels`: The number of hidden channels for the GNN.
    - `gnn_out_channels`: The number of output channels for the GNN.
    - `gnn_layers`: The number of layers for the GNN.
    - `gnn_kwargs`: Additional keyword arguments for the GNN constructor (`None` means no extra arguments).
    - `edge_classifier_hidden_channels`: The number of hidden channels for the MLP edge classifier.
    - `edge_classifier_layers`: The number of layers for the MLP edge classifier.
    """

    gnn_architecture: str = "GraphSAGE"
    in_channels: int = 128
    gnn_hidden_channels: int = 128
    gnn_out_channels: int = 128
    gnn_layers: int = 6
    gnn_kwargs: Optional[dict[str, Any]] = None
    edge_classifier_hidden_channels: int = 64
    edge_classifier_layers: int = 2

    # overriding base class constants (no type hints on purpose, so that
    # `@dataclass` does not turn them into instance fields)
    MODEL_NAME = "simple-gnn"
    MODEL_CLASS = SimpleGNN

    def get_gnn_architecture(self) -> type[BasicGNN]:
        """
        Resolve `gnn_architecture` to the corresponding GNN class.

        The class is looked up by its dotted path
        `<TORCH_GEOMETRIC_MODELS_MODULE>.<gnn_architecture>` using `pydoc.locate`.

        Returns:
            The located class, which is guaranteed to be a subclass of `BasicGNN`.

        Raises:
            ValueError: If no object exists at that path, or the object found is
                not a class inheriting from `BasicGNN`.
        """
        qualified_name = f"{TORCH_GEOMETRIC_MODELS_MODULE}.{self.gnn_architecture}"
        architecture = pydoc.locate(qualified_name)
        # Explicit checks rather than `assert`: asserts are stripped under `python -O`,
        # and `issubclass` raises an unhelpful TypeError when `locate` returns None
        # (unknown name) or a non-class object.
        if architecture is None:
            raise ValueError(f"Could not locate GNN architecture {qualified_name!r}")
        if not (isinstance(architecture, type) and issubclass(architecture, BasicGNN)):
            raise ValueError(
                f"GNN architecture {qualified_name!r} is not a class inheriting from BasicGNN"
            )
        return architecture
