import torch.nn.functional as F
import torch
import torch.nn as nn
import numpy as np
from models.layers import GraphConvNew
from models.confromer import ConformerBlock

class dot_attention(nn.Module):
    """Scaled dot-product attention over batched 3-D tensors (uses ``torch.bmm``)."""

    def __init__(self, attention_dropout=0.5):
        super(dot_attention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        # Normalise over the key axis of the (batch, Lq, Lk) score matrix.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """
        Forward pass.

        :param q: query tensor of shape (batch, Lq, d).
        :param k: key tensor of shape (batch, Lk, d).
        :param v: value tensor of shape (batch, Lk, dv).
        :param scale: optional multiplier applied to the raw scores (e.g. 1/sqrt(d)).
        :param attn_mask: optional bool tensor broadcastable to (batch, Lq, Lk);
            True marks positions that must not be attended to.
        :return: (context of shape (batch, Lq, dv), attention weights of shape (batch, Lq, Lk)).
        """
        attention = torch.bmm(q, k.transpose(1, 2))
        # Explicit None checks: truth-testing a multi-element tensor raises a RuntimeError.
        if scale is not None:
            attention = attention * scale
        if attn_mask is not None:
            # Masked scores become -inf so softmax assigns them zero weight.
            # A row that is fully masked yields NaN after softmax.
            attention = attention.masked_fill(attn_mask, float('-inf'))
        attention = self.softmax(attention)
        attention = self.dropout(attention)
        # Weighted sum of the values.
        context = torch.bmm(attention, v)
        return context, attention

class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by dropout, a residual connection and LayerNorm."""
    def __init__(self, model_dim=400, num_heads=4, dropout=0.0):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim // num_heads   # dimension of each head
        self.num_heads = num_heads
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)

        self.dot_product_attention = dot_attention(dropout)

        # Consumes the concatenated heads (dim_per_head * num_heads), which is less
        # than model_dim when model_dim is not divisible by num_heads.
        self.linear_final = nn.Linear(self.dim_per_head * num_heads, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, key, value, query, attn_mask=None):
        """Attend from ``query`` to ``key``/``value`` and add the result to ``query``.

        :param key: tensor of shape (B, Lk, model_dim).
        :param value: tensor of shape (B, Lk, model_dim).
        :param query: tensor of shape (B, Lq, model_dim); also used as the residual.
        :param attn_mask: optional bool tensor with leading batch dim B, e.g.
            (B, Lq, Lk); True marks positions that must not be attended to.
        :return: (output of shape (B, Lq, model_dim),
                  attention weights of shape (B * num_heads, Lq, Lk)).
        """
        residual = query

        dim_per_head = self.dim_per_head
        num_heads = self.num_heads
        batch_size = key.size(0)

        def split_heads(t):
            # (B, L, H*D) -> (B, H, L, D) -> (B*H, L, D). Row b*H + h is head h of
            # batch item b, so each row keeps every sequence position of one head.
            return (t.view(batch_size, -1, num_heads, dim_per_head)
                    .transpose(1, 2)
                    .reshape(batch_size * num_heads, -1, dim_per_head))

        # Linear projections, then split into heads.
        key = split_heads(self.linear_k(key))
        value = split_heads(self.linear_v(value))
        query = split_heads(self.linear_q(query))

        if attn_mask is not None:
            # Duplicate each batch item's mask per head to match the b*H + h row order.
            attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)

        # Standard 1/sqrt(d_k) scaling with d_k the per-head dimension.
        scale = dim_per_head ** -0.5
        context, attention = self.dot_product_attention(query, key, value, scale, attn_mask)

        # Concatenate heads: (B*H, Lq, D) -> (B, H, Lq, D) -> (B, Lq, H*D).
        context = (context.view(batch_size, num_heads, -1, dim_per_head)
                   .transpose(1, 2)
                   .reshape(batch_size, -1, num_heads * dim_per_head))

        output = self.linear_final(context)
        output = self.dropout(output)

        # Residual connection followed by LayerNorm.
        output = self.layer_norm(residual + output)

        return output, attention

