import ast
from distutils.util import strtobool
import datetime
import json
import os
import sys
import argparse
from datetime import datetime
from typing import List, Tuple, Optional, Iterable

import torch
from torch.nn import L1Loss, MSELoss
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from BaselineModel.BaselineBaseLSTMRegressor import BaseLSTMRegressor
from BaselineModel.BaselineModelLoader import load_single_ae_model
from BaselineModel.BaselineCombinedModel import CombinedLSTMRegressor
from BaselineModel.BaselineLSTMRankingEvaluation import lstm_maps_ranking_evaluation, _actual_lstm_maps_eval
from DataHandling.FileBasedDatasetBase import FileBasedDataset
from Utils import logger, get_logger, config as cfg
from Utils.utils import function_start_save_params
from Utils.Constants import Diff as Const, FileNamesConstants as Names
from ModelsUtils.ModelCallbacks import ModelCheckpointCallback
from ModelsUtils.ModelTrainer import ModelTrainer
from ModelsUtils.Metrics import regression_log_loss
from DataHandling.DataUtils import stats_folders_train_test_split
from BaselineModel.NaiveDataHandling import LSTMNaiveFeatureSelection, PipeLines, LSTMNaiveDataset
from BaselineLSTMRegressionModel import LSTMRegressor
from DataHandling.LSTMSingleFeatureMapBasedDataset import RegressionSingleFeatureMapDataset
from DataHandling.LSTMCombinedFeatureMapsDataset import CombinedRegressionMapsDataset


def _folder_to_train_val_files(folders, is_weights, val_size) -> Tuple[List[str], List[str]]:
    """
    Split the 'stats' sub-directories of the given folders into train and validation files.

    :param folders: iterable of folders, each expected to contain a 'stats' sub-directory
    :param is_weights: forwarded to stats_folders_train_test_split (True for weights stats, False for gradients)
    :param val_size: validation size forwarded to stats_folders_train_test_split
    :return: Tuple of lists of train files and validation files
    """
    stats_dirs = list(map(lambda folder: os.path.join(folder, 'stats'), folders))
    train_part, val_part = stats_folders_train_test_split(is_weights, stats_dirs, val_size)
    return train_part, val_part


def _create_pipe_input_data(folders: List, is_weights: bool, val_size: float, is_single_layer: bool,
                            variance_th: float, num_features: int,
                            model_input_save_dir: str = '/sise/group/inputs/lstm_1'):
    """
    Create the LSTM inputs for the basic feature-selection pipeline.

    :param folders: list of folder with stats folder within and results files
    :param is_weights: True for weights stats, False for gradients stats
    :param val_size: validation size used when splitting the stats files
    :param is_single_layer: True for single layer stats, False for multi layer stats
    :param variance_th: variance th for basic pipeline
    :param num_features: num of features for basic pipeline
    :param model_input_save_dir: directory where LSTMNaiveFeatureSelection saves the created model inputs
                                 (defaults to the previously hard-coded location)
    :return: None
    """
    train_files, val_files = _folder_to_train_val_files(folders, is_weights, val_size)
    pipeline, pipeline_name = PipeLines.get_initial_pipeline(variance_th, num_features)
    fs = LSTMNaiveFeatureSelection(train_files, val_files, pipeline, pipeline_name=pipeline_name,
                                   is_weights=is_weights, is_single_layer=is_single_layer, layers_first=True,
                                   model_input_save_dir=model_input_save_dir)
    # The same folders hold both the 'stats' sub-directories (split above) and the results files
    fs.create_lstm_input(folders)


