
import os
import sys
current_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(current_file_path)
from default_parameter import *

# Experiment-level settings. They are assigned after the star import above, so
# they take precedence over any same-named values that import brings in.
# get_default_local() copies most of them onto the argument object.

# Written to args.args_train.nb_epoch_pretrain.
nb_epoch_pretrain = 0
# Written to args.args_train.nb_epoch.
nb_epoch = 50
# Written to args.args_train.fix_classifier_parameters.
fix_classifier_parameters = False
# Written to args.args_train.post_hoc.
post_hoc = False

# Experiment identifier. Not read in the visible part of this file; presumably
# consumed by whatever imports this module -- TODO confirm.
name_experiment = "MNIST_and_FASHIONMNIST_L1_realparam_longer"



# Written to args.args_test.liste_mc; empty for this experiment.
list_test_mc = []

def get_default_local(args):
    """Apply this experiment's overrides to ``args`` and return it.

    ``args`` is modified in place: the test, dataset, selection,
    classification and training sub-objects each receive the values set
    below, most of them taken from the module-level settings of this file.

    Args:
        args: Argument container exposing ``args_test``, ``args_dataset``
            (with ``args_dataset_parameters``), ``args_selection``,
            ``args_classification`` and ``args_train`` attributes.

    Returns:
        The same ``args`` object that was passed in.
    """
    # NOTE(review): presumably (channels, height, width) for two 28x28 images
    # placed side by side -- confirm against the dataset definition.
    sample_shape = (1, 28, 56)
    batch_size = 64

    args.args_test.liste_mc = list_test_mc

    dataset_parameters = args.args_dataset.args_dataset_parameters
    dataset_parameters.batch_size_test = batch_size
    dataset_parameters.batch_size_train = batch_size

    selection = args.args_selection
    selection.input_size_selector = sample_shape
    selection.output_size_selector = sample_shape
    # Size before imputation
    args.args_classification.input_size_prediction_module = sample_shape

    train = args.args_train
    train.post_hoc = post_hoc
    train.argmax_post_hoc = False
    train.post_hoc_guidance = None
    # Epoch counts for the pretraining phase and for training the complete model.
    train.nb_epoch_pretrain = nb_epoch_pretrain
    train.nb_epoch = nb_epoch
    train.fix_classifier_parameters = fix_classifier_parameters

    return args
    