from . import CHT_Model
from . import VGG
from . import MLP
import sys
sys.path.append("..")
from Models.load_data import load_data_mlp

def prepare_model_and_loader(args):
    """Build the network selected by ``args`` along with its data loaders.

    The loaders and layer sizes come from ``load_data_mlp``. When
    ``args.architecture`` is exactly ``'MLP'``, an ``MLP.MLP_optimalThres``
    is built. Any other architecture value builds a VGG-16 variant:
    - the dense ``VGG.VGG16_optimalThres`` when ``args.conv_sparsity`` is 0.0;
    - otherwise ``CHT_Model.VGG16_CHT_optimalThres``.

    Args:
        args: Namespace-like object (presumably an ``argparse.Namespace`` --
            confirm against the caller). This function reads its ``dataset``,
            ``bs``, ``dim``, ``architecture``, ``conv_sparsity`` and
            ``one_fc`` attributes.

    Returns:
        tuple: ``(ann, num_activations, train_loader, test_loader)``.
        ``num_activations`` is 4 for the MLP. For VGG-16 it is 14 when
        ``args.one_fc`` is truthy and 16 otherwise.
    """
    train_loader, test_loader, in_dim, out_dim, hidden_dim = load_data_mlp(
        args.dataset, args.bs, args.dim
    )

    # Guard clause: the MLP path is the only one that uses the input and
    # hidden sizes.
    if args.architecture == 'MLP':
        mlp = MLP.MLP_optimalThres(in_dim, hidden_dim, out_dim)
        return mlp, 4, train_loader, test_loader

    # Every non-MLP architecture is treated as VGG-16. Zero conv sparsity
    # picks the dense implementation; any other value picks the CHT one.
    vgg_factory = (
        VGG.VGG16_optimalThres
        if args.conv_sparsity == 0.0
        else CHT_Model.VGG16_CHT_optimalThres
    )
    vgg = vgg_factory(out_dim, args.one_fc)

    # VGG-16 activation count: 14 with a single FC layer, 16 otherwise.
    activation_count = 14 if args.one_fc else 16
    return vgg, activation_count, train_loader, test_loader