######################################################################################################################
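# Feature map shape: (layers, values, features), where each feature is one collected stat.
# With layers_first=False the axes are reordered to (features, layers, values).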
#
######################################################################################################################
import os
from typing import Dict, Tuple, Union, List, Callable, Optional, Iterator
from collections import defaultdict
import numpy as np

from Utils.Constants import StatsConst
from Utils.Constants import Diff as DiffConst
from Utils import config, logger, get_logger
from Utils.utils import pkl_file_reader_gen, flatten_problematic_stats

# Feature-name tuples used by FeatureMapGenerator.create when stats_to_use is None.
# A stat belongs to a feature when the stat name starts with the feature name (see _group_stats_to_feature).

# Single-layer stats of weights stats files.
WEIGHTS_SINGLE_LAYER_FEATURES = ('mean', 'variance', 'median', 'std', 'max', 'min', 'covariance', 'geometric_mean',
                                 'harmonic_mean', 'skewness', 'kurtosis',  # 'geometric_std',
                                 'anderson_norm', 'anderson_logistic', 'anderson_gumbel',  # 'anderson_expon',
                                 'q-th_percentile', 'L1_norm', 'L2_norm')

# Single-layer stats of gradients stats files.
# NOTE(review): uses 'hmean' where the weights tuple uses 'harmonic_mean' - presumably the gradients stats file
# names this stat differently; confirm.
GRADIENTS_SINGLE_LAYER_FEATURES = ('mean', 'variance', 'median', 'std', 'max', 'min', 'covariance', 'geometric_mean',
                                   'hmean', 'skewness', 'kurtosis',  # 'geometric_std',
                                   'anderson_norm', 'anderson_logistic', 'anderson_gumbel',  # 'anderson_expon',
                                   'q-th_percentile', 'L1_norm', 'L2_norm')

# Diff-layers stats (these look like two-sample comparisons, e.g. t_test, wasserstein_distance) of weights files.
WEIGHTS_DIFF_LAYERS_FEATURES = ('spearman_correlation', 'point_biserial_correlation', 'weighted_tau', 't_test',
                                'kolmogorov_smirnov_test', 'epps_singleton', 'brunner_munzel', 'levene_var_test',
                                'median_test', 'wasserstein_distance', 'energy_distance', 'kl_div')

# Gradients diff-layers stats use exactly the same feature tuple as weights.
GRADIENTS_DIFF_LAYERS_FEATURES = WEIGHTS_DIFF_LAYERS_FEATURES

# Features whose collected values are passed through flatten_problematic_stats because they hold nested info.
# NOTE(review): grouped features are keyed by the names in the tuples above, so 'spearman_corr' only matches when it
# is passed explicitly via stats_to_use - 'spearman_correlation' is never fixed. Confirm this is intended.
SINGLE_STAT_FEATURES_TO_FIX = ('anderson_norm', 'anderson_expon', 'anderson_logistic', 'anderson_gumbel',
                               'median_test', 'spearman_corr')


class FeatureMapGenerator:
    """
    Iterable that builds feature maps for weights and gradients, for both single-layer and diff-layers stats.

    Each ``next()`` reads one step record from the stats file and returns a single array with one of two shapes:
        * layers_first = True  -> (config.feature_map_layers_size, config.feature_map_values_size, number of features)
        * layers_first = False -> (number of features, config.feature_map_layers_size, config.feature_map_values_size)
    The layers axis is truncated or zero-padded to ``config.feature_map_layers_size``, and every feature vector is
    truncated or zero-padded to ``config.feature_map_values_size``.
    """
    def __init__(self, file_path: str, features_names: Tuple, stats_col_name: str, skip_rows: int, pad_history: bool,
                 is_single_layer_stats: bool, layers_first: bool, features_to_fix: Tuple):
        """
        :param file_path: path of the stats file, read record by record with pkl_file_reader_gen
        :param features_names: names of the features to extract; a stat belongs to a feature by name prefix
        :param stats_col_name: key of the stats dictionary inside every layer entry
        :param skip_rows: number of leading records of the file to skip
        :param pad_history: stored on the instance; not used by this class
        :param is_single_layer_stats: selects the single-layer or the diff-layers step map builder
        :param layers_first: True -> layers axis first (features last), False -> features axis first
        :param features_to_fix: feature names whose collected values go through flatten_problematic_stats
        """
        logger().log(__name__, 'Generator for ', file_path, ' skip rows=', skip_rows)
        self.file_path = file_path
        self.features_names = features_names
        self.stats_col_name = stats_col_name
        self.features_to_fix = features_to_fix
        self.pad_history = pad_history
        # Despite its name this holds ``layers_first``: True keeps the features ("channels") axis last
        self.channels_last = layers_first
        self.file_gen = pkl_file_reader_gen(file_path)
        # Typing is used to remove annoying pycharm warnings
        self.gen_function: Callable[[Dict], np.ndarray] = \
            self._generator_single_layer_step_map if is_single_layer_stats else self._generator_diff_layers_step_map
        for _ in range(skip_rows):
            next(self.file_gen)

    def _fix_features(self, grouped_stats: Dict) -> Dict:
        """
        Flatten the features that hold nested info.
        :param grouped_stats: dictionary of feature name -> collected values
        :return: the same dictionary (mutated in place) where every feature listed in self.features_to_fix is
                 replaced by flatten_problematic_stats(values)
        """
        for feature_name in self.features_to_fix:
            if feature_name in grouped_stats:
                grouped_stats[feature_name] = flatten_problematic_stats(grouped_stats[feature_name])
        return grouped_stats

    def _group_stats_to_feature(self, stats: Dict) -> Dict:
        """
        Group raw stats into features by name prefix, e.g. ``mean`` and ``mean_axis_-1`` both go to ``mean``.
        Stats whose name contains ``orig`` are ignored; a feature without any matching stat is logged and omitted.
        :param stats: dictionary of stat name -> stat value
        :return: dictionary of feature name -> list of flattened object arrays (one per matching stat), after
                 _fix_features was applied
        """
        features = defaultdict(list)
        for feature_name in self.features_names:
            relevant_stats = [stat_name for stat_name in stats
                              if stat_name.startswith(feature_name) and 'orig' not in stat_name]
            if not relevant_stats:
                logger().error(__name__, None, f"Error finding feature {feature_name}")
                continue
            for stat in relevant_stats:
                features[feature_name].append(np.array(stats[stat], dtype='object').flatten())
        return dict(self._fix_features(features))

    @staticmethod
    def _convert_collected_features_to_map(features: Dict) -> Dict:
        """
        Flatten every feature's collected values into one vector of size config.feature_map_values_size.
        Longer vectors are truncated; shorter ones are zero-padded into a float vector.
        :param features: dictionary of feature name -> iterable of values (ndarrays, lists or scalars), e.g. the
                         output of _group_stats_to_feature
        :return: dictionary of feature name -> 1-D np.ndarray of size config.feature_map_values_size (same key order)
        """
        map_size = config.feature_map_values_size
        flat_features = dict()
        for name, values in features.items():
            flat_values = list()
            for curr_val in values:
                if isinstance(curr_val, np.ndarray):
                    flat_values += curr_val.flatten().tolist()
                elif isinstance(curr_val, list):
                    flat_values += curr_val
                else:
                    flat_values.append(curr_val)
            if len(flat_values) >= map_size:
                flat_features[name] = np.array(flat_values[:map_size])
            else:
                padded = np.zeros((map_size,))
                padded[:len(flat_values)] = flat_values
                flat_features[name] = padded
        return flat_features

    def _relevant_layers_stats(self, step_stats: Dict) -> List[Dict]:
        """
        Collect the ``self.stats_col_name`` entry of every relevant layer of a single step record.
        Only the first top-level key of the record is used (presumably the batch number - confirm).
        :param step_stats: step record: first key -> dictionary of layer name -> layer data
        :return: list of per-layer stats, in layer order
        """
        batch_num = list(step_stats)[0]   # IndexError on an empty record, as before
        layers = step_stats[batch_num]
        return [layer[self.stats_col_name] for layer_name, layer in layers.items()
                if self.is_layer_relevant(layer_name)]

    def _stack_layers(self, layers_maps: List) -> np.ndarray:
        """
        Stack per-layer (features, values) maps and reorder them to (layers, values, features).
        :raises ValueError: if no relevant layer was found in the step
        """
        if not layers_maps:
            raise ValueError(f"No relevant layers found in a step of {self.file_path} "
                             f"(relevant names: {config.relevant_layers_names})")
        return np.swapaxes(np.stack(layers_maps, axis=0), 1, 2)

    def _generator_single_layer_step_map(self, step_stats: Dict) -> np.ndarray:
        """
        Build the feature map of a single step from single-layer stats (weights or gradients).
        :param step_stats: one record of the stats file
        :return: numpy array with shape (number of layers, config.feature_map_values_size, number of features)
        """
        step_features = list()
        for stats in self._relevant_layers_stats(step_stats):
            features_data = self._convert_collected_features_to_map(self._group_stats_to_feature(stats))
            step_features.append(list(features_data.values()))
        return self._stack_layers(step_features)

    def _weights_diff_stats_to_features(self, stats: Dict) -> List:
        """
        Build the per-layer map of diff-layers stats stored as a history keyed by numeric keys.
        Histories are concatenated per feature in the reverse of the dictionary's iteration order.
        :param stats: dictionary of history key -> dictionary of stat name -> stat value
        :return: a list with shape of (features, feature map size)
        """
        layer_step_diff_features = [self._group_stats_to_feature(diff_step_stats) for diff_step_stats in stats.values()]

        # Combine all histories of the same feature into a single list
        concatenated_history = defaultdict(list)
        for curr_hist in reversed(layer_step_diff_features):
            for feature, value in curr_hist.items():
                concatenated_history[feature] += value
        features_dict = self._convert_collected_features_to_map(concatenated_history)
        return list(features_dict.values())

    def _grads_diff_stats_to_features(self, diff_step_stats: Dict) -> List:
        """
        Build the per-layer map of diff-layers stats keyed directly by feature name.
        :param diff_step_stats: dictionary of feature name -> iterable of dictionaries whose values are the stat
                                values; must contain every name in self.features_names
        :return: a list with shape of (features, feature map size), ordered as self.features_names
        """
        features_data = {feature: [val for curr_stat in diff_step_stats[feature] for val in curr_stat.values()]
                         for feature in self.features_names}
        # Fix nested stats exactly once, after every feature was collected
        features_data = self._fix_features(features_data)
        layer_map = self._convert_collected_features_to_map(features_data)
        return list(layer_map.values())

    @staticmethod
    def _is_key_numeric(stats_dict: Dict) -> bool:
        """
        Whether the first key of ``stats_dict`` is an integer (Python or NumPy) or a numeric string.
        """
        key = list(stats_dict)[0]   # IndexError on an empty dictionary, as before
        if isinstance(key, (int, np.integer)):
            return True
        return isinstance(key, str) and key.isnumeric()

    def _generator_diff_layers_step_map(self, step_stats: Dict) -> np.ndarray:
        """
        Build the feature map of a single step from diff-layers stats (weights or gradients).
        :param step_stats: one record of the stats file
        :return: numpy array with shape (number of layers, config.feature_map_values_size, number of features)
        """
        step_features = list()
        for stats in self._relevant_layers_stats(step_stats):
            if self._is_key_numeric(stats):     # Numeric key in this case means this is weights stats
                step_features.append(self._weights_diff_stats_to_features(stats))
            else:
                step_features.append(self._grads_diff_stats_to_features(stats))
        return self._stack_layers(step_features)

    @staticmethod
    def _set_final_layers_size(to_pad: np.ndarray) -> np.ndarray:
        """
        Truncate or zero-pad the first (layers) axis of a 3-D array to config.feature_map_layers_size.
        """
        num_layers = to_pad.shape[0]
        if num_layers > config.feature_map_layers_size:
            return to_pad[:config.feature_map_layers_size]
        missing_layers = config.feature_map_layers_size - num_layers
        return np.concatenate((to_pad, np.zeros((missing_layers, to_pad.shape[1], to_pad.shape[2]))))

    def _final_shape_manipulations(self, arr: np.ndarray) -> np.ndarray:
        """
        Fix the layers axis size and, when layers_first is False, move the features axis to the front.
        """
        final_size_arr = self._set_final_layers_size(arr)
        # Convert into channels first format if needed
        if not self.channels_last:
            return np.moveaxis(final_size_arr, 2, 0)
        return final_size_arr

    @staticmethod
    def is_layer_relevant(layer_name) -> bool:
        """
        True when any of config.relevant_layers_names is a substring of ``layer_name``.
        """
        return any(name in layer_name for name in config.relevant_layers_names)

    def __iter__(self):
        return self

    def __next__(self):
        curr_step_data = next(self.file_gen)
        return self._final_shape_manipulations(self.gen_function(curr_step_data))

    def __len__(self):
        # NOTE(review): a constant; it does not account for skip_rows or the actual file length - confirm
        return DiffConst.NUMBER_STEPS_SAVED

    @staticmethod
    def create(is_weights: bool, is_single_layer: bool, file_name_or_path: str, stats_dir: Optional[str] = None,
               layers_first: bool = True, stats_to_use: Optional[Tuple] = None) -> Iterator[np.ndarray]:
        """
        :param is_weights: True for a weights stats file (its first record is skipped), False for gradients
        :param is_single_layer: True for single-layer stats, False for diff-layers stats (first step also skipped)
        :param file_name_or_path: file name or full path if stats_dir is None
        :param stats_dir: folder of file if using name or None if file_path is full path to file
        :param layers_first: the generator will return feature map with channels last,
               i.e, shape will be: layers, size, features - if False shape is: features, layers, size
        :param stats_to_use: which stats to create feature map for (None -> the module default for the file kind)
        :return: FeatureMapGenerator object for the specified file
        """
        file_name_or_path = os.path.join(stats_dir, file_name_or_path) if stats_dir is not None else file_name_or_path
        if stats_to_use is None:
            features_names_ops = {(True, True): WEIGHTS_SINGLE_LAYER_FEATURES,
                                  (True, False): WEIGHTS_DIFF_LAYERS_FEATURES,
                                  (False, True): GRADIENTS_SINGLE_LAYER_FEATURES,
                                  (False, False): GRADIENTS_DIFF_LAYERS_FEATURES}
            features_names = features_names_ops[(is_weights, is_single_layer)]
        else:
            features_names = stats_to_use
        stats_col_name = StatsConst.LAYER_STATS if is_single_layer else StatsConst.DIFF_STATS
        skip_rows = 1 if is_weights else 0  # Skip layers data in weights stats file
        if not is_single_layer:  # First step in weights and gradients doesn't have diff layer stats
            skip_rows += 1
        return FeatureMapGenerator(file_name_or_path, features_names, stats_col_name, skip_rows,
                                   pad_history=is_weights, is_single_layer_stats=is_single_layer,
                                   layers_first=layers_first, features_to_fix=SINGLE_STAT_FEATURES_TO_FIX)


def test(file_path: str = '/sise/group/models_meds_400_450/stats/gradients_stats_2542335.bz2'):
    """
    Smoke test: iterate over every feature map of a gradients single-layer stats file and print each step index.
    :param file_path: full path of the stats file to read (defaults to the previously hard-coded file)
    """
    gen = FeatureMapGenerator.create(is_weights=False, is_single_layer=True, file_name_or_path=file_path,
                                     stats_to_use=('mean', 'variance', 'median', 'std', 'max', 'min', 'covariance',
                                                   'skewness', 'kurtosis', 'q-th_percentile', 'L1_norm', 'L2_norm'))
    for idx, _ in enumerate(gen):
        print(idx)


if __name__ == '__main__':
    # Name the logger after this script's file name (text before the first dot)
    script_name = os.path.basename(__file__).split('.')[0]
    get_logger(script_name)
    test()

