
import os
import sys
current_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(current_file_path)
from default_parameter import *
from missingDataTrainingModule.Prediction import *
from missingDataTrainingModule.Selection import *

# Covariance settings. get_default_local writes cov_type and cov into
# args.args_dataset.args_dataset_parameters; cov_name is not used by it.
cov_type = cov_name = "spherical"
cov = torch.tensor(1.0, dtype=torch.float32)

# Task flags, also written into the dataset parameters.
classification = True
scale_regression = True

# Training settings, written into args.args_train.
nb_epoch_pretrain = 0
nb_epoch = 1000
fix_classifier_parameters = False
post_hoc = False

name_experiment = "SYN_DATASET_UsingRealSeedHighIWAE"

# Source for args.args_test.liste_mc in get_default_local.
list_test_mc = []

def get_default_local(args):
    """Override the generic defaults in ``args`` with this experiment's settings.

    ``args`` is modified in place and also returned. The values written come
    from the module-level constants of this file (``cov``, ``cov_type``,
    ``classification``, ``nb_epoch``, ...) plus a few literals.

    Args:
        args: Nested configuration object exposing ``args_test``,
            ``args_dataset.args_dataset_parameters``, ``args_selection``,
            ``args_classification`` and ``args_train``. Presumably built by
            the defaults in ``default_parameter`` -- confirm with the caller.

    Returns:
        The same ``args`` object, updated.
    """
    # Assign a copy: sharing the module-level list would let any later
    # mutation of args.args_test.liste_mc leak back into list_test_mc and
    # into every other args object configured by this function.
    args.args_test.liste_mc = list(list_test_mc)

    dataset_params = args.args_dataset.args_dataset_parameters
    dataset_params.batch_size_test = 1000
    dataset_params.classification = classification
    dataset_params.scale_regression = scale_regression
    dataset_params.epsilon_sigma = None
    dataset_params.cov = cov
    dataset_params.covariance_type = cov_type

    # Selector input, selector output and prediction-module input all use the
    # same shape. NOTE(review): presumably 11 features per sample -- confirm.
    input_shape = (1, 11)
    args.args_selection.input_size_selector = input_shape
    args.args_selection.output_size_selector = input_shape
    args.args_classification.input_size_prediction_module = input_shape  # Size before imputation

    train_args = args.args_train
    train_args.post_hoc = post_hoc
    train_args.argmax_post_hoc = False
    train_args.post_hoc_guidance = None
    # NOTE(review): presumably the number of pretraining epochs -- confirm.
    train_args.nb_epoch_pretrain = nb_epoch_pretrain
    train_args.nb_epoch = nb_epoch  # Training the complete model
    train_args.fix_classifier_parameters = fix_classifier_parameters

    return args
    