# -*- coding: utf-8 -*-
import torch
from torch import nn
import torch.nn.functional as F
from omegaconf import DictConfig
from typing import Dict, Tuple
from .base_temporal import PositionalEmbedding, CheckpointTransformerEncoderLayer
from .base_tft import FeatureEmbeddingLayer, MultiHeadAttentionLayer, TaskOutputLayer
from .base_graph import (HeterogeneousGraphTransformer, HeterogeneousGraphNeuralNetwork,
                        HeterogeneousTemporalGraphNeuralNetwork)


def concat_embedding_result(embedding_result: Dict) -> Tuple[torch.Tensor, Dict]:
    """Pack a nested ``{obj_type: {obj_id: embedding}}`` mapping into one tensor.

    Each leaf is either a tensor or a dict whose ``'selected_temporal'`` entry is the tensor.
    All leaf tensors are concatenated along dim 0 in iteration order.

    Returns:
        A ``(packed, order)`` pair. ``order[obj_type]`` holds:
            ``start`` / ``end`` -- running object indices (end exclusive) across all types,
            ``obj_id``          -- the object ids of that type, in packing order,
            ``batch_size``      -- ``size(0)`` of the type's last tensor (0 if the type is empty).

    Note: callers in this module slice rows as ``start * batch_size : end * batch_size``, which
    is only correct when every object shares the same batch size.
    """
    tensors = []
    order = {}
    cursor = 0
    for obj_type, per_object in embedding_result.items():
        type_start = cursor
        ids = []
        rows = 0
        for obj_id, value in per_object.items():
            tensor = value.get('selected_temporal') if isinstance(value, dict) else value
            ids.append(obj_id)
            tensors.append(tensor)
            rows = tensor.size(0)
            cursor += 1
        order[obj_type] = {'start': type_start, 'end': cursor, 'obj_id': ids, 'batch_size': rows}
    packed = torch.concat(tensors, dim=0)
    return packed, order


class FeatureEmbedding(nn.Module):
    """Embed raw per-object features with one ``FeatureEmbeddingLayer`` per object type.

    Objects of the same type are concatenated along dim 0, embedded in a single call and split
    back per object; all per-object embeddings are finally packed by ``concat_embedding_result``.
    """

    # Feature groups concatenated across objects in the batched (multi-object) path.
    _FEATURE_KEYS = ('static_feats_numeric', 'static_feats_categorical',
                     'temporal_feats_numeric', 'temporal_feats_categorical')

    def __init__(self, config: DictConfig):
        super(FeatureEmbedding, self).__init__()
        self.feature_embedding_layer_dict = nn.ModuleDict({
            obj_type: FeatureEmbeddingLayer(config, structure, obj_type) for obj_type, structure in config.data_props.items()
        })

    @staticmethod
    def _concat_feature_group(obj_data_dict, key):
        """Concatenate feature group ``key`` of every object in ``obj_data_dict`` along dim 0.

        Returns None when no object provides the group (the same value ``obj_data.get(key)``
        yields for a lone object), instead of handing a list of Nones to ``torch.cat``.

        Raises:
            ValueError: if the group is provided by some objects of the type but not by others.
        """
        values = [obj_data.get(key) for obj_data in obj_data_dict.values()]
        present = [value for value in values if value is not None]
        if not present:
            return None
        if len(present) != len(values):
            raise ValueError(f"feature group {key!r} is missing for some objects of the same type")
        return torch.cat(present, dim=0)

    def forward(self, batch):
        """Embed every object in ``batch`` and pack the results.

        Args:
            batch: ``{obj_type: {obj_id: feature_dict}}``. When a type has more than one object,
                the ``_FEATURE_KEYS`` groups of all its objects are concatenated along dim 0 and
                embedded in one call; a lone object's ``feature_dict`` is passed to the layer as-is.

        Returns:
            The ``(tensor, obj_order_dict)`` pair produced by ``concat_embedding_result`` for
            ``{obj_type: {str(obj_id): embedding}}``.
        """
        embedding_result = {}
        for obj_type, obj_data_dict in batch.items():
            type_result = embedding_result.setdefault(obj_type, {})
            if len(obj_data_dict) > 1:
                obj_id_list = list(obj_data_dict.keys())
                # NOTE(review): only the _FEATURE_KEYS groups are forwarded here; any other key in an
                # object's feature dict is dropped in the batched path (the single-object path below
                # passes the dict through unchanged) -- confirm FeatureEmbeddingLayer needs nothing else.
                concat_obj_data = {key: self._concat_feature_group(obj_data_dict, key)
                                   for key in self._FEATURE_KEYS}
                group_embedding_result = self.feature_embedding_layer_dict[obj_type](concat_obj_data)
                batch_sample, _, _ = group_embedding_result.shape
                # Split evenly: assumes every object contributed the same number of rows.
                batch_size = batch_sample // len(obj_id_list)
                for i, obj_id in enumerate(obj_id_list):
                    type_result[str(obj_id)] = group_embedding_result[i * batch_size: (i + 1) * batch_size, ...]
            else:
                for obj_id, obj_data in obj_data_dict.items():
                    type_result[str(obj_id)] = self.feature_embedding_layer_dict[obj_type](obj_data)

        return concat_embedding_result(embedding_result)


