# -*- coding: utf-8 -*-
import copy
import gc
import torch
import numpy as np
from omegaconf import DictConfig

from common.log import access_log
from hgtft.model.model import HeterogeneousGraphTemporalFusionTransformerTask
from hgtft.utils.train_utils import load_graph
from hgtft.utils.data_utils import (task_temporal_batch_mask, task_temporal_target_mask, make_batch, ModelData,
                                    EvalResultData, init_minmax_normalization_model)

from hgtft.inference.inference_base import InferenceBase


class InferenceTaskManager(InferenceBase):
    """Run multi-task inference with a trained HGTFT model and save de-normalized results.

    For every dataset in ``config.test_data`` the manager:

    1. builds per-task input/target feature masks;
    2. runs the model batch by batch for each task listed in ``config.use_task``;
    3. collects predictions and targets;
    4. inverse-transforms them with the scalers returned by ``init_minmax_normalization_model``;
    5. saves them through ``EvalResultData``.
    """

    def __init__(self, config: DictConfig):
        super().__init__(config)

    def _load_model(self, model_name, DEVICE):
        """Build the task model and load the weights stored under ``model_name``.

        Args:
            model_name: Name passed to ``ModelData.load_model`` to fetch the state dict.
            DEVICE: ``torch.device`` the model is built for and moved to.

        Returns:
            A ``HeterogeneousGraphTemporalFusionTransformerTask`` with the loaded weights, on ``DEVICE``.
        """
        model = HeterogeneousGraphTemporalFusionTransformerTask(config=self.config, device=DEVICE)
        model_params = ModelData.load_model(model_name, DEVICE)
        model.load_state_dict(model_params)
        model.to(DEVICE)
        return model

    @staticmethod
    def _feature_mask(feature_names, selected_features, length):
        """Return a ``(length, len(feature_names))`` bool mask; a column is True iff its feature is selected.

        An empty ``feature_names`` yields a ``(length, 0)`` tensor instead of raising, which is
        what ``torch.concat`` over an empty list of per-column tensors would do.
        """
        selected = set(selected_features or ())
        row = torch.tensor([name in selected for name in feature_names], dtype=torch.bool)
        # Every time step uses the same per-feature flag, so repeat the single row ``length`` times.
        return row.unsqueeze(0).repeat(length, 1)

    @staticmethod
    def _get_input_target_mask_tensor(task_history_feature_list, task_future_feature_list, task_target_feature_list,
                                      temporal_feats_numeric, temporal_feats_categorical, history_len, future_len):
        """Build the boolean input and target masks of one task for one object type.

        Rows are time steps: the first ``history_len`` rows cover the history window, the next
        ``future_len`` rows cover the future window. Columns follow the order of
        ``temporal_feats_numeric`` / ``temporal_feats_categorical``.

        Args:
            task_history_feature_list: Features the task may read in the history window.
            task_future_feature_list: Features the task may read in the future window.
            task_target_feature_list: Numeric features the task predicts in the future window.
            temporal_feats_numeric: Ordered numeric temporal feature names of the object type.
            temporal_feats_categorical: Ordered categorical temporal feature names of the object type.
            history_len: Number of history time steps.
            future_len: Number of future time steps.

        Returns:
            ``(input_numeric, input_categorical, target)`` bool tensors, each with
            ``history_len + future_len`` rows. ``input_numeric`` and ``target`` have one column per
            numeric feature; ``input_categorical`` has one column per categorical feature.
        """
        feature_mask = InferenceTaskManager._feature_mask

        # input - numeric
        input_mask_numeric_tensor = torch.concat([
            feature_mask(temporal_feats_numeric, task_history_feature_list, history_len),
            feature_mask(temporal_feats_numeric, task_future_feature_list, future_len),
        ], dim=0)

        # input - categorical
        input_mask_categorical_tensor = torch.concat([
            feature_mask(temporal_feats_categorical, task_history_feature_list, history_len),
            feature_mask(temporal_feats_categorical, task_future_feature_list, future_len),
        ], dim=0)

        # target: history rows are never targets; future rows are selected by the task's target list
        target_mask_tensor = torch.concat([
            torch.zeros((history_len, len(temporal_feats_numeric)), dtype=torch.bool),
            feature_mask(temporal_feats_numeric, task_target_feature_list, future_len),
        ], dim=0)

        return input_mask_numeric_tensor, input_mask_categorical_tensor, target_mask_tensor

    def _get_mask_tensor(self):
        """Build the masks for every task in ``config.task`` and every object type it uses.

        Returns:
            ``{task_name: {obj_type: {"input_numeric": Tensor, "input_categorical": Tensor, "target": Tensor}}}``
            as produced by ``_get_input_target_mask_tensor``.
        """
        history_len = self.config.configuration.model.history_len
        future_len = self.config.configuration.model.future_len
        mask_result_dict = {}
        for task_name, task_item in self.config.task.items():
            task_feature_info = task_item.task_feature_info
            mask_result_dict[task_name] = {}
            for obj_type, obj_task_feature_dict in task_feature_info.items():
                # feature info of the object type
                feature_map = self.config.sample_data[obj_type]['feature_map']
                temporal_feats_numeric = feature_map['temporal_feats_numeric']
                temporal_feats_categorical = feature_map['temporal_feats_categorical']
                # features the task reads / predicts for this object type
                task_future_feature_list = obj_task_feature_dict['future']
                task_history_feature_list = obj_task_feature_dict['history']
                task_target_feature_list = obj_task_feature_dict['target']
                input_mask_numeric_tensor, input_mask_categorical_tensor, target_mask_tensor = (
                    self._get_input_target_mask_tensor(task_history_feature_list, task_future_feature_list,
                                                       task_target_feature_list, temporal_feats_numeric,
                                                       temporal_feats_categorical, history_len, future_len)
                )
                mask_result_dict[task_name][obj_type] = {
                    "input_numeric": input_mask_numeric_tensor,
                    "input_categorical": input_mask_categorical_tensor,
                    "target": target_mask_tensor
                }
        return mask_result_dict

    @staticmethod
    def _get_task_graph_data(task_item, graph_data):
        """Pick the relation graphs listed in ``task_item.task_graph_relation`` that exist in ``graph_data``.

        Relations are looked up under the key ``"<from>-<edge>-<to>"``; missing ones are skipped.
        """
        task_graph_data_dict = {}
        for item in task_item.task_graph_relation:
            file_name = "{}-{}-{}".format(item.get('from'), item.get('edge'), item.get('to'))
            if file_name in graph_data:
                task_graph_data_dict[file_name] = graph_data.get(file_name)
        return task_graph_data_dict

    @staticmethod
    def _build_task_batch(next_dict, task_item, task_masks, device, is_cuda):
        """Mask one loader batch for a task and pack it into model inputs and targets.

        Args:
            next_dict: Loader batch whose keys are ``"<obj_type>+<obj_id>+<data_type>"``.
            task_item: Task config; keys whose object type is not in ``task_item.task_feature_info`` are skipped.
            task_masks: This task's per-object-type masks, as built by ``_get_mask_tensor``.
            device: Device forwarded to ``make_batch``.
            is_cuda: CUDA availability flag forwarded to ``make_batch``.

        Returns:
            ``(batch_input, batch_target)`` dicts assembled by ``make_batch``. Targets are built only
            from numeric temporal features of object types that have at least one target feature.
        """
        task_feature_info_dict = task_item.task_feature_info
        batch_input, batch_target = dict(), dict()
        for key, values in next_dict.items():
            key_list = key.split('+')
            obj_type, obj_id, data_type = key_list[0], key_list[1], key_list[2]
            if obj_type not in task_feature_info_dict:
                continue
            if data_type == 'temporal_feats_numeric':
                X = task_temporal_batch_mask(values, task_masks[obj_type]['input_numeric'])
                batch_input = make_batch(batch_input, obj_type, obj_id, data_type, X,
                                         device=device, is_cuda=is_cuda)
                task_target_len = len(task_feature_info_dict[obj_type]['target'])
                if task_target_len > 0:
                    target = task_temporal_target_mask(values, task_masks[obj_type]['target'], task_target_len)
                    batch_target = make_batch(batch_target, obj_type, obj_id, data_type, target,
                                              device=device, is_cuda=is_cuda)
            elif data_type == 'temporal_feats_categorical':
                X = task_temporal_batch_mask(values, task_masks[obj_type]['input_categorical'])
                batch_input = make_batch(batch_input, obj_type, obj_id, data_type, X,
                                         device=device, is_cuda=is_cuda)
            else:
                # non-temporal data is passed through unmasked
                batch_input = make_batch(batch_input, obj_type, obj_id, data_type, values,
                                         device=device, is_cuda=is_cuda)
        return batch_input, batch_target

    @staticmethod
    def _collect_outputs(task_aggregator, batch_outputs, batch_target):
        """Append one batch of predictions and matching targets (as CPU numpy arrays) to ``task_aggregator``.

        ``task_aggregator`` is nested as ``{obj_type: {obj_id: {'predict': [...], 'target': [...]}}}``.
        A numeric target must exist in ``batch_target`` for every object the model returned.
        """
        for obj_type, obj_type_data in batch_outputs.items():
            type_aggregator = task_aggregator.setdefault(obj_type, {})
            for obj_id, obj_data in obj_type_data.items():
                obj_aggregator = type_aggregator.setdefault(obj_id, {})
                obj_aggregator.setdefault('predict', []).append(obj_data.cpu().numpy())
                target = batch_target[obj_type][obj_id]['temporal_feats_numeric']
                # The trailing axis presumably matches the prediction's rank; _inverse_transform
                # indexes both as 4-D arrays.
                obj_aggregator.setdefault('target', []).append(target.unsqueeze(-1).cpu().numpy())

    def main(self):
        """Run inference on every dataset in ``config.test_data`` and save the de-normalized results.

        Only tasks listed in ``config.use_task`` are evaluated. If the inverse transform fails for a
        dataset, the error is logged and that dataset's results are not saved.
        """
        IS_CUDA = torch.cuda.is_available()
        DEVICE = torch.device("cuda" if IS_CUDA else "cpu")
        model_name = self.config.model_name
        access_log.info(f"============== start predict, model name: {model_name} ==============")

        model = self._load_model(model_name, DEVICE)
        task_mask_tensor_dict = self._get_mask_tensor()
        # The task selection does not change between batches; filter it once, keeping config order.
        active_tasks = [(task_name, task_item) for task_name, task_item in self.config.task.items()
                        if task_name in self.config.use_task]

        model.eval()
        with torch.no_grad():
            for test_name in self.config.test_data:
                test_graph_data = load_graph(test_name, self.config.graph_relation,
                                             dataset_data_path=self.config.get('dataset_data_path', None))
                test_loader = self._get_dataloader(test_name, 'test')
                output_aggregator = dict()
                for test_i, next_dict in enumerate(test_loader):
                    access_log.info(f'========== {len(test_loader)}-{test_i+1} ==========')
                    for task_name, task_item in active_tasks:
                        output_aggregator.setdefault(task_name, {})
                        task_graph_data_dict = self._get_task_graph_data(task_item, test_graph_data)
                        test_batch_input, test_batch_target = self._build_task_batch(
                            next_dict, task_item, task_mask_tensor_dict[task_name], DEVICE, IS_CUDA)
                        # torch.autocast replaces the deprecated torch.cuda.amp.autocast; with
                        # enabled=False on CPU it is a no-op, as the old call was without CUDA.
                        with torch.autocast(device_type=DEVICE.type, enabled=IS_CUDA):
                            batch_outputs = model(test_batch_input, task_graph_data_dict, task_name)
                        self._collect_outputs(output_aggregator[task_name], batch_outputs, test_batch_target)

                out_put_data_dict = self._inverse_transform(output_aggregator)
                if out_put_data_dict is None:
                    # _inverse_transform already logged the cause; do not save a None payload.
                    access_log.error(f"inverse transform failed, results of {test_name} are not saved")
                else:
                    eval_result_data_obj = EvalResultData(test_name)
                    eval_result_data_obj.save_data(out_put_data_dict)

                del output_aggregator
                del out_put_data_dict
                gc.collect()

    def _inverse_transform(self, output_aggregator_dict):
        """Undo the feature scaling of aggregated predictions and targets.

        Args:
            output_aggregator_dict: ``{task_name: {obj_type: {obj_id: {predict_type: [np.ndarray, ...]}}}}``
                as collected by ``main``. The per-batch arrays are concatenated along axis 0 and
                indexed as 4-D with target features on axis 2.

        Returns:
            ``{task_name: {obj_type: {obj_id: {predict_type: {feature_name: [value, ...]}}}}}``.
            Returns ``None`` if any step raises; the error is logged.
        """
        access_log.info('start inverse transform')
        try:
            out_put_data_dict = {}
            scaler_model_dict = init_minmax_normalization_model(self.config.get('dataset_data_path'))
            for task_name, task_item in self.config.task.items():

                if task_name not in self.config.use_task:
                    continue

                # A task may have produced no output at all (e.g. an empty test loader).
                task_outputs = output_aggregator_dict.get(task_name, {})
                task_feature_info_dict = task_item.get('task_feature_info')
                out_put_data_dict.setdefault(task_name, {})
                for object_type, task_obj_feature_info_dict in task_feature_info_dict.items():
                    if object_type not in task_outputs:
                        continue
                    object_result = out_put_data_dict[task_name].setdefault(object_type, {})
                    target_features = task_obj_feature_info_dict.get('target')
                    if not target_features:
                        continue

                    for obj_id, obj_data in task_outputs[object_type].items():
                        obj_result = object_result.setdefault(obj_id, {})

                        for predict_type, predict_data in obj_data.items():
                            type_result = obj_result.setdefault(predict_type, {})
                            # Concatenate the per-batch arrays once, not once per feature.
                            stacked_data = np.concatenate(predict_data, axis=0)

                            for feature_index, feature_name in enumerate(target_features):
                                predict_data_list = stacked_data[:, :, feature_index, :].flatten().tolist()
                                scaler_model_list = scaler_model_dict.get(object_type).get(feature_name)
                                # Apply the scalers last-to-first; reversed() leaves the shared list
                                # untouched, so no copy is needed.
                                for scaler_model in reversed(scaler_model_list):
                                    predict_data_list = [scaler_model.inverse_transform(round(i, 7))
                                                         for i in predict_data_list]
                                type_result[feature_name] = predict_data_list
            return out_put_data_dict
        except Exception as e:
            # Best effort: log with traceback and return None so the caller can skip saving.
            import traceback
            access_log.error(e)
            access_log.error(traceback.format_exc())
            return None
