import copy
from datetime import datetime
from typing import Optional

import numpy as np
import pandas as pd
from torch import nn as nn
from ray import cloudpickle

from src.logger.logger import Logger

class ManyFoldLogger:
    """
    Summarizes the training information of all folds of a cross validation.

    Dicts, texts and the model written to this object are forwarded to every fold logger and to a summary
    logger. They are also stored, so that fold loggers created or added later receive them as well. After
    training, the per-epoch means of the training loss, validation loss, test loss and scores of all folds are
    written to the summary logger. The best-epoch results of all folds can be aggregated (mean / std) and
    appended as a row to a csv file.
    """

    def __init__(self, name: str = "", current_time_string: Optional[str] = None) -> None:
        """
        Args:
            name (str): Name of the run. Used for the fold loggers and, suffixed with "_summary", for the
                summary logger.
            current_time_string (Optional[str]): Time stamp shared by all loggers of this run. Defaults to the
                current time formatted as "%Y-%m-%d_%H-%M-%S".
        """
        self.start_time_str = (
            datetime.now().strftime("%Y-%m-%d_%H-%M-%S") if current_time_string is None else current_time_string
        )
        self.n_folds = 0
        self.fold_logger = []
        # Last value logged per name via log_performance / log_measure
        self.performance_values = {}
        self.measure_values = {}
        self.summary_logger = Logger(name=f'{name}_summary', current_time_string=self.start_time_str)
        self._training_start: Optional[datetime] = None
        self.name = name
        # Everything written so far; replayed to fold loggers that are created or added later
        self.storage = {"dicts": {}, "texts": {}, "model": None}

    def new_fold(self) -> "Logger":
        """
        Creates a logger for the next fold and registers it.

        The new logger receives every dict, text and model that was written to this object before.

        Returns:
            Logger: The logger of the new fold.
        """
        logger = Logger(fold=self.n_folds, current_time_string=self.start_time_str, name=self.name)
        self.add_fold(logger)
        return logger

    def add_fold(self, logger) -> None:
        """
        Registers an existing logger as the logger of the next fold.

        The logger receives every dict, text and model that was written to this object before.

        Args:
            logger: The fold logger. It must provide write_dict, write_text and write_model.
        """
        self.fold_logger.append(logger)
        self._replay_storage(logger)
        self.n_folds += 1

    def _replay_storage(self, logger) -> None:
        """Writes every stored dict, text and (if one was registered) the model to `logger`."""
        for name, data in self.storage["dicts"].items():
            logger.write_dict(data, name)
        for tag, text in self.storage["texts"].items():
            logger.write_text(tag, text)
        # Before write_model was called there is no model to replay
        if self.storage["model"] is not None:
            logger.write_model(self.storage["model"])

    def log_performance(self, value: float, epoch: int, name: str) -> None:
        """
        Logs a performance value to the summary logger and remembers it as the latest value for `name`.

        Args:
            value (float): The value to log.
            epoch (int): The epoch the value belongs to.
            name (str): Name of the performance value.
        """
        self.summary_logger.log_performance(value=value,
                                            epoch=epoch,
                                            name=name)

        self.performance_values[name] = value

    def log_measure(self, value: float, epoch: int, name: str) -> None:
        """
        Logs a measure to the summary logger and remembers it as the latest value for `name`.

        Args:
            value (float): The value to log.
            epoch (int): The epoch the value belongs to.
            name (str): Name of the measure.
        """
        self.summary_logger.log_measure(value=value,
                                        epoch=epoch,
                                        name=name)

        self.measure_values[name] = value

    def write_dict(self, data: dict, name: str = "config") -> None:
        """
        Writes a dict to all fold loggers and the summary logger and stores a deep copy of it.

        The stored copy is replayed to later folds and flattened into the csv row by save_test_scores.

        Args:
            data (dict): The dict to write, e.g. a configuration.
            name (str): Name under which the dict is written and stored.
        """
        for logger in self.fold_logger + [self.summary_logger]:
            logger.write_dict(data, name)

        self.storage["dicts"][name] = copy.deepcopy(data)

    def write_text(self, tag: str, text: str) -> None:
        """
        Writes a custom text to a custom tag for all fold loggers and the summary logger and stores it.

        Args:
            tag (str): The tag for the text.
            text (str): The text to write.
        """
        for logger in self.fold_logger + [self.summary_logger]:
            logger.write_text(tag, text)

        self.storage["texts"][tag] = text

    def write_model(self, model: nn.Module) -> None:
        """
        Writes the model to all fold loggers and the summary logger and stores it for later folds.

        Args:
            model (nn.Module): The model.
        """
        for logger in self.fold_logger + [self.summary_logger]:
            logger.write_model(model)

        self.storage["model"] = model

    @staticmethod
    def _mean_at_epoch(series: list, epoch) -> float:
        """Mean of the values stored under `epoch` in each dict of `series`; 0 if no dict contains `epoch`."""
        values = [per_epoch[epoch] for per_epoch in series if epoch in per_epoch]
        return np.mean(values) if values else 0

    def summarize_folds(self) -> None:
        """
        Writes the mean results over all folds to the summary logger.

        Text entries of the first fold are copied; entries whose tag contains 'score' are replaced by their
        mean over the folds that have that tag. For each epoch of any fold, the mean training loss is logged.
        The mean validation loss, test loss and per-class scores are logged for the epochs that the first
        fold logged them for. Folds lacking a value for an epoch are left out of that epoch's mean.
        Does nothing if no fold logger has been registered.
        """
        if not self.fold_logger:
            return
        first_fold = self.fold_logger[0]

        # Collect final scores and other text information from the first fold logger
        for key, value in first_fold.text.items():
            if 'score' in key:
                fold_scores = [fold.text[key] for fold in self.fold_logger if key in fold.text]
                self.summary_logger.write_text(key, np.mean(fold_scores))
            else:
                self.summary_logger.write_text(key, value)

        if hasattr(first_fold, 'model'):
            self.summary_logger.write_model(first_fold.model)

        # Union of the epochs of all folds (folds may have stopped at different epochs)
        all_epochs = set()
        for fold in self.fold_logger:
            all_epochs.update(fold.train_loss.keys())

        train_losses = [fold.train_loss for fold in self.fold_logger]
        val_losses = [fold.val_loss for fold in self.fold_logger]
        test_losses = [fold.test_loss for fold in self.fold_logger]

        for epoch in sorted(all_epochs):
            self.summary_logger.log_loss(self._mean_at_epoch(train_losses, epoch), epoch, '1_train')

            if epoch in first_fold.val_loss:
                self.summary_logger.log_loss(self._mean_at_epoch(val_losses, epoch), epoch, '2_validation')

            if epoch in first_fold.test_loss:
                self.summary_logger.log_loss(self._mean_at_epoch(test_losses, epoch), epoch, '3_test')

            # scores are nested as score -> class label -> epoch -> value
            for score, per_class in first_fold.scores.items():
                for class_label, per_epoch in per_class.items():
                    if epoch in per_epoch:
                        series = [fold.scores.get(score, {}).get(class_label, {}) for fold in self.fold_logger]
                        self.summary_logger.log_test_score(self._mean_at_epoch(series, epoch), epoch, score,
                                                           class_label)

    @staticmethod
    def _best_epoch_values(fold) -> dict:
        """Collects the best-epoch values of one fold logger, keyed without a fold suffix (None = no data)."""
        values = {'optimal_epoch': fold.optimal_epoch}

        for score, per_class in fold.best_results.get("scores", {}).items():
            for class_label, value in per_class.items():
                values[f'{score}_{class_label}'] = value

        for loss in ('train', 'val', 'test'):
            values[f'{loss}_loss'] = fold.best_results.get(f'{loss}_loss')

        for perf_value, per_epoch in fold.performance_values.items():
            perf = list(per_epoch.values())
            # np.max raises on an empty sequence, so a fold without values yields None
            values[f'mean_{perf_value}_performance'] = np.mean(perf) if perf else None
            values[f'max_{perf_value}_performance'] = np.max(perf) if perf else None

        for measure_value, per_epoch in fold.measure_values.items():
            measures = list(per_epoch.values())
            values[f'measure_{measure_value}'] = np.mean(measures) if measures else None

        return values

    def get_max_scores(self) -> dict:
        """
        Returns the results of the best performing model of every fold together with their mean and std.

        For each fold the values of its optimal epoch (scores per class, losses, performance values and
        measures) are collected under keys suffixed with "_fold{i}". For every such key, the mean and
        (population) standard deviation over the folds are added as "{key}_mean" / "{key}_std". Folds whose
        value is None are left out of these statistics. The performance and measure values logged on this
        object are added as "mean_{name}_performance" / "measure_{name}".

        Returns:
            dict: The aggregated values first, followed by the per-fold values in collection order.
        """
        values_best_epoch = {}
        for i, fold in enumerate(self.fold_logger):
            for base_key, value in self._best_epoch_values(fold).items():
                values_best_epoch[f'{base_key}_fold{i}'] = value

        # Group by the exact key without its trailing "_fold{i}". A substring match would mix e.g. class
        # "1" into class "11" ("f1_1" is contained in "f1_11_fold0").
        grouped = {}
        for key, value in values_best_epoch.items():
            if value is not None:
                grouped.setdefault(key.rsplit('_', 1)[0], []).append(value)

        summary = {}
        for base_key, fold_values in grouped.items():
            summary[f'{base_key}_mean'] = np.mean(fold_values)
            summary[f'{base_key}_std'] = np.std(fold_values)

        # Performance values and measures logged on this object are not part of the fold loggers
        for key, value in self.performance_values.items():
            summary[f'mean_{key}_performance'] = value

        for key, value in self.measure_values.items():
            summary[f'measure_{key}'] = value

        return {**summary, **{k: v for k, v in values_best_epoch.items() if k not in summary}}

    def save_test_scores(self, eval_csv_path: str) -> None:
        """
        Appends the results of all folds as a new row to a csv file.

        The row contains the run name, the entries of all stored dicts (entries of nested dicts are
        flattened one level), the aggregated and per-fold results from get_max_scores and all stored texts.
        A missing or empty csv file is created/overwritten. Any other read error is raised, so that an
        unreadable results file is not silently replaced.

        Args:
            eval_csv_path (str): Path of the csv in which the results are saved.
        """
        row = {'run_name': f'{self.start_time_str}_{self.name}'}

        # Insert dicts like config at the beginning of the results row
        for top_d in self.storage['dicts'].values():
            for top_key, d in top_d.items():
                if isinstance(d, dict):
                    for key, value in d.items():
                        row[str(key)] = value
                else:
                    row[top_key] = d

        # Insert the results
        row.update(self.get_max_scores())

        # Insert any texts that have been logged
        row.update(self.storage['texts'])

        new_row = pd.DataFrame([row])

        try:
            existing = pd.read_csv(eval_csv_path)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # No (or an empty) results file yet: start a new one with this row
            scores_file = new_row
        else:
            scores_file = pd.concat([existing, new_row], ignore_index=True)

        scores_file.to_csv(eval_csv_path, index=False)

    def close(self) -> None:
        """Marks the end of the training in the summary logger."""
        self.summary_logger.write_training_end("Summary finished")
