import argparse
from typing import List, Optional

import torch

from src.utils.save_data import SaveData
from src.runner import Runner


#test modification 
def _build_save_dir(*,
                    data_type: str,
                    model_name: str,
                    datasets_str: str,
                    epochs: int,
                    learning_rate: float,
                    batch_size: int,
                    seed: int,
                    loss_name: str,
                    train_ratio: float,
                    patience: int,
                    delta: float,
                    testing: bool) -> str:
    """Return the results directory that encodes every hyper-parameter of a run.

    Testing runs are placed under ``results/results/testing/`` instead of
    ``results/results/`` so they stay separate from regular results.
    """
    root = "results/results/testing" if testing else "results/results"
    # NOTE(review): ``datasets_str`` appears both early in the path and as the
    # last component; kept unchanged so existing result locations stay valid.
    return (f"{root}/{data_type}/{model_name}/{datasets_str}/{epochs}_epochs/"
            f"{learning_rate}_lr/batch_{batch_size}/seed_{seed}/loss_{loss_name}/"
            f"train_{train_ratio}/patience_{patience}/delta_{delta}/{datasets_str}")


def main(model_name: str = "linear",
         data_type: str = "features",
         val_methods: Optional[List[str]] = None,
         learning_rate: float = 0.01,
         loss_name: str = "mse",
         datasets_names: Optional[List[str]] = None,
         train_ratio: float = 0.7,
         epochs: int = 100,
         batch_size: int = 8,
         patience: int = 10,
         delta: float = 0.01,
         seed: int = 42,
         ressources: str = 'cpu',
         testing: bool = False) -> None:
    """Evaluate one configuration with ``Runner`` and persist its results with ``SaveData``.

    Every hyper-parameter is forwarded unchanged to both ``Runner`` and
    ``SaveData`` and is also encoded in the results directory path.

    Args:
        model_name: Model identifier.
        data_type: Input data type identifier.
        val_methods: Validation methods to evaluate. ``None`` (the default, and
            what the CLI passes when ``--val_methods`` is omitted) selects
            ``["first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows"]``.
        learning_rate: Learning rate.
        loss_name: Loss identifier.
        datasets_names: Datasets to use. ``None`` selects ``["supervised"]``.
            The names are joined with ``-`` to build the results path.
        train_ratio: Fraction of the data used for training.
        epochs: Number of epochs.
        batch_size: Batch size.
        patience: Patience forwarded to ``Runner``/``SaveData``.
        delta: Minimum loss change forwarded to ``Runner``/``SaveData``.
        seed: Random seed forwarded to ``Runner``/``SaveData``.
        ressources: ``'gpu'`` selects CUDA when it is available; any other
            value (or missing CUDA) runs on CPU.
        testing: When true, results are written under ``results/results/testing/``.
    """
    # None sentinels instead of list defaults: a list default is created once
    # and shared by every call, and the CLI passes None for omitted flags.
    if val_methods is None:
        val_methods = ["first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows"]
    if datasets_names is None:
        datasets_names = ['supervised']

    device = torch.device("cuda" if (torch.cuda.is_available() and ressources == 'gpu') else "cpu")
    print(f" cuda availability : {torch.cuda.is_available()}")
    print(f"Device: {device}")

    datasets_str = "-".join(datasets_names)

    path_save_dir = _build_save_dir(data_type=data_type,
                                    model_name=model_name,
                                    datasets_str=datasets_str,
                                    epochs=epochs,
                                    learning_rate=learning_rate,
                                    batch_size=batch_size,
                                    seed=seed,
                                    loss_name=loss_name,
                                    train_ratio=train_ratio,
                                    patience=patience,
                                    delta=delta,
                                    testing=testing)

    runner = Runner(path_save_dir=path_save_dir,
                    model_name=model_name,
                    data_type=data_type,
                    val_methods=val_methods,
                    learning_rate=learning_rate,
                    loss_name=loss_name,
                    datasets_names=datasets_names,
                    train_ratio=train_ratio,
                    epochs=epochs,
                    batch_size=batch_size,
                    patience=patience,
                    delta=delta,
                    seed=seed,
                    device=device,
                    testing=testing)

    # Runner.evaluate returns three objects that SaveData persists as-is.
    res_dict_plots, res_dict_models, stoppers = runner.evaluate()

    SaveData(path_save_dir=path_save_dir,
             model_name=model_name,
             data_type=data_type,
             val_methods=val_methods,
             learning_rate=learning_rate,
             loss_name=loss_name,
             datasets_names=datasets_names,
             train_ratio=train_ratio,
             epochs=epochs,
             batch_size=batch_size,
             patience=patience,
             delta=delta,
             seed=seed,
             testing=testing).save_results(res_dict_plots=res_dict_plots,
                                           res_dict_models=res_dict_models,
                                           stoppers=stoppers)

if __name__ == "__main__":

    def _str2bool(value: str) -> bool:
        """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

        argparse's ``type=bool`` cannot be used: ``bool("False")`` is ``True``
        because every non-empty string is truthy.
        """
        normalized = value.strip().lower()
        if normalized in ("true", "t", "yes", "y", "1"):
            return True
        if normalized in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="Launch MRAD on the TSB-UAD benchmark. Outputs a plots of performances and two .pkl files containing the model and a dictionnary to reproduce the plot")

    parser.add_argument('--model_name', type=str, default="linear", help='Model to be used for the evaluation')
    parser.add_argument('--data_type', type=str, default="features")
    # Explicit defaults (same as main's) so omitting these flags no longer yields None.
    parser.add_argument('--val_methods', nargs="+",
                        default=["first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows"],
                        help='Validation methods to be used for the evaluation')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate to be used for the evaluation')
    parser.add_argument('--loss_name', type=str, default="mse", help='Loss to be used for the evaluation')
    parser.add_argument("--datasets_names", nargs="+", default=['supervised'], help="List of datasets")
    parser.add_argument('--train_ratio', type=float, default=0.7, help='Ratio of the dataset to be used for training')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for the evaluation')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for the evaluation')
    parser.add_argument('--patience', type=int, default=10, help='Patience for the evaluation')
    parser.add_argument('--delta', type=float, default=0.01, help='Smallest gap considered as a loss improvement for the evaluation')
    parser.add_argument('--seed', type=int, default=42, help='Seed for the evaluation')
    parser.add_argument('--ressources', type=str, default='cpu', help='Ressources to be used for the evaluation')
    # nargs='?' + const=True: a bare `--testing` enables it, while an explicit
    # `--testing true` / `--testing false` is parsed strictly.
    parser.add_argument('--testing', type=_str2bool, nargs='?', const=True, default=False,
                        help='Run in testing mode (bare flag, or an explicit true/false)')
    args = parser.parse_args()

    main(
        model_name=args.model_name,
        data_type=args.data_type,
        val_methods=args.val_methods,
        learning_rate=args.learning_rate,
        loss_name=args.loss_name,
        datasets_names=args.datasets_names,
        train_ratio=args.train_ratio,
        epochs=args.epochs,
        batch_size=args.batch_size,
        patience=args.patience,
        delta=args.delta,
        seed=args.seed,
        ressources=args.ressources,
        testing=args.testing
    )











