
import os
import sys
current_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(current_file_path)
from default_parameter import *


from multiple_experiment_launcher import  get_dataset

class ExperimentIterator:
    """Enumerate the grid of experiment configurations to run.

    Each configuration is applied by mutating a caller-supplied ``args``
    object in place. The generator then yields the dataset/loader pair that
    ``get_dataset`` built for the current dataset name and seed pair.

    Use ``iterate(args)``. ``__iter__(args)`` is kept as an alias for
    existing callers, but it does not support ``iter(obj)`` or ``for x in obj``
    because it needs ``args``.
    """

    def __init__(self, iter_cste=False, iter_lambda=False, iter_rate=False, add_noise=False) -> None:
        """Build the per-axis value lists of the experiment grid.

        :param iter_cste: if true, sweep ``cste_imputation`` over several
            constants; otherwise the only value is ``None``.
        :param iter_lambda: if true, sweep ``lambda_reg`` over several values;
            otherwise the only value is ``None``.
        :param iter_rate: if true, sweep the selection ``rate`` over several
            fractions; otherwise the only value is ``None``.
        :param add_noise: stored as ``self.add_noise``. This class does not
            read it; it was previously accepted and silently discarded.
        """
        self.list_selector = ["SelectorLVL2"]
        self.list_classifier = ["RealXClassifier"]
        # (train_seed, test_seed) pairs.
        # NOTE(review): the pair (4, 5) is absent from this sequence -- confirm intended.
        self.list_seed = [(0, 1), (1, 2), (2, 3), (3, 4), (5, 6)]

        self.list_loss_function = ["NLL"]
        # (mc_mask, iwae_mask, mc_sample, iwae_sample); see ``iterate`` for
        # which ``args`` fields each entry is written to.
        self.list_mc_iwae = [(1, 100, 1, 1)]
        # (name, dataset class) pairs; only the name is read by ``iterate``.
        self.list_dataset = [("Syn4", Syn4),
                             ("Syn5", Syn5),
                             ("Syn6", Syn6)]

        self.add_noise = add_noise

        self.list_cste = [0, 1, 3, -1, -3] if iter_cste else [None]
        self.list_lambda = [0.01, 0.05, 0.1, 0.3, 0.5, 1.0, 5.0] if iter_lambda else [None]
        # NOTE(review): 8/11 is absent from this sweep -- confirm intended.
        self.list_rate = ([2./11, 3./11, 4./11, 5./11, 6./11, 7./11, 9./11]
                          if iter_rate else [None])

    def iterate(self, args):
        """Yield one entry per configuration of the experiment grid.

        For each (dataset, seed pair), the dataset name and seeds are written
        into ``args`` and ``get_dataset(args)`` is called once. The generator
        then walks the product of the remaining axes. The order is cste,
        lambda, rate, classifier, selector, mc/iwae, loss, with the last axis
        varying fastest. Before each yield, all of those fields of ``args`` are
        rewritten with the current values. ``args`` is shared, so the consumer
        must read it before advancing the generator.

        :param args: nested argument namespace, mutated in place.
        :yields: ``(index, dataset_name, dataset, loader)``. ``index`` counts
            configurations starting at 1. ``dataset`` and ``loader`` are
            whatever ``get_dataset`` returned for the current seed pair.
        """
        from itertools import product

        index = 0
        for dataset_name, _dataset_class in self.list_dataset:
            for train_seed, test_seed in self.list_seed:
                args.args_dataset.dataset = dataset_name
                dataset_parameters = args.args_dataset.args_dataset_parameters
                dataset_parameters.train_seed = train_seed
                dataset_parameters.test_seed = test_seed
                # Built once per (dataset, seed pair) and shared by every
                # configuration of the inner grid.
                dataset, loader = get_dataset(args)

                grid = product(self.list_cste, self.list_lambda, self.list_rate,
                               self.list_classifier, self.list_selector,
                               self.list_mc_iwae, self.list_loss_function)
                for cste, lambda_reg, rate, classifier, selector, mc_iwae, loss in grid:
                    mc_mask, iwae_mask, mc_sample, iwae_sample = mc_iwae

                    args.args_classification.cste_imputation = cste
                    args.args_selection.lambda_reg = lambda_reg
                    args.args_selection.rate = rate
                    args.args_classification.classifier = classifier
                    args.args_selection.selector = selector
                    # Number of samples for the Monte Carlo gradient estimator.
                    args.args_train.nb_sample_z_train_monte_carlo = mc_mask
                    # Number K in the IWAE-similar loss.
                    args.args_train.nb_sample_z_train_IWAE = iwae_mask
                    args.args_classification.nb_imputation_mc = mc_sample
                    args.args_classification.nb_imputation_iwae = iwae_sample
                    args.args_train.loss_function = loss  # NLL, MSE

                    index += 1
                    yield index, dataset_name, dataset, loader

    def __iter__(self, args):
        """Backward-compatible alias for ``iterate(args)``.

        Because this requires ``args``, ``iter(obj)`` and ``for x in obj``
        raise ``TypeError`` on this class. Call ``obj.iterate(args)`` (or
        ``obj.__iter__(args)``) instead.
        """
        return self.iterate(args)
                
