
import argparse
from typing import List, Optional

import torch

from src.save_data import SaveData
from src.runner import Runner


def main(model_name: str = "cnn_128",
         data_type: str = "raw",
         val_methods: Optional[List[str]] = None,
         learning_rate: float = 3e-4,
         datasets_names: Optional[List[str]] = None,
         train_ratio: float = 0.7,
         n_steps: int = 50,
         total_timesteps: int = 50,
         n_val: int = 50,
         epochs: int = 1,
         batch_size: int = 32,
         patience: int = 10,
         delta: float = 0.01,
         seed: int = 42,
         ressources: str = 'cpu',
         testing: bool = False) -> None:
    """Run one training/validation experiment and save its results.

    The run is delegated to ``Runner(...).run()``. The three values it returns
    (plot data, models, stoppers) are handed to ``SaveData(...).save_results``.
    Both receive ``path_save_dir``, which is a directory path encoding every
    hyper-parameter of the run.

    Args:
        model_name: Model to be used for the evaluation.
        data_type: Data variant identifier (e.g. "raw"). It is forwarded to
            Runner/SaveData and used as a path component.
        val_methods: Validation methods to be used for the evaluation.
            ``None`` (default) means
            ``["first_window", "shift_windows", "pw_shift_windows"]``.
        learning_rate: Learning rate to be used for the evaluation.
        datasets_names: List of datasets. Names are joined with "_" in the
            output path. ``None`` (default) means ``["supervised"]``.
        train_ratio: Ratio of the dataset to be used for training.
        n_steps: Number of steps collecting rewards.
        total_timesteps: Total number of timesteps for the training.
        n_val: Total number of validation steps for the training.
        epochs: Number of times each batch is used for training.
        batch_size: Batch size.
        patience: Patience (forwarded to Runner/SaveData).
        delta: Smallest gap considered as a loss improvement.
        seed: Seed for the run.
        ressources: ``'gpu'`` selects CUDA when it is available. Any other
            value, or no CUDA, selects the CPU.
        testing: If True, results go under ``results/results/testing/``
            instead of ``results/results/``. The flag is also forwarded to
            Runner/SaveData.

    Returns:
        None
    """
    # Resolve the list defaults here rather than in the signature. A list
    # literal default is created once and shared by every call, and it is
    # passed to external objects that could mutate it.
    if val_methods is None:
        val_methods = ["first_window", "shift_windows", "pw_shift_windows"]
    if datasets_names is None:
        datasets_names = ['supervised']

    device = torch.device("cuda" if (torch.cuda.is_available() and ressources == 'gpu') else "cpu")
    print(f" cuda availability : {torch.cuda.is_available()}")
    print(f"Device: {device}")

    datasets_str = "_".join(datasets_names)
    dirname = f"train_{train_ratio}_patience{patience}_delta{delta}"

    # Testing runs are kept apart from regular runs by an extra "testing" level.
    root = "results/results/testing" if testing else "results/results"
    path_save_dir = (
        f"{root}/{data_type}/{model_name}/{datasets_str}"
        f"/{total_timesteps}_total_timesteps/{epochs}_epochs/{n_val}_eval"
        f"/{learning_rate}_lr/batch_{batch_size}/seed_{seed}/nsteps_{n_steps}"
        f"/{dirname}"
    )

    res_dict_plots, res_dict_models, stoppers = Runner(model_name=model_name,
                                                       data_type=data_type,
                                                       val_methods=val_methods,
                                                       learning_rate=learning_rate,
                                                       datasets_names=datasets_names,
                                                       train_ratio=train_ratio,
                                                       n_steps=n_steps,
                                                       epochs=epochs,
                                                       total_timesteps=total_timesteps,
                                                       n_val=n_val,
                                                       batch_size=batch_size,
                                                       patience=patience,
                                                       delta=delta,
                                                       seed=seed,
                                                       device=device,
                                                       testing=testing,
                                                       path_save_dir=path_save_dir).run()

    SaveData(model_name=model_name,
             data_type=data_type,
             val_methods=val_methods,
             learning_rate=learning_rate,
             datasets_names=datasets_names,
             n_steps=n_steps,
             total_timesteps=total_timesteps,
             n_val=n_val,
             batch_size=batch_size,
             epochs=epochs,
             train_ratio=train_ratio,
             patience=patience,
             delta=delta,
             seed=seed,
             testing=testing,
             path_save_dir=path_save_dir).save_results(res_dict_plots=res_dict_plots,
                                                       res_dict_models=res_dict_models,
                                                       stoppers=stoppers)


def _parse_bool(value) -> bool:
    """Convert a command-line token such as "true" or "False" into a bool.

    ``argparse`` with ``type=bool`` evaluates ``bool(text)``, which is True for
    every non-empty string, "False" and "0" included. Without this explicit
    converter, an explicit false value could never switch a flag off.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser whose options mirror ``main``'s parameters.

    The CLI keeps its own defaults (e.g. ``--datasets_names`` defaults to
    ``['OPPORTUNITY']`` and ``--total_timesteps`` to 10). These differ from
    ``main``'s keyword defaults.
    """
    parser = argparse.ArgumentParser(
        description="Launch MRAD on the TSB-UAD benchmark. Outputs a plots of performances and two .pkl files containing the model and a dictionnary to reproduce the plot")

    parser.add_argument('--data_type', type=str, default="raw")
    parser.add_argument('--model_name', type=str, default="cnn_128", help='Model to be used for the evaluation')
    parser.add_argument('--val_methods', nargs="+", default=["first_window", "shift_windows", "pw_shift_windows"], help='Validation methods to be used for the evaluation')
    parser.add_argument('--learning_rate', type=float, default=3e-4, help='Learning rate to be used for the evaluation')
    parser.add_argument("--datasets_names", nargs="+", default=['OPPORTUNITY'], help="List of datasets")
    parser.add_argument('--train_ratio', type=float, default=0.7, help='Ratio of the dataset to be used for training')
    parser.add_argument('--n_steps', type=int, default=50, help='Number of steps collecting rewards')
    parser.add_argument('--epochs', type=int, default=1, help='Number of time each batch is used for training')
    parser.add_argument('--total_timesteps', type=int, default=10, help='Total number of timesteps for the training')
    parser.add_argument('--n_val', type=int, default=10, help='Total number of validations steps for the training')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for the evaluation')
    parser.add_argument('--patience', type=int, default=10, help='Patience for the evaluation')
    parser.add_argument('--delta', type=float, default=0.01, help='Smallest gap considered as a loss improvement for the evaluation')
    parser.add_argument('--seed', type=int, default=42, help='Seed for the evaluation')
    parser.add_argument('--ressources', type=str, default='cpu', help='Ressources to be used for the evaluation')
    # `--testing`, `--testing True` and `--testing yes` all enable testing mode.
    # `--testing False` / `--testing 0` now really disable it; with type=bool
    # they used to enable it.
    parser.add_argument('--testing', type=_parse_bool, nargs='?', const=True, default=False,
                        help='Run in testing mode')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    main(
        model_name=args.model_name,
        data_type=args.data_type,
        val_methods=args.val_methods,
        learning_rate=args.learning_rate,
        datasets_names=args.datasets_names,
        train_ratio=args.train_ratio,
        n_steps=args.n_steps,
        total_timesteps=args.total_timesteps,
        n_val=args.n_val,
        epochs=args.epochs,
        batch_size=args.batch_size,
        patience=args.patience,
        delta=args.delta,
        seed=args.seed,
        ressources=args.ressources,
        testing=args.testing
    )













