import torch
from alg.models import testNNet
from alg.SAalg import SAbound, MYSAalg, inv_binomial
from alg.data import loaddataset



def SArunexp(data_slice_idx, g, C, name_data, delta, learning_rate, momentum, batch_size, train_epochs, dropout_prob, n_datapoints, pretrain_frac, device):
    """Run one SA + SGD experiment and return its bounds and test errors.

    Loads the dataset ``name_data`` and runs ``MYSAalg`` (the meta-algorithm,
    "SA", run jointly with SGD) with a single-batch loader over the whole test
    split. It then computes two quantities, both with confidence parameter
    ``delta``:

    * an SA bound from the support-set counts (``SAbound``);
    * a binomial bound (``inv_binomial``) from the error value ``MYSAalg``
      returns.

    Args:
        data_slice_idx: forwarded to ``MYSAalg``.
        g: accepted for interface compatibility; not used in this function.
        C: forwarded to ``MYSAalg``.
        name_data: dataset name passed to ``loaddataset``; also forwarded to
            ``MYSAalg``.
        delta: confidence parameter passed to ``SAbound`` and ``inv_binomial``.
        learning_rate, momentum, batch_size, train_epochs, dropout_prob:
            training settings forwarded unchanged to ``MYSAalg``.
        n_datapoints: forwarded to ``MYSAalg``. Also multiplied by
            ``pretrain_frac`` to size the initial support set.
        pretrain_frac: fraction of ``n_datapoints`` used as the initial
            support set. The product is truncated with ``int``.
        device: forwarded to ``MYSAalg`` and ``testNNet``.

    Returns:
        Tuple ``(bin_ub, sa_ub, SA_p_misclass, SGD_p_misclass)``, with the
        elements taken from:

        * ``inv_binomial`` (the binomial bound);
        * ``SAbound`` (the SA bound);
        * ``testNNet`` on the network returned by ``MYSAalg``;
        * the SGD misclassification value reported by ``MYSAalg``.
    """
    # Only the test split is used here; the training split is discarded.
    _, test_set = loaddataset(name_data)

    # One batch spanning the entire test set, in fixed order, loaded in the
    # main process (num_workers=0).
    full_test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=len(test_set),
        shuffle=False,
        num_workers=0,
    )

    # Size of the initial (pre-training) support set.
    n_init_support = int(pretrain_frac * n_datapoints)

    (
        err_for_binomial,
        sgd_misclass,
        trained_net,
        support_idx,
        nonsupport_idx,
    ) = MYSAalg(
        full_test_loader, data_slice_idx, C, n_init_support, n_datapoints,
        learning_rate, momentum, batch_size, train_epochs, dropout_prob,
        device, name_data,
    )

    # The initial support size is subtracted before counting support points
    # for the SA bound.
    n_support = len(support_idx) - n_init_support
    n_total = n_support + len(nonsupport_idx)

    sa_bound = SAbound(n_support, n_total, delta)
    # NOTE(review): n_total * err_for_binomial is passed as the error count
    # without rounding. Confirm that inv_binomial accepts a non-integer count.
    binomial_bound = inv_binomial(n_total, n_total * err_for_binomial, delta)
    sa_misclass = testNNet(trained_net, full_test_loader, device)

    return binomial_bound, sa_bound, sa_misclass, sgd_misclass
  
    
