from equislt.methods import base
from equislt.methods.gcn import GCN, PruneGCN
from equislt.methods.e2cnn import E2CNN, PruneE2CNN
from equislt.methods.ign import InvariantGraphNets, PruneInvariantGraphNets

# Registry of the model classes used for training, keyed by a short method
# name. Presumably the key is selected from a config/CLI option — confirm
# against callers.
TRAIN_METHODS = {
    "gcn": GCN,
    "e2cnn": E2CNN,
    "ign": InvariantGraphNets,
}

# Registry of the pruning counterparts of the classes above, keyed by the
# same method names as TRAIN_METHODS.
PRUNE_METHODS = {
    "gcn": PruneGCN,
    "e2cnn": PruneE2CNN,
    "ign": PruneInvariantGraphNets,
}

# ``__all__`` must list attribute *names* (strings). The previous version
# concatenated the class objects themselves, which makes
# ``from equislt.methods import *`` fail with a TypeError. The names are
# spelled out explicitly rather than derived from ``cls.__name__``, because
# an imported name is not guaranteed to match the class's ``__name__``.
__all__ = [
    "base",
    "GCN",
    "E2CNN",
    "InvariantGraphNets",
    "PruneGCN",
    "PruneE2CNN",
    "PruneInvariantGraphNets",
    "TRAIN_METHODS",
    "PRUNE_METHODS",
]
