# -*- coding: utf-8 -*-
import os
import functools
import numpy as np
import pandas as pd
from typing import List

from omegaconf import DictConfig
from multiprocessing import Pool
from common.log import access_log
from hgtft.utils.data_utils import (init_minmax_normalization_model, init_standard_normalization_model,
                                    ProcessedTimeData, ScalerInfoData, NormalizationData, remove_dir)


class NormalizationProcessedData:
    """Collect scaler statistics over processed time-series data and normalize it.

    The two ``statistics_collector_*`` methods aggregate per-feature statistics
    across projects and persist them with ``ScalerInfoData.save_data``.
    ``project_normalization`` applies scaler models to each project's processed
    data in a process pool and writes the result through ``NormalizationData``.
    """

    def __init__(self, config: DictConfig):
        # Config keys read by this class: project_train, feature_info,
        # original_data_path, normalization_data_path, dataset_data_path,
        # max_workers.
        self.config = config

    @staticmethod
    def _merge_bound(current, candidate, reducer):
        """Combine a running bound with a new candidate while ignoring NaN.

        ``reducer`` is ``min`` or ``max``. The first value seen is stored as-is
        (even if it is NaN) so the key always exists; afterwards a NaN on either
        side never overrides a real number. The builtins are order dependent
        with NaN (``min(nan, 1)`` returns ``nan``), so without this a single
        project with an all-NaN column would poison the aggregated bound.
        """
        if current is None or pd.isna(current):
            return candidate
        if pd.isna(candidate):
            return current
        return reducer(current, candidate)

    @staticmethod
    def _merge_categories(existing, new_values):
        """Return the union of two category lists, keeping first-seen order.

        ``list(set(...))`` would give a run-dependent order (string hashing is
        randomized per process), making the persisted category list unstable.
        """
        return list(dict.fromkeys(list(existing) + list(new_values)))

    @staticmethod
    def _mean_std(values):
        """Return ``(mean, std)`` of the non-null entries of a pandas Series.

        Returns ``(nan, nan)`` when no non-null values remain. A standard
        deviation of 0, or an undefined one (a single sample with pandas'
        default ``ddof=1`` yields NaN), is replaced by 1 — presumably the std
        is used as a divisor downstream; confirm against the scaler model.
        """
        values = values.dropna()
        if values.empty:
            return np.nan, np.nan
        value_mean = values.mean()
        value_std = values.std()
        if pd.isna(value_std) or value_std == 0:
            value_std = 1
        return value_mean, value_std

    def statistics_collector_minmax(self, project_list: List[str] = None) -> None:
        """Collect min/max and category statistics and save them as scaler info.

        For every project and every data type in ``config.feature_info``, the
        0.01% and 99.99% percentiles of each numeric feature are merged into a
        global min/max (NaN-safe), and the distinct non-null values of each
        categorical feature are merged into an order-preserving category list.

        :param project_list: projects to scan; defaults to ``config.project_train``.
        :return: None. The result is persisted via ``ScalerInfoData.save_data``.
        """
        if project_list is None:
            project_list = self.config.project_train
        original_data_path = self.config.get('original_data_path')
        scaler_dict = {}
        access_log.info(f"============== start statistics data ==============")
        for project_index, project_name in enumerate(project_list):
            access_log.info(f"============== {project_name}-{project_index} ==============")
            for data_type, feature_info in self.config.feature_info.items():
                type_scaler = scaler_dict.setdefault(data_type, {"numeric": {}, "categorical": {}})
                processed_data_obj = ProcessedTimeData(project_name, data_type, original_data_path)
                processed_data_df = processed_data_obj.load_data()
                if processed_data_df.shape[0] == 0:
                    continue
                processed_data_columns = set(processed_data_df.columns)
                # pandas labels these percentile rows '0.01%' and '99.99%'.
                describe_df = processed_data_df.describe(include='all', percentiles=[0.0001, 0.9999])
                numeric_feature_list, categorical_feature_list = (
                    processed_data_obj.get_numeric_and_categorical_feature_list(feature_info))

                for numeric_feature in numeric_feature_list:
                    if numeric_feature not in processed_data_columns:
                        continue
                    bounds = type_scaler["numeric"].setdefault(numeric_feature, {})
                    bounds['min'] = self._merge_bound(
                        bounds.get('min'), describe_df.loc['0.01%', numeric_feature], min)
                    bounds['max'] = self._merge_bound(
                        bounds.get('max'), describe_df.loc['99.99%', numeric_feature], max)

                for categorical_feature in categorical_feature_list:
                    if categorical_feature not in processed_data_columns:
                        continue
                    unique_values = processed_data_df[categorical_feature].dropna().unique().tolist()
                    type_scaler['categorical'][categorical_feature] = self._merge_categories(
                        type_scaler['categorical'].get(categorical_feature, []), unique_values)

        ScalerInfoData.save_data(scaler_dict)
        access_log.info(f"============== finish statistics data ==============")

    def statistics_collector_standard(self, project_list: List[str] = None) -> None:
        """Collect mean/std and category statistics and save them as scaler info.

        For each data type in ``config.feature_info`` the processed data of all
        projects is concatenated; each numeric feature gets its mean and std
        (see ``_mean_std``) and each categorical feature its distinct non-null
        values in first-seen order.

        :param project_list: projects to scan; defaults to ``config.project_train``.
        :return: None. The result is persisted via ``ScalerInfoData.save_data``.
        """
        scaler_dict = {}
        access_log.info(f"============== start statistics data ==============")
        if project_list is None:
            project_list = self.config.project_train
        # Read from the same location as the minmax collector and the
        # normalization step so statistics match the data being normalized.
        original_data_path = self.config.get('original_data_path')

        for data_type, feature_info in self.config.feature_info.items():
            access_log.info(f"============== obj type: {data_type} ==============")
            type_scaler = scaler_dict.setdefault(data_type, {"numeric": {}, "categorical": {}})

            numeric_feature_list, categorical_feature_list = (
                ProcessedTimeData.get_numeric_and_categorical_feature_list(feature_info))

            # Concatenate once instead of growing a frame inside the loop (quadratic).
            frames = []
            for project_name in project_list:
                processed_data_df = ProcessedTimeData(project_name, data_type, original_data_path).load_data()
                if processed_data_df.shape[0] > 0:
                    frames.append(processed_data_df)
            all_project_df = pd.concat(frames, axis=0) if frames else pd.DataFrame()
            all_columns = set(all_project_df.columns)

            for numeric_feature in numeric_feature_list:
                if numeric_feature not in all_columns:
                    continue
                value_mean, value_std = self._mean_std(all_project_df[numeric_feature])
                type_scaler["numeric"][numeric_feature] = {
                    'mean': value_mean,
                    'std': value_std
                }
            for categorical_feature in categorical_feature_list:
                if categorical_feature not in all_columns:
                    continue
                type_scaler['categorical'][categorical_feature] = (
                    all_project_df[categorical_feature].dropna().unique().tolist())

        ScalerInfoData.save_data(scaler_dict)
        access_log.info(f"============== finish statistics data ==============")

    @staticmethod
    def _project_normalization(arg):
        """Normalize every data type of a single project (executed by a pool worker).

        :param arg: tuple ``(project_name, feature_info, normalization_info_dict,
            original_data_path, normalization_data_path)``.

        The project's ``scaler`` output directory is removed first. Any error
        while processing the project is logged and swallowed so one failing
        project does not abort the other pool tasks.
        """
        project_name, feature_info, normalization_info_dict, original_data_path, normalization_data_path = arg

        access_log.info(f"============== normalization: {project_name} ==============")
        remove_dir(os.path.join(normalization_data_path, f'{project_name}/scaler'))
        try:
            for data_type in feature_info.keys():
                data_df = ProcessedTimeData(project_name, data_type, original_data_path).load_data()
                if data_df.shape[0] == 0:
                    continue
                type_scalers = normalization_info_dict.get(data_type)
                if type_scalers is None:
                    raise KeyError(f"no scaler models for data type: {data_type}")
                for feature_name, scaler_model_tuple in type_scalers.items():
                    try:
                        if feature_name not in data_df.columns:
                            # Every feature that has a scaler is present in the
                            # saved frame, as an all-NaN column if it is missing.
                            data_df[feature_name] = np.nan
                        # Scalers are applied element-wise, in tuple order.
                        for scaler_model in scaler_model_tuple:
                            data_df[feature_name] = data_df[feature_name].apply(scaler_model.transform)
                        data_df[feature_name] = data_df[feature_name].astype(np.float32)
                    except Exception as e:
                        access_log.error(f"{data_type}-{feature_name}")
                        raise ValueError(e) from e
                NormalizationData(project_name, data_type, normalization_data_path).save_data(data_df)
        except Exception as e:
            access_log.error(f"Error normalization time series: {project_name}")
            access_log.error(e)

    def project_normalization(self, project_sample_id: List[str] = None, scaler_func: str = 'standard') -> None:
        """Normalize the processed data of several projects in a process pool.

        @param project_sample_id: project names to normalize; ``None`` means no
            projects (nothing is processed).
        @param scaler_func: ``'standard'`` or ``'minmax'``; selects the scaler
            model factory.
        @return: None. Results are written by ``_project_normalization``.
        @raise ValueError: if ``scaler_func`` is not a known scaler name.
        """
        if scaler_func == 'standard':
            scaler_model_dict = init_standard_normalization_model()
        elif scaler_func == 'minmax':
            scaler_model_dict = init_minmax_normalization_model(self.config.get('dataset_data_path', None))
        else:
            raise ValueError(f'no function {scaler_func}')

        if project_sample_id is None:
            project_sample_id = []
        original_data_path = self.config.get('original_data_path')
        normalization_data_path = self.config.get('normalization_data_path')
        arg_list = [(project_id, self.config.feature_info, scaler_model_dict, original_data_path,
                     normalization_data_path) for project_id in project_sample_id]
        if not arg_list:
            # Nothing to do: avoid spawning a process pool.
            return
        with Pool(processes=self.config.max_workers) as pool:
            pool.map(self._project_normalization, arg_list)
