# -*- coding: utf-8 -*-
import os
import gzip
import copy
import json
import random
import shutil
import pickle
import struct
import yaml
import torch
import threading
import numpy as np
import pandas as pd
from datetime import datetime
from common.log import access_log
from collections import OrderedDict
from typing import Dict, List, Tuple, Union, Any
from torch.utils.data import Dataset
from omegaconf import OmegaConf, DictConfig

from .. import ROOT_DIR


class ProcessedTimeData:
    """Parquet-backed accessor for the time-series table of one data type in a project.

    The file lives at ``<base>/<project_id>/timeseries/<data_type>.parquet`` where
    ``<base>`` is ``original_data_path`` when it is truthy, otherwise ``ROOT_DIR/data``.
    """

    def __init__(self, project_id: str, data_type: str, original_data_path: str = None):
        """
        @param project_id: project folder name under the base directory
        @param data_type: name of the parquet file (without extension)
        @param original_data_path: optional base directory overriding ``ROOT_DIR/data``
        """
        self.project_id = project_id
        self.data_type = data_type
        base_path = original_data_path or os.path.join(ROOT_DIR, 'data')
        self.data_path = os.path.join(base_path, f'{project_id}/timeseries/{data_type}.parquet')

    def load_data(self) -> pd.DataFrame:
        """
        Read the parquet file, parse its ``time`` column and replace ``None`` with NaN.

        @return: the stored table, or an empty DataFrame when the file does not exist
        """
        if not os.path.exists(self.data_path):
            return pd.DataFrame()
        data_df = pd.read_parquet(self.data_path, engine='pyarrow')
        data_df.time = pd.to_datetime(data_df.time)
        return data_df.replace({None: np.nan})

    def save_data(self, data_df: pd.DataFrame) -> None:
        """
        Write ``data_df`` (without its index) to the parquet file.

        @param data_df: table to persist
        """
        # Create the target folder on first save; without it to_parquet fails
        # for a project / data type that has never been written before.
        os.makedirs(os.path.dirname(self.data_path), exist_ok=True)
        data_df.to_parquet(self.data_path, index=False, engine='pyarrow')

    @staticmethod
    def get_numeric_and_categorical_feature_list(feature_info: DictConfig) -> Tuple[List[str], List[str]]:
        """
        Collect every numeric and every categorical feature name of one data type.

        @param feature_info: feature description accepted by ``get_feature_collection``
        @return: ``(numeric_features, categorical_features)``; duplicates are removed
            through a set, so the order of each list is not guaranteed
        """
        feature_map = get_feature_collection(feature_info)
        numeric_feature_list = list(set(
            feature_map['static_feats_numeric']
            + feature_map['temporal_feats_numeric']
            + feature_map['target_feats_numeric']
        ))
        categorical_feature_list = list(set(
            feature_map['static_feats_categorical']
            + feature_map['temporal_feats_categorical']
            + feature_map['target_feats_categorical']
        ))
        return numeric_feature_list, categorical_feature_list

    def get_feature_map(self, feature_info: DictConfig, data_all_feature_list: List) -> Dict[str, List[str]]:
        """
        Split the non-meta columns into static / historical / future / target groups,
        each further divided into numeric and categorical columns.

        @param feature_info: mapping whose ``self.data_type`` entry provides ``meta_attrs``,
            ``static_attrs``, ``categorical_attrs``, ``known_attrs`` and ``target_signal``
        @param data_all_feature_list: all column names of the table
        @return: group name -> list of column names, in ``data_all_feature_list`` order
        """
        type_info = feature_info[self.data_type]
        meta_attrs = type_info.meta_attrs
        static_attrs = type_info.static_attrs
        categorical_attrs = type_info.categorical_attrs
        known_attrs = type_info.known_attrs
        target_signal = type_info.target_signal
        feature_cols = [col for col in data_all_feature_list if col not in meta_attrs]
        feature_map = {
            'static_feats_numeric': [
                col for col in feature_cols if col in static_attrs and col not in categorical_attrs
            ],
            'static_feats_categorical': [
                col for col in feature_cols if col in static_attrs and col in categorical_attrs
            ],
            # "historical" means every non-static column, known-in-advance or not
            'historical_ts_numeric': [
                col for col in feature_cols if col not in static_attrs and col not in categorical_attrs
            ],
            'historical_ts_categorical': [
                col for col in feature_cols if col not in static_attrs and col in categorical_attrs
            ],
            'future_ts_numeric': [
                col for col in feature_cols if col in known_attrs and col not in categorical_attrs
            ],
            'future_ts_categorical': [
                col for col in feature_cols if col in known_attrs and col in categorical_attrs
            ],
            'target_categorical': [
                col for col in feature_cols if col in target_signal and col in categorical_attrs
            ],
            'target_numeric': [
                col for col in feature_cols if col in target_signal and col not in categorical_attrs
            ]
        }
        return feature_map

    def del_data(self):
        """
        Delete the parquet file.

        @raise ValueError: when the file cannot be removed; the original ``OSError``
            is attached as ``__cause__``
        """
        try:
            os.remove(self.data_path)
        except OSError as e:
            raise ValueError(e) from e


