import json
import os
from typing import Tuple, List, Dict, Optional

import numpy as np
import torch

from BaselineModel.BaselineLSTMRegressionModel import LSTMRegressor
from BaselineModel.BaselineBaseLSTMRegressor import BaseLSTMRegressor
from BaselineModel.BaselineModelLoader import load_all_objects_from_lstm_run_for_inference
from DataHandling.FileBasedDatasetBase import FileBasedDataset
from DataHandling.DataUtils import collect_models_results
from ModelsUtils.Metrics import precision_at_k, average_precision_at_k, recall_at_k, average_recall_at_k
from Utils import logger, config as cfg
from Utils.Constants import FileNamesConstants, Diff
from Utils.utils import function_start_save_params


def _get_results(train_files, val_files, new_files, results_folders, metric_name):
    """
    Collect the real models metric for the train, validation and new (=test) files.

    The three file lists are concatenated (train, validation, new) and sent to ``collect_models_results`` in a
    single call. The collected results are then split back into three lists according to the sizes of the inputs.
    :param train_files: list of train files
    :param val_files: list of validation files
    :param new_files: list of new (=test) files
    :param results_folders: passed as is to ``collect_models_results``
    :param metric_name: passed as is to ``collect_models_results``
    :return: Tuple of lists: train files results, val files results, new files (=test files) results
    """
    all_files = train_files + val_files + new_files
    models_ids, models_results_map = collect_models_results(all_files, results_folders, metric_name)
    # NOTE(review): the split below assumes models_ids follows the order of all_files - confirm against
    # collect_models_results
    models_results = [models_results_map[model_id] for model_id in models_ids]

    val_start = len(train_files)
    new_start = val_start + len(val_files)
    new_end = new_start + len(new_files)

    train_results = models_results[:val_start]
    val_results = models_results[val_start:new_start]
    # The end of the slice must be offset by the train and validation sizes; using len(new_files) alone
    # returns an empty (or truncated) list
    new_results = models_results[new_start:new_end]
    return train_results, val_results, new_results


def load_pre_trained_lstm(lstm_folder: str, lstm_model_path: str, use_tb: bool,
                          lstm_class: BaseLSTMRegressor = LSTMRegressor):
    """
    Build an LSTM model from its saved hyper parameters and load its trained weights
    :param lstm_folder: Folder with the model hyper parameters file (FileNamesConstants.MODEL_HYPER_PARAMS)
    :param lstm_model_path: absolute path to the saved lstm weights, or a path relative to lstm_folder
    :param use_tb: when False, the 'tb_log_path' hyper parameter is replaced with None before the model is built
    :param lstm_class: class used to build the model, called with the hyper parameters as keyword arguments
    :return: Tuple of: the LSTM model with the loaded weights, the value of the 'use_split_lstm' hyper parameter
             (False when the hyper parameters file does not contain it)
    """
    hyper_params_path = os.path.join(lstm_folder, FileNamesConstants.MODEL_HYPER_PARAMS)
    with open(hyper_params_path, 'r') as params_file:
        hyper_params = json.load(params_file)

    if not use_tb:
        hyper_params['tb_log_path'] = None
    model = lstm_class(**hyper_params)

    if os.path.isabs(lstm_model_path):
        weights_path = lstm_model_path
    else:
        weights_path = os.path.join(lstm_folder, lstm_model_path)

    # torch.load unpickles the file, so weights_path should only point to a trusted checkpoint
    target_device = torch.device(cfg.device)
    state_dict = torch.load(weights_path, map_location=target_device)
    model.load_state_dict(state_dict)

    uses_split_lstm = hyper_params.get('use_split_lstm', False)
    return model, uses_split_lstm


def _eval_data_gen(dataset: FileBasedDataset, batch_size: int):
    """
    Generator that walks over the dataset in order, in batches of at most batch_size items
    :param dataset: dataset to read from, each item is a pair of (input, result)
    :param batch_size: maximal number of items in a batch, the last batch may be smaller
    :return: yields a tuple for each batch: numpy array built from the batch inputs, list with the batch results
    """
    for first_idx in range(0, len(dataset), batch_size):
        # Clamp the end of the batch to the dataset size so the last batch is just shorter
        last_idx = min(first_idx + batch_size, len(dataset))
        batch_inputs = []
        batch_results = []
        for item_idx in range(first_idx, last_idx):
            item_input, item_result = dataset[item_idx]
            batch_inputs.append(item_input)
            batch_results.append(item_result)
        yield np.array(batch_inputs), batch_results


