# -*- coding: utf-8 -*-
import random
from typing import List
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F
from omegaconf import DictConfig


class GATLayer(nn.Module):
    """
    Single graph-attention head.

    Neighbor features are linearly projected, scored against the projected
    target node, and summed along the neighbor axis (axis 1) using the
    softmax-normalised scores as weights.
    """

    def __init__(self, in_dim: int, out_dim: int):
        """
        :param in_dim: input dimension
        :param out_dim: output dimension
        """
        super().__init__()
        # Shared bias-free projection applied to both target and neighbors.
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        # Bias-free scorer over the concatenated (target, neighbor) projections.
        self.a = nn.Linear(2 * out_dim, 1, bias=False)

    def forward(self, target_node, neighbor_node_list):
        """
        Aggregate the neighbors of one target node.

        The original note describes each node's data as
        batch_size * time_step * d. The attention branch concatenates on
        dim 3 and tiles the target along axis 1, so both inputs have to be
        4-D there, with neighbors indexed on axis 1 -- presumably
        (batch, neighbors, time_step, d); TODO confirm with the caller.

        :param target_node: target node features; only read when there is
            more than one neighbor
        :param neighbor_node_list: neighbor features, neighbors on axis 1
        :return: projected neighbor features reduced (summed) over axis 1
        """
        neighbor_count = neighbor_node_list.size(1)
        # formula 1: project the features with the shared linear map
        projected_neighbors = self.W(neighbor_node_list)

        if neighbor_count <= 1:
            # Softmax over a single neighbor is always 1, so the attention
            # scores are not needed and the target is never consulted.
            return torch.sum(projected_neighbors, dim=1)

        projected_target = self.W(target_node)
        # formula 2: similarity score of every (target, neighbor) pair
        tiled_target = projected_target.repeat(1, neighbor_count, 1, 1)
        pair_features = torch.cat([tiled_target, projected_neighbors], dim=3)
        scores = F.leaky_relu(self.a(pair_features), negative_slope=0.2)
        # formula 3: normalise the scores across neighbors
        weights = F.softmax(scores, dim=1)
        # formula 4: attention-weighted sum gives the new feature
        return torch.sum(weights * projected_neighbors, dim=1)

class IntraRelation(nn.Module):
    """
    Multi-head wrapper that runs several independent GATLayer heads on the
    same (target, neighbors) input and merges their outputs.
    """

    def __init__(self, in_dim, out_dim, num_heads, merge='mean'):
        """
        :param in_dim: input dimension of every head
        :param out_dim: output dimension of every head
        :param num_heads: number of attention heads to create
        :param merge: 'cat' concatenates the head outputs along dim 1; any
            other value averages them
        """
        super().__init__()
        self.num_heads = num_heads
        self.heads = nn.ModuleList([GATLayer(in_dim, out_dim) for _ in range(num_heads)])
        self.merge = merge

    def forward(self, target_node, neighbor_node_list):
        """
        :param target_node: target node features, forwarded to every head
        :param neighbor_node_list: neighbor features, forwarded to every head
        :return: merged head outputs
        """
        per_head = [head(target_node, neighbor_node_list) for head in self.heads]
        if self.merge != 'cat':
            # Average the heads: stack on a new axis 2, then reduce it away.
            stacked = torch.stack(per_head, dim=2)
            return stacked.mean(dim=2)
        # NOTE(review): concatenation happens along dim 1, not the last
        # (feature) axis -- confirm this is the intended axis.
        return torch.cat(per_head, dim=1)


class InterRelation(nn.Module):
    """
    Attention over axis 2 of the input.

    A small MLP scores every feature vector; the scores are averaged over the
    first two axes, softmax-normalised, and used to take a weighted sum over
    axis 2.
    """

    def __init__(self, in_dim: int):
        """
        :param in_dim: int, input dimension
        """
        super().__init__()
        hidden_dim = in_dim // 2
        scoring_layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1, bias=False),
        ]
        self.project = nn.Sequential(*scoring_layers)

    def forward(self, h):
        """
        :param h: tensor whose last axis has size ``in_dim`` and whose axis 2
            is the one being fused (presumably the relation axis -- TODO
            confirm against the caller)
        :return: ``h`` with axis 2 reduced by the attention-weighted sum
        """
        scores = self.project(h)
        # One score per position of the remaining axes, shared by every entry
        # of the first two axes.
        shared_scores = scores.mean(dim=(0, 1))
        beta = torch.softmax(shared_scores, dim=0)
        # Re-attach the two leading axes so beta lines up with h.
        leading_shape = (h.shape[0], h.shape[1])
        beta = beta.expand(leading_shape + tuple(beta.shape))
        return (beta * h).sum(2)