class GraphData:
    """Edge table between two object types, read from a CSV file at construction time.

    The CSV is indexed by its ``id`` column. Index values are exposed as
    ``dst_type_id`` and column labels as ``src_type_id``, both as strings.
    """

    def __init__(self, project_name: str, from_type: str, to_type: str, edge: str, dataset_data_path: str = None):
        """
        @param project_name: project folder name
        @param from_type: first component of the CSV file name
        @param to_type: last component of the CSV file name
        @param edge: middle component of the CSV file name
        @param dataset_data_path: optional base directory; ``ROOT_DIR/data`` when None
        """
        self.project_name = project_name
        self.from_type = from_type
        self.to_type = to_type
        self.edge = edge
        self.file_path = (
            os.path.join(ROOT_DIR, f'data/{project_name}/graph/{from_type}-{edge}-{to_type}.csv')
            if dataset_data_path is None
            else os.path.join(dataset_data_path, f'{project_name}/graph/{from_type}-{edge}-{to_type}.csv')
        )
        self._load_data()
        self._get_all_id()

    def _load_data(self):
        """Read the CSV and normalise both the column labels and the index to strings."""
        table = pd.read_csv(self.file_path, index_col='id')
        table.columns = list(map(str, table.columns))
        table.index = table.index.astype(str)
        self.graph_data = table

    def _get_all_id(self):
        """Cache the row ids and column ids of the loaded table as plain lists."""
        table = self.graph_data
        self.dst_type_id = table.index.tolist()
        self.src_type_id = table.columns.tolist()

    @staticmethod
    def save_data(project_name, sample_name, from_type, edge, to_type, data_df: pd.DataFrame, dataset_data_path: str = None) -> None:
        """
        Write ``data_df`` (index included) as the graph CSV of ``<project_name>_<sample_name>``.

        @param dataset_data_path: optional base directory; ``ROOT_DIR/data`` when None
        """
        target = (
            os.path.join(ROOT_DIR, f'data/{project_name}_{sample_name}/graph/{from_type}-{edge}-{to_type}.csv')
            if dataset_data_path is None
            else os.path.join(dataset_data_path, f'{project_name}_{sample_name}/graph/{from_type}-{edge}-{to_type}.csv')
        )
        data_df.to_csv(target, index=True)


class DatasetData:
    """Static helpers that persist dataset dictionaries as gzip-compressed pickles.

    Files live at ``<base>/<project_name>/series/<data_type>.pkl.gz`` where ``<base>``
    is ``dataset_data_path`` when it is truthy, otherwise ``ROOT_DIR/data``.
    """

    @staticmethod
    def _dataset_file(project_name: str, data_type: str, dataset_data_path: str = None) -> str:
        """Return the full path of the pickle file for ``data_type``."""
        base_dir = dataset_data_path or os.path.join(ROOT_DIR, 'data')
        return os.path.join(base_dir, project_name, 'series', f'{data_type}.pkl.gz')

    @staticmethod
    def make_dir(project_name: str, dataset_data_path: str = None) -> None:
        """Create the ``series`` directory of the project if nothing exists at that path."""
        base_dir = dataset_data_path or os.path.join(ROOT_DIR, 'data')
        series_dir = os.path.join(base_dir, f'{project_name}/series')
        if os.path.exists(series_dir):
            return
        os.makedirs(series_dir)

    @staticmethod
    def load_dataset(project_name: str, data_type: str, dataset_data_path: str = None) -> Dict:
        """
        Unpickle and return the stored dataset.

        NOTE: ``pickle.load`` executes arbitrary code from the file, so only files
        produced by ``save_dataset`` should be read here.
        """
        source = DatasetData._dataset_file(project_name, data_type, dataset_data_path)
        with gzip.open(source, 'rb') as stream:
            return pickle.load(stream)

    @staticmethod
    def save_dataset(project_name: str, data_type: str, data_dict: Dict, dataset_data_path: str = None) -> None:
        """Pickle ``data_dict`` into the gzip file, overwriting any previous content."""
        target = DatasetData._dataset_file(project_name, data_type, dataset_data_path)
        with gzip.open(target, 'wb') as stream:
            pickle.dump(data_dict, stream)

    @staticmethod
    def del_dataset(project_name: str, data_type: str, dataset_data_path: str = None) -> None:
        """Remove the stored dataset file; a missing file is silently ignored."""
        target = DatasetData._dataset_file(project_name, data_type, dataset_data_path)
        if not os.path.exists(target):
            return
        os.remove(target)


