import json
import os
from functools import partial, wraps
from typing import Tuple, Optional, List
import torch

from BaselineModel.BaselineLSTMRegressionModel import LSTMRegressor
from BaselineModel.BaselineCombinedModel import CombinedLSTMRegressor
from BaselineModel.BaselineBaseLSTMRegressor import BaseLSTMRegressor
from DataHandling.LSTMSingleFeatureMapBasedDataset import RegressionSingleFeatureMapDataset
from DataHandling.LSTMCombinedFeatureMapsDataset import CombinedRegressionMapsDataset
from DataHandling.AutoEncodeSingleDataset import AutoEncoderSingleMapsDataset
from Encoder.FeatureMapAutoEncoder import FeatureMapAutoEncoder
from Utils import config as cfg, logger
from Utils.Constants import FileNamesConstants


def load_single_ae_model(ae_folder: str, ae_model_path: str, use_tb: bool) -> Tuple[FeatureMapAutoEncoder, int]:
    """
    Rebuild a FeatureMapAutoEncoder from the hyper-parameters saved in ``ae_folder`` and load its state dict.

    :param ae_folder: run folder of the auto encoder, holding the hyper-parameters json
    :param ae_model_path: path of the saved state dict; a relative path is resolved against ``ae_folder``
    :param use_tb: when falsy, the ``tb_log_path`` hyper-parameter is overridden with None
    :return: Tuple of:
                * full encoder
                * embedding size of the encoder (last value of the last entry of ``layers_shapes``)
    """
    params_path = os.path.join(ae_folder, FileNamesConstants.MODEL_HYPER_PARAMS)
    with open(params_path, 'r') as params_file:
        hyper_params = json.load(params_file)

    if not use_tb:
        hyper_params['tb_log_path'] = None

    last_layer_shape = hyper_params["layers_shapes"][-1]
    embedding_size = last_layer_shape[-1]

    model = FeatureMapAutoEncoder(**hyper_params)
    state_path = ae_model_path if os.path.isabs(ae_model_path) else os.path.join(ae_folder, ae_model_path)
    state_dict = torch.load(state_path, map_location=torch.device(cfg.device))
    model.load_state_dict(state_dict)
    return model, embedding_size


def _get_files_to_evaluate(input_folders, lstm_save_folder) -> Tuple:
    """
    Get all files for LSTM evaluation.

    The train / val / test split saved during training is read from ``lstm_save_folder``. When ``input_folders``
    are given, every file found in them that is not part of the saved split is treated as a new file and appended
    to the test files.

    All paths are migrated from the old storage prefix to the new one *before* they are compared. Saved split
    files and files listed from ``input_folders`` are therefore matched using the same prefix.

    :param input_folders: optional folders with more files to evaluate. When given, train and val files are
                          restricted to files that also appear in these folders (saved test files are kept as is)
    :param lstm_save_folder: folder holding the saved train / val / test file lists
    :return: Tuple of lists: train files, val files, test files (saved test files followed by new files)
    """
    old_path = '/sise/group/'
    new_path = '/sise/Group-2/'

    def _to_new_storage(paths):
        # Rewrite every path to the new storage prefix
        return [curr.replace(old_path, new_path) for curr in paths]

    logger().log('BaselineLSTMTraining::_get_files_to_evaluate', f'Updating files to new storage form: {old_path} to {new_path}')
    saved_splits = list()
    for curr_file in [FileNamesConstants.TRAIN_FILES, FileNamesConstants.VAL_FILES, FileNamesConstants.TEST_FILES]:
        with open(os.path.join(lstm_save_folder, curr_file)) as jf:
            # Normalize before any comparison so paths on old and new storage match each other
            saved_splits.append(_to_new_storage(json.load(jf)))
    train_files, val_files, test_files = saved_splits

    if input_folders:
        rank_files = _to_new_storage(os.path.join(curr_folder, curr_file) for curr_folder in input_folders
                                     for curr_file in os.listdir(curr_folder))
        # Sets give O(1) membership tests instead of scanning lists for every file
        rank_set = set(rank_files)
        known_files = set(train_files) | set(val_files) | set(test_files)
        new_files = [file for file in rank_files if file not in known_files]
        train_files = [file for file in train_files if file in rank_set]
        val_files = [file for file in val_files if file in rank_set]
        logger().force_log_and_print('BaselineLSTMTraining::_get_files_to_evaluate',
                                     f'Train files evaluated: {train_files}\nVal files evaluated: {val_files}\n'
                                     f'Test files evaluated: {test_files}\nNew files evaluated: {new_files}')
        test_files.extend(new_files)

    def _first(files):
        # Lists can be empty (e.g. after filtering); indexing [0] here used to raise IndexError
        return files[0] if files else None

    logger().log('BaselineLSTMTraining::_get_files_to_evaluate', f'Path update example '
                                                                 f'{_first(train_files)}, {_first(val_files)}, '
                                                                 f'{_first(test_files)}')

    return train_files, val_files, test_files