class HeterogeneousTemporalGraphNeuralNetwork(nn.Module):
    """
    Two-stage aggregation over a heterogeneous graph.

    Stage 1 (intra-relation): for every relation, each destination object
    attends over its neighbors of the relation's source type.
    Stage 2 (inter-relation): for every destination type, the per-relation
    results of each object are fused by an ``InterRelation`` module.
    """

    def __init__(self, config: DictConfig, graph_relation_list: List, device: torch.device):
        """
        :param config: configuration; reads ``configuration.model.state_size``,
            ``dropout``, ``history_len``, ``future_len`` and
            ``intra_attention_heads``
        :param graph_relation_list: relation names; the text after the last
            '-' is taken as the destination object type
        :param device: device on which the zero placeholders are created
        """
        super(HeterogeneousTemporalGraphNeuralNetwork, self).__init__()
        model_cfg = config.configuration.model
        # De-duplicate while keeping first-seen order. A plain set() would
        # order string keys differently from one interpreter run to the next,
        # which changes module creation order and therefore seeded
        # initialisation.
        self.dst_obj_list = list(dict.fromkeys(
            relation.split('-')[-1] for relation in graph_relation_list
        ))
        self.state_size = model_cfg.state_size
        self.dropout = model_cfg.dropout
        self.device = device
        self.num_time_steps = model_cfg.history_len + model_cfg.future_len

        self.intra_attention_heads = 1 if model_cfg.intra_attention_heads is None else model_cfg.intra_attention_heads

        # One multi-head attention module per relation.
        self.intra_relation = nn.ModuleDict({
            key: IntraRelation(self.state_size, self.state_size, self.intra_attention_heads)
            for key in graph_relation_list
        })

        # One relation-fusion module per destination object type.
        self.inter_relation = nn.ModuleDict({
            key: InterRelation(self.state_size) for key in self.dst_obj_list
        })

    @staticmethod
    def _concat_time(embedding):
        """Join an object's 'selected_historical' and 'selected_future' tensors along dim 1."""
        return torch.cat([embedding['selected_historical'], embedding['selected_future']], dim=1)

    def forward(self, x, graph):
        """
        :param x: ``x[obj_type][obj_id]`` is a mapping holding the tensors
            'selected_historical' and 'selected_future' (presumably
            (batch, steps, state_size) -- TODO confirm with the caller)
        :param graph: mapping from relation name ('src-...-dst') to an object
            whose ``graph_data`` is a DataFrame indexed by destination id with
            one column per source id; a cell equal to 1 marks a neighbor
        :return: ``result[obj_type][obj_id]`` -> fused tensor, for every
            destination type that had at least one object processed
        """
        intra_features = {obj_type: {} for obj_type in self.dst_obj_list}
        for graph_type, graph_data_obj in graph.items():
            src_obj_type_name = graph_type.split('-')[0]
            dst_obj_type_name = graph_type.split('-')[-1]
            if dst_obj_type_name not in intra_features:
                continue
            graph_data = graph_data_obj.graph_data
            type_features = intra_features[dst_obj_type_name]
            for obj_id, obj_embedding_result in x[dst_obj_type_name].items():
                per_relation = type_features.setdefault(obj_id, [])
                neighbor_obj_id_list = graph_data.loc[obj_id, graph_data.loc[obj_id] == 1].index.tolist()
                if neighbor_obj_id_list:
                    # Neighbors are stacked on a new axis 1; embeddings are
                    # looked up by the string form of the column label.
                    neighbor_node_data_list = torch.stack(
                        [self._concat_time(x[src_obj_type_name][str(neighbor_obj_id)])
                         for neighbor_obj_id in neighbor_obj_id_list],
                        dim=1)
                    target_node_data = self._concat_time(obj_embedding_result).unsqueeze(1)
                    per_relation.append(
                        self.intra_relation[graph_type](target_node_data, neighbor_node_data_list)
                    )
                else:
                    # No neighbor under this relation: contribute zeros so that
                    # every object ends up with one entry per relation.
                    batch_size = obj_embedding_result['selected_historical'].shape[0]
                    per_relation.append(
                        torch.zeros((batch_size, self.num_time_steps, self.state_size), device=self.device)
                    )

        inter_features = {obj_type: {} for obj_type in self.dst_obj_list}
        for obj_type, per_object in intra_features.items():
            if not per_object:
                continue
            obj_id_list = list(per_object)
            # Relations are stacked on axis 1 per object, then objects are
            # stacked on axis 1, which moves the relation axis to axis 2.
            intra_obj_time_data_tensor = torch.stack(
                [torch.stack(relation_outputs, dim=1) for relation_outputs in per_object.values()],
                dim=1)
            inter_result = self.inter_relation[obj_type](intra_obj_time_data_tensor)
            for index, obj_id in enumerate(obj_id_list):
                inter_features[obj_type][obj_id] = inter_result[:, index, :, :]
        return inter_features


