
import os
import sys
# NOTE: despite its name, ``current_file_path`` is the directory two levels
# above this file (the parent of the folder that contains it), not a file path.
# It is appended to sys.path, presumably so the project-local imports that
# follow (default_parameter, missingDataTrainingModule) can be resolved when
# this script is run directly.
current_file_path = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)),
)
sys.path.append(current_file_path)
from default_parameter import *
from missingDataTrainingModule.Prediction import *
from missingDataTrainingModule.Selection import *

# Training schedule and switches; get_default_local copies each of these into
# args.args_train. Both epoch counts are 0 in this configuration.
nb_epoch_pretrain, nb_epoch = 0, 0
fix_classifier_parameters, post_hoc = False, False

# Identifier string for this experiment configuration (not read inside this
# file; presumably consumed by the script that launches the experiment).
name_experiment = "CELEBA_experiment_withpretraining"

# Assigned to args.args_test.liste_mc by get_default_local; empty here.
# NOTE(review): "mc" presumably refers to Monte-Carlo sample counts — confirm.
list_test_mc = list()

def get_default_local(args):
    """Apply this CelebA experiment's local defaults to ``args`` in place.

    Overwrites test, dataset, selection, classification and training settings
    on the nested configuration namespaces. It uses the module-level settings
    of this file (``list_test_mc``, ``post_hoc``, ``nb_epoch_pretrain``,
    ``nb_epoch``, ``fix_classifier_parameters``) plus the hard-coded sizes below.

    Args:
        args: Configuration object exposing the attribute containers
            ``args_test``, ``args_dataset.args_dataset_parameters``,
            ``args_selection``, ``args_classification`` and ``args_train``.
            It is modified in place.

    Returns:
        The same ``args`` object, to allow call chaining.
    """
    # Store a copy of the module-level list: assigning the list itself would
    # alias it, so any mutation made through ``args`` would leak back into
    # ``list_test_mc`` and into every later call of this function.
    args.args_test.liste_mc = list(list_test_mc)
    args.args_dataset.args_dataset_parameters.batch_size_test = 100
    args.args_dataset.args_dataset_parameters.batch_size_train = 16

    # Module input/output sizes. Presumably (channels, height, width) for
    # 128x128 RGB images — TODO confirm against the data loader.
    args.args_selection.input_size_selector = (3, 128, 128)
    args.args_selection.output_size_selector = (1, 128, 128)
    args.args_classification.input_size_prediction_module = (3, 128, 128)  # Size before imputation
    args.args_selection.kernel_size = (4, 4)
    args.args_selection.kernel_stride = (4, 4)

    args.args_train.post_hoc = post_hoc
    args.args_train.argmax_post_hoc = False
    args.args_train.post_hoc_guidance = None
    args.args_train.nb_epoch_pretrain = nb_epoch_pretrain  # Number of pretraining epochs
    args.args_train.nb_epoch = nb_epoch  # Training the complete model
    args.args_train.fix_classifier_parameters = fix_classifier_parameters

    return args
    