
import os
import sys
current_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(current_file_path)
from default_parameter import *


from multiple_experiment_launcher import  get_dataset

class ExperimentIterator():
    """Walk a grid of experiment settings, writing each combination into ``args``.

    The grid is the cartesian product of the ``list_*`` attributes built in
    ``__init__``. For every combination the corresponding fields of the
    ``args`` object are overwritten in place, then a tuple
    ``(index, dataset_name, dataset, loader)`` is yielded so the caller can
    run one experiment with the freshly configured ``args``.

    The instance can be iterated either with the explicit legacy call
    ``iterator.__iter__(args)`` or, when ``args`` was given to the
    constructor, with a plain ``for ... in iterator`` loop.
    """

    def __init__(self, iter_cste = False, iter_lambda = False, iter_rate = False, add_noise = False, args = None) -> None:
        """Build the lists of values explored along each axis of the grid.

        Args:
            iter_cste: if true, sweep several imputation constants; otherwise
                the single value ``None`` is used.
            iter_lambda: if true, sweep several ``lambda_reg`` values;
                otherwise the single value ``0.0`` is used.
            iter_rate: if true, sweep several ``rate`` values; otherwise the
                single value ``None`` is used.
            add_noise: value copied to the dataset parameters' ``add_noise``
                field for every configuration.
            args: optional default configuration object used when
                ``__iter__`` is called without an argument (for instance by
                ``iter()`` or a ``for`` loop).
        """
        self.args = args
        self.add_noise = add_noise

        # Axes that currently hold a single fixed value.
        self.list_selector = ["SelectorUNET"]
        self.list_classifier = ["ConvClassifier2"]
        self.list_seed = [(0, 1)]  # (train_seed, test_seed)
        self.list_loss_function = ["NLL"]
        # (mc_mask, iwae_mask, mc_sample, iwae_sample)
        self.list_mc_iwae = [(1, 10, 1, 1)]
        self.list_dataset = [("MNIST_and_FASHIONMNIST", MNIST_and_FASHIONMNIST)]

        # Axes that are swept only on request.
        self.list_cste = [0, 1, 3, -1, -3] if iter_cste else [None]
        self.list_lambda = [0.01, 0.05, 0.1, 0.3, 0.5] if iter_lambda else [0.0]
        self.list_rate = (
            [0.01, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8]
            if iter_rate
            else [None]
        )

    def __iter__(self, args = None):
        """Return a generator over every configuration of the grid.

        Args:
            args: configuration object mutated in place for each combination.
                When omitted, the ``args`` given to the constructor is used,
                which is what makes ``iter(self)`` / ``for ... in self`` work.

        Returns:
            A generator yielding ``(index, dataset_name, dataset, loader)``,
            where ``index`` starts at 1 and ``dataset, loader`` is the pair
            returned by ``get_dataset(args)``.

        Raises:
            ValueError: if no ``args`` object is available.
        """
        if args is None:
            args = self.args
        if args is None:
            raise ValueError(
                "ExperimentIterator needs an `args` object: "
                "pass it to __init__ or to __iter__."
            )
        return self._iterate(args)

    def _iterate(self, args):
        """Generator behind ``__iter__``: mutate ``args`` and yield each run."""
        # Local import so this class does not depend on the file's import block.
        import itertools

        index = 0

        # The second tuple element of list_dataset is not used here: the
        # dataset that is yielded is the one returned by get_dataset.
        for dataset_name, _dataset_entry in self.list_dataset:
            args.args_dataset.dataset = dataset_name
            for train_seed, test_seed in self.list_seed:
                dataset_parameters = args.args_dataset.args_dataset_parameters
                dataset_parameters.train_seed = train_seed
                dataset_parameters.test_seed = test_seed
                dataset_parameters.add_noise = self.add_noise
                dataset_parameters.target_mnist = False
                # Called once per (dataset, seed) pair and shared by every
                # configuration of the inner grid.
                dataset, loader = get_dataset(args)

                # product() varies its last factor fastest, which is the same
                # order as nesting one loop per factor.
                grid = itertools.product(
                    self.list_cste,
                    self.list_lambda,
                    self.list_rate,
                    self.list_classifier,
                    self.list_selector,
                    self.list_mc_iwae,
                    self.list_loss_function,
                )
                for cste, lambda_reg, rate, classifier, selector, mc_iwae, loss in grid:
                    mc_mask, iwae_mask, mc_sample, iwae_sample = mc_iwae

                    args.args_classification.cste_imputation = cste
                    args.args_selection.lambda_reg = lambda_reg
                    args.args_selection.rate = rate
                    args.args_classification.classifier = classifier
                    args.args_selection.selector = selector
                    args.args_train.nb_sample_z_train_monte_carlo = mc_mask # Number of samples for monte carlo gradient estimator
                    args.args_train.nb_sample_z_train_IWAE = iwae_mask # Number K in the IWAE-similar loss
                    args.args_classification.nb_imputation_mc = mc_sample
                    args.args_classification.nb_imputation_iwae = iwae_sample
                    args.args_train.loss_function = loss # NLL, MSE

                    index += 1
                    yield index, dataset_name, dataset, loader
                
