from script.models.GCN.model import GCN
from script.models.HGCN.model import HGCN
from script.models.kGCN.model import kGCN
from script.models.QGCN.model import QGCN
from script.models.Hypformer.model import Hypformer
from script.models.GraphGPS.model import GraphGPS
from script.models.NodeFormer.model import NodeFormer
from script.models.SGFormer.model import SGFormer
from script.models.FPST.model import FPST
from script.models.QGCN2.model import QGCN2
from script.models.QGT.model import QGT


def load_model(args, logger):
    """Instantiate the model class selected by ``args.model``.

    Args:
        args: Configuration object. Its ``model`` attribute is the name of
            the model to build, and the whole object is passed unchanged to
            the selected model's constructor.
        logger: Logger used to report which model was selected (via
            ``logger.info``).

    Returns:
        A new instance of the selected model class, built as ``cls(args)``.

    Raises:
        ValueError: If ``args.model`` is not one of the registered names.
            ``ValueError`` subclasses ``Exception``, so callers that caught
            the previous bare ``Exception`` still catch it.
    """
    # Name -> class lookup table; insertion order is kept so the error
    # message lists the choices in a stable, readable order.
    registry = {
        'GCN': GCN,
        'HGCN': HGCN,
        'kGCN': kGCN,
        'QGCN': QGCN,
        'Hypformer': Hypformer,
        'GraphGPS': GraphGPS,
        'NodeFormer': NodeFormer,
        'SGFormer': SGFormer,
        'FPST': FPST,
        'QGCN2': QGCN2,
        'QGT': QGT,
    }
    try:
        model_cls = registry[args.model]
    except KeyError:
        # Suppress the internal KeyError context; the ValueError says it all.
        raise ValueError(
            'unknown model {!r}; expected one of: {}'.format(
                args.model, ', '.join(registry)
            )
        ) from None
    model = model_cls(args)
    # Lazy %-style args: same output text as before, formatted only if emitted.
    logger.info('using model %s ', args.model)
    return model