class ClusterLayer(nn.Module):
    """Assign every node vector to its nearest cluster centre (squared Euclidean distance)."""

    def __init__(self, input_dim, cluster_num):
        super(ClusterLayer, self).__init__()
        self.cluster_num = cluster_num
        # Centres are drawn from a standard normal. NOTE: forward returns argmin
        # indices, which carry no gradient, so these centres receive no gradient
        # through this layer.
        self.cluster_centers = nn.Parameter(torch.randn(cluster_num, input_dim))

    def forward(self, x):
        """Return the index of the closest centre for each node.

        :param x: tensor of shape (B, N, input_dim).
        :return: long tensor of shape (B, N) with values in [0, cluster_num).
        """
        batch, nodes, feat = x.size()
        points = x.unsqueeze(2)                                        # (B, N, 1, C)
        centres = self.cluster_centers.view(1, 1, self.cluster_num, feat)  # (1, 1, K, C)
        sq_dist = torch.sum((points - centres) ** 2, dim=3)           # (B, N, K)
        return sq_dist.argmin(dim=2)


class GNNModel(nn.Module):
    """Graph-level predictor that learns its own adjacency matrix.

    ``forward`` pipeline:
      1. ``edge_pred`` scores every pair of valid nodes, giving a symmetric
         (B, N, N) adjacency.
      2. ``gc1`` and dropout are applied.
      3. Nodes are hard-assigned to ``cluster_layer`` centres and max-pooled per
         cluster.
      4. ``gc2`` -> ``conformer_block`` -> ``gc3`` run on the pooled graph.
      5. The output is a max over nodes followed by the ``fc`` head.
    """

    def __init__(self, in_features, out_features, hidden_dim, dropout=0.001, n_hidden_edge=32, adj_sq=False, scale_identity=False, filters='64,64,64', n_hidden=256):
        super(GNNModel, self).__init__()

        # NOTE(review): `filters` is parsed but never used below, and neither are
        # `adj_sq`, `scale_identity` or `n_hidden`; kept for interface compatibility.
        if isinstance(filters, str):
            filters = list(map(int, filters.split(',')))

        self.gc1 = GraphConvNew(in_features, hidden_dim, activation=F.relu)
        self.gc2 = GraphConvNew(hidden_dim, hidden_dim)
        # Third GCN layer; presumably emits 32 features per node, matching fc's input.
        self.gc3 = GraphConvNew(hidden_dim, 32)
        self.dropout = nn.Dropout(dropout)

        # Hard clustering of node embeddings into 10 groups (see _cluster_max_pool).
        self.cluster_layer = ClusterLayer(hidden_dim, cluster_num=10)
        # NOTE(review): registered (its weights are in the state_dict) but never
        # called in forward.
        self.multihead_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=8, dropout=dropout)

        self.conformer_block = ConformerBlock(
            dim=hidden_dim,
            dim_head=64,
            heads=8,
            ff_mult=4,
            conv_expansion_factor=2,
            conv_kernel_size=31,
            attn_dropout=0.0,
            ff_dropout=dropout,
            conv_dropout=dropout,
        )

        # Scores one node pair from the concatenated features [x_i, x_j].
        self.edge_pred = nn.Sequential(
            nn.Linear(in_features * 2, n_hidden_edge),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(n_hidden_edge),
            nn.Dropout(0.5),
            nn.Linear(n_hidden_edge, n_hidden_edge // 2),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden_edge // 2, 1)
        )

        # Input: gc3 features after max-pooling over nodes (32-dim); output: out_features.
        self.fc = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(16, out_features)
        )

    def forward(self, data):
        """Predict graph-level outputs for a padded batch of graphs.

        :param data: indexable container. ``data[0]`` holds node features of
            shape (B, N, C). ``data[2]`` holds the node mask of shape (B, N),
            where non-zero entries mark valid nodes. ``data[1]`` is not read here.
        :return: tensor of shape (B, out_features).
        """
        x = data[0]
        mask = data[2]

        # Learned adjacency over the original nodes: (B, N, N).
        A_pred = self._predict_adjacency(x, mask)

        x = self.gc1((x, A_pred, mask))[0]
        x = self.dropout(x)

        # (B, N, hidden_dim) -> (B, cluster_num, hidden_dim).
        x = self._cluster_max_pool(x)

        # Resize adjacency and mask to the pooled node count.
        # NOTE(review): pooled node k is a cluster, not an original node. A_pred and
        # mask are simply resized (bilinear / nearest), so the new mask does not
        # reflect which clusters are empty. Confirm this is intended.
        new_N = x.size(1)
        A_pred = F.interpolate(A_pred.unsqueeze(1), size=(new_N, new_N), mode='bilinear', align_corners=False).squeeze(1)
        mask = F.interpolate(mask.unsqueeze(1).float(), size=(new_N,), mode='nearest').long().squeeze(1)
        # gc2/gc3 receive the adjacency duplicated along a trailing axis: (B, new_N, new_N, 2).
        # NOTE(review): gc1 above received a 3-D adjacency; presumably GraphConvNew
        # accepts both layouts — confirm.
        A = torch.stack((A_pred, A_pred), 3)

        x = self.gc2((x, A, mask))[0]
        # mask is (B, new_N); unsqueeze(-1) makes it (B, new_N, 1) to zero out masked nodes.
        x = x * mask.unsqueeze(-1)

        x = self.conformer_block(x)
        x = self.dropout(x)
        x = self.gc3((x, A, mask))[0]

        # Graph-level readout: max over nodes, then flatten for the MLP head.
        x = torch.max(x, dim=1)[0]
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def _predict_adjacency(self, x, mask):
        """Build a symmetric (B, N, N) adjacency; only pairs of distinct valid nodes get a score."""
        B, N, C = x.shape
        pair_feats, flat_idx = [], []
        for b in range(B):
            nodes = torch.nonzero(mask[b])           # (n, 1) indices of valid nodes
            n = nodes.size(0)
            node_i = nodes.repeat(1, n).view(-1, 1)  # i0,i0,...,i1,i1,...
            node_j = nodes.repeat(n, 1).view(-1, 1)  # i0,i1,...,i0,i1,...
            # Keep each unordered pair once. view(-1) (not squeeze) keeps a 1-D mask even for n == 1.
            upper = (node_i < node_j).view(-1)
            pair_feats.append(torch.cat((x[b, node_i[upper]], x[b, node_j[upper]]), 2).view(-1, C * 2))
            flat_idx.append((node_i * N + node_j)[upper].view(-1))

        pair_feats = torch.cat(pair_feats)
        # Score each pair in both orders and average the logits so the score is order-invariant.
        swapped = torch.cat((pair_feats[:, C:], pair_feats[:, :C]), dim=1)
        # view(-1) keeps a 1-D tensor even when there is exactly one pair in the batch.
        scores = torch.exp(0.5 * (self.edge_pred(pair_feats) + self.edge_pred(swapped))).view(-1)

        A_pred = torch.zeros(B, N * N, device=x.device)
        offset = 0
        for b in range(B):
            count = flat_idx[b].numel()
            A_pred[b, flat_idx[b]] = scores[offset:offset + count]
            offset += count
        A_pred = A_pred.view(B, N, N)
        # Mirror the upper triangle so the adjacency is symmetric.
        return A_pred + A_pred.permute(0, 2, 1)

    def _cluster_max_pool(self, x):
        """Max-pool node features per cluster: (B, N, H) -> (B, cluster_num, H); empty clusters become zeros."""
        # NOTE(review): padded nodes (mask == 0) are clustered and pooled as well.
        assignments = self.cluster_layer(x)
        pooled = []
        for b in range(x.size(0)):
            per_cluster = []
            for k in range(self.cluster_layer.cluster_num):
                members = assignments[b] == k
                if members.any():
                    per_cluster.append(torch.max(x[b][members], dim=0)[0])
                else:
                    per_cluster.append(torch.zeros(x.size(-1), dtype=x.dtype, device=x.device))
            pooled.append(torch.stack(per_cluster, dim=0))
        return torch.stack(pooled, dim=0)