
import argparse
from typing import List, Optional

import torch

from src.save_data import SaveData
from src.runner import Runner


def main(model_name: str = "cnn_128",
         data_type: str = "raw",
         val_methods: Optional[List[str]] = None,
         datasets_names: Optional[List[str]] = None,
         train_ratio: float = 0.7,
         num_generations: int = 1,
         num_parents_mating: int = 5,
         num_solutions: int = 5,
         sample_size: int = 32,
         patience: int = 10,
         delta: float = 0.01,
         seed: int = 42,
         ressources: str = 'cpu',
         testing: bool = False) -> None:
    """Run the MRAD pipeline and save its results.

    Builds an output directory from the run configuration, executes
    ``Runner(...).run()`` and passes the returned plot data, models and
    stoppers to ``SaveData(...).save_results``.

    Args:
        model_name: Model identifier forwarded to ``Runner`` and ``SaveData``.
        data_type: Data variant forwarded to both; also the first component
            of the save path (after the root prefix).
        val_methods: Validation methods to run. ``None`` (the default)
            selects ``["first_window", "shift_windows", "pw_shift_windows"]``.
        datasets_names: Datasets to use. ``None`` (the default) selects
            ``["supervised"]``. The names are joined with ``"_"`` in the path.
        train_ratio, num_generations, num_parents_mating, num_solutions,
        sample_size, patience, delta, seed: Forwarded unchanged to
            ``Runner`` and ``SaveData`` and encoded in the save path.
        ressources: ``'gpu'`` requests CUDA, which is used only if
            ``torch.cuda.is_available()``. Any other value runs on CPU.
        testing: If true, results go under ``results/results/testing``
            instead of ``results/results``. Also forwarded to both classes.
    """
    # Fresh lists per call: a shared mutable default (or a caller-owned list)
    # must not be modified by whatever Runner/SaveData do with it.
    if val_methods is None:
        val_methods = ["first_window", "shift_windows", "pw_shift_windows"]
    else:
        val_methods = list(val_methods)
    if datasets_names is None:
        datasets_names = ['supervised']
    else:
        datasets_names = list(datasets_names)

    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda" if (cuda_available and ressources == 'gpu') else "cpu")
    print(f" cuda availability : {cuda_available}")
    print(f"Device: {device}")
    if ressources == 'gpu' and not cuda_available:
        # Make the CPU fallback visible instead of silently ignoring the request.
        print("Warning: GPU requested but CUDA is not available; running on CPU.")

    dataset_str = "_".join(datasets_names)

    # Testing runs are kept apart from real results under a dedicated prefix;
    # the rest of the path is identical for both modes.
    root_dir = "results/results/testing" if testing else "results/results"
    path_save_dir = (
        f"{root_dir}/{data_type}/{model_name}/{dataset_str}/"
        f"{num_generations}_generations/{num_solutions}_solutions/"
        f"{num_parents_mating}_parents_mating/{sample_size}_samples/"
        f"seed_{seed}/train_{train_ratio}_patience{patience}_delta{delta}_{dataset_str}"
    )

    res_dict_plots, res_dict_models, stoppers = Runner(model_name=model_name,
                                                       data_type=data_type,
                                                       val_methods=val_methods,
                                                       datasets_names=datasets_names,
                                                       train_ratio=train_ratio,
                                                       num_generations=num_generations,
                                                       num_parents_mating=num_parents_mating,
                                                       num_solutions=num_solutions,
                                                       sample_size=sample_size,
                                                       patience=patience,
                                                       delta=delta,
                                                       seed=seed,
                                                       device=device,
                                                       testing=testing,
                                                       path_save_dir=path_save_dir).run()

    SaveData(model_name=model_name,
             data_type=data_type,
             val_methods=val_methods,
             datasets_names=datasets_names,
             num_generations=num_generations,
             num_parents_mating=num_parents_mating,
             num_solutions=num_solutions,
             sample_size=sample_size,
             train_ratio=train_ratio,
             patience=patience,
             delta=delta,
             seed=seed,
             testing=testing,
             path_save_dir=path_save_dir).save_results(res_dict_plots=res_dict_plots,
                                                       res_dict_models=res_dict_models,
                                                       stoppers=stoppers)

def str_to_bool(value) -> bool:
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This converter parses the text instead.
    Values that are already ``bool`` are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Launch MRAD on the TSB-UAD benchmark. Outputs a plots of performances and two .pkl files containing the model and a dictionnary to reproduce the plot")

    parser.add_argument('--data_type', type=str, default="raw")
    parser.add_argument('--model_name', type=str, default="cnn_128", help='Model to be used for the evaluation')
    parser.add_argument('--val_methods', nargs="+", default=["first_window", "shift_windows", "pw_shift_windows"], help='Validation methods to be used for the evaluation')
    # NOTE(review): CLI default ('OPPORTUNITY') differs from main()'s default ('supervised') — confirm intended.
    parser.add_argument("--datasets_names", nargs="+", default=['OPPORTUNITY'], help="List of datasets")
    parser.add_argument('--train_ratio', type=float, default=0.7, help='Ratio of the dataset to be used for training')
    # NOTE(review): --num_eval is parsed but never forwarded to main(); kept so existing command lines still parse.
    parser.add_argument('--num_eval', type=int, default=10, help='Total number of validations steps for the training')
    parser.add_argument('--num_generations', type=int, default=1, help='Number of generations for the evaluation')
    parser.add_argument('--num_parents_mating', type=int, default=5, help='Number of parents mating for the evaluation')
    parser.add_argument('--num_solutions', type=int, default=5)
    # NOTE(review): CLI default (60) differs from main()'s default (32) — confirm intended.
    parser.add_argument('--sample_size', type=int, default=60, help='the nb of times series used to evaluate all solutions before validation on the best solution with the full validation set')
    parser.add_argument('--patience', type=int, default=10, help='Patience for the evaluation')
    parser.add_argument('--delta', type=float, default=0.01, help='Smallest gap considered as a loss improvement for the evaluation')
    parser.add_argument('--seed', type=int, default=42, help='Seed for the evaluation')
    parser.add_argument('--ressources', type=str, default='cpu', help='Ressources to be used for the evaluation')
    # With type=bool, "--testing False" would yield True. str_to_bool parses
    # the text, and nargs='?' + const=True lets a bare "--testing" enable it.
    parser.add_argument('--testing', type=str_to_bool, nargs='?', const=True, default=False, help='Run in testing mode')
    args = parser.parse_args()

    main(
        model_name=args.model_name,
        data_type=args.data_type,
        val_methods=args.val_methods,
        datasets_names=args.datasets_names,
        train_ratio=args.train_ratio,
        num_generations=args.num_generations,
        num_parents_mating=args.num_parents_mating,
        num_solutions=args.num_solutions,
        sample_size=args.sample_size,
        patience=args.patience,
        delta=args.delta,
        seed=args.seed,
        ressources=args.ressources,
        testing=args.testing
        )













