from models.gcn import GCN, SGC, GCNL, newGCN, newGCNL, newSGC
from models.gat import GAT, GATL, newGATL, newGAT
from models.gcn2 import GCN2, newGCN2

def _sgc_dropout(dataset):
    """Return the dropout rate used by SGC / newSGC for ``dataset``."""
    return 0.4 if dataset == 'PPI' else 0.8


def _gcn2_hparams(dataset, new_variant=False):
    """Return ``(alpha, theta, dropout)`` used by GCN2 / newGCN2 for ``dataset``.

    Both variants share the same settings except that the ``newGCN2`` variant
    (``new_variant=True``) lowers dropout to 0.4 on Cora and CiteSeer.
    """
    if dataset == 'PPI':
        return 0.5, 1.0, 0.4
    if new_variant and dataset in ('Cora', 'CiteSeer'):
        return 0.1, 0.5, 0.4
    return 0.1, 0.5, 0.6


def construct_model(num_features, num_classes, args):
    """Build the graph neural network selected by ``args.model``.

    Architecture-specific hyper-parameters are hard-coded here. These are the
    number of attention heads, the dropout rate, and the GCN2 ``alpha`` /
    ``theta`` coefficients. For SGC/newSGC and GCN2/newGCN2 they also depend
    on ``args.dataset``.

    Args:
        num_features: Number of input node features.
        num_classes: Number of output classes.
        args: Namespace-like object. Always reads ``args.model``; depending on
            the model also reads ``args.hidden_features``, ``args.depth`` and
            ``args.dataset``. ``args.dropout`` is NOT consulted: every model
            that takes a dropout rate uses the value fixed below.

    Returns:
        The constructed model. Any ``args.model`` value not matched below
        (including ``'GCN'``) produces a ``GCN``.
    """
    model = args.model

    # Attention models: the hidden width is divided across the heads
    # (presumably the head outputs are concatenated - confirm in models.gat).
    if model == 'GAT':
        heads = 8
        return GAT(num_features, args.hidden_features // heads, heads, num_classes, dropout=0.6)
    if model == 'newGAT':
        heads = 8
        return newGAT(num_features, args.hidden_features // heads, heads, num_classes, dropout=0.8)
    if model == 'GATL':
        heads = 4
        return GATL(num_features, args.hidden_features // heads, heads, num_classes)
    if model == 'newGATL':
        heads = 4
        return newGATL(num_features, args.hidden_features // heads, heads, num_classes)

    if model == 'SGC':
        return SGC(num_features, args.depth, num_classes, _sgc_dropout(args.dataset))
    if model == 'newSGC':
        return newSGC(num_features, args.depth, num_classes, _sgc_dropout(args.dataset))

    if model in ('GCN2', 'newGCN2'):
        alpha, theta, dropout = _gcn2_hparams(args.dataset, new_variant=(model == 'newGCN2'))
        net_cls = GCN2 if model == 'GCN2' else newGCN2
        return net_cls(num_features, args.hidden_features, num_classes, args.depth,
                       alpha=alpha, theta=theta, dropout=dropout)

    if model == 'newGCN':
        return newGCN(num_features, args.hidden_features, num_classes, 0.8)
    if model == 'newGCNL':
        return newGCNL(num_features, args.hidden_features, num_classes)
    if model == 'GCNL':
        return GCNL(num_features, args.hidden_features, num_classes)

    # Fallback: 'GCN' and any unrecognised model name build a plain GCN.
    return GCN(num_features, args.hidden_features, num_classes, dropout=0.4)