from add_sys import *
from multiple_experiment_launcher import multiple_experiment
from iterator_exp import ExperimentIterator
from utils_experiment import create_name


if __name__ == "__main__":
    # Build the base configuration, then apply the local overrides to it.
    args = get_default_local(get_default())

    # Gradient estimation: REBAR, with a Bernoulli distribution and its
    # relaxed counterpart.
    args.args_trainer.monte_carlo_gradient_estimator = "REBAR"
    args.args_distribution_module.distribution_module = "REBARBernoulli"
    args.args_distribution_module.distribution = "Bernoulli"
    args.args_distribution_module.distribution_relaxed = "RelaxedBernoulli"

    # Output path as configured at this point; it is not read again in the
    # visible part of this script.
    origin_path = args.args_output.path

    # Trainer, interpretable module and imputation choices.
    args.args_trainer.complete_trainer = "trainingWithSelection"
    args.args_interpretable_module.interpretable_module = "EVAL_X"
    args.args_classification.imputation = "ConstantImputation"

    # No regularization on the selection side. The original comment lists
    # "L1, L2" next to loss_regularization (presumably the other accepted
    # values -- confirm against the selection module).
    args.args_selection.regularization = "None"
    args.args_selection.loss_regularization = "None"

    experiment_iterator = ExperimentIterator(
        iter_cste=True,
        iter_lambda=False,
        iter_rate=False,
    )
    experiment_count = 0

    # NOTE(review): `name_experiment` is not defined in this script; it has to
    # be supplied by `from add_sys import *`, otherwise create_name() below
    # raises NameError -- confirm.
    for _index, dataset_name, dataset, loader in experiment_iterator.__iter__(args):
        # These four values are assigned on every iteration, after the
        # iterator has yielded (the iterator receives `args`, so they are
        # deliberately left inside the loop).
        # Number of samples for the Monte Carlo gradient estimator.
        args.args_train.nb_sample_z_train_monte_carlo = 1
        # Number K in the IWAE-similar loss.
        args.args_train.nb_sample_z_train_IWAE = 1
        args.args_classification.nb_imputation_mc = 1
        args.args_classification.nb_imputation_iwae = 1

        experiment_args = create_name(
            args, dataset_name, name_experiment, experiment_count
        )

        # The value returned by multiple_experiment becomes the counter used
        # for the next dataset.
        experiment_count = multiple_experiment(
            experiment_count,
            dataset,
            loader,
            complete_args=experiment_args,
            name_modification=True,
        )