class HeterogeneousGraphEmbedding(nn.Module):
    """Select and wrap one heterogeneous graph neural network backbone.

    The backbone is chosen by ``config.configuration.model.graph_embedding_type``:

        -- 'hgt':   HeterogeneousGraphTransformer, https://github.com/acbull/pyHGT
        -- 'htgnn': HeterogeneousTemporalGraphNeuralNetwork, https://github.com/YesLab-Code/HTGNN
        -- 'hgn':   HeterogeneousGraphNeuralNetwork, https://github.com/chuxuzhang/KDD2019_HetGNN

    Every backbone receives ``config``, the relation list built from ``config.graph_relation``
    as ``"<from>-<edge>-<to>"`` strings, and ``device``.

    Raises:
        ValueError: if ``graph_embedding_type`` is not one of the names above.
    """
    def __init__(self, config: DictConfig, device: torch.device):
        super(HeterogeneousGraphEmbedding, self).__init__()
        embedding_name = config.configuration.model.graph_embedding_type
        graph_relation_list = ["{}-{}-{}".format(i.get('from'), i.get('edge'), i.get('to')) for i in
                               config.graph_relation]
        if embedding_name == 'hgt':
            self.base_graph_embedding = HeterogeneousGraphTransformer(config, graph_relation_list, device)
        elif embedding_name == 'htgnn':
            self.base_graph_embedding = HeterogeneousTemporalGraphNeuralNetwork(config, graph_relation_list, device)
        elif embedding_name == 'hgn':
            self.base_graph_embedding = HeterogeneousGraphNeuralNetwork(config, graph_relation_list, device)
        else:
            # Fail fast: without a backbone, forward() would otherwise break later with an
            # unrelated AttributeError.
            raise ValueError(f"unsupported graph embedding type {embedding_name!r}; "
                             f"expected one of 'hgt', 'htgnn', 'hgn'")

    def forward(self, x, graph):
        """Run the selected backbone on ``x`` and ``graph`` and return its output unchanged."""
        return self.base_graph_embedding(x, graph)


class TemporalEmbedding(nn.Module):
    """One ``nn.TransformerEncoderLayer`` per object type, applied to the packed embedding tensor.

    Layers are built with ``batch_first=True`` and a feed-forward width of ``4 * state_size``.
    In ``forward``, each object type's contiguous row range
    ``[start * batch_size, end * batch_size)`` taken from ``obj_order_dict`` goes through that
    type's layer. The outputs are re-stacked along dim 0 in ``obj_order_dict`` order.
    """
    def __init__(self, config, activation='gelu'):
        super(TemporalEmbedding, self).__init__()
        model_cfg = config.configuration.model
        p_drop = model_cfg.dropout
        n_heads = model_cfg.attention_heads
        d_model = model_cfg.state_size

        layers = {}
        for obj_type, _structure in config.data_props.items():
            layers[obj_type] = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=d_model * 4,
                dropout=p_drop,
                activation=activation,
                batch_first=True,
            )
        self.temporal_embedding = nn.ModuleDict(layers)

    def forward(self, x, obj_order_dict, mask=None):
        """Encode each object type's rows of ``x``; ``mask`` is forwarded as ``src_mask``."""
        encoded = []
        for obj_type, info in obj_order_dict.items():
            rows = info.get('batch_size')
            chunk = x[info.get('start') * rows: info.get('end') * rows, ...]
            encoded.append(self.temporal_embedding[obj_type](chunk, src_mask=mask))
        return torch.concat(encoded, dim=0)


