import sys

from models.Nets import CNNCifar,CNNMnist,MLP,Linear,MM_CNN,MLP1,MLP2,CNN1Cifar,CNN2Cifar,CNN3Cifar

def build_model(args):
    """Build the network selected by ``args.model`` (and ``args.dataset`` for CNNs).

    Args:
        args: Configuration namespace. Always read: ``model`` and (on success)
            ``device``. Read only for the matching model:
            ``dataset`` for ``cnn``/``cnn1``/``cnn2``/``cnn3``;
            ``img_size`` and ``num_classes`` for ``mlp``/``mlp1``/``mlp2``;
            ``d`` and ``n`` for ``linear``. The CIFAR/MNIST CNNs and
            ``fmnist_cnn`` receive the whole ``args`` object.

    Returns:
        The constructed network after ``.to(args.device)`` and ``.train()``
        have been called on it. The network is also printed to stdout.

    Raises:
        SystemExit: if the model / dataset combination is not recognized.
    """
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn1' and args.dataset == 'cifar':
        net_glob = CNN1Cifar(args=args).to(args.device)
    elif args.model == 'cnn2' and args.dataset == 'cifar':
        net_glob = CNN2Cifar(args=args).to(args.device)
    elif args.model == 'cnn3' and args.dataset == 'cifar':
        net_glob = CNN3Cifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model in ('mlp', 'mlp1', 'mlp2'):
        # The MLP input width is the product of every entry of args.img_size.
        len_in = 1
        for dim in args.img_size:
            len_in *= dim
        # Resolve the class only once this branch is taken, so other models
        # never touch these names.
        mlp_cls = {'mlp': MLP, 'mlp1': MLP1, 'mlp2': MLP2}[args.model]
        net_glob = mlp_cls(dim_in=len_in, dim_hidden=200,
                           dim_out=args.num_classes).to(args.device)
    elif args.model == 'linear':
        net_glob = Linear(d=args.d, n=args.n).to(args.device)
    elif args.model == 'fmnist_cnn':
        net_glob = MM_CNN(args).to(args.device)
    else:
        # Use sys.exit rather than the site-injected ``exit`` builtin, which is
        # intended for the interactive shell and absent under ``python -S``.
        # Report both values: a known model with an unsupported dataset
        # (e.g. 'cnn' + 'svhn') also lands here.
        sys.exit("Error: unrecognized model '{}' for dataset '{}'".format(
            args.model, getattr(args, 'dataset', None)))
    print(net_glob)
    net_glob.train()
    return net_glob