def _load_single_lstm(lstm_run_folder: str, lstm_load_path: str) -> BaseLSTMRegressor:
    """
    Rebuild an LSTM regressor from the hyper-parameters saved in ``lstm_run_folder`` and load its state dict.

    A CombinedLSTMRegressor is built when the hyper-parameters contain ``weights_embedding_size``; otherwise an
    LSTMRegressor is built. TensorBoard logging is always disabled.

    :param lstm_run_folder: run folder holding the hyper-parameters json
    :param lstm_load_path: state dict path; a relative path is resolved against ``lstm_run_folder``
    :return: the loaded regressor
    """
    params_path = os.path.join(lstm_run_folder, FileNamesConstants.MODEL_HYPER_PARAMS)
    with open(params_path, 'r') as params_file:
        hyper_params = json.load(params_file)

    # Runs saved before bidirectional LSTM support was added have no 'bi_directional' entry
    hyper_params.setdefault('bi_directional', False)
    hyper_params['tb_log_path'] = None

    regressor_cls = CombinedLSTMRegressor if 'weights_embedding_size' in hyper_params else LSTMRegressor
    regressor = regressor_cls(**hyper_params)

    if os.path.isabs(lstm_load_path):
        state_path = lstm_load_path
    else:
        state_path = os.path.join(lstm_run_folder, lstm_load_path)
    regressor.load_state_dict(torch.load(state_path, map_location=torch.device(cfg.device)))

    return regressor


