import torch
from alg.models import mytestNNet
from alg.SAalg import SAbound, MYSAalg, inv_binomial
from alg.data import loaddataset


   
def SArunexp(data_slice_idx, g, C, name_data, delta, learning_rate, momentum, batch_size, train_epochs, dropout_prob, n_points_dataset, fraction_pretrain, device):
    """Run one SA-algorithm experiment and compute its generalization bounds.

    Runs ``MYSAalg`` (which, per the original comment, runs the meta-algorithm
    and SGD + test-set evaluation jointly), then computes two upper bounds:
    ``SAbound`` from the support / non-support counts returned by
    ``MYSAalg``, and ``inv_binomial`` from the reported binomial test error.
    Finally, the returned network is evaluated on the full test split with
    ``mytestNNet``.

    Args:
        data_slice_idx: Forwarded unchanged to ``MYSAalg``.
        g: Unused by this function; kept for interface compatibility.
        C: Forwarded to ``MYSAalg`` and ``mytestNNet``.
        name_data: Dataset name passed to ``loaddataset`` and ``MYSAalg``.
        delta: Passed to both ``SAbound`` and ``inv_binomial``
            (presumably the confidence parameter of the bounds — TODO confirm).
        learning_rate, momentum, batch_size, train_epochs, dropout_prob,
        device: Forwarded unchanged to ``MYSAalg``.
        n_points_dataset: Forwarded to ``MYSAalg``; also scaled by
            ``fraction_pretrain`` to obtain the number of initial support points.
        fraction_pretrain: Fraction of ``n_points_dataset`` used as the
            initial support set (``int(fraction_pretrain * n_points_dataset)``).

    Returns:
        Tuple ``(bin_ub, sa_ub, SGD_p_outside, SA_p_outside)``:
        the ``inv_binomial`` bound, the ``SAbound`` bound, the SGD result
        reported by ``MYSAalg``, and the ``mytestNNet`` result for the
        returned network on the test split.
    """
    loader_kargs = {'num_workers': 0}

    # Load data. Only the test split is used directly here; ``name_data`` is
    # also handed to MYSAalg, which presumably loads its own training data
    # (TODO confirm).
    _train_set, test_set = loaddataset(name_data)

    # Full test set as a single, unshuffled batch.
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=len(test_set), shuffle=False, **loader_kargs
    )

    # Run meta-alg and SGD+test-set jointly.
    num_supp_init = int(fraction_pretrain * n_points_dataset)
    test_err_binomial, SGD_p_outside, net, supp_indx, nonsupp_indx = MYSAalg(
        data_slice_idx, C, num_supp_init, n_points_dataset, learning_rate,
        momentum, batch_size, train_epochs, dropout_prob, device, name_data,
    )

    # Compute the generalization error bound for the SA alg. The
    # ``num_supp_init`` initial points are subtracted from the support count,
    # so they are excluded from both the support and the total counts.
    supp_num = len(supp_indx) - num_supp_init
    nonsupp_num = len(nonsupp_indx)
    total_num = supp_num + nonsupp_num

    sa_ub = SAbound(supp_num, total_num, delta)
    # NOTE(review): ``total_num * test_err_binomial`` is a float and may not be
    # an exact integer (e.g. 0.7 * 10 == 6.999...). If inv_binomial expects an
    # integer error count, consider round() — confirm against its definition.
    bin_ub = inv_binomial(total_num, total_num * test_err_binomial, delta)

    # Evaluate the returned network on the single full-test-set batch.
    # NOTE(review): the batch is not moved to ``device`` and the net is not put
    # in eval mode here; presumably mytestNNet handles both — TODO confirm.
    test_batch = next(iter(test_loader))
    SA_p_outside = mytestNNet(net, test_batch, C)

    return bin_ub, sa_ub, SGD_p_outside, SA_p_outside
    