class HeterogeneousGraphNeuralNetwork(nn.Module):
    """
    Heterogeneous neighbor aggregation over the 'selected_temporal' embeddings.

    For every destination node, the neighbors of each source type are
    summarised by a bidirectional LSTM, and the node's own embedding is then
    combined with those per-type summaries through a learned attention
    vector.
    """

    def __init__(self, config: DictConfig, graph_relation_list: List, device: torch.device):
        """
        :param config: configuration; reads ``configuration.model.state_size``,
            ``history_len`` and ``future_len``
        :param graph_relation_list: relation names made of exactly three
            '-'-separated parts; the first and last are node types
        :param device: device for the LSTMs and the attention parameters
        """
        super(HeterogeneousGraphNeuralNetwork, self).__init__()
        self.config = config
        self.embed_d = config.configuration.model.state_size
        self.time_step = config.configuration.model.history_len + config.configuration.model.future_len
        self.device = device

        node_types = []
        for item in graph_relation_list:
            fr, _, too = item.split('-')
            node_types.extend([fr, too])
        # De-duplicate while keeping first-seen order. A plain set() would
        # order string keys differently from one interpreter run to the next,
        # which changes parameter creation order and therefore seeded
        # initialisation.
        self.arr_types = list(dict.fromkeys(node_types))

        # Bidirectional LSTM with embed_d // 2 hidden units per direction, so
        # the output width equals embed_d again when embed_d is even.
        self.neigh_rnn = nn.ModuleDict(
            {k: nn.LSTM(self.embed_d, self.embed_d // 2, 1, bidirectional=True).to(device) for k in self.arr_types})

        # One attention vector per node type, applied to [self ; candidate].
        params = OrderedDict(
            (k, nn.Parameter(nn.init.normal_(torch.empty(self.embed_d * 2, 1, device=device), mean=0.0, std=1.0)))
            for k in self.arr_types
        )
        self.neigh_att = nn.ParameterDict(params)

        self.softmax = nn.Softmax(dim=1)
        self.act = nn.LeakyReLU()
        # Not used by any method in this class; kept so existing checkpoints
        # still load.
        self.bn = nn.BatchNorm1d(self.embed_d)

    @staticmethod
    def gen_object(graph, src_embedding_result):
        """
        Build the neighbor lookup restricted to ids that have embeddings.

        :param graph: mapping from relation name to an object exposing
            ``src_type_id``, ``dst_type_id`` and ``graph_data`` (a DataFrame
            indexed by destination id with one column per source id; a cell
            equal to 1 marks a neighbor)
        :param src_embedding_result: ``{node_type: {node_id: ...}}``
        :return: ``rel[to_type][dst_id][from_type]`` -> list of neighbor ids
        """
        # Collect the ids of the samples that are actually present. Ids of all
        # node types share a single pool, exactly as in the membership tests
        # below.
        known_ids = set()
        for per_type in src_embedding_result.values():
            known_ids.update(per_type.keys())

        rel = {}
        for k, v in graph.items():
            # Skip relations for which no source or no destination id is present.
            if known_ids.isdisjoint(v.src_type_id) or known_ids.isdisjoint(v.dst_type_id):
                continue
            # Node types are the text before the first and after the last '-'.
            p1, p2 = k.index('-'), k.rfind('-')
            from_type, to_type = k[:p1], k[p2 + 1:]
            df = v.graph_data
            targets = rel.setdefault(to_type, {})
            for s2 in v.dst_type_id:
                if s2 not in known_ids:
                    continue
                if s2 not in targets:
                    targets[s2] = {from_type: []}
                s_df = df.loc[s2, df.loc[s2] == 1]
                arr = [item for item in s_df.index.tolist() if item in known_ids]
                if arr:
                    targets[s2][from_type] = arr
        return rel

    def forward(self, src_embedding_result, graph):
        """
        Update 'selected_temporal' of every destination node in place.

        Node types are visited in a random order, and nodes updated earlier
        are read in their updated form by later aggregations.

        :param src_embedding_result: ``{node_type: {node_id: {'selected_temporal': tensor}}}``
        :param graph: see :meth:`gen_object`
        :return: the same ``src_embedding_result`` object, mutated
        """
        rel = self.gen_object(graph, src_embedding_result)
        rel_keys = list(rel.keys())
        random.shuffle(rel_keys)
        for object_type in rel_keys:
            for idx, neighbors_by_type in rel[object_type].items():
                if not neighbors_by_type:
                    continue
                src_embedding_result[object_type][idx]['selected_temporal'] = self.node_het_agg(
                    idx, object_type, src_embedding_result, rel)

        return src_embedding_result

    def node_neigh_agg(self, id_batch, node_type, embedding_result):
        """
        Summarise the same-type neighbors ``id_batch`` into one embedding.

        :param id_batch: neighbor ids of type ``node_type``
        :param node_type: node type shared by all neighbors
        :param embedding_result: ``{node_type: {node_id: {'selected_temporal': tensor}}}``
        :return: tensor of shape (-1, embed_d)
        """
        # One leading slot per neighbor. The LSTM is not batch_first, so the
        # neighbor axis is what it treats as the sequence.
        neigh_agg = torch.cat(
            [embedding_result[node_type][idx]['selected_temporal'] for idx in id_batch], 0
        ).view(len(id_batch), -1, self.embed_d)

        all_state, _ = self.neigh_rnn[node_type](neigh_agg)

        # Average over the neighbor axis.
        return torch.mean(all_state, 0).view(-1, self.embed_d)

    def node_het_agg(self, idx, node_type, embedding_result, rel):
        """
        Attention-combine a node's own embedding with its per-type neighbor summaries.

        :param idx: id of the node being updated
        :param node_type: type of that node
        :param embedding_result: ``{node_type: {node_id: {'selected_temporal': tensor}}}``
        :param rel: lookup produced by :meth:`gen_object`
        :return: tensor of shape (-1, time_step, embed_d)
        """
        neighbor_summaries = {}
        for nei_type, neighbors in rel[node_type][idx].items():
            if not neighbors:
                continue
            neighbor_summaries[nei_type] = self.node_neigh_agg(neighbors, nei_type, embedding_result)

        self_embed = embedding_result[node_type][idx]['selected_temporal'].contiguous().view(-1, self.embed_d)

        # Attention inputs: [self ; self] followed by [self ; summary] per type.
        pair_list = [torch.cat((self_embed, self_embed), 1)]
        for summary in neighbor_summaries.values():
            pair_list.append(torch.cat((self_embed, summary), 1))

        count_agg = len(pair_list)
        pair_embed = torch.cat(pair_list, 1).view(-1, count_agg, self.embed_d * 2)
        attention_w = self.act(torch.matmul(pair_embed, self.neigh_att[node_type].unsqueeze(0)))
        # Softmax over the candidates (dim 1), then lay out as a row vector.
        attention_w = self.softmax(attention_w).view(-1, 1, count_agg)

        # weighted combination
        candidates = torch.cat(
            [self_embed] + list(neighbor_summaries.values()), 1
        ).view(-1, count_agg, self.embed_d)
        return torch.matmul(attention_w, candidates).view(-1, self.time_step, self.embed_d)


class HeterogeneousGraphTransformer(nn.Module):
    """
    Heterogeneous-graph-transformer style update of the object embeddings.

    For every target object, each neighbor contributes an attention term
    (query of the target times relation-transformed key of the neighbor) and
    a message (relation-transformed value of the neighbor). The messages are
    combined with softmax weights derived from the attention terms and
    blended with the original embedding through a learned skip gate.

    NOTE(review): the original implementation read ``self.graph`` and
    ``self.num_steps``, neither of which was ever assigned, so the class could
    not be constructed. Relation names are now taken from
    ``graph_relation_list``. The graph itself is taken either from a mapping
    passed as ``graph_relation_list`` or from the optional ``graph`` argument
    of :meth:`forward` -- confirm this against the intended caller.
    """

    def __init__(self, config: DictConfig, graph_relation_list: List, device: torch.device):
        """
        :param config: configuration; reads ``configuration.model.state_size``,
            ``history_len``, ``future_len``, ``output_target_len``, plus
            ``data_props`` and ``obj_type_name_list``
        :param graph_relation_list: relation names, or a mapping from relation
            name to graph object (the mapping is then used as the default
            graph in :meth:`forward`)
        :param device: unused here; accepted for signature parity with the
            other graph models in this file
        """
        super(HeterogeneousGraphTransformer, self).__init__()
        self.config = config
        model_cfg = config.configuration.model
        self.state_size = model_cfg.state_size
        self.data_props_dict = self.config.data_props
        self.time_steps = model_cfg.future_len + model_cfg.history_len

        # Default graph for forward(); stays None when only names were given.
        self.graph = graph_relation_list if isinstance(graph_relation_list, dict) else None
        relation_types = list(graph_relation_list)
        self.num_relations = len(relation_types)

        obj_types = list(config.obj_type_name_list)
        # Object types that actually produce an output.
        target_types = [
            obj_type for obj_type, _ in self.data_props_dict.items()
            if model_cfg.output_target_len[obj_type] > 0
        ]

        self.k_linear_dict = nn.ModuleDict({obj_type: self._projection_block() for obj_type in obj_types})
        self.q_linear_dict = nn.ModuleDict({obj_type: self._projection_block() for obj_type in obj_types})
        self.v_linear_dict = nn.ModuleDict({obj_type: self._projection_block() for obj_type in obj_types})
        self.a_linear_dict = nn.ModuleDict({obj_type: self._projection_block() for obj_type in obj_types})

        # Scalar prior per relation, multiplied into the attention term.
        self.relation_pri_dict = nn.ParameterDict({
            relation_type: nn.Parameter(torch.randn(1)) for relation_type in relation_types
        })
        self.relation_att_dict = nn.ModuleDict({
            relation_type: self._projection_block() for relation_type in relation_types
        })
        self.relation_msg_dict = nn.ModuleDict({
            relation_type: self._projection_block() for relation_type in relation_types
        })
        # Skip gate; passed through a sigmoid in forward().
        self.skip = nn.ParameterDict({
            obj_type: nn.Parameter(torch.randn(1)) for obj_type in target_types
        })
        self.norm_dict = nn.ModuleDict({
            obj_type: nn.LayerNorm(self.state_size) for obj_type in target_types
        })

        # Scores each attention term with a single scalar.
        self.project_dict = nn.ModuleDict({
            obj_type: nn.Sequential(
                nn.Linear(self.state_size, self.state_size // 2), nn.Tanh(),
                nn.Linear(self.state_size // 2, 1, bias=False)
            ) for obj_type in target_types
        })

    def _projection_block(self):
        """Build one Linear -> BatchNorm1d -> LeakyReLU block.

        BatchNorm1d normalises dim 1 of its input, which is the concatenated
        history + future axis here, hence ``time_steps`` features.
        """
        return nn.Sequential(
            nn.Linear(self.state_size, self.state_size),
            nn.BatchNorm1d(self.time_steps),
            nn.LeakyReLU(),
        )

    @staticmethod
    def _concat_time(embedding):
        """Join an object's 'selected_historical' and 'selected_future' tensors along dim 1."""
        return torch.cat([embedding['selected_historical'], embedding['selected_future']], dim=1)

    @staticmethod
    def _relation_endpoints(graph_type):
        """Return (source type, destination type) encoded in a relation name.

        Names containing '-' follow the 'src-...-dst' convention used by the
        other graph models in this file; otherwise the original underscore
        parsing ('src_dst', parts 0 and 1) is kept.
        """
        if '-' in graph_type:
            parts = graph_type.split('-')
            return parts[0], parts[-1]
        parts = graph_type.split('_')
        return parts[0], parts[1]

    def forward(self, embedding_result, graph=None):
        """
        Update 'selected_historical' / 'selected_future' of every target object in place.

        Objects updated earlier are read in their updated form when they are
        neighbors of objects processed later.

        :param embedding_result: ``embedding_result[obj_type][obj_id]`` holds
            the tensors 'selected_historical' and 'selected_future'; the two
            are concatenated on dim 1 and must together span
            ``history_len + future_len`` steps (required by the BatchNorm1d
            layers)
        :param graph: mapping from relation name to an object whose
            ``graph_data`` is a DataFrame indexed by destination id with one
            column per source id (a cell equal to 1 marks a neighbor);
            defaults to the mapping given to the constructor
        :return: the same ``embedding_result`` object, mutated
        :raises ValueError: if no graph is available
        """
        graph = self.graph if graph is None else graph
        if graph is None:
            raise ValueError(
                'No graph available: pass `graph` to forward() or construct the model with a '
                'mapping of relation name to graph object.'
            )
        model_cfg = self.config.configuration.model
        history_len = model_cfg.history_len
        obj_type_list = [obj_type for obj_type, out_put_target in model_cfg.output_target_len.items()
                         if out_put_target > 0]
        for obj_type in obj_type_list:
            for obj_id, obj_embedding_result in embedding_result[obj_type].items():
                res_msg_list = []
                res_att_list = []
                target_data = None
                for graph_type, graph_data_obj in graph.items():
                    src_obj_type_name, dst_obj_type_name = self._relation_endpoints(graph_type)
                    if obj_type != dst_obj_type_name:
                        continue
                    graph_data = graph_data_obj.graph_data
                    neighbor_obj_id_list = graph_data.loc[
                        str(obj_id), graph_data.loc[str(obj_id)] == 1].index.tolist()
                    if not neighbor_obj_id_list:
                        continue
                    if target_data is None:
                        # Built lazily so objects without neighbors are never touched.
                        target_data = self._concat_time(embedding_result[obj_type][str(obj_id)])
                    for neighbor_obj_id in neighbor_obj_id_list:
                        # Both halves come from the neighbor (the original
                        # mistakenly took the history half from the target).
                        src_obj_data = self._concat_time(
                            embedding_result[src_obj_type_name][str(neighbor_obj_id)])
                        # Step 1: attention term = query(target) * W_att(key(source)) * prior
                        q_mat = self.q_linear_dict[obj_type](target_data)
                        src_obj_k = self.k_linear_dict[src_obj_type_name](src_obj_data)
                        src_obj_k_w = self.relation_att_dict[graph_type](src_obj_k)
                        att_set = torch.mul(src_obj_k_w, q_mat)
                        res_att_list.append(torch.mul(att_set, self.relation_pri_dict[graph_type]))
                        # Step 2: message = W_msg(value(source))
                        src_obj_v = self.v_linear_dict[src_obj_type_name](src_obj_data)
                        res_msg_list.append(self.relation_msg_dict[graph_type](src_obj_v))
                if not res_att_list:
                    continue
                # Step 3: score the attention terms, normalise across
                # neighbors and aggregate the messages.
                obj_att = torch.stack(res_att_list, dim=1)
                obj_att = self.project_dict[obj_type](obj_att)
                obj_att_mean = obj_att.mean(dim=0)

                obj_msg = torch.stack(res_msg_list, dim=1)
                att = F.softmax(obj_att_mean, dim=0)
                att = att.expand((obj_msg.shape[0],) + att.shape)

                out_put_sum = torch.sum(torch.mul(obj_msg, att), dim=1)
                out_put_sum = F.gelu(out_put_sum)
                trans_out = self.a_linear_dict[obj_type](out_put_sum)
                alpha = torch.sigmoid(self.skip[obj_type])

                # target_data is still the pre-update embedding of this object.
                hgt_result = self.norm_dict[obj_type](
                    trans_out * alpha + target_data * (1 - alpha)
                )
                target_entry = embedding_result[obj_type][str(obj_id)]
                # History occupies the first history_len steps; the future
                # part is everything after it.
                target_entry['selected_historical'] = hgt_result[:, :history_len, :]
                target_entry['selected_future'] = hgt_result[:, history_len:, :]
        return embedding_result