def __custom_dataset_creator_wrapper(create_single, maps_files, maps_folders, results_folders, result_metric,
                                     pre_load_data, target_mult_10, auto_encoder=None,
                                     use_2_lstms=False, ae_weights_model=None, ae_gradients_model=None,
                                     weights_maps_files=None, gradients_maps_files=None, augment_few_steps_training=None):
    """
    Build a callable that creates a regression dataset for a list of files. All other dataset parameters are
    frozen to the values given here.

    :param create_single: True -> RegressionSingleFeatureMapDataset (uses ``auto_encoder`` and ``maps_folders``),
                          False -> CombinedRegressionMapsDataset (uses the weights / gradients encoders and files)
    :param maps_files: initial files value; always replaced by the files passed to the returned callable
    :param maps_folders: only used when ``create_single`` is True
    :param results_folders: forwarded to the dataset
    :param result_metric: forwarded as ``result_metric_name``
    :param pre_load_data: forwarded as ``pre_load_maps``
    :param target_mult_10: forwarded to the dataset
    :param auto_encoder: single-map encoder (``create_single`` only)
    :param use_2_lstms: combined dataset only
    :param ae_weights_model: combined dataset only, forwarded as ``weights_encoder``
    :param ae_gradients_model: combined dataset only, forwarded as ``gradients_encoder``
    :param weights_maps_files: combined dataset only
    :param gradients_maps_files: combined dataset only
    :param augment_few_steps_training: forwarded to the dataset
    :return: callable ``(maps_files, **kwargs) -> dataset``. ``kwargs`` override the frozen parameters for that
             call only.
    """
    if create_single:
        base_params = dict(maps_files=maps_files, maps_folders=maps_folders, results_folders=results_folders,
                           result_metric_name=result_metric, pre_load_maps=pre_load_data,
                           auto_encoder=auto_encoder, target_mult_10=target_mult_10,
                           augment_few_steps_training=augment_few_steps_training)
    else:
        base_params = dict(all_files=maps_files, results_folders=results_folders, result_metric_name=result_metric,
                           use_2_lstms=use_2_lstms, pre_load_maps=pre_load_data, target_mult_10=target_mult_10,
                           weights_encoder=ae_weights_model, gradients_encoder=ae_gradients_model,
                           weights_maps_files=weights_maps_files, gradients_maps_files=gradients_maps_files,
                           augment_few_steps_training=augment_few_steps_training)

    def _wrapper(maps_files, **kwargs):
        # Work on a per-call copy. Mutating base_params would make overrides from one call leak into later calls.
        call_params = {**base_params, **kwargs}
        if create_single:
            call_params['maps_files'] = maps_files
            return RegressionSingleFeatureMapDataset.create_dataset(**call_params)
        call_params['all_files'] = maps_files
        return CombinedRegressionMapsDataset.create_dataset(**call_params)

    return _wrapper


def load_all_objects_from_lstm_run_for_inference(lstm_run_folder: str, lstm_load_path: str = 'cp/model_state_cp/model.pt',
                                                 pre_load_data: bool = True, input_folders: Optional[List[str]] = None,
                                                 result_metric_override: Optional[str] = None,
                                                 skip_original_files: bool = False, augment_few_steps_training = None):
    """
    Load everything needed to evaluate any of my LSTM regression models.

    :param lstm_run_folder: where all regressor data was saved (run parameters, hyper-parameters, file splits)
    :param lstm_load_path: LSTM state dict path; a relative path is resolved against ``lstm_run_folder``
    :param pre_load_data: forwarded to the dataset creator (as ``pre_load_maps``)
    :param input_folders: for more files to evaluate (passed to ``_get_files_to_evaluate``)
    :param result_metric_override: results metric to use that is different from the one used during the run
    :param skip_original_files: when True no dataset is created (all dataset entries are None). The file lists
                                and the dataset creator are still returned.
    :param augment_few_steps_training: forwarded to the dataset creator
    :return: dict with keys train_dataset (always None - the train dataset is not built), val_dataset,
             test_dataset (None when skipped or when there are no files), target_mult_10, ae_model,
             ae_weights_model, ae_gradients_model, embedding_size, lstm_model, dataset_creator, train_files,
             val_files, test_files
    """
    logger().log('load_all_objects_from_lstm_run_for_inference', f'Loading data from: {lstm_run_folder}')
    with open(os.path.join(lstm_run_folder, FileNamesConstants.RUN_PARAMETERS), 'r') as jf:
        lstm_run_params = json.load(jf)

    results_folders = lstm_run_params['results_folders']
    result_metric = lstm_run_params['result_metric_name'] if result_metric_override is None else result_metric_override
    target_mult_10 = lstm_run_params['target_mult_10']

    ae_model_path = lstm_run_params['ae_model_path']
    if 'ae_folder' in lstm_run_params:
        # Run with a single auto encoder
        ae_model, embedding_size = load_single_ae_model(lstm_run_params['ae_folder'], ae_model_path, False)
        ae_model.eval()
        ae_weights_model = None
        ae_gradients_model = None
        dataset_creator = __custom_dataset_creator_wrapper(
            True, maps_files=None, maps_folders=None, results_folders=results_folders, result_metric=result_metric,
            pre_load_data=pre_load_data, auto_encoder=ae_model, target_mult_10=target_mult_10,
            augment_few_steps_training=augment_few_steps_training)
    else:
        # Run with separate weights and gradients auto encoders
        ae_model = None
        ae_weights_model, embedding_size_weights = load_single_ae_model(lstm_run_params['ae_weights_folder'],
                                                                        ae_model_path, False)
        ae_gradients_model, gradients_embedding_size = load_single_ae_model(lstm_run_params['ae_gradients_folder'],
                                                                            ae_model_path, False)
        # Both encoders are used for inference here, so both must be in eval mode
        ae_weights_model.eval()
        ae_gradients_model.eval()
        embedding_size = embedding_size_weights + gradients_embedding_size
        dataset_creator = __custom_dataset_creator_wrapper(
            False, maps_files=None, maps_folders=None, results_folders=results_folders, result_metric=result_metric,
            use_2_lstms=lstm_run_params['use_split_lstm'], pre_load_data=pre_load_data,
            target_mult_10=target_mult_10, ae_weights_model=ae_weights_model, ae_gradients_model=ae_gradients_model,
            weights_maps_files=None, gradients_maps_files=None,
            augment_few_steps_training=augment_few_steps_training)

    lstm_model = _load_single_lstm(lstm_run_folder, lstm_load_path=lstm_load_path)
    lstm_model.eval()

    train_files, val_files, test_files = _get_files_to_evaluate(input_folders=input_folders, lstm_save_folder=lstm_run_folder)
    # The train dataset is intentionally never built (skipped for speed); train_files are still returned.
    # Note: all three must be plain None (a trailing comma previously made train_dataset the tuple (None,)).
    train_dataset = None
    val_dataset = None
    test_dataset = None
    if not skip_original_files:
        if val_files:
            val_dataset = dataset_creator(val_files)
        if test_files:
            test_dataset = dataset_creator(test_files)

    return dict(train_dataset=train_dataset, val_dataset=val_dataset, test_dataset=test_dataset, target_mult_10=target_mult_10,
                ae_model=ae_model, ae_weights_model=ae_weights_model, ae_gradients_model=ae_gradients_model,
                embedding_size=embedding_size, lstm_model=lstm_model, dataset_creator=dataset_creator,
                train_files=train_files, val_files=val_files, test_files=test_files)


