
import os
import sys
current_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(current_file_path)
from default_parameter_autoencoder import *
from multiple_experiment_launcher import multiple_experiment, get_dataset
from missingDataTrainingModule import GaussianMixtureImputation, train_gmm, experiment



def train_autoencoder(dataset_init, default_args):
    """Configure the autoencoder settings and launch one training experiment.

    Starts from ``get_default_autoencoder()``, overrides the input/output
    sizes, the epoch count and the reconstruction-regularization settings,
    points the output location at a dataset-specific ``weights`` sub-folder,
    then builds the dataset/loader with ``get_dataset`` and passes them to
    ``experiment``.

    Args:
        dataset_init: Value stored in
            ``args_dataset.args_dataset_parameters.dataset`` of the
            autoencoder configuration.
        default_args: Configuration object; only
            ``default_args.args_dataset.dataset`` is read here, to name the
            output folder.

    Returns:
        None. Whatever ``experiment`` returns is discarded.
    """
    image_shape = (1, 28, 56)

    config = get_default_autoencoder()
    classification = config.args_classification
    selection = config.args_selection
    output = config.args_output

    # The prediction module and the selector (input and output) share one shape.
    classification.input_size_prediction_module = image_shape
    selection.input_size_selector = image_shape
    selection.output_size_selector = image_shape

    config.args_train.nb_epoch = 100

    # Reconstruction regularization settings. Per the original author's note:
    # possibility of autoencoder regularization, where the output of the
    # autoencoder is not given to classification (simple regularization of
    # the mask).
    classification.add_mask = True
    classification.reconstruction_regularization = "AutoEncoderLatentReconstruction"
    classification.network_reconstruction = "self"
    # Parameter controlling the reconstruction regularization.
    classification.lambda_reconstruction = 0.1

    config.args_dataset.args_dataset_parameters.dataset = dataset_init

    # Build the target directory from the *default* output folder before that
    # folder attribute is overwritten below.
    weights_dir = os.path.join(
        output.folder,
        "weights",
        default_args.args_dataset.dataset + "_autoencoder",
    )
    output.path = weights_dir
    output.folder = weights_dir

    dataset, loader = get_dataset(config)
    experiment(dataset, loader, complete_args=config)



if __name__ == "__main__":
    # Imported only when this file is executed as a script.
    from default_parameter import get_default

    base_args = get_default()
    base_args.args_dataset.dataset = "MNIST_and_FASHIONMNIST"

    # get_dataset returns a (dataset, loader) pair; only the dataset is
    # forwarded to the training routine, the loader is not used here.
    dataset_init, _loader_init = get_dataset(base_args)
    train_autoencoder(dataset_init, base_args)