def train_validate_pipeline_lstm(folders: List[str], is_weights: bool, val_size: float, is_single_layer: bool,
                                 variance_th: float, num_features: int, save_base_dir: str, model_save_dir: Optional[str],
                                 epochs: int, batch_size: int,
                                 lstm_hidden_size: int, lstm_num_layers: int, model_dense_sizes: Tuple,
                                 sequence_size: int, batch_first: bool, ):
    """
    Train and validate a LSTM regression model with a basic pipeline
    :param folders: folders with a 'stats' sub-directory and results files
    :param is_weights: True for weights stats, False for gradients stats
    :param val_size: validation size used when splitting the stats files
    :param is_single_layer: True for single layer stats, False for multi layer stats
    :param variance_th: variance threshold of the basic pipeline
    :param num_features: number of features selected by the basic pipeline (also the LSTM embedding size)
    :param save_base_dir: base directory where inputs folder exists and where the model data will be saved
    :param model_save_dir: directory where to save model outputs, if not an absolute path will use the save_base_dir as
                           the base directory where to create and save the model_save_dir.
                           If None will use the pipeline name with model_prefix and time suffix
    :param epochs: number of training epochs
    :param batch_size: batch size of the train and validation data loaders
    :param lstm_hidden_size: LSTM hidden size (single int)
    :param lstm_num_layers: number of LSTM layers (single int)
    :param model_dense_sizes: sizes of the dense layers after the LSTM
    :param sequence_size: LSTM sequence length
    :param batch_first: forwarded to LSTMRegressor
    :return: None
    """
    # Must be the first statement so only the call arguments are recorded
    test_parameters = locals().copy()
    pipeline, pipeline_name = PipeLines.get_initial_pipeline(variance_th, num_features)
    test_parameters['pipeline_name'] = pipeline_name
    test_parameters['config'] = cfg.to_json()
    if model_save_dir is None:
        model_save_dir = f'model_{pipeline_name}_{datetime.now().strftime("%Y_%m_%d_%H_%M")}'
    if not os.path.isabs(model_save_dir):
        model_save_dir = os.path.join(save_base_dir, model_save_dir)
    os.makedirs(model_save_dir, exist_ok=True)
    with open(os.path.join(model_save_dir, Names.RUN_PARAMETERS), "w") as f:
        json.dump(test_parameters, f)

    logger().force_log_and_print('BaselineLSTMTraining::train_validate_pipeline_lstm',
                                 f'save_base_dir={save_base_dir}, model_save_dir={model_save_dir}')

    train_files, val_files = _folder_to_train_val_files(folders, is_weights, val_size)
    for name, data in zip([Names.TRAIN_FILES, Names.VAL_FILES], [train_files, val_files]):
        with open(os.path.join(model_save_dir, name), "w") as f:
            json.dump(data, f)

    train_dataset, val_dataset, _ =\
        LSTMNaiveDataset.create_dataset(train_files, val_files, test_files=None, results_files_locations=folders,
                                        result_metric_name='accuracy', pre_load_mode=True,
                                        feature_selection_pipeline=pipeline, pipeline_name=pipeline_name,
                                        is_weights=is_weights, is_single_layer=is_single_layer,
                                        model_inputs_save_dir=save_base_dir)

    lstm_model = LSTMRegressor(embedding_size=num_features, hidden_size=lstm_hidden_size, lstm_layers=lstm_num_layers,
                               inner_dense_layer_sizes=model_dense_sizes, sequence_len=sequence_size,
                               tb_log_path=model_save_dir, batch_first=batch_first, last_layer='sigmoid',
                               bi_directional=True)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Validation data needs no shuffling; keeping a fixed order makes validation results reproducible
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Maybe should allow for more parameters in this function
    metrics = [L1Loss()]
    metric_name = ['MAE']
    trainer = ModelTrainer(model=lstm_model, tb_writer=lstm_model.tb_writer(), optimizer=None, convert_to_double=True)
    logger().log('BaselineLSTMTraining::train_validate_pipeline_lstm', 'Start fitting LSTM regression model')
    trainer.fit(epochs, train_dataloader, val_dataloader, metrics_funcs=metrics, metrics_names=metric_name,
                checkpoint_cb=ModelCheckpointCallback(os.path.join(model_save_dir, 'lstm_reg'), use_loss=True))
    logger().log('BaselineLSTMTraining::train_validate_pipeline_lstm', '!!!!! FINISH !!!!!!! fitting LSTM regression model')
    loss, metric_res = trainer.evaluate(val_dataloader, metrics, metric_name)
    logger().force_log_and_print('BaselineLSTMTraining::train_validate_pipeline_lstm',
                                 f'LSTM regeression with pipeline: {pipeline_name} - Validation final results:'
                                 f' loss={loss}, {metric_name}={metric_res}')


def _load_json_file_list(path: str, required: bool = True):
    """
    Load a JSON file that stores a list of file paths.
    :param path: path of the JSON file
    :param required: if False a missing file yields an empty list instead of raising FileNotFoundError
    :return: the decoded JSON content (expected to be a list of file paths)
    """
    if not required and not os.path.exists(path):
        return list()
    with open(path, 'r') as jf:
        return json.load(jf)


def __lstm_pretrained_encoder_train_test_files(ae_folder: str, maps_folders: Optional[Iterable[str]], val_size: float):
    """
    Split all maps files for LSTM with pre-trained auto encoder according to the files used by auto encoder training.

    :param ae_folder: auto encoder folder holding the Names.TRAIN_FILES / Names.VAL_FILES (and, optionally when
                      maps_folders is given, Names.TEST_FILES) JSON file lists
    :param maps_folders: folders whose files are the maps to split. If None the auto encoder's own train,
                         validation and test lists are returned as-is (all three list files must exist).
    :param val_size: validation fraction used to split maps files the auto encoder never saw. If 0 (and
                     maps_folders is given) all maps files are used for training, with no validation/test files.
    :return: Tuple of (train_files, val_files, test_files) lists
    """
    train_list_path = os.path.join(ae_folder, Names.TRAIN_FILES)
    val_list_path = os.path.join(ae_folder, Names.VAL_FILES)
    test_list_path = os.path.join(ae_folder, Names.TEST_FILES)

    if maps_folders is None:
        return (_load_json_file_list(train_list_path), _load_json_file_list(val_list_path),
                _load_json_file_list(test_list_path))

    all_maps_files = [os.path.join(curr_folder, curr_file)
                      for curr_folder in maps_folders for curr_file in os.listdir(curr_folder)]

    if val_size == 0:
        logger().warning('BaselineLSTMTraining::_encoder_lstm_train_test_files',
                         'Not using validation in train_validate_encoder_lstm - dataset might have data leakage')
        return all_maps_files, list(), list()

    # Sets give O(1) membership tests; the output lists keep the order of all_maps_files
    ae_train_files = set(_load_json_file_list(train_list_path))
    ae_val_files = set(_load_json_file_list(val_list_path))
    ae_test_files = set(_load_json_file_list(test_list_path, required=False))

    train_files = [file for file in all_maps_files if file in ae_train_files]
    val_files = [file for file in all_maps_files if file in ae_val_files]
    test_files = [file for file in all_maps_files if file in ae_test_files]

    # Files the auto encoder never saw in any split. Test files must be excluded too, otherwise they would also be
    # added to the train/validation lists and leak test data into training.
    known_files = ae_train_files | ae_val_files | ae_test_files
    remaining_files = [file for file in all_maps_files if file not in known_files]
    if len(remaining_files) != 0:
        logger().log('BaselineLSTMTraining::_encoder_lstm_train_test_files',
                     'Files not found in ae train or validation and used in LSTM: ', remaining_files)
        if len(remaining_files) == 1:
            # train_test_split cannot split a single sample with a fractional test size (one split would be empty)
            extra_train, extra_val = remaining_files, list()
        else:
            extra_train, extra_val = train_test_split(remaining_files, test_size=val_size, random_state=cfg.seed)
        train_files.extend(extra_train)
        val_files.extend(extra_val)

    return train_files, val_files, test_files


def _lstm_base_train_eval(logger_start: str, loss_name: str, epochs: int, batch_size: int, model_save_dir: str,
                          lstm_reg_model: BaseLSTMRegressor, train_dataset: FileBasedDataset,
                          val_dataset: Optional[FileBasedDataset]):
    """
    Train an LSTM regression model and save its final state dict.

    Checkpoints are written to '<model_save_dir>/cp' (ModelCheckpointCallback with use_loss=True) and the final
    weights to '<model_save_dir>/lstm_reg_final.pt'.

    :param logger_start: tag used for log messages
    :param loss_name: training loss, one of 'mse', 'mae', 'log'
    :param epochs: number of training epochs
    :param batch_size: batch size of the train and validation data loaders
    :param model_save_dir: folder where checkpoints and the final model are saved
    :param lstm_reg_model: model to train (trained in place)
    :param train_dataset: training dataset
    :param val_dataset: validation dataset; None or an empty dataset means training without validation
    :raises ValueError: if loss_name is not a supported loss
    :return: None
    """
    # Factories so only the selected loss is instantiated
    loss_factories = {
        'mse': MSELoss,
        'mae': L1Loss,
        'log': lambda: regression_log_loss,
    }
    if loss_name not in loss_factories:
        raise ValueError(f'Bad loss name: {loss_name!r}, expected one of {sorted(loss_factories)}')
    loss_func = loss_factories[loss_name]()

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    has_val_data = val_dataset is not None and len(val_dataset) > 0
    # Validation data needs no shuffling; a fixed order keeps the validation loss (used for checkpointing) reproducible
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) if has_val_data else None

    trainer = ModelTrainer(model=lstm_reg_model, tb_writer=lstm_reg_model.tb_writer(), loss_func=loss_func,
                           convert_to_double=lstm_reg_model.is_double())
    model_cp_callback = ModelCheckpointCallback(save_folder=os.path.join(model_save_dir, 'cp'),
                                                use_train=False, use_loss=True)
    logger().log(logger_start, f'Start Training LSTM with AE for: {epochs} epochs')
    trainer.fit(epochs, train_dataloader, val_dataloader, checkpoint_cb=model_cp_callback,
                metrics_funcs=[MSELoss(), L1Loss(), regression_log_loss], metrics_names=['MSE', 'MAE', 'reg_log_loss'])
    logger().log(logger_start, 'Finished Training LSTM with AE')
    torch.save(lstm_reg_model.state_dict(), os.path.join(model_save_dir, 'lstm_reg_final.pt'))


def train_validate_pre_trained_encoder_lstm(maps_folders: List[str], results_folders: List[str], ae_folder: str,
                                            ae_model_path: str, result_metric_name: str, val_size: float,
                                            model_save_dir: List[str], epochs: int, batch_size: int,
                                            lstm_hidden_size: List[int], lstm_num_layers: List[int], model_dense_sizes: Tuple[int],
                                            batch_first: bool, sequence_size: int, bi_directional: bool,
                                            loss_name: str, last_layer: str, recall_th: float):
    """
    Train an LSTM that uses a pre-trained and fixed encoder. This uses pre-created feature maps.
    One LSTM is trained and evaluated per (model_save_dir, lstm_hidden_size, lstm_num_layers) triple; all of them
    share the same train/validation/test split and datasets.

    :param maps_folders: folders with the feature maps; if None the split saved with the auto encoder is used as-is
    :param results_folders: folders with the results files, passed to the datasets together with result_metric_name
    :param ae_folder: auto encoder folder (holds its train/val/test file lists)
    :param ae_model_path: absolute path to AE model or relative from the ae_folder
    :param result_metric_name: metric to use
    :param val_size: validation size used to split maps the auto encoder never saw; 0 disables validation
    :param model_save_dir: Where to save LSTM models, one folder per trained model
    :param epochs: number of training epochs
    :param batch_size: batch size used for training
    :param lstm_hidden_size: LSTM hidden size per model
    :param lstm_num_layers: number of LSTM layers per model
    :param model_dense_sizes: sizes of the dense layers after the LSTM
    :param batch_first: forwarded to LSTMRegressor
    :param sequence_size: LSTM sequence length
    :param bi_directional: use bidirectional lstm
    :param loss_name: training loss, one of 'mse', 'mae', 'log'
    :param last_layer: one of sigmoid/ linear; anything other than 'sigmoid' sets target_mult_10 on the datasets
    :param recall_th: recall threshold forwarded to _actual_lstm_maps_eval
    :return: None
    """
    target_mult_10 = last_layer != 'sigmoid'
    pre_load_maps = True

    logger_start = 'BaselineLSTMTraining::train_validate_encoder_lstm'
    func_parameters = locals().copy()
    func_parameters['config'] = cfg.to_json()

    num_models = min(len(model_save_dir), len(lstm_hidden_size), len(lstm_num_layers))
    if not (len(model_save_dir) == len(lstm_hidden_size) == len(lstm_num_layers)):
        # zip() below stops at the shortest list, so some requested configurations would silently never be trained
        logger().warning(logger_start, f'model_save_dir, lstm_hidden_size and lstm_num_layers have different lengths '
                                       f'- only the first {num_models} models will be trained')

    train_files, val_files, test_files = __lstm_pretrained_encoder_train_test_files(ae_folder, maps_folders, val_size)

    for curr_folder in model_save_dir:
        os.makedirs(curr_folder, exist_ok=True)
        with open(os.path.join(curr_folder, Names.RUN_PARAMETERS), 'w') as jf:
            json.dump(func_parameters, jf)
        # Log the recorded call parameters (not the full locals(), which include the whole file lists)
        logger().force_log_and_print(logger_start, f'train_validate_encoder_lstm test params {func_parameters}')

        for name, data in zip([Names.TRAIN_FILES, Names.VAL_FILES, Names.TEST_FILES],
                              [train_files, val_files, test_files]):
            with open(os.path.join(curr_folder, name), "w") as jf:
                json.dump(data, jf)

    if not os.path.isabs(ae_model_path):
        ae_model_path = os.path.join(ae_folder, ae_model_path)
    ae_model, embedding_size = load_single_ae_model(ae_folder, ae_model_path, False)

    def _maps_dataset(maps_files):
        # All splits share the same encoder and dataset settings
        return RegressionSingleFeatureMapDataset.create_dataset(maps_files=maps_files, results_folders=results_folders,
                                                                result_metric_name=result_metric_name,
                                                                pre_load_maps=pre_load_maps, auto_encoder=ae_model,
                                                                maps_folders=None, target_mult_10=target_mult_10)

    train_dataset = _maps_dataset(train_files)
    val_dataset = _maps_dataset(val_files) if val_size != 0 else None
    # Only build a test dataset when there are test files (e.g. there are none when val_size == 0)
    test_dataset = _maps_dataset(test_files) if len(test_files) != 0 else None

    for curr_folder, curr_hidden, curr_num_layers in zip(model_save_dir, lstm_hidden_size, lstm_num_layers):

        lstm_reg_model = LSTMRegressor(embedding_size=embedding_size, hidden_size=curr_hidden, lstm_layers=curr_num_layers,
                                       inner_dense_layer_sizes=model_dense_sizes, sequence_len=sequence_size,
                                       tb_log_path=curr_folder, batch_first=batch_first, last_layer=last_layer,
                                       bi_directional=bi_directional)

        _lstm_base_train_eval(logger_start=logger_start, loss_name=loss_name, model_save_dir=curr_folder,
                              lstm_reg_model=lstm_reg_model, train_dataset=train_dataset, val_dataset=val_dataset,
                              epochs=epochs, batch_size=batch_size)

        logger().log(logger_start, 'Calling Model Ranking')

        _actual_lstm_maps_eval(lstm_model=lstm_reg_model, lstm_save_folder=curr_folder, batch_size=64, recall_th=recall_th,
                               train_dataset=train_dataset, val_dataset=val_dataset, test_dataset=test_dataset)


def train_eval_lstm_combined(results_folders: List[str], ae_weights_folder: str, ae_gradients_folder: str,
                             ae_model_path: str, result_metric_name: str, model_save_dir: List[str], epochs: int, batch_size: int,
                             lstm_hidden_size: List[int], lstm_num_layers: List[int], model_dense_sizes: Tuple[int], use_split_lstm: bool,
                             batch_first: bool, sequence_size: int, loss_name: str, last_layer: str, bi_directional: bool,
                             recall_th: float, augment_few_steps_training: Optional[int]):
    """
    Train and evaluate LSTM regressors on combined weights and gradients feature maps, each encoded by its own
    pre-trained auto encoder. The train/validation/test split of each map type is the one saved with its auto
    encoder. One model is trained per (model_save_dir, lstm_hidden_size, lstm_num_layers) triple.

    :param results_folders: folders with the results files, passed to the datasets together with result_metric_name
    :param ae_weights_folder: folder of the weights auto encoder (holds its train/val/test file lists)
    :param ae_gradients_folder: folder of the gradients auto encoder (holds its train/val/test file lists)
    :param ae_model_path: model path relative to each AE folder (absolute paths are not supported)
    :param result_metric_name: metric to use
    :param model_save_dir: where to save the LSTM models, one folder per trained model
    :param epochs: number of training epochs
    :param batch_size: batch size used for training
    :param lstm_hidden_size: LSTM hidden size per model
    :param lstm_num_layers: number of LSTM layers per model
    :param model_dense_sizes: sizes of the dense layers after the LSTM
    :param use_split_lstm: True for CombinedLSTMRegressor (datasets built with use_2_lstms=True), False for a single
                           LSTMRegressor over the summed embedding size of both auto encoders
    :param batch_first: forwarded to the LSTM model
    :param sequence_size: LSTM sequence length
    :param loss_name: training loss, one of 'mse', 'mae', 'log'
    :param last_layer: anything other than 'sigmoid' sets target_mult_10 on the datasets
    :param bi_directional: use bidirectional lstm
    :param recall_th: recall threshold forwarded to _actual_lstm_maps_eval
    :param augment_few_steps_training: forwarded to the train/validation datasets; None or <= 0 disables it
    :raises NotImplementedError: if ae_model_path is absolute
    :return: None
    """
    if augment_few_steps_training is None or augment_few_steps_training <= 0:
        augment_few_steps_training = None
    pre_load_maps = True
    target_mult_10 = last_layer != 'sigmoid'
    logger_start = 'BaselineLSTMTraining::train_eval_lstm_combined'
    locals_copy = locals().copy()

    num_models = min(len(model_save_dir), len(lstm_hidden_size), len(lstm_num_layers))
    if not (len(model_save_dir) == len(lstm_hidden_size) == len(lstm_num_layers)):
        # zip() below stops at the shortest list, so some requested configurations would silently never be trained
        logger().warning(logger_start, f'model_save_dir, lstm_hidden_size and lstm_num_layers have different lengths '
                                       f'- only the first {num_models} models will be trained')

    train_weights, val_weights, test_weights = __lstm_pretrained_encoder_train_test_files(ae_weights_folder, None, 0)
    train_gradients, val_gradients, test_gradients = __lstm_pretrained_encoder_train_test_files(ae_gradients_folder,
                                                                                                None, 0)

    for curr_folder in model_save_dir:
        # Files are written into curr_folder below, so make sure it exists
        os.makedirs(curr_folder, exist_ok=True)
        function_start_save_params(local_params=locals_copy, extra_data=None, config=cfg,
                                   save_path=os.path.join(curr_folder, Names.RUN_PARAMETERS))
        logger().force_log_and_print(logger_start, f'train_eval_lstm_combined test params {locals_copy}')

        for name, data in zip([Names.TRAIN_FILES, Names.VAL_FILES, Names.TEST_FILES], [train_weights + train_gradients,
                                                                                       val_weights + val_gradients,
                                                                                       test_weights + test_gradients]):
            with open(os.path.join(curr_folder, name), "w") as jf:
                json.dump(data, jf)

    if not os.path.isabs(ae_model_path):
        ae_weights_model_path = os.path.join(ae_weights_folder, ae_model_path)
        ae_gradients_model_path = os.path.join(ae_gradients_folder, ae_model_path)
    else:
        raise NotImplementedError('Not implemented absolute path for ae model')

    ae_weights_model, weights_embedding_size = load_single_ae_model(ae_weights_folder, ae_weights_model_path, False)
    ae_gradients_model, gradients_embedding_size = load_single_ae_model(ae_gradients_folder, ae_gradients_model_path, False)

    def _combined_dataset(weights_files, gradients_files, augment):
        # All splits share the same encoders and dataset settings; only the files and augmentation differ
        return CombinedRegressionMapsDataset.create_dataset(
            weights_maps_files=weights_files, weights_encoder=ae_weights_model,
            gradients_maps_files=gradients_files, gradients_encoder=ae_gradients_model,
            results_folders=results_folders, result_metric_name=result_metric_name, pre_load_maps=pre_load_maps,
            target_mult_10=target_mult_10, use_2_lstms=use_split_lstm, augment_few_steps_training=augment)

    train_dataset = _combined_dataset(train_weights, train_gradients, augment_few_steps_training)
    val_dataset = _combined_dataset(val_weights, val_gradients, augment_few_steps_training)
    # The test split is never augmented
    test_dataset = _combined_dataset(test_weights, test_gradients, None)

    for curr_folder, curr_hidden, curr_num_layers in zip(model_save_dir, lstm_hidden_size, lstm_num_layers):
        if use_split_lstm:
            lstm_reg_model = CombinedLSTMRegressor(weights_embedding_size=weights_embedding_size, hidden_size=curr_hidden,
                                                   gradients_embedding_size=gradients_embedding_size, lstm_layers=curr_num_layers,
                                                   inner_dense_layer_sizes=model_dense_sizes, sequence_len=sequence_size,
                                                   tb_log_path=curr_folder, batch_first=batch_first, last_layer=last_layer,
                                                   bi_directional=bi_directional)
        else:
            lstm_reg_model = LSTMRegressor(embedding_size=weights_embedding_size+gradients_embedding_size,
                                           hidden_size=curr_hidden, lstm_layers=curr_num_layers,
                                           inner_dense_layer_sizes=model_dense_sizes, tb_log_path=curr_folder,
                                           sequence_len=sequence_size, batch_first=batch_first, last_layer=last_layer,
                                           bi_directional=bi_directional)

        _lstm_base_train_eval(logger_start=logger_start, model_save_dir=curr_folder, loss_name=loss_name,
                              lstm_reg_model=lstm_reg_model, train_dataset=train_dataset, val_dataset=val_dataset,
                              epochs=epochs, batch_size=batch_size)

        logger().log(logger_start, 'Calling Model Ranking')
        _actual_lstm_maps_eval(lstm_model=lstm_reg_model, lstm_save_folder=curr_folder, batch_size=64,
                               train_dataset=train_dataset, val_dataset=val_dataset, test_dataset=test_dataset,
                               recall_th=recall_th)


