from .ResNet import *
from .VGG import *
from .MLP import *
from .load_data import load_data_mlp
from .CHTModel import VGG16_CIFAR_BN
import sys
sys.path.append("..")
import CHT

def prepare_model_and_loader(args):
    train_loader,test_loader,indim,outdim,hiddim=load_data_mlp(args.dataset,args.bs,args.dim)
    
    model_name=args.architecture.lower()
    match model_name:
        case 'resnet-20': model, num_activations=resnet20(num_classes=outdim), 15
        case 'vgg-16': 
            if args.conv_sparsity==0.0:
                model = vgg16(num_classes=outdim, dropout=args.dropout, one_fc=args.one_fc)
            else:
                model = VGG16_CIFAR_BN(CHT.hanming_config, outdim, dropout=args.dropout, one_fc=args.one_fc)
            num_activations = 13 if args.one_fc else 15
        case 'mlp': model, num_activations = MLP(indim,hiddim,outdim,dropout=0), 3
        case _ : raise NotImplemented

    return model, num_activations, train_loader,test_loader