MinMax BNN Reproduction Guidelines:
First, download the CLTR code. Then, in mcrgan/models/__init__.py, add the function get_BNN_models(data_name, device, label=0):


def get_BNN_models(data_name, device, label=0):
    """Build the three networks ``(netD, netG, netV)`` for the BNN experiment.

    Each network is a ``DiscriminatorMNIST`` moved to ``device``, initialised
    with an init function via ``Module.apply``, and then wrapped in
    ``nn.DataParallel``. ``netV`` uses ``weights_init_mnist_model_Var``;
    ``netG`` and ``netD`` use ``weights_init_mnist_model``.

    Note: use the configuration without bias to get a stable radius.

    Args:
        data_name: dataset name; only ``"mnist"`` is handled here.
        device: target device, passed to ``Module.to``.
        label: model variant; only ``0`` is handled here.

    Returns:
        Tuple ``(netD, netG, netV)``.

    Raises:
        ValueError: if ``(data_name, label)`` is not a supported combination.
    """
    if data_name != "mnist" or label != 0:
        # Without this guard the function would reach ``return`` with unbound
        # names (UnboundLocalError); fail with an explicit message instead.
        raise ValueError(
            f"unsupported data_name/label combination: {data_name!r}, {label!r}"
        )

    def _build(init_fn):
        # Move to device, initialise weights, then wrap for multi-GPU use.
        net = DiscriminatorMNIST().to(device)
        net.apply(init_fn)
        return nn.DataParallel(net)

    # Build in the original order (V, G, D) so random-number consumption
    # during weight initialisation is unchanged.
    netV = _build(weights_init_mnist_model_Var)
    netG = _build(weights_init_mnist_model)
    netD = _build(weights_init_mnist_model)
    return netD, netG, netV

Then, in the same file, add the function get_noise(data_name, device, label=0):
def get_noise(data_name, device, label=0):
    """Build the noise network ``netNoise``.

    The network is a ``DiscriminatorMNIST`` moved to ``device``, initialised
    with ``weights_init_mnist_model_noise``, and wrapped in
    ``nn.DataParallel``.

    Args:
        data_name: dataset name; only ``"mnist"`` is handled here.
        device: target device, passed to ``Module.to``.
        label: model variant; only ``0`` is handled here.

    Returns:
        The ``nn.DataParallel``-wrapped noise network.

    Raises:
        ValueError: if ``(data_name, label)`` is not a supported combination.
    """
    if data_name != "mnist" or label != 0:
        # Without this guard the function would return an unbound name.
        raise ValueError(
            f"unsupported data_name/label combination: {data_name!r}, {label!r}"
        )
    netNoise = DiscriminatorMNIST().to(device)
    # Initialise before wrapping, as get_BNN_models does. Module.apply also
    # calls the init function on ``self``, so applying it after wrapping would
    # pass the DataParallel wrapper itself to the init function.
    netNoise.apply(weights_init_mnist_model_noise)
    return nn.DataParallel(netNoise)

Finally, at the end of utils/utils.py, add the following code:
# Variant that replaces Z_bar with f + noise (the second network is applied directly to X).
def extract_features2(data_loader, encoder, decoder, device="cuda"):
    """Collect outputs of ``encoder`` and ``decoder`` over a whole dataset.

    Both networks receive the raw input batch ``X``, so ``Z_bar`` is
    ``decoder(X)`` rather than being derived from ``Z``. Both networks are
    switched to eval mode (and left in it) and run under ``torch.no_grad()``.

    Args:
        data_loader: iterable yielding ``(X, labels)`` batches.
        encoder: network producing ``Z``.
        decoder: network producing ``Z_bar``.
        device: device each batch is moved to before the forward passes.
            Defaults to ``"cuda"``, matching the previous hard-coded
            ``.cuda()``.

    Returns:
        ``(None, Z, None, Z_bar, labels)``. ``Z`` and ``Z_bar`` are the
        per-batch outputs reshaped to ``(-1, output.shape[1])``, moved to CPU
        and concatenated; ``labels`` is the concatenation of the batch labels.
        The ``None`` entries are placeholders for X / X_bar, which are not
        collected.
    """
    Z_all = []
    Z_bar_all = []
    labels_all = []
    train_bar = tqdm(data_loader, desc="extracting all features from dataset")
    with torch.no_grad():
        encoder.eval()
        decoder.eval()
        for X, labels in train_bar:
            # Transfer the batch once instead of once per network.
            X = X.to(device)
            Z = encoder(X)
            Z_bar = decoder(X)
            # reshape (unlike view) also works on non-contiguous outputs.
            Z_all.append(Z.reshape(-1, Z.shape[1]).cpu())
            Z_bar_all.append(Z_bar.reshape(-1, Z_bar.shape[1]).cpu())
            labels_all.append(labels)

    return None, torch.cat(Z_all), None, torch.cat(Z_bar_all), torch.cat(labels_all)

The other functions and the trained models can be found in the Jupyter notebook.
The label argument selects the model variant: label 0 is the model with 128 dimensions, label 1 with 64, label 2 with 32, and label 3 with 11;
label 4 is the variant with batch normalization, and label 5 is the variant with bias.
