import torch.nn.functional as F
import torch
import torch.nn as nn
from parser_1 import _parser
from models.layers import GraphConvNew
import numpy as np
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Parsed once at import time via parser_1._parser (presumably command-line
# options — confirm in parser_1). MGCN reads args.filters, args.n_hidden and
# args.dropout as its default hyperparameters when the class is defined.
args = _parser()

class ResidualGraphConvNew(nn.Module):
    """GraphConvNew layer with an additive residual connection.

    Consumes and returns the same ``(x, A, mask)`` tuple that GraphConvNew
    layers are chained with in ``nn.Sequential``, so it can be stacked the
    same way.

    :param in_features: input feature size of ``x``.
    :param out_features: output feature size.
    :param n_relations: number of adjacency relations, forwarded to GraphConvNew.
    :param activation: activation module used both inside GraphConvNew and
        after the residual sum. ``None`` (the default) creates a fresh
        ``nn.ReLU(inplace=True)`` per instance, instead of one module object
        shared by every instance.
    :param adj_sq: forwarded to GraphConvNew.
    :param scale_identity: forwarded to GraphConvNew.
    """

    def __init__(self, in_features, out_features, n_relations, activation=None, adj_sq=False, scale_identity=False):
        super(ResidualGraphConvNew, self).__init__()
        if activation is None:
            activation = nn.ReLU(inplace=True)
        self.gconv = GraphConvNew(in_features, out_features, n_relations, activation, adj_sq, scale_identity)
        # Project the input when widths differ so the residual sum is well-defined.
        self.residual = nn.Linear(in_features, out_features) if in_features != out_features else nn.Identity()
        self.activation = activation

    def forward(self, data):
        """Apply graph convolution plus residual; returns ``(out, A, mask)``."""
        x, A, mask = data
        # GraphConvNew returns an (x, A, mask) tuple (that is how it is chained
        # in nn.Sequential); only its feature tensor enters the residual sum.
        conv_out = self.gconv(data)[0]
        out = conv_out + self.residual(x)
        return self.activation(out), A, mask

class TransformerEncoderModule(torch.nn.Module):
    """Stack of ``num_layers`` standard PyTorch Transformer encoder layers.

    With TransformerEncoderLayer's default ``batch_first=False`` the input
    ``src`` is laid out as (seq_len, batch, embed_size).

    :param embed_size: model width (``d_model``).
    :param heads: number of attention heads; must divide ``embed_size``.
    :param num_layers: number of stacked encoder layers.
    :param forward_expansion: feed-forward width as a multiple of ``embed_size``.
    :param dropout: dropout probability inside each encoder layer (default 0.1).
    """

    def __init__(self, embed_size, heads, num_layers, forward_expansion, dropout=0.1):
        super(TransformerEncoderModule, self).__init__()
        # TransformerEncoder clones this template layer num_layers times. The
        # template is kept as an attribute so existing "encoder_layer.*"
        # state_dict keys still load; forward only uses the clones.
        self.encoder_layer = TransformerEncoderLayer(d_model=embed_size, nhead=heads,
                                                     dim_feedforward=embed_size * forward_expansion,
                                                     dropout=dropout)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=num_layers)

    def forward(self, src):
        """Encode ``src``; the output has the same shape as the input."""
        return self.transformer_encoder(src)

