# -*- coding: utf-8 -*-
import functools
import gc
import traceback
from multiprocessing import Pool
from typing import List, Optional, Tuple

import numpy as np
import omegaconf

from common.log import access_log
from hgtft.utils.data_utils import (get_feature_collection, ScalerInfoData, DatasetData, NormalizationData)


class CreateDataset:
    """Build windowed train/validation/test datasets for a set of projects.

    :meth:`start` fans out one task per (project, split) pair to a process
    pool. Each task (:meth:`_project_create_dataset`) loads the normalized data
    of every object type, cuts it into fixed-length history+future windows and
    saves the result through ``DatasetData``.
    """

    def __init__(self, config: omegaconf.DictConfig, project_list: Optional[list] = None):
        """Store the config and the projects to process.

        Args:
            config: Run configuration. ``project_train``, ``project_validation``
                and ``project_test`` are read here; the other keys are read by
                :meth:`_get_project_data_condition` and :meth:`start`.
            project_list: Projects to process. When ``None``, defaults to the
                de-duplicated union of the three project lists in ``config``.
        """
        self.config = config
        if project_list is None:
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # task order does not depend on string hash randomization.
            project_list = list(dict.fromkeys(
                [*config.project_train, *config.project_validation, *config.project_test]
            ))
        self.project_list = project_list

    def _get_project_data_condition(self) -> List[Tuple]:
        """Build one worker argument tuple per (project, split) pair.

        A project is included for a split only if it is both in
        ``self.project_list`` and in the config list of that split. Tuples are
        produced in train, validation, test order.

        Returns:
            Tuples whose field order must match the unpacking in
            :meth:`_project_create_dataset`.
        """
        scalers = ScalerInfoData.get_data(self.config.get('dataset_data_path'))
        # De-duplicate while keeping a deterministic order.
        candidate_projects = list(dict.fromkeys(self.project_list))
        split_projects = (
            ("train", self.config.project_train),
            ("validation", self.config.project_validation),
            ("test", self.config.project_test),
        )
        mode_range_info_list = []
        for data_mode, configured_projects in split_projects:
            configured_set = set(configured_projects)
            for project_id in candidate_projects:
                if project_id not in configured_set:
                    continue
                mode_range_info_list.append(
                    (
                        project_id,
                        data_mode,
                        self.config.data.time[f'{data_mode}_start'],
                        self.config.data.time[f'{data_mode}_end'],
                        self.config.data.interval[data_mode],
                        self.config.configuration.model.history_len,
                        self.config.configuration.model.future_len,
                        scalers,
                        self.config.feature_info,
                        self.config.get('normalization_data_path'),
                        self.config.get('dataset_data_path'),
                    )
                )
        return mode_range_info_list

    @staticmethod
    def _normalize_obj_id(obj_id):
        """Return numeric ids as strings (floats truncated to int); other ids unchanged."""
        if isinstance(obj_id, (float, np.floating)):
            return str(int(obj_id))
        if isinstance(obj_id, (int, np.integer)):
            return str(obj_id)
        return obj_id

    @staticmethod
    def _build_windows(subset_data, static_numeric_cols, static_categorical_cols,
                       temporal_feats_numeric, temporal_feats_categorical,
                       target_signal, history_len, future_len, step) -> dict:
        """Cut one object's rows into windows of ``history_len + future_len`` rows.

        Windows start every ``step`` rows; a trailing partial window is dropped.
        Static features come from the first row of each window, the target from
        its last ``future_len`` rows.

        Returns:
            Dict of numpy arrays keyed by ``time_index`` (window start row within
            ``subset_data``), ``static_feats_numeric``,
            ``static_feats_categorical``, ``temporal_feats_numeric``,
            ``temporal_feats_categorical`` and ``target``. When there are fewer
            rows than one window, every array is empty.
        """
        temporal_len = history_len + future_len
        data_sets = {
            'time_index': [],
            'static_feats_numeric': [],
            'static_feats_categorical': [],
            'temporal_feats_numeric': [],
            'temporal_feats_categorical': [],
            'target': []
        }
        for start_index in range(0, len(subset_data) - temporal_len + 1, step):
            slc = subset_data.iloc[start_index: start_index + temporal_len]
            first_row = slc.iloc[0]
            data_sets['time_index'].append(start_index)
            # static features: taken from the first row of the window
            data_sets['static_feats_numeric'].append(
                first_row[static_numeric_cols].values.astype(np.float32))
            data_sets['static_feats_categorical'].append(
                first_row[static_categorical_cols].values.astype(np.int16))
            # temporal features: one row per time step of the window
            data_sets['temporal_feats_numeric'].append(
                slc[temporal_feats_numeric].values.reshape(temporal_len, -1))
            data_sets['temporal_feats_categorical'].append(
                slc[temporal_feats_categorical].values.reshape(temporal_len, -1))
            # target: only the future part of the window
            data_sets['target'].append(slc[target_signal].values[history_len:])
        return {key: np.array(values) for key, values in data_sets.items()}

    @staticmethod
    def _create_obj_type_data(project_name, obj_type, obj_feature, scaler_info,
                              normalization_data_path, data_start, data_end,
                              step, history_len, future_len):
        """Build the windowed data and metadata of one object type.

        Rows are grouped by ``id`` and restricted to ``data_start <= time < data_end``
        before windowing.

        Returns:
            The per-type entry stored in the saved dataset, or ``None`` when the
            normalized data of this object type has no rows.
        """
        feature_map = get_feature_collection(obj_feature)
        target_signal = obj_feature.target_signal
        categorical_cardinalities = {
            feature_name: len(label_list)
            for feature_name, label_list in scaler_info[obj_type]['categorical'].items()
        }
        static_numeric_cols = feature_map['static_feats_numeric']
        static_categorical_cols = feature_map['static_feats_categorical']
        temporal_feats_numeric = feature_map['temporal_feats_numeric']
        temporal_feats_categorical = feature_map['temporal_feats_categorical']
        all_conv_feat_list = list(set(static_numeric_cols +
                                      static_categorical_cols +
                                      temporal_feats_numeric +
                                      temporal_feats_categorical +
                                      target_signal))
        float32_dtypes = dict.fromkeys(all_conv_feat_list, np.float32)

        obj_data = NormalizationData(project_name, obj_type, normalization_data_path).load_data()
        if obj_data.shape[0] == 0:
            return None

        object_data_dict = {}
        obj_id_list = []
        for obj_id, subset_data in obj_data.groupby('id'):
            obj_id = CreateDataset._normalize_obj_id(obj_id)
            in_range = (subset_data['time'] >= data_start) & (subset_data['time'] < data_end)
            # astype returns a new frame. This avoids SettingWithCopyWarning on the
            # filtered slice and guarantees float32 columns (an in-place
            # ``.loc[:, cols] = ...`` assignment may keep the original dtype).
            subset_data = subset_data[in_range].astype(float32_dtypes)
            object_data_dict[obj_id] = CreateDataset._build_windows(
                subset_data, static_numeric_cols, static_categorical_cols,
                temporal_feats_numeric, temporal_feats_categorical, target_signal,
                history_len, future_len, step,
            )
            obj_id_list.append(obj_id)

        return {
            'data_sets': object_data_dict,
            'feature_map': feature_map,
            'scalers': scaler_info,
            'categorical_cardinalities': categorical_cardinalities,
            'output_target_len': len(target_signal),
            'target_signal': target_signal,
            'obj_id_list': obj_id_list
        }

    @staticmethod
    def _project_create_dataset(arg):
        """Create and save the dataset of one project for one split (pool worker).

        Args:
            arg: Tuple built by :meth:`_get_project_data_condition`:
                ``(project_name, sets_type, data_start, data_end, step,
                history_len, future_len, scaler_info, feature_info,
                normalization_data_path, dataset_data_path)``.

        Any exception raised while building or saving is logged with its
        traceback, and the dataset is removed via ``DatasetData.del_dataset``.
        Nothing is re-raised, so the failure is only visible in the log.
        """
        (project_name, sets_type, data_start, data_end, step, history_len,
         future_len, scaler_info, feature_info, normalization_data_path,
         dataset_data_path) = arg[:11]
        access_log.info(f"============== {project_name}-{sets_type} ==============")
        DatasetData.make_dir(project_name, dataset_data_path)
        try:
            full_data_dict = {}
            for obj_type, obj_feature in feature_info.items():
                obj_type_data = CreateDataset._create_obj_type_data(
                    project_name, obj_type, obj_feature, scaler_info,
                    normalization_data_path, data_start, data_end, step,
                    history_len, future_len,
                )
                if obj_type_data is not None:
                    full_data_dict[obj_type] = obj_type_data
                # Collect what the per-type frames left behind before the next load.
                gc.collect()

            access_log.info(f"============== save full data ==============")
            DatasetData.save_dataset(project_name, sets_type, full_data_dict, dataset_data_path)
            del full_data_dict
            gc.collect()
        except Exception:
            DatasetData.del_dataset(project_name, sets_type, dataset_data_path)
            access_log.error(f'fail create dataset {project_name}-{sets_type}')
            # Log the full traceback (not just str(e)) so worker failures can be diagnosed.
            access_log.error(traceback.format_exc())

    def start(self):
        """Create every (project, split) dataset in a pool of ``config.max_workers`` processes."""
        access_log.info(f"============== start create dataset ==============")
        mode_range_info_list = self._get_project_data_condition()
        # The staticmethod worker is picklable as-is, so it is passed to the pool directly.
        with Pool(processes=self.config.max_workers) as pool:
            pool.map(self._project_create_dataset, mode_range_info_list)

        access_log.info(f"============== finish create dataset ==============")
