import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F


class Attention(nn.Module):
    """Fuse per-graph node embeddings with a learned per-node softmax attention.

    One ``nn.Linear(hidden_nodes, 1)`` scorer is kept per graph. For every node,
    the ``graph_num`` scores are softmax-normalized across graphs and used to
    take a weighted sum of that node's embeddings from each graph.

    Args:
        args: Config object. ``hidden_nodes``, ``graph_num`` and ``addVector``
            are read here; ``n`` is read only when ``addVector`` is truthy.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # One scoring head per graph: embedding (hidden_nodes) -> scalar score.
        self.A = nn.ModuleList(
            [nn.Linear(args.hidden_nodes, 1) for _ in range(args.graph_num)]
        )
        if self.args.addVector:
            # NOTE(review): self.B is registered (so it shows up in parameters()
            # and state_dict) but is not used by forward/attn_feature and is
            # not re-initialized by weight_init.
            self.B = nn.ModuleList(
                [nn.Linear(args.hidden_nodes + args.n, 1) for _ in range(args.graph_num)]
            )
        self.weight_init()

    def weight_init(self):
        """Xavier-normal init for each scorer weight in ``self.A``; zero its bias."""
        for layer in self.A:
            nn.init.xavier_normal_(layer.weight)
            layer.bias.data.fill_(0.0)

    def forward(self, feat_pos, feat_neg):
        """Fuse the positive and negative feature sets independently.

        Returns:
            ``(fused_pos, fused_neg)``. The attention weights computed for each
            set are discarded; call ``attn_feature`` to inspect them.
        """
        fused_pos, _ = self.attn_feature(feat_pos)
        fused_neg, _ = self.attn_feature(feat_neg)
        return fused_pos, fused_neg

    def attn_feature(self, features):
        """Attention-weighted sum of node embeddings across graphs.

        Args:
            features: Sequence of ``graph_num`` tensors, each expected to be
                ``(1, num_nodes, hidden_nodes)``. The leading axis must be 1
                for the concatenation and reshapes below to line up.

        Returns:
            Tuple ``(fused, weights)``:
            - ``fused`` has shape ``(1, num_nodes, hidden_nodes)``.
            - ``weights`` has shape ``(num_nodes, graph_num)`` and sums to 1
              along its last axis.
        """
        graph_num = self.args.graph_num
        hidden = self.args.hidden_nodes

        # Per-graph scores, each (num_nodes, 1). Squeeze only the batch axis:
        # a bare squeeze() would also drop the node axis when num_nodes == 1.
        scores = [self.A[i](features[i].squeeze(0)) for i in range(graph_num)]
        # (num_nodes, graph_num): softmax across graphs, independently per node.
        weights = F.softmax(torch.cat(scores, 1), -1)
        num_nodes = weights.shape[0]

        # Graph-major stack of all embeddings: (graph_num * num_nodes, hidden).
        stacked = torch.cat(features, 1).squeeze(0)
        # Lay the weights out in the same graph-major order, (graph_num * num_nodes, 1),
        # and broadcast them over the hidden dimension.
        weighted = stacked * weights.t().reshape(-1, 1)

        # The node count comes from the input rather than args.nb_nodes, so
        # graphs whose size differs from the config are handled too.
        fused = weighted.view(graph_num, num_nodes, hidden).sum(0).unsqueeze(0)
        return fused, weights


# import torch.nn as nn
#
#
# class Attention(nn.Module):
#     def __init__(self, num_heads, dim):
#         super().__init__()
#         # Q, K, V projection matrices; assumes input and output feature dims are equal
#         self.q = nn.Linear(dim, dim)
#         self.k = nn.Linear(dim, dim)
#         self.v = nn.Linear(dim, dim)
#         self.num_heads = num_heads
#
#     def forward(self, x):
#         B, N, C = x[0].shape
#         # Project the inputs and split them into attention heads
#         q = self.q(x[0]).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
#         k = self.k(x[1]).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
#         v = self.k(x[2]).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
#
#         # Scaled dot product gives the attention scores
#         attn = q @ k.transpose(2, 3) * (x[0].shape[-1] ** -0.5)
#         attn = attn.softmax(dim=-1)
#
#         # Weight the values by the attention scores and merge heads for the output
#         v = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, C)
#         return v