class dot_attention(nn.Module):
    """Scaled dot-product attention over batched 3-D tensors."""

    def __init__(self, attention_dropout=0.5):
        super(dot_attention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """Compute ``softmax(q @ k^T * scale) @ v``.

        :param q: queries, shape (batch, len_q, dim).
        :param k: keys, shape (batch, len_k, dim).
        :param v: values, shape (batch, len_k, dim_v).
        :param scale: optional multiplier applied to the attention logits.
        :param attn_mask: optional bool tensor broadcastable to
            (batch, len_q, len_k); True marks positions that must not be
            attended to.
        :return: ``(context, attention)`` with shapes (batch, len_q, dim_v)
            and (batch, len_q, len_k).
        """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale is not None:
            attention = attention * scale
        # Explicit None check: truth-testing a multi-element tensor raises.
        if attn_mask is not None:
            # -inf logits become exactly 0 after softmax (a fully masked row
            # yields NaN).
            attention = attention.masked_fill(attn_mask, float("-inf"))
        attention = self.softmax(attention)
        attention = self.dropout(attention)
        context = torch.bmm(attention, v)
        return context, attention

class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by dropout, a residual add and LayerNorm.

    The returned attention weights have shape (batch * num_heads, len_q, len_k),
    ordered batch-major: row ``b * num_heads + h`` is head ``h`` of sample ``b``.
    """
    def __init__(self, model_dim=400, num_heads=4, dropout=0.5):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim // num_heads
        self.num_heads = num_heads
        inner_dim = self.dim_per_head * num_heads  # equals model_dim when divisible
        self.linear_k = nn.Linear(model_dim, inner_dim)
        self.linear_v = nn.Linear(model_dim, inner_dim)
        self.linear_q = nn.Linear(model_dim, inner_dim)

        self.dot_product_attention = dot_attention(dropout)

        # Input width is the concatenated head width, so model_dim need not be
        # an exact multiple of num_heads.
        self.linear_final = nn.Linear(inner_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def _split_heads(self, t, batch_size):
        """Reshape (B, L, H*d) -> (B*H, L, d) so each row is one head of one sample."""
        return (t.view(batch_size, -1, self.num_heads, self.dim_per_head)
                 .transpose(1, 2)
                 .reshape(batch_size * self.num_heads, -1, self.dim_per_head))

    def forward(self, key, value, query, attn_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        :param key: (B, len_k, model_dim).
        :param value: (B, len_k, model_dim).
        :param query: (B, len_q, model_dim); also the residual input.
        :param attn_mask: optional bool tensor (B, len_q, len_k), True = masked.
        :return: ``(output, attention)``; output is (B, len_q, model_dim).
        """
        residual = query
        batch_size = key.size(0)

        key = self._split_heads(self.linear_k(key), batch_size)
        value = self._split_heads(self.linear_v(value), batch_size)
        query = self._split_heads(self.linear_q(query), batch_size)

        if attn_mask is not None:
            # Match the batch-major (b * H + h) row order produced by _split_heads.
            attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)

        # Standard 1/sqrt(d_head) scaling.
        scale = self.dim_per_head ** -0.5
        context, attention = self.dot_product_attention(query, key, value, scale, attn_mask)

        # Undo the head split: (B*H, len_q, d) -> (B, len_q, H*d).
        context = (context.view(batch_size, self.num_heads, -1, self.dim_per_head)
                   .transpose(1, 2)
                   .reshape(batch_size, -1, self.num_heads * self.dim_per_head))

        output = self.dropout(self.linear_final(context))
        output = self.layer_norm(residual + output)

        return output, attention