def parse_args():
    """
    Parse the command line arguments of the LSTM training script (read from sys.argv).
    :return: argparse.Namespace with the parsed arguments
    """
    def _str_to_bool(value: str) -> bool:
        # argparse's type=bool turns any non-empty string (including 'False') into True, so parse explicitly.
        # Accepts the same spellings as distutils.util.strtobool.
        lowered = value.strip().lower()
        if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
            return True
        if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
            return False
        raise argparse.ArgumentTypeError(f'invalid boolean value: {value!r}')

    parser = argparse.ArgumentParser(description='LSTM Training')
    parser.add_argument('-f', '--folders', nargs="+", type=str, default=None,
                        help='Dump folders that contain results files and/or stats directory and/or feature maps')
    parser.add_argument('-s', '--model_save_dir', type=str, nargs='+')
    parser.add_argument('-rf', '--results_folders', nargs='+', type=str,
                        help='Folders that contain relevant results files')
    parser.add_argument('-m', '--mode', type=str, required=True,
                        choices=['inp', 'train_pipe', 'train_ae', 'rank_lstm', 'lstm_combined'],
                        help='options are:\n'
                             '* inp - create basic pipeline LSTM input\n'
                             '* train_pipe - for training LSTM with basic pipeline for features\n'
                             '* train_ae - for training LSTM with auto encoder\n'
                             '* rank_lstm - evaluate lstm ranking\n'
                             '* lstm_combined - train LSTM on combined weights and gradients auto encoders')
    parser.add_argument('-wm', '--weights_modes', type=str, default='True',
                        help='True for using weights stats, False for gradients')
    parser.add_argument('-vs', '--val_size', type=float, default=0.2)
    parser.add_argument('-isl', '--is_single_layer', type=_str_to_bool, default=True,
                        help='True for single layer stats, False for multi layer stats')

    # Pipeline args
    parser.add_argument('-v', '--variance_threshold', type=float, default=0.3,
                        help='Variance threshold for initial feature selection')
    parser.add_argument('-nf', '--num_features', type=int, default=300,
                        help='Number of features for pipeline feature selection')
    # LSTM args
    parser.add_argument('-bs', '--batch_size', type=int, default=32)
    parser.add_argument('-e', '--epochs', type=int, default=150)
    parser.add_argument('-lstm_size', '--lstm_hidden_size', type=int, nargs='+')
    parser.add_argument('-lstm_layers', '--lstm_num_layers', type=int, nargs='+')
    parser.add_argument('-reg_dense', '--regressor_dense_sizes', type=int, nargs='+')
    parser.add_argument('-lstm_model', '--lstm_model_save', type=str, default='cp/model_state_cp/model.pt')
    parser.add_argument('--lstm_eval_save_folder', type=str, default=None)
    parser.add_argument('-last_layer', '--last_layer', type=str, default=None)
    parser.add_argument('-loss_name', '--loss_name', type=str, default=None)
    # Parsed with ast.literal_eval by the caller, so the default must be a valid Python literal
    parser.add_argument('--split_combined_lstm', type=str, default='False',
                        help='True to use separate LSTMs for weights and gradients in lstm_combined mode')
    parser.add_argument('-bidi', '--bidirectional', type=str, default='False')
    parser.add_argument('-augment', '--augment_few_steps_training', type=int, default=0)

    # Auto Encoder args
    parser.add_argument('-ae_folder', '--ae_save_folder', type=str, default=None,
                        help='In case of combined this is the weights ae')
    parser.add_argument('--ae_gradients_save', type=str, default=None)
    parser.add_argument('-ae_model', '--ae_model_save', type=str, default='cp/model_state_cp/model.pt')

    # Ranking args
    parser.add_argument('--cnn_steps_for_eval', type=int, nargs='+', default=(30,))
    parser.add_argument('--recall_th', type=float, default=0.1)

    return parser.parse_args(sys.argv[1:])


