from typing import Final

from src.models.gnn_classes.expander_graph_propagation import ExpanderGNN
from src.models.gnn_classes.g2 import G2_GNN
from src.models.gnn_classes.gcn import GCN, GCNPair, GATv2, GIN, GraphSAGE
from src.models.gnn_classes.gcnii import GCNII

# Registry mapping a model-name string to the GNN model class it selects.
# Values are the classes themselves (not instances); presumably callers look a
# name up here and instantiate the result — confirm against callers.
# "gatv2" and "gat_v2" are aliases that both resolve to the same GATv2 class.
#
# The value type is the builtin `type` (every value is a class). The previous
# annotation used `any`, which is the builtin function, not a typing construct.
MODEL_NAME_MAPPING: Final[dict[str, type]] = {
    "gcn": GCN,
    "gcnii": GCNII,
    "g2": G2_GNN,
    "expander": ExpanderGNN,
    "gcn_pair": GCNPair,
    "gatv2": GATv2,
    "gat_v2": GATv2,
    "gin": GIN,
    "sage": GraphSAGE,
}