class TemporalStackedEncoder(nn.Module):
    """Positional embedding followed by ``stacked_encoder_layers`` Transformer encoder layers.

    The first ``stacked_encoder_checkpoint_layers`` layers are called with
    ``use_check_point=True``. All later layers are called with ``use_check_point=False``.
    """
    def __init__(self, config: DictConfig, activation='gelu'):
        super(TemporalStackedEncoder, self).__init__()
        model_cfg = config.configuration.model
        d_model = model_cfg.state_size
        n_heads = model_cfg.attention_heads
        p_drop = model_cfg.dropout
        n_layers = model_cfg.stacked_encoder_layers
        seq_len = model_cfg.history_len + model_cfg.future_len

        # Kept as a public attribute; forward() in this class does not read it.
        self.activation = F.gelu if activation != "relu" else F.relu
        self.positional_embedding = PositionalEmbedding(d_model, seq_len, max_timescale=10000)
        self.stacked_encoder_checkpoint_layers = model_cfg.stacked_encoder_checkpoint_layers

        # Feed-forward width is fixed at 4x the model width.
        self.encoder_layers = nn.ModuleList(
            CheckpointTransformerEncoderLayer(d_model, n_heads, d_model * 4, p_drop, activation,
                                              batch_first=True)
            for _ in range(n_layers)
        )

    def forward(self, x, mask=None) -> torch.Tensor:
        """Add positional information to ``x`` and run it through every encoder layer in order."""
        hidden = self.positional_embedding(x)
        n_checkpointed = self.stacked_encoder_checkpoint_layers
        for idx, layer in enumerate(self.encoder_layers):
            hidden = layer(hidden, src_mask=mask, use_check_point=idx < n_checkpointed)
        return hidden


class MaskedInterpretableMultiHeadAttention(nn.Module):
    """Apply a per-object-type ``MultiHeadAttentionLayer`` to the packed embedding tensor.

    One layer is built for every object type in ``config.data_props``. In ``forward``, each
    type's contiguous row range ``[start * batch_size, end * batch_size)`` (from
    ``obj_order_dict``) is routed through that type's layer. The results are re-stacked along
    dim 0 in ``obj_order_dict`` order.
    """
    def __init__(self, config: DictConfig):
        super(MaskedInterpretableMultiHeadAttention, self).__init__()
        layers = {}
        for obj_type, _ in config.data_props.items():
            layers[obj_type] = MultiHeadAttentionLayer(config)
        self.multi_head_attention_layer_dict = nn.ModuleDict(layers)

    def forward(self, x, obj_order_dict):
        def _rows(info):
            # Contiguous slice of x that belongs to one object type.
            step = info.get('batch_size')
            return x[info.get('start') * step: info.get('end') * step, ...]

        return torch.concat(
            [self.multi_head_attention_layer_dict[obj_type](_rows(info))
             for obj_type, info in obj_order_dict.items()],
            dim=0,
        )
# ======================================================================================================================
class TaskOutput(nn.Module):
    """Per-object-type prediction heads for one task.

    A ``TaskOutputLayer`` is created for every entry of ``task_item_dict.task_feature_info``
    whose ``target`` is non-empty; object types without targets get no head.
    """
    def __init__(self, config: DictConfig, task_item_dict: DictConfig):
        super(TaskOutput, self).__init__()
        heads = {}
        for obj_type, obj_data in task_item_dict.task_feature_info.items():
            if len(obj_data.target):
                heads[obj_type] = TaskOutputLayer(config, obj_data)
        self.tasks = nn.ModuleDict(heads)

    def forward(self, x, obj_order_dict):
        """Return ``{obj_type: {obj_id: prediction}}`` for every head whose type is present.

        Head types missing from ``obj_order_dict`` are skipped.
        """
        predictions = {}
        for obj_type, head in self.tasks.items():
            if obj_type not in obj_order_dict:
                continue
            info = obj_order_dict[obj_type]
            first, last, ids, rows = info['start'], info['end'], info['obj_id'], info['batch_size']

            prediction = head(x[first * rows: last * rows, ...])
            predictions[obj_type] = {
                object_id: prediction[k * rows: (k + 1) * rows, ...]
                for k, object_id in enumerate(ids)
            }
        return predictions