if __name__ == '__main__':
    print('!'*50, '\nCUDA: ', torch.cuda.is_available(), '\n', '!'*50)
    args = parse_args()
    get_logger(os.path.basename(__file__).split('.')[0] + '_' + args.mode)
    logger().log('LSTMTraining::main_script', 'args received: ', args)

    def _optional_literal_bool(value: Optional[str], default: bool = False) -> bool:
        # ast.literal_eval(None) raises ValueError, so an omitted optional flag falls back to the default
        return default if value is None else bool(ast.literal_eval(value))

    model_save_dir_ = args.model_save_dir
    weights_mode = bool(strtobool(args.weights_modes))
    if args.mode == 'inp':
        _create_pipe_input_data(folders=args.folders, val_size=args.val_size,
                                is_weights=weights_mode, is_single_layer=args.is_single_layer,
                                variance_th=args.variance_threshold, num_features=args.num_features)

    elif args.mode == 'train_pipe':
        # train_validate_pipeline_lstm trains a single model and expects a single save dir / hidden size /
        # layer count, while the CLI collects these as lists (nargs='+') for the multi-model modes.
        train_validate_pipeline_lstm(folders=args.folders, val_size=args.val_size, is_weights=weights_mode,
                                     is_single_layer=args.is_single_layer, variance_th=args.variance_threshold,
                                     num_features=args.num_features, epochs=args.epochs, batch_size=args.batch_size,
                                     lstm_hidden_size=args.lstm_hidden_size[0],
                                     lstm_num_layers=args.lstm_num_layers[0],
                                     model_dense_sizes=args.regressor_dense_sizes,
                                     save_base_dir='/sise/group/inputs/lstm_1',
                                     model_save_dir=model_save_dir_[0] if model_save_dir_ else None,
                                     batch_first=True, sequence_size=Const.NUMBER_STEPS_SAVED)

    elif args.mode == 'train_ae':
        train_validate_pre_trained_encoder_lstm(maps_folders=args.folders, results_folders=args.results_folders,
                                                ae_folder=args.ae_save_folder, ae_model_path=args.ae_model_save,
                                                result_metric_name='accuracy', model_save_dir=model_save_dir_,
                                                val_size=args.val_size, epochs=args.epochs, batch_size=args.batch_size,
                                                lstm_hidden_size=args.lstm_hidden_size,
                                                lstm_num_layers=args.lstm_num_layers,
                                                model_dense_sizes=args.regressor_dense_sizes,
                                                sequence_size=Const.NUMBER_STEPS_SAVED, batch_first=True,
                                                last_layer=args.last_layer, loss_name=args.loss_name,
                                                bi_directional=_optional_literal_bool(args.bidirectional),
                                                recall_th=args.recall_th)

    elif args.mode == 'rank_lstm':
        if not torch.cuda.is_available():
            cfg.device = 'cpu'
        lstm_maps_ranking_evaluation(lstm_save_folder=args.model_save_dir[0], lstm_model_path=args.lstm_model_save,
                                     input_folders=args.folders, result_metric_override=None, batch_size=args.batch_size,
                                     pre_load_maps=True, cnn_steps_for_eval=args.cnn_steps_for_eval,
                                     recall_th=args.recall_th)

    elif args.mode == 'lstm_combined':
        # Combined maps concatenate weights and gradients steps, hence twice the sequence length
        train_eval_lstm_combined(ae_gradients_folder=args.ae_gradients_save, ae_weights_folder=args.ae_save_folder,
                                 results_folders=args.results_folders, ae_model_path=args.ae_model_save,
                                 result_metric_name='accuracy', model_save_dir=model_save_dir_, epochs=args.epochs,
                                 batch_size=args.batch_size, lstm_hidden_size=args.lstm_hidden_size,
                                 lstm_num_layers=args.lstm_num_layers, model_dense_sizes=args.regressor_dense_sizes,
                                 sequence_size=Const.NUMBER_STEPS_SAVED*2, batch_first=True, last_layer=args.last_layer,
                                 loss_name=args.loss_name,
                                 use_split_lstm=_optional_literal_bool(args.split_combined_lstm),
                                 bi_directional=_optional_literal_bool(args.bidirectional),
                                 recall_th=args.recall_th, augment_few_steps_training=args.augment_few_steps_training)
    else:
        raise NotImplementedError(f'Unknown mode: {args.mode}')