def load_ae_for_checking_inference(ae_folder: str, ae_save_path: str, pre_load_data: bool):
    """
    Load a trained auto encoder together with datasets built from the file splits saved in its run folder.

    :param ae_folder: run folder of the auto encoder (hyper-parameters and saved train / val / test file lists)
    :param ae_save_path: saved state dict path; a relative path is resolved against ``ae_folder``
    :param pre_load_data: forwarded to ``AutoEncoderSingleMapsDataset.create_dataset``
    :return: Tuple of (auto encoder, dataset creator, train dataset, val dataset or None, test dataset or None)
    """
    ae_model, _ = load_single_ae_model(ae_folder=ae_folder, ae_model_path=ae_save_path, use_tb=False)
    train_files, val_files, test_files = _get_files_to_evaluate(input_folders=None, lstm_save_folder=ae_folder)

    # is_weights is derived from the second-to-last '/'-separated component of ae_folder.
    # Which component that is depends on whether ae_folder ends with '/'.
    folder_part = ae_folder.split('/')[-2]
    dataset_creator = partial(AutoEncoderSingleMapsDataset.create_dataset,
                              is_weights='weights' in folder_part, pre_load_data=pre_load_data)

    def _create_if_any(files):
        # Only build a dataset when there is at least one file
        if files is None or len(files) == 0:
            return None
        return dataset_creator(files)

    train_dataset = dataset_creator(train_files)
    val_dataset = _create_if_any(val_files)
    test_dataset = _create_if_any(test_files)

    return ae_model, dataset_creator, train_dataset, val_dataset, test_dataset