class HeterogeneousGraphTemporalFusionTransformerTask(nn.Module):
    """End-to-end model: feature embedding, temporal and graph encoding, attention, task heads.

    ``forward`` runs, in order:
        per-type feature embedding -> masked temporal encoder -> heterogeneous graph embedding
        -> positional embedding -> two more masked temporal encoders -> per-type multi-head
        attention -> the output heads of the requested task.
    """
    def __init__(self, config: DictConfig, device: torch.device):
        super(HeterogeneousGraphTemporalFusionTransformerTask, self).__init__()

        self.config = config
        # Handed to the graph embedding backbone at construction time.
        self.device = device
        # Sequence length used for the positional embedding and the temporal mask.
        self.num_time_steps = config.configuration.model.history_len + config.configuration.model.future_len

        self.feature_embedding = FeatureEmbedding(config)

        self.graph_embedding = HeterogeneousGraphEmbedding(config, device)

        self.temporal_embedding_first = TemporalEmbedding(config)
        self.temporal_embedding_second = TemporalEmbedding(config)
        self.temporal_embedding_third = TemporalEmbedding(config)

        state_size = config.configuration.model.state_size
        self.positional_embedding = PositionalEmbedding(state_size, self.num_time_steps, max_timescale=10000)

        self.multi_head_attention = MaskedInterpretableMultiHeadAttention(config)

        self.task_output = nn.ModuleDict({
            task_name: TaskOutput(config, task_item) for task_name, task_item in config.task.items()
        })

    def _temporal_mask(self, reference: torch.Tensor) -> torch.Tensor:
        """Boolean ``(num_time_steps, num_time_steps)`` mask, True strictly above the diagonal.

        As a boolean ``src_mask``, True marks positions that may not be attended to, so step t
        only sees steps <= t. The mask is created on ``reference.device`` rather than the
        construction-time ``self.device``, so it stays on the same device as the activations
        after the module has been moved with ``.to(...)``.
        """
        steps = self.num_time_steps
        return torch.triu(torch.ones(steps, steps, device=reference.device), diagonal=1).bool()

    @staticmethod
    def _split_by_object(embedding_result, obj_order_dict):
        """Regroup the packed tensor into ``{obj_type: {obj_id: {"selected_temporal": rows}}}``.

        This is the layout passed to ``self.graph_embedding``; row ranges come from
        ``obj_order_dict`` (``start``/``end`` object indices scaled by ``batch_size``).
        """
        graph_input = {}
        for obj_type, obj_info_dict in obj_order_dict.items():
            batch_start, batch_end, obj_id_list, batch_size = (obj_info_dict.get('start'), obj_info_dict.get('end'),
                                                               obj_info_dict.get('obj_id'), obj_info_dict.get('batch_size'))
            obj_embedding_result = embedding_result[batch_start * batch_size: batch_end * batch_size, ...]
            graph_input[obj_type] = {
                obj_id: {"selected_temporal": obj_embedding_result[obj_index * batch_size: (obj_index + 1) * batch_size, ...]}
                for obj_index, obj_id in enumerate(obj_id_list)
            }
        return graph_input

    def forward(self, batch, graph, task_name):
        """Run the full pipeline and return the predictions of ``task_name``.

        Args:
            batch: per-object-type feature dicts, as consumed by ``FeatureEmbedding``.
            graph: graph structure, passed unchanged to the graph embedding backbone.
            task_name: key into ``self.task_output`` (one of ``config.task``).

        Returns:
            The output of ``self.task_output[task_name]``, i.e. ``{obj_type: {obj_id: prediction}}``.
        """
        # =========== feature embedding ==============
        embedding_result, obj_order_dict = self.feature_embedding(batch)

        # =========== temporal mask ==============
        temporal_mask = self._temporal_mask(embedding_result)

        # =========== temporal embedding ==============
        embedding_result = self.temporal_embedding_first(embedding_result, obj_order_dict, mask=temporal_mask)

        # =========== graph embedding ===============
        graph_embedding_result = self._split_by_object(embedding_result, obj_order_dict)
        graph_embedding_result = self.graph_embedding(graph_embedding_result, graph)
        embedding_result, obj_order_dict = concat_embedding_result(graph_embedding_result)

        # =========== temporal embedding ==============
        embedding_result = self.positional_embedding(embedding_result)
        embedding_result = self.temporal_embedding_second(embedding_result, obj_order_dict, mask=temporal_mask)
        embedding_result = self.temporal_embedding_third(embedding_result, obj_order_dict, mask=temporal_mask)

        # =========== multi head attention ==============
        embedding_result = self.multi_head_attention(embedding_result, obj_order_dict)

        # =========== task ==============
        return self.task_output[task_name](embedding_result, obj_order_dict)
