from .gnns import GCN, U_GCN, SGC
from .feedforward_models import LinearModel, SGC_MLP
from .streaming_lda import StreamingLDA
from .graph_neural_features import GraphRandomNeuralFeatures

def get_classifier(args):
    """Build the classifier selected by ``args.method`` / ``args.backbone``.

    Selection rules:

    * ``args.method == 'slda'`` -> ``StreamingLDA``; ``args.backbone`` is
      not consulted in that case.
    * backbone ``'GCN'`` -> ``GCN``
    * backbone ``'SGC'`` -> ``SGC_MLP``
    * backbone ``'UGCN'``, ``'GRNF'`` or ``'UMIXED'`` -> ``LinearModel``

    Every constructor receives the full ``args`` object.

    Returns:
        The constructed classifier, or ``None`` when no rule matches.
    """
    # The method check comes first so that 'slda' wins over any backbone.
    if args.method == 'slda':
        return StreamingLDA(args)

    backbone = args.backbone
    if backbone == 'GCN':
        return GCN(args)
    if backbone == 'SGC':
        return SGC_MLP(args)
    if backbone in ('UGCN', 'GRNF', 'UMIXED'):
        return LinearModel(args)

    # Unrecognised backbone: no classifier is built.
    return None

def get_feat_extractor(args):
    """Build the feature extractor matching ``args.backbone``, if any.

    Selection rules:

    * ``'UGCN'`` -> ``U_GCN(args)``
    * ``'SGC'``  -> ``SGC(args)``
    * ``'GRNF'`` -> ``GraphRandomNeuralFeatures`` configured from
      ``args.backbone_args`` (keys ``'h_dims'`` and ``'order_2_prc'``),
      ``args.d_data`` and ``args.gain``.

    Returns:
        The constructed extractor, or ``None`` for any other backbone.
    """
    backbone = args.backbone

    if backbone == 'UGCN':
        return U_GCN(args)
    if backbone == 'SGC':
        return SGC(args)
    if backbone != 'GRNF':
        # No extractor is built for the remaining backbones.
        return None

    grnf_cfg = args.backbone_args
    # Only the first entry of 'h_dims' is used here.
    # NOTE(review): the literal 1 is the third positional argument of
    # GraphRandomNeuralFeatures; its meaning is not visible from this file.
    return GraphRandomNeuralFeatures(
        grnf_cfg['h_dims'][0],
        args.d_data,
        1,
        order_2_prc=grnf_cfg['order_2_prc'],
        gain=args.gain,
    )