class DictDataSet(Dataset):
    """Dataset wrapping a dict of numpy arrays; every array becomes a tensor attribute.

    Indexing returns a dict with the same keys holding the indexed slice of each
    tensor. The length is the first dimension of the first array in the dict.
    """

    def __init__(self, array_dict: Dict[str, np.ndarray]):
        """
        @param array_dict: name -> array; each array is converted with the legacy
            tensor constructor matching its dtype (``FloatTensor`` as fallback)
        """
        # Ordered (numpy dtype, tensor constructor) pairs; the first match wins.
        constructors = (
            (np.dtype('bool'), torch.ByteTensor),
            (np.int8, torch.CharTensor),
            (np.int16, torch.ShortTensor),
            (np.int32, torch.IntTensor),
            (np.int64, torch.LongTensor),
            (np.float32, torch.FloatTensor),
            (np.float64, torch.DoubleTensor),
        )
        self.keys_list = []
        for name, array in array_dict.items():
            self.keys_list.append(name)
            tensor_cls = torch.FloatTensor
            for np_type, candidate in constructors:
                if np.issubdtype(array.dtype, np_type):
                    tensor_cls = candidate
                    break
            setattr(self, name, tensor_cls(array))

    def __getitem__(self, index):
        sample = {}
        for name in self.keys_list:
            sample[name] = getattr(self, name)[index]
        return sample

    def __len__(self):
        first_key = self.keys_list[0]
        return getattr(self, first_key).shape[0]


class ModelData:
    """Static helpers that save and load model state dicts under ``<base>/model_file``.

    ``<base>`` is ``dataset_data_path`` unless it is None, in which case ``ROOT_DIR`` is used.
    """

    @staticmethod
    def save_model(model_name: str, state_dict: Dict, dataset_data_path: str = None) -> None:
        """
        Serialise ``state_dict`` to ``<base>/model_file/<model_name>.pth``.

        @param model_name: file name without extension
        @param state_dict: object handed to ``torch.save``
        @param dataset_data_path: optional base directory
        """
        base_dir = ROOT_DIR if dataset_data_path is None else dataset_data_path
        torch.save(state_dict, os.path.join(base_dir, 'model_file', f'{model_name}.pth'))

    @staticmethod
    def load_model(model_name: str, device: torch.device, dataset_data_path: str = None) -> Dict:
        """
        Load a saved state dict and strip a leading ``module.`` from every key.

        @param model_name: file name without extension
        @param device: passed to ``torch.load`` as ``map_location``
        @param dataset_data_path: optional base directory
        @return: ``OrderedDict`` with the same values and the cleaned keys
        """
        base_dir = ROOT_DIR if dataset_data_path is None else dataset_data_path
        stored_params = torch.load(os.path.join(base_dir, 'model_file', f'{model_name}.pth'), map_location=device)
        # The prefix is presumably added by a parallel wrapper at training time -- TODO confirm.
        return OrderedDict(
            (key[7:] if key.startswith('module.') else key, value)
            for key, value in stored_params.items()
        )


def scale_back(scaler_obj, signal):
    """
    Apply ``scaler_obj.inverse_transform`` to a deep copy of ``signal``.

    @param scaler_obj: any object exposing ``inverse_transform``
    @param signal: scaled data; left untouched because the scaler receives a copy
    @return: whatever ``inverse_transform`` returns
    """
    signal_copy = copy.deepcopy(signal)
    return scaler_obj.inverse_transform(signal_copy)


def make_batch(nested_dict: Dict, obj_type: str, obj_id: str, data_type: str, data: torch.Tensor, device: torch.device,
              is_cuda: bool = False, ) -> Dict:
    """
    Store ``data`` at ``nested_dict[obj_type][obj_id][data_type]``, creating levels as needed.

    Previously the tensor was stored only when ``is_cuda`` was True, so the default
    call created the nested keys but silently dropped the data. It is now always
    stored, and moved to ``device`` only when ``is_cuda`` is True.

    @param nested_dict: batch container, updated in place
    @param obj_type: first-level key
    @param obj_id: second-level key
    @param data_type: third-level key
    @param data: tensor to store
    @param device: target device, used only when ``is_cuda`` is True
    @param is_cuda: whether to move ``data`` to ``device`` before storing it
    @return: the same ``nested_dict`` object
    """
    per_object = nested_dict.setdefault(obj_type, {}).setdefault(obj_id, {})
    per_object[data_type] = data.to(device) if is_cuda else data
    return nested_dict


