# -*- coding: utf-8 -*-
import numpy as np
from omegaconf import DictConfig

from abc import ABC, abstractmethod
import torch

from common.log import access_log
from torch.utils.data import DataLoader
from hgtft.model.model import HeterogeneousGraphTemporalFusionTransformerTask
from hgtft.utils.train_utils import load_graph, load_iter_loaders
from hgtft.utils.data_utils import (make_batch, add_feature_configuration, DatasetData, ModelData, EvalResultData,
                                    ScalerInfoData, CustomMinMaxScaler, CustomLabelEncoder, DictDataSet)


class InferenceBase(ABC):
    """Base class that runs validation-set inference and saves de-normalized results.

    Subclasses must implement ``_load_model`` (it is declared abstract below) and can
    then call ``main()`` to run inference over every project in
    ``config.project_validation``.
    """

    def __init__(self, config: DictConfig):
        """Store the config after augmenting it with feature configuration.

        :param config: run configuration; passed through ``add_feature_configuration``.
        """
        self.config = add_feature_configuration(config)

    @abstractmethod
    def _load_model(self, model_name, DEVICE):
        """Build the model, load the stored weights for ``model_name`` and move it to ``DEVICE``.

        NOTE(review): this is decorated ``@abstractmethod`` yet has a concrete body, so
        subclasses must override it (they may delegate via ``super()._load_model(...)``).
        Confirm that forcing an override is intended.

        :param model_name: identifier passed to ``ModelData.load_model``.
        :param DEVICE: ``torch.device`` used for construction, weight loading and placement.
        :return: the model with its state dict loaded.
        """
        model = HeterogeneousGraphTemporalFusionTransformerTask(config=self.config, device=DEVICE)
        model_params = ModelData.load_model(model_name, DEVICE)
        model.load_state_dict(model_params)
        model.to(DEVICE)
        return model

    def _get_normalization_model_dict(self):
        """Rebuild the per-feature scalers from stored scaler info.

        Numeric features get a ``CustomMinMaxScaler`` built from their ``min``/``max``.
        Categorical features get a ``CustomLabelEncoder`` built from their label list.

        :return: ``{obj_type: {feature_name: scaler}}``.
        """
        scaler_dict = ScalerInfoData.get_data(dataset_data_path=self.config.get('dataset_data_path', None))
        scaler_model_dict = {}
        for obj_type, obj_data_dict in scaler_dict.items():
            type_scalers = scaler_model_dict[obj_type] = {}
            # A missing section means "no features of that kind", not an error.
            numeric_dict = obj_data_dict.get('numeric') or {}
            for numeric_feature_name, min_max_range in numeric_dict.items():
                type_scalers[numeric_feature_name] = CustomMinMaxScaler(
                    data_min=min_max_range.get('min'), data_max=min_max_range.get('max'))
            categorical_dict = obj_data_dict.get('categorical') or {}
            for categorical_feature_name, label_list in categorical_dict.items():
                type_scalers[categorical_feature_name] = CustomLabelEncoder(classes=label_list)
        return scaler_model_dict

    def _get_dataloader(self, project_id: str, data_type: str) -> DataLoader:
        """Flatten a project's dataset into a non-shuffling ``DataLoader``.

        Every (obj_type, obj_id, feature) series becomes one entry keyed
        ``"obj_type+obj_id+feature"``. The ``'time'`` feature is skipped.

        :param project_id: project whose dataset is loaded.
        :param data_type: dataset split name. ``'train'`` selects the training batch size;
            any other value selects the inference batch size.
        :return: ``DataLoader`` over a ``DictDataSet`` of the flattened entries, ``shuffle=False``.
        """
        project_data = DatasetData.load_dataset(project_id, data_type, self.config.get('dataset_data_path', None))
        ignore_keys = {'time'}
        full_data_dict = {}
        for obj_type in project_data.keys():
            data_sets = project_data[obj_type]['data_sets']
            # Only objects listed in 'obj_id_list' are included.
            for obj_id in project_data[obj_type]['obj_id_list']:
                for feature_key, series in data_sets[obj_id].items():
                    if feature_key in ignore_keys:
                        continue
                    full_data_dict[f"{obj_type}+{obj_id}+{feature_key}"] = series
        dataset = DictDataSet(full_data_dict)
        batch_size_config = self.config.configuration.optimization.batch_size
        batch_size = batch_size_config.training if data_type == 'train' else batch_size_config.inference
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    @staticmethod
    def _build_batch(loader_item, device, is_cuda):
        """Convert one loader item keyed ``"obj_type+obj_id+feature"`` into a nested batch via ``make_batch``.

        :param loader_item: mapping yielded by the validation loader iterator.
        :param device: target ``torch.device`` forwarded to ``make_batch``.
        :param is_cuda: CUDA flag forwarded to ``make_batch``.
        :return: the batch accumulated by successive ``make_batch`` calls.
        """
        batch = {}
        for key, values in loader_item.items():
            key_parts = key.split('+')
            obj_type, obj_id, feature_key = key_parts[0], key_parts[1], key_parts[2]
            batch = make_batch(batch, obj_type, obj_id, feature_key, values, device=device, is_cuda=is_cuda)
        return batch

    @staticmethod
    def _accumulate_batch_outputs(output_aggregator, batch_outputs, batch):
        """Append one batch's predictions and matching targets (as numpy arrays) to ``output_aggregator``.

        :param output_aggregator: ``{obj_type: {obj_id: {'predict': [...], 'target': [...]}}}``, mutated in place.
        :param batch_outputs: model output, ``{obj_type: {obj_id: tensor}}``.
        :param batch: the model input; ``batch[obj_type][obj_id]['target']`` supplies the ground truth.
        """
        for obj_type, obj_type_data in batch_outputs.items():
            type_bucket = output_aggregator.setdefault(obj_type, {})
            for obj_id, obj_data in obj_type_data.items():
                obj_bucket = type_bucket.setdefault(obj_id, {})
                obj_bucket.setdefault('predict', []).append(obj_data.cpu().numpy())
                obj_bucket.setdefault('target', []).append(batch[obj_type][obj_id]['target'].cpu().numpy())

    def _inverse_transform_outputs(self, output_aggregator, scaler_model_dict):
        """Concatenate the per-batch arrays and inverse-transform each target signal.

        The signal index selects axis 2 of both arrays. Targets are sliced with 3 axes and
        predictions with 4; the trailing prediction axis is kept (presumably quantile
        outputs — confirm against the model).

        :param output_aggregator: result of ``_accumulate_batch_outputs``.
        :param scaler_model_dict: result of ``_get_normalization_model_dict``.
        :return: ``{obj_type: {obj_id: {signal: {'predict': [...], 'target': [...]}}}}``, where each list
            holds one ``inverse_transform`` result per flattened value.
        """
        validation_outputs = {}
        for obj_type, obj_type_outputs in output_aggregator.items():
            access_log.info(f'inverse transform: {obj_type}')
            type_outputs = validation_outputs.setdefault(obj_type, {})
            for obj_id, obj_outputs in obj_type_outputs.items():
                signal_outputs = type_outputs.setdefault(obj_id, {})
                target_signals = self.config.feature_info[obj_type].target_signal
                # Concatenate once per object rather than once per signal.
                all_targets = np.concatenate(obj_outputs['target'], axis=0)
                all_predicts = np.concatenate(obj_outputs['predict'], axis=0)
                for index, signal in enumerate(target_signals):
                    result = signal_outputs.setdefault(signal, {})
                    target = all_targets[:, :, index: index + 1]
                    predict = all_predicts[:, :, index: index + 1, :]
                    transformation = scaler_model_dict[obj_type][signal]
                    result['predict'] = [transformation.inverse_transform(i) for i in predict.reshape(-1, 1)]
                    result['target'] = [transformation.inverse_transform(i) for i in target.reshape(-1, 1)]
        return validation_outputs

    def main(self):
        """Run inference on every validation project and save the de-normalized predictions and targets.

        For each project in ``config.project_validation``:

        1. Load the graph and the validation dataset.
        2. Run the model batch by batch under ``torch.no_grad()``.
        3. Inverse-transform the collected outputs.
        4. Save them with ``EvalResultData``.
        """
        is_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if is_cuda else "cpu")
        model_name = self.config.model_name
        access_log.info(f"============== start inference, model name: {model_name} ==============")
        model = self._load_model(model_name, device)
        scaler_model_dict = self._get_normalization_model_dict()
        # Use the same dataset location as the scalers and _get_dataloader, so data and scalers match.
        dataset_data_path = self.config.get('dataset_data_path', None)

        model.eval()
        with torch.no_grad():
            for project_id in self.config.project_validation:
                output_aggregator = {}
                # load data
                graph_data_dict = load_graph(project_id, self.config.graph_relation)
                project_validation_data = DatasetData.load_dataset(project_id, 'validation', dataset_data_path)
                meta_keys = ['time']
                test_loader, loader_count = load_iter_loaders(self.config, project_validation_data, meta_keys,
                                                              data_type='validation')
                for batch_idx in range(loader_count):
                    batch = self._build_batch(next(test_loader), device, is_cuda)
                    batch_outputs = model(batch, graph_data_dict)
                    self._accumulate_batch_outputs(output_aggregator, batch_outputs, batch)
                    access_log.info(f"-------------- {project_id}: {batch_idx} --------------")

                validation_outputs = self._inverse_transform_outputs(output_aggregator, scaler_model_dict)
                # save result
                eval_result_data_obj = EvalResultData(project_id)
                eval_result_data_obj.save_data(validation_outputs)
            access_log.info('============== finish predict ==============')