def preds_results_lstm_model(lstm_model: BaseLSTMRegressor, dataset: FileBasedDataset, cnn_steps_for_eval:  Tuple[int] = (30,),
                             batch_size: int = 128) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Calculate the predictions and real results of an entire dataset for a LSTMRegressor for multiple number of cnn steps
    :param lstm_model: model used for the predictions, called with a tensor created on cfg.device
    :param dataset: dataset to evaluate. None is handled like an empty dataset
    :param cnn_steps_for_eval: for each step, the model gets only the first `step` entries of the second axis of the
                               batch inputs
    :param batch_size: number of dataset items sent to the model in each call
    :return: Tuple of:
        * lstm model predictions for all files in dataset for the different CNN steps, shape: (cnn steps, dataset size)
        * actual models results from dataset - shape: (dataset size)
        * All the files in the dataset, ordered according to preds and results
    """
    # The None check must happen before len(dataset) / dataset.files are used, otherwise it can never help
    if dataset is None:
        return np.zeros((len(cnn_steps_for_eval), 0), dtype='float'), np.zeros(0, dtype='float'), list()

    dataset_size = len(dataset)
    curr_preds = np.zeros((len(cnn_steps_for_eval), dataset_size), dtype='float')
    curr_results = np.zeros(dataset_size, dtype='float')
    if dataset_size != 0:
        with torch.no_grad():
            gen = _eval_data_gen(dataset, batch_size)
            for idx, (curr_x, curr_y_true) in enumerate(gen):
                start_idx = idx * batch_size
                # The last batch can be shorter than batch_size, so use its actual size for the end index
                end_idx = start_idx + len(curr_y_true)
                for step_idx, step in enumerate(cnn_steps_for_eval):
                    tensor = torch.tensor(np.array(curr_x[:, :step, :]), device=cfg.device)
                    preds = lstm_model(tensor)
                    # NOTE(review): assumes the model returns a single value per item in the batch - confirm
                    curr_preds[step_idx, start_idx:end_idx] = preds.to('cpu').detach().numpy().ravel()
                curr_results[start_idx:end_idx] = curr_y_true
    return curr_preds, curr_results, dataset.files.copy()


def __lstm_eval_collect_all_data(all_files_datasets: List[FileBasedDataset], batch_size: int, lstm_model: BaseLSTMRegressor,
                                 cnn_steps_for_eval: Tuple[int]) -> Tuple[List, List, List]:
    """
    Collect all models predictions and expected value

    Steps larger than Diff.NUMBER_STEPS_SAVED are removed (with a warning) before the predictions are calculated.
    If no step is left, Diff.NUMBER_STEPS_SAVED is used as the only step. This means that the number of rows in the
    predictions can be smaller than the number of steps that were requested.
    :param all_files_datasets: datasets to evaluate, one entry in each returned list per dataset
    :param batch_size: batch size used for the model predictions
    :param lstm_model: model to evaluate
    :param cnn_steps_for_eval: tuple of steps used for each prediction
    :return: Tuple of lists:
                * List with preds for each dataset - each entry is a numpy array with shape (num cnn steps, dataset size)
                * y_true_labels
                * dataset files names
    """
    cnn_steps_for_eval = np.array(cnn_steps_for_eval)
    is_bad_step = cnn_steps_for_eval > Diff.NUMBER_STEPS_SAVED
    if is_bad_step.any():
        bad_steps = cnn_steps_for_eval[is_bad_step]
        # Boolean mask instead of np.delete + np.where: np.where with exactly two arguments raises ValueError
        cnn_steps_for_eval = cnn_steps_for_eval[~is_bad_step]

        if len(cnn_steps_for_eval) == 0:
            cnn_steps_for_eval = (Diff.NUMBER_STEPS_SAVED, )

        logger().warning('BaselineLSTMRankingEvaluation::__lstm_eval_collect_all_data',
                         f'Have bad steps: {bad_steps}, removing the bad steps. Steps kept: {cnn_steps_for_eval}')

    all_preds = list()
    all_results = list()
    all_files = list()
    for curr_dataset in all_files_datasets:
        curr_preds, curr_results, curr_files = preds_results_lstm_model(dataset=curr_dataset, lstm_model=lstm_model,
                                                                        cnn_steps_for_eval=cnn_steps_for_eval,
                                                                        batch_size=batch_size)
        all_preds.append(curr_preds)
        all_results.append(curr_results)
        all_files.append(curr_files)

    return all_preds, all_results, all_files


def __lstm_actual_ranking_evaluation(all_files_lists: List[List[str]], all_preds: List[List[float]],
                                     all_results: List[List[float]], recall_th: float) -> Dict:
    """
    Evaluate LSTM ranking

    The inputs are matched by position with the dataset names 'train', 'validation' and 'test'. Each metric is
    calculated for several values of k, once with random_optimistic=True (saved with an '_opt' suffix) and once with
    random_optimistic=False (saved without a suffix).
    :param all_files_lists: list of lists with all inputs files
    :param all_preds: list of lists with LSTM predictions
    :param all_results: list of lists with LSTM results
    :param recall_th: threshold passed to the recall metrics. The recall metrics are calculated only when it is
                      larger than 0
    :return: dictionary with final ranking results. Keys are the names of the datasets, each value in another dictionary
            where keys are the metric name and values are the ranking values. Each dataset also has a 'dataset_size'
            entry, and it is the only entry for an empty dataset
    """
    # (key suffix, random_optimistic value) - the optimistic variant is always saved first
    variants = (('_opt', True), ('', False))
    final_eval = dict()
    for name, curr_files, curr_preds, curr_results in zip(['train', 'validation', 'test'],
                                                          all_files_lists, all_preds, all_results):
        dataset_eval = dict()
        final_eval[name] = dataset_eval
        dataset_eval['dataset_size'] = len(curr_preds)
        if len(curr_preds) == 0:
            continue
        for k in [3, 5, 10, 15, 20, 50, 100]:
            for suffix, optimistic in variants:
                dataset_eval[f'precision@{k}{suffix}'] = precision_at_k(
                    names=curr_files, y_preds=curr_preds, y_true=curr_results, k=k, random_optimistic=optimistic)
            for suffix, optimistic in variants:
                dataset_eval[f'average_precision@{k}{suffix}'] = average_precision_at_k(
                    curr_files, y_preds=curr_preds, y_true=curr_results, k=k, random_optimistic=optimistic)
            if recall_th > 0:
                for suffix, optimistic in variants:
                    dataset_eval[f'recall_{recall_th}@{k}{suffix}'] = recall_at_k(
                        names=curr_files, y_preds=curr_preds, th=recall_th, y_true=curr_results, k=k,
                        random_optimistic=optimistic)
                # The optimistic average recall needs its own '_opt' key, otherwise the non optimistic value
                # overwrites it
                for suffix, optimistic in variants:
                    dataset_eval[f'average_recall_{recall_th}@{k}{suffix}'] = average_recall_at_k(
                        names=curr_files, y_preds=curr_preds, th=recall_th, y_true=curr_results, k=k,
                        random_optimistic=optimistic)
    return final_eval


def _actual_lstm_maps_eval(lstm_model: BaseLSTMRegressor, lstm_save_folder: str,
                           train_dataset: FileBasedDataset, val_dataset: FileBasedDataset, test_dataset: FileBasedDataset,
                           batch_size: int, recall_th: float, cnn_steps_for_eval: Tuple[int] = (30,)) -> Dict:
    """
    Evaluate the ranking of a loaded LSTM model on the train, validation and test datasets for each cnn step.
    The results are saved as json to FileNamesConstants.RANKING_RESULTS under lstm_save_folder, and logged.
    :param lstm_model: model to evaluate
    :param lstm_save_folder: folder in which the ranking results file is written
    :param train_dataset: train dataset
    :param val_dataset: validation dataset
    :param test_dataset: test dataset
    :param batch_size: batch size used for the model predictions
    :param recall_th: threshold for the recall metrics
    :param cnn_steps_for_eval: steps to evaluate. Steps larger than Diff.NUMBER_STEPS_SAVED are not evaluated; if
                               none is left, Diff.NUMBER_STEPS_SAVED is used as the only step
    :return: dictionary where keys are the evaluated steps and values are the ranking results for that step
    """
    # __lstm_eval_collect_all_data only evaluates steps that do not exceed Diff.NUMBER_STEPS_SAVED. Apply the same
    # rule here, before collecting, so row i of each predictions array always belongs to steps_to_eval[i]
    steps_to_eval = tuple(step for step in cnn_steps_for_eval if not step > Diff.NUMBER_STEPS_SAVED)
    if len(steps_to_eval) != len(cnn_steps_for_eval):
        if len(steps_to_eval) == 0:
            steps_to_eval = (Diff.NUMBER_STEPS_SAVED, )
        logger().warning('BaselineLSTMRankingEvaluation::_actual_lstm_maps_eval',
                         f'Steps larger than {Diff.NUMBER_STEPS_SAVED} were requested: {cnn_steps_for_eval}. '
                         f'Steps evaluated: {steps_to_eval}')

    all_datasets = [train_dataset, val_dataset, test_dataset]
    all_preds, all_results, all_files = __lstm_eval_collect_all_data(all_datasets, batch_size, lstm_model,
                                                                     cnn_steps_for_eval=steps_to_eval)
    all_evals = dict()
    for step_idx, step in enumerate(steps_to_eval):
        # Take the predictions row of the current step from each dataset
        curr_preds = [p[step_idx] for p in all_preds]
        final_eval = __lstm_actual_ranking_evaluation(all_files_lists=all_files, all_preds=curr_preds,
                                                      all_results=all_results, recall_th=recall_th)
        all_evals[step] = final_eval
    with open(os.path.join(lstm_save_folder, FileNamesConstants.RANKING_RESULTS), 'w') as jf:
        json.dump(all_evals, jf)
    logger().force_log_and_print('BaselineLSTMTraining::lstm_maps_ranking_evaluation', f'Results:\n {all_evals}\n\n')
    return all_evals


def lstm_maps_ranking_evaluation(lstm_save_folder: str, lstm_model_path: str, input_folders: Optional[List[str]],
                                 result_metric_override: Optional[str], batch_size: int, pre_load_maps: bool,
                                 recall_th: float, cnn_steps_for_eval: Tuple[int] = (30,)) -> Dict:
    """
    Load a trained LSTM model with its datasets and evaluate its ranking for different K and different cnn steps
    :param lstm_save_folder: folder of the LSTM run. The run objects are loaded from it, and the run parameters and
                             ranking results are saved in it
    :param lstm_model_path: path of the saved model, passed to the loader as lstm_load_path
    :param input_folders: passed to the loader as input_folders (logged as 'extra files')
    :param result_metric_override: passed to the loader as result_metric_override
    :param batch_size: batch size used for the model predictions
    :param pre_load_maps: passed to the loader as pre_load_data
    :param recall_th: threshold for the recall metrics
    :param cnn_steps_for_eval: cnn steps to evaluate the ranking for
    :return: dictionary with the ranking results for each evaluated cnn step
    """
    # Must stay the first statement: at this point locals() holds only the function arguments
    call_arguments = locals().copy()
    params_save_path = os.path.join(lstm_save_folder, 'rank_eval_parameters')
    function_start_save_params(call_arguments, config=cfg, extra_data=None, save_path=params_save_path)
    logger().log('BaselineLSTMTraining::lstm_maps_ranking_evaluation', 'Ranking for: ', lstm_save_folder,
                 'with extra files: ', input_folders)

    loaded_objects = load_all_objects_from_lstm_run_for_inference(
        lstm_run_folder=lstm_save_folder,
        lstm_load_path=lstm_model_path,
        pre_load_data=pre_load_maps,
        input_folders=input_folders,
        result_metric_override=result_metric_override)

    return _actual_lstm_maps_eval(lstm_model=loaded_objects['lstm_model'],
                                  lstm_save_folder=lstm_save_folder,
                                  train_dataset=loaded_objects['train_dataset'],
                                  val_dataset=loaded_objects['val_dataset'],
                                  test_dataset=loaded_objects['test_dataset'],
                                  batch_size=batch_size,
                                  recall_th=recall_th,
                                  cnn_steps_for_eval=cnn_steps_for_eval)