def remove_dir(dir_path: str) -> None:
    """
    Reset ``dir_path`` to an empty directory.

    A missing path is simply created. An existing one is deleted recursively and
    re-created; any failure during that step is logged through ``access_log``
    (as ``rm_tree`` does) instead of being printed, and is not re-raised.

    @param dir_path: directory to clear
    @return: None
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
        return
    try:
        shutil.rmtree(dir_path)
        os.makedirs(dir_path)
    except Exception as e:
        # Best effort by design: report the failure but keep the caller running.
        access_log.error(e)


def rm_tree(dir_path: str) -> None:
    """
    Recursively delete ``dir_path``; failures are logged through ``access_log`` and swallowed.

    @param dir_path: directory to delete
    @return: None
    """
    try:
        shutil.rmtree(dir_path)
    except Exception as err:
        access_log.error(err)
    return None


class CustomLabelEncoder:
    """Scalar label encoder: known classes map to 1..N (by sorted position), unknown to 0.

    NaN inputs are encoded as ``nan_value``.
    """

    def __init__(self, classes: List[float], nan_value: float = -0.0000001):
        """
        @param classes: known class values; stored sorted in ``classes_``
        @param nan_value: code returned by ``transform`` for NaN input
        """
        self.nan_value = nan_value
        self.classes_ = sorted(classes)

    def transform(self, y):
        """
        @param y: a single value
        @return: ``nan_value`` for NaN, 1-based sorted position for a known class, else 0
        """
        if np.isnan(y):
            return self.nan_value
        try:
            position = self.classes_.index(y)
        except ValueError:
            # value was not among the known classes
            return 0
        return position + 1

    def inverse_transform(self, y):
        """
        @param y: a code produced by ``transform``
        @return: the class for a positive code, NaN for anything that is not positive
        """
        return self.classes_[y - 1] if y > 0 else np.nan

    def name(self):
        """Return the class name of this encoder."""
        return type(self).__name__


class CustomMinMaxScaler:
    """Scalar min-max scaler mapping ``[data_min, data_max]`` onto the range ``(0, 1)``.

    NaN inputs are encoded as ``nan_value``; a zero data range (or a NaN bound)
    maps every non-NaN input to the lower end of the feature range.
    """

    def __init__(self, data_min, data_max, nan_value: float = -0.0000001):
        """
        @param data_min: minimum of the fitted data
        @param data_max: maximum of the fitted data
        @param nan_value: code returned by ``transform`` for NaN input
        """
        self.min_ = data_min
        self.max_ = data_max
        self.nan_value = nan_value
        self.feature_range = (0, 1)
        # Fixed: both bounds are checked (the maximum used to be tested twice, so
        # a NaN minimum leaked into data_range_ and made every transform NaN).
        if np.isnan(self.min_) or np.isnan(self.max_):
            self.data_range_ = 0
        else:
            self.data_range_ = self.max_ - self.min_
        self.scale_range_ = self.feature_range[1] - self.feature_range[0]

    def transform(self, x):
        """
        @param x: a single value
        @return: ``nan_value`` for NaN, the range minimum when the data range is 0,
            otherwise the linearly scaled value
        """
        if np.isnan(x):
            return self.nan_value
        if self.data_range_ == 0:
            return self.feature_range[0]
        return self.feature_range[0] + (x - self.min_) / self.data_range_ * self.scale_range_

    def inverse_transform(self, x):
        """
        @param x: a scaled value
        @return: NaN when ``x`` equals ``nan_value``, otherwise the value mapped back
            to the original data range
        """
        # The former extra operand ``or self.feature_range[0]`` was a truthiness test
        # on the constant 0 and never fired; it is dropped with no behavior change.
        if x == self.nan_value:
            return np.nan
        return (x - self.feature_range[0]) / self.scale_range_ * self.data_range_ + self.min_

    def name(self):
        """Return the class name of this scaler."""
        return self.__class__.__name__


class CustomStandardScaler:
    """Scalar standardiser computing ``(x - mean) / std``; NaN is encoded as ``nan_value``."""

    def __init__(self, value_mean: float, value_std: float, nan_value: float = -0.0000001):
        """
        @param value_mean: mean of the fitted data
        @param value_std: standard deviation of the fitted data
        @param nan_value: code returned by ``transform`` for NaN input
        """
        self.value_mean = value_mean
        self.value_std = value_std
        self.nan_value = nan_value

    def transform(self, x):
        """
        @param x: a single value
        @return: ``nan_value`` for NaN, ``0.0`` when the standard deviation is 0,
            otherwise the standardised value
        """
        if np.isnan(x):
            return self.nan_value
        if self.value_std == 0:
            # A constant feature has zero spread; dividing would raise (or yield
            # inf/NaN with numpy scalars). Map it to the centre instead, mirroring
            # how CustomMinMaxScaler handles a zero data range.
            return 0.0
        return (x - self.value_mean) / self.value_std

    def inverse_transform(self, x):
        """
        @param x: a standardised value
        @return: NaN when ``x`` equals ``nan_value``, otherwise ``x * std + mean``
        """
        if x == self.nan_value:
            return np.nan
        return x * self.value_std + self.value_mean

    def name(self):
        """Return the class name of this scaler."""
        return self.__class__.__name__


class CustomOffsetScaler:
    """Scalar scaler that subtracts a fixed offset and clips results below it to 0."""

    def __init__(self, offset_value: float):
        """
        @param offset_value: amount subtracted by ``transform`` and added back by
            ``inverse_transform``
        """
        self.offset_value = offset_value

    def transform(self, x):
        """
        @param x: a single value
        @return: ``x`` itself for NaN, 0 when ``x`` is below the offset, else ``x - offset``
        """
        if np.isnan(x):
            return x
        return 0 if x < self.offset_value else x - self.offset_value

    def inverse_transform(self, x):
        """
        @param x: a shifted value
        @return: ``x + offset``
        """
        shifted_back = x + self.offset_value
        return shifted_back

    def name(self):
        """Return the class name of this scaler."""
        return type(self).__name__


class CustomLogScaler:
    """Scalar scaler computing ``log(x + offset)``; NaN values pass through unchanged."""

    def __init__(self, offset_value: float = 0.01):
        """
        @param offset_value: constant added before the logarithm and subtracted
            again after the exponential
        """
        self.offset_value = offset_value

    def transform(self, x):
        """
        @param x: a single value
        @return: ``x`` itself for NaN, otherwise ``np.log(x + offset)``
        """
        if np.isnan(x):
            return x
        return np.log(x + self.offset_value)

    def inverse_transform(self, x):
        """
        @param x: a log-scaled value
        @return: ``x`` itself for NaN, otherwise ``np.exp(x) - offset``
        """
        if np.isnan(x):
            return x
        return np.exp(x) - self.offset_value

    def name(self):
        """Return the class name of this scaler."""
        return type(self).__name__


class ScalerInfoData:
    """Read, write and delete the JSON file that stores per-feature scaler statistics.

    Every method works on ``<data_path>/scaler_dict.json`` when ``data_path`` is
    truthy and on the default ``PATH`` otherwise.
    """

    # Default location of the scaler info file
    PATH = os.path.join(ROOT_DIR, 'data/scaler_dict.json')

    @classmethod
    def _resolve_path(cls, data_path: str = None) -> str:
        """Return the scaler info file path for ``data_path`` (default: ``cls.PATH``)."""
        return os.path.join(data_path, 'scaler_dict.json') if data_path else cls.PATH

    @classmethod
    def save_data(cls, scaler_dict: Dict, data_path: str = None) -> None:
        """
        Dump ``scaler_dict`` as JSON, replacing any existing file.

        @param scaler_dict: JSON-serialisable scaler description
        @param data_path: optional directory holding ``scaler_dict.json``
        """
        info_data_path = cls._resolve_path(data_path)
        if os.path.exists(info_data_path):
            os.remove(info_data_path)
        with open(info_data_path, 'w') as f:
            json.dump(scaler_dict, f)

    @classmethod
    def get_data(cls, data_path: str = None) -> Dict:
        """
        Load the scaler description.

        @param data_path: optional directory holding ``scaler_dict.json``
        @return: the parsed JSON content
        @raise IOError: when the file does not exist
        """
        info_data_path = cls._resolve_path(data_path)
        if not os.path.exists(info_data_path):
            raise IOError('not exist scaler info file')
        with open(info_data_path, 'r') as f:
            scalers = json.load(f)
        return scalers

    @classmethod
    def del_data(cls, data_path: str = None) -> None:
        """
        Delete the scaler info file.

        @param data_path: optional directory holding ``scaler_dict.json``; added so a
            file written through ``save_data(..., data_path)`` can be removed as well
        """
        os.remove(cls._resolve_path(data_path))


def init_minmax_normalization_model(dataset_data_path: str = None) -> Dict:
    """
    Build per-feature scaler pipelines from the min/max statistics in the scaler info file.

    Numeric features get a single ``CustomMinMaxScaler``; categorical features get a
    ``CustomLabelEncoder`` followed by a ``CustomMinMaxScaler`` over ``0..len(labels) - 1``.

    @param dataset_data_path: forwarded to ``ScalerInfoData.get_data``
    @return: ``{obj_type: {feature_name: [scaler, ...]}}``
    """
    scaler_info = ScalerInfoData.get_data(dataset_data_path)
    models = {}
    for obj_type, sections in scaler_info.items():
        per_feature = {}
        models[obj_type] = per_feature
        for feature_name, bounds in sections.get('numeric').items():
            lower, upper = bounds.get('min'), bounds.get('max')
            per_feature[feature_name] = [CustomMinMaxScaler(data_min=lower, data_max=upper)]
        for feature_name, labels in sections.get('categorical').items():
            per_feature[feature_name] = [
                CustomLabelEncoder(classes=labels),
                CustomMinMaxScaler(data_min=0, data_max=len(labels) - 1),
            ]
    return models


def init_standard_normalization_model(dataset_data_path: str = None) -> Dict:
    """
    Build per-feature scaler pipelines from the mean/std statistics in the scaler info file.

    Numeric features get a single ``CustomStandardScaler``; categorical features get a
    ``CustomLabelEncoder`` followed by a ``CustomMinMaxScaler`` over ``0..len(labels) - 1``.

    @param dataset_data_path: forwarded to ``ScalerInfoData.get_data``; added for parity
        with ``init_minmax_normalization_model``. The default keeps the old behavior.
    @return: ``{obj_type: {feature_name: [scaler, ...]}}``
    """
    scaler_dict = ScalerInfoData.get_data(dataset_data_path)
    scaler_model_dict = {}
    for obj_type, obj_data_dict in scaler_dict.items():
        scaler_model_dict[obj_type] = {}
        numeric_dict = obj_data_dict.get('numeric')
        for numeric_feature_name, stats in numeric_dict.items():
            data_mean, data_std = stats.get('mean'), stats.get('std')
            scaler_model_dict[obj_type][numeric_feature_name] = [CustomStandardScaler(data_mean, data_std)]
        categorical_dict = obj_data_dict.get('categorical')
        for categorical_feature_name, label_list in categorical_dict.items():
            data_min, data_max = 0, len(label_list) - 1
            scaler_model_dict[obj_type][categorical_feature_name] = \
                [CustomLabelEncoder(classes=label_list), CustomMinMaxScaler(data_min=data_min, data_max=data_max)]
    return scaler_model_dict


def get_feature_collection(feature_info):
    """
    Group the feature names of one data type into numeric / categorical buckets.

    Static buckets are taken from ``all_feature`` minus ``meta_attrs``; temporal and
    target buckets are taken directly from ``history_attrs`` and ``target_signal``.

    @param feature_info: object exposing ``all_feature``, ``meta_attrs``, ``static_attrs``,
        ``categorical_attrs``, ``history_attrs`` and ``target_signal``
    @return: dict with the keys ``static_feats_*``, ``temporal_feats_*`` and
        ``target_feats_*`` (each ``numeric`` and ``categorical``)
    """
    all_cols = feature_info.all_feature
    meta_attrs = feature_info.meta_attrs
    static_attrs = feature_info.static_attrs
    categorical_attrs = feature_info.categorical_attrs
    history_attrs = feature_info.history_attrs
    target_signal = feature_info.target_signal

    def _split(columns):
        # One pass: route every column to the numeric or the categorical bucket.
        numeric, categorical = [], []
        for column in columns:
            bucket = categorical if column in categorical_attrs else numeric
            bucket.append(column)
        return numeric, categorical

    static_cols = [col for col in all_cols if col not in meta_attrs and col in static_attrs]
    static_numeric, static_categorical = _split(static_cols)
    temporal_numeric, temporal_categorical = _split(history_attrs)
    target_numeric, target_categorical = _split(target_signal)

    return {
        'static_feats_numeric': static_numeric,
        'static_feats_categorical': static_categorical,
        'temporal_feats_numeric': temporal_numeric,
        'temporal_feats_categorical': temporal_categorical,
        'target_feats_numeric': target_numeric,
        'target_feats_categorical': target_categorical,
    }


def add_feature_configuration(config: DictConfig) -> DictConfig:
    """
    Derive feature metadata for every data type and merge it into ``config``.

    The merged config gains ``output_target_len``, ``data_props``, ``obj_type_name_list``,
    ``sample_data`` and ``scalers``.

    @param config: configuration with a ``feature_info`` mapping and, optionally,
        ``dataset_data_path`` pointing at the directory of the scaler info file
    @return: a new config produced by ``OmegaConf.merge``
    """
    scalers = ScalerInfoData.get_data(data_path=config.get('dataset_data_path', None))
    sample_data_dict, structure_dict, output_target_len = {}, {}, {}
    obj_type_name_list = []

    for data_type, type_features in config.feature_info.items():
        feature_map = get_feature_collection(type_features)
        # Number of known labels per categorical feature
        cardinalities = {
            feature_name: len(label_list)
            for feature_name, label_list in scalers[data_type]['categorical'].items()
        }
        sample_data_dict[data_type] = {
            "categorical_attrs": type_features.categorical_attrs,
            'feature_cols': [col for col in type_features.all_feature if col not in type_features.meta_attrs],
            'feature_map': feature_map,
            'categorical_cardinalities': cardinalities,
            'output_target_len': len(type_features.get('target_signal')),
            'target_signal': type_features.get('target_signal'),
            'obj_id_list': []
        }

        # With no cardinalities at all the lists stay empty rather than raising KeyError.
        has_cardinalities = len(cardinalities) > 0
        temporal_categorical = feature_map['temporal_feats_categorical']
        static_categorical = feature_map['static_feats_categorical']
        structure_dict[data_type] = {
            'num_temporal_feats_numeric': len(feature_map['temporal_feats_numeric']),
            'num_temporal_feats_categorical': len(temporal_categorical),
            'temporal_categorical_cardinalities': (
                [cardinalities[feat] + 1 for feat in temporal_categorical] if has_cardinalities else []
            ),
            'num_static_numeric': len(feature_map['static_feats_numeric']),
            'num_static_categorical': len(static_categorical),
            'static_categorical_cardinalities': (
                [cardinalities[feat] + 1 for feat in static_categorical] if has_cardinalities else []
            ),
        }
        output_target_len[data_type] = len(type_features.get('target_signal'))
        obj_type_name_list.append(data_type)

    feature_config = OmegaConf.create()
    feature_config.output_target_len = output_target_len
    feature_config.data_props = structure_dict
    feature_config.obj_type_name_list = obj_type_name_list
    feature_config.sample_data = sample_data_dict
    feature_config.scalers = scalers
    return OmegaConf.merge(config, feature_config)


def _geom_noise_mask_single(L, lm, masking_ratio):
    """
    Draw a boolean keep-mask of length ``L`` from a two-state Markov chain.

    State 0 marks masked positions (False), state 1 kept positions (True). After each
    position the chain leaves the masked state with probability ``1 / lm`` and the
    kept state with probability ``(1 / lm) * masking_ratio / (1 - masking_ratio)``.

    @param L: length of the mask
    @param lm: sets the stop probability of a masked run to ``1 / lm``
    @param masking_ratio: probability of starting in the masked state; also scales
        the stop probability of kept runs
    @return: boolean array of shape ``(L,)``; True means keep
    """
    stop_masked = 1 / lm
    stop_kept = stop_masked * masking_ratio / (1 - masking_ratio)
    switch_prob = (stop_masked, stop_kept)  # indexed by the current state

    keep_mask = np.ones(L, dtype=bool)
    # Start masked (state 0) with probability masking_ratio
    state = int(np.random.rand() > masking_ratio)
    for position in range(L):
        keep_mask[position] = state  # state value doubles as the mask value
        if np.random.rand() < switch_prob[state]:
            state = 1 - state
    return keep_mask


def noise_mask(X, masking_ratio=0.15, lm=3, mode='separate', distribution='geometric', exclude_feats=None):
    """
    Build a boolean keep-mask with the shape of ``X`` (True = keep, False = mask).

    @param X: 2-D array; the code treats axis 0 as the sequence axis and axis 1 as
        the feature axis
    @param masking_ratio: proportion of positions to mask
    @param lm: run-length parameter forwarded to ``_geom_noise_mask_single``
    @param mode: ``'separate'`` masks each feature column independently; any other
        value repeats one column mask across all features
    @param distribution: ``'geometric'`` uses the Markov-chain mask; any other value
        draws each position independently
    @param exclude_feats: feature indices left unmasked; only honoured for
        ``distribution='geometric'`` with ``mode='separate'``
    @return: boolean array shaped like ``X``
    """
    excluded = None if exclude_feats is None else set(exclude_feats)
    per_feature = mode == 'separate'

    if distribution == 'geometric':  # stateful (Markov chain)
        if not per_feature:
            # one mask along axis 0, replicated for every feature
            shared = _geom_noise_mask_single(X.shape[0], lm, masking_ratio)
            return np.tile(np.expand_dims(shared, 1), X.shape[1])
        mask = np.ones(X.shape, dtype=bool)
        for feat_idx in range(X.shape[1]):
            if excluded is None or feat_idx not in excluded:
                mask[:, feat_idx] = _geom_noise_mask_single(X.shape[0], lm, masking_ratio)
        return mask

    # each position is an independent Bernoulli draw with p(keep) = 1 - masking_ratio
    if per_feature:
        return np.random.choice(np.array([True, False]), size=X.shape, replace=True,
                                p=(1 - masking_ratio, masking_ratio))
    column = np.random.choice(np.array([True, False]), size=(X.shape[0], 1), replace=True,
                              p=(1 - masking_ratio, masking_ratio))
    return np.tile(column, X.shape[1])


def ssl_temporal_batch_mask(batch: torch.tensor, masks: torch.tensor, mask_values: float = -0.0000001) -> Tuple[torch.tensor, ...]:
    """
    Hide the masked positions of ``batch`` and return what is needed to score them.

    @param batch: input tensor
    @param masks: keep-mask multiplied with ``batch``; its inversion via ``~`` marks
        the positions to hide
    @param mask_values: value written at the hidden positions
    @return: ``(X, targets, target_masks)`` -- the masked input, a copy of the original
        batch, and the inverted mask
    """
    targets = batch.clone()
    keep = masks.clone()

    visible = batch * keep
    target_masks = ~keep

    hidden_positions = target_masks.float() == 1
    X = torch.where(hidden_positions, mask_values, visible)
    return X, targets, target_masks


def task_temporal_batch_mask(batch: torch.tensor, masks: torch.tensor, mask_values: float = -0.0000001) -> torch.tensor:
    """
    Apply one shared keep-mask to every sample of ``batch``.

    @param batch: input tensor
    @param masks: keep-mask without the leading batch dimension; it is unsqueezed
        at dim 0 and expanded to ``batch.shape``
    @param mask_values: value written where the expanded mask is False
    @return: copy of ``batch`` with the hidden positions set to ``mask_values``
    """
    data = batch.clone()
    keep = masks.clone().unsqueeze(0).expand(data.shape)

    visible = data * keep
    hidden_positions = (~keep).float() == 1
    return torch.where(hidden_positions, mask_values, visible)


def task_temporal_target_mask(batch: torch.tensor, masks: torch.tensor, task_target_len: int) -> torch.tensor:
    """
    Select the entries picked by ``masks`` from every sample and regroup them.

    @param batch: input tensor whose first dimension is kept as is
    @param masks: index applied to the remaining dimensions (``batch[:, masks]``)
    @param task_target_len: size of the last dimension of the result
    @return: tensor of shape ``(batch.size(0), -1, task_target_len)``
    """
    selector = masks.clone()
    picked = batch.clone()[:, selector]
    return picked.view(batch.size(0), -1, task_target_len)


class EvalResultData:
    """Gzip-pickle store for evaluation results under ``<base>/output``.

    ``<base>`` is ``dataset_data_path`` unless it is None, in which case ``ROOT_DIR`` is
    used. The file is ``<project_id>.pkl.gz`` or, when ``result_name`` is given,
    ``<project_id>_<result_name>.pkl.gz``.
    """

    def __init__(self, project_id: str, result_name: str = None, dataset_data_path: str = None):
        """
        @param project_id: first part of the file name
        @param result_name: optional suffix of the file name
        @param dataset_data_path: optional base directory
        """
        self.project_id = project_id
        base_dir = ROOT_DIR if dataset_data_path is None else dataset_data_path
        if result_name is None:
            relative = f'output/{project_id}.pkl.gz'
        else:
            relative = f'output/{project_id}_{result_name}.pkl.gz'
        self.data_path = os.path.join(base_dir, relative)

    def load_data(self):
        """
        Unpickle and return the stored result.

        @raise FileNotFoundError: when the result file does not exist
        """
        if os.path.exists(self.data_path):
            with gzip.open(self.data_path, 'rb') as stream:
                return pickle.load(stream)
        raise FileNotFoundError(f'no file {self.data_path}')

    def save_data(self, data):
        """Pickle ``data`` into the result file, creating the output directory if needed."""
        output_dir = os.path.dirname(self.data_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with gzip.open(self.data_path, 'wb') as stream:
            pickle.dump(data, stream)


class NormalizationData:
    """Parquet-backed accessor for the normalisation table of one object type.

    The file lives at ``<base>/<project_id>/scaler/<object_type>.parquet`` where
    ``<base>`` is ``normalization_data_path`` when it is truthy, otherwise ``ROOT_DIR/data``.
    """

    def __init__(self, project_id: str, object_type: str, normalization_data_path: str = None):
        """
        @param project_id: project folder name under the base directory
        @param object_type: name of the parquet file (without extension)
        @param normalization_data_path: optional base directory overriding ``ROOT_DIR/data``
        """
        base_path = normalization_data_path or os.path.join(ROOT_DIR, 'data')
        self.data_path = os.path.join(base_path, f'{project_id}/scaler/{object_type}.parquet')

    def load_data(self) -> pd.DataFrame:
        """
        @return: the stored table, or an empty DataFrame when the file does not exist
        """
        if not os.path.exists(self.data_path):
            return pd.DataFrame()
        return pd.read_parquet(self.data_path, engine='pyarrow')

    def save_data(self, normalization_data_df: pd.DataFrame) -> None:
        """
        Write ``normalization_data_df`` (without its index) to the parquet file.

        @param normalization_data_df: table to persist
        """
        # Create the scaler folder on first save; without it to_parquet fails
        # for a project that has never been written before.
        os.makedirs(os.path.dirname(self.data_path), exist_ok=True)
        normalization_data_df.to_parquet(self.data_path, engine='pyarrow', index=False)

    def del_data(self) -> None:
        """
        Delete the parquet file.

        @raise ValueError: when the file cannot be removed; the original ``OSError``
            is attached as ``__cause__``
        """
        try:
            os.remove(self.data_path)
        except OSError as e:
            raise ValueError(e) from e