class EdgePredictionNN(nn.Module):
    """MLP that scores a node pair from its concatenated features.

    Input is presumably 2-D, (num_pairs, 2 * in_features), since BatchNorm1d
    normalises dim 1. Output is (num_pairs, 1) raw scores.

    :param in_features: feature size of a single node.
    :param n_hidden_edge: width of the first hidden layer (the second is half).
    :param dropout: dropout probability after the batch-norm layer (default 0.5).
    """

    def __init__(self, in_features, n_hidden_edge=32, dropout=0.5):
        super(EdgePredictionNN, self).__init__()
        # Layer order/indices are unchanged so existing state_dict keys still load.
        self.edge_pred = nn.Sequential(
            nn.Linear(in_features * 2, n_hidden_edge),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(n_hidden_edge),
            nn.Dropout(dropout),
            nn.Linear(n_hidden_edge, n_hidden_edge // 2),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden_edge // 2, 1),
        )

    def forward(self, x):
        """Return one raw score per row of ``x``."""
        return self.edge_pred(x)


class ClusterLayer(nn.Module):
    """Hard assignment of each feature vector to its nearest learnable centre.

    Distances are squared Euclidean. ``argmin`` is not differentiable, so no
    gradient reaches ``cluster_centers`` through this layer's output.
    """

    def __init__(self, input_dim, cluster_num):
        super(ClusterLayer, self).__init__()
        self.cluster_num = cluster_num
        # One randomly initialised centre per row: (cluster_num, input_dim).
        self.cluster_centers = nn.Parameter(torch.randn(cluster_num, input_dim))

    def forward(self, x):
        """Return the nearest-centre index for every vector in ``x``.

        :param x: tensor of shape (B, N, C) with C == input_dim.
        :return: long tensor of shape (B, N) with values in [0, cluster_num).
        """
        batch, nodes, channels = x.size()
        grid = (batch, nodes, self.cluster_num, channels)
        points = x.unsqueeze(2).expand(*grid)
        centres = self.cluster_centers[None, None].expand(*grid)
        sq_dist = (points - centres).pow(2).sum(dim=-1)
        return sq_dist.argmin(dim=-1)

class MGCN(nn.Module):
    '''
    Multigraph Convolutional Network.

    Predicts an extra adjacency matrix from node features, runs a stack of
    GraphConvNew layers over the annotated and predicted adjacencies (stacked
    as two relations along the last axis), applies multi-head self-attention,
    max-pools over nodes and classifies with a fully connected head.
    '''

    def __init__(self,
                 in_features,
                 out_features,
                 n_relations,
                 filters=args.filters,
                 n_hidden=args.n_hidden,
                 n_hidden_edge=32,
                 dropout=args.dropout,
                 adj_sq=False,
                 scale_identity=False):
        super(MGCN, self).__init__()

        # Graph convolution layers, chained via nn.Sequential, so each layer
        # takes and returns an (x, A, mask) tuple.
        self.gconv = nn.Sequential(*([GraphConvNew(in_features=in_features if layer == 0 else filters[layer - 1],
                                                   out_features=f,
                                                   n_relations=n_relations,
                                                   activation=nn.ReLU(inplace=True),
                                                   adj_sq=adj_sq,
                                                   scale_identity=scale_identity) for layer, f in enumerate(filters)]))
        # Cluster layer: kept so existing checkpoints load; forward does not use it.
        self.cluster_layer = ClusterLayer(filters[-1], cluster_num=10)

        # Edge prediction MLP: scores a node pair from its concatenated features.
        self.edge_pred = nn.Sequential(
            nn.Linear(in_features * 2, n_hidden_edge),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(n_hidden_edge),
            nn.Dropout(0.5),
            nn.Linear(n_hidden_edge, n_hidden_edge // 2),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden_edge // 2, 1)
        )

        # Fully connected classification head.
        fc = []
        if dropout > 0:
            fc.append(nn.Dropout(p=dropout))
        if n_hidden > 0:
            fc.append(nn.Linear(filters[-1], n_hidden))
            if dropout > 0:
                fc.append(nn.Dropout(p=dropout))
            n_last = n_hidden
        else:
            n_last = filters[-1]
        fc.append(nn.Linear(n_last, out_features))
        self.fc = nn.Sequential(*fc)

        # Self-attention over node embeddings. Registered as a submodule so its
        # weights are trained, saved, moved by .to(device) and follow
        # train()/eval(). Created last so existing parameter order is unchanged.
        self.attention = MultiHeadAttention(model_dim=filters[-1])

    def _predict_adjacency(self, x, mask):
        """Predict a symmetric, positive edge-weight matrix from node features.

        Each unordered pair of distinct valid nodes (nonzero ``mask`` entries)
        is scored by ``edge_pred`` in both endpoint orders. The two logits are
        averaged and exponentiated, so the weight does not depend on pair
        order. The diagonal and entries involving masked nodes stay 0.

        :param x: node features, shape (B, N, C) with C == in_features.
        :param mask: per-node validity mask, presumably shape (B, N) — TODO confirm.
        :return: tensor of shape (B, N, N).
        """
        B, N, C = x.shape
        pair_feats, pair_idx = [], []
        for b in range(B):
            nodes = torch.nonzero(mask[b]).view(-1)  # indices of valid nodes
            n = nodes.numel()
            # All ordered pairs (i, j) of valid nodes, i-major.
            node_i = nodes.repeat_interleave(n)
            node_j = nodes.repeat(n)
            upper = node_i < node_j  # skip self-loops and the mirrored (j, i) pair
            node_i, node_j = node_i[upper], node_j[upper]
            pair_feats.append(torch.cat((x[b, node_i], x[b, node_j]), dim=1))
            pair_idx.append(node_i * N + node_j)  # flat index into an N*N row

        pair_feats = torch.cat(pair_feats)
        if pair_feats.shape[0] == 0:
            # No graph has two valid nodes, so there is nothing to score.
            return x.new_zeros(B, N, N)

        # Score both endpoint orders to encourage invariance to node order.
        swapped = torch.cat((pair_feats[:, C:], pair_feats[:, :C]), dim=1)
        logits = 0.5 * (self.edge_pred(pair_feats) + self.edge_pred(swapped))
        weights = torch.exp(logits).view(-1)  # view, not squeeze: safe for one edge

        A_pred = x.new_zeros(B, N * N)  # same device/dtype as the features
        offset = 0
        for b, idx in enumerate(pair_idx):
            A_pred[b, idx] = weights[offset:offset + idx.numel()]
            offset += idx.numel()
        A_pred = A_pred.view(B, N, N)
        return A_pred + A_pred.transpose(1, 2)  # assume undirected edges

    def forward(self, data):
        """Run the model on one batch.

        :param data: indexable batch, per the original comment
            [node_features, A, graph_support, N_nodes, label]. Only the first
            three entries are read: node features (B, N, C), annotated
            adjacency (B, N, N) and the node mask used for pair selection.
        :return: logits of shape (B, out_features).
        """
        x, A, mask = data[0], data[1], data[2]

        A_pred = self._predict_adjacency(x, mask)

        # Use both annotated and predicted adjacency matrices as two relations.
        x = self.gconv((x, torch.stack((A, A_pred), 3), mask))[0]

        x, _ = self.attention(x, x, x)

        # Max-pool over nodes -> (B, features). No squeeze(): with B == 1 it
        # would drop the batch dimension.
        x = x.max(dim=1).values
        return self.fc(x)
