import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalBranch(nn.Module):
    def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim, output_dim, device, type, reduction):
        super(GlobalBranch, self).__init__()
        self.device = device
        self.type = type
        self.reduction = reduction
        self.layer1_output = None
        self.layer2_output = None

        # 定义 GCN 层
        self.gcn1 = GCNLayer(input_dim, hidden1_dim)
        self.gcn2 = GCNLayer(hidden1_dim, hidden2_dim)

        # 定义全连接层
        self.tf_linear1 = nn.Linear(hidden2_dim, hidden3_dim)
        self.target_linear1 = nn.Linear(hidden2_dim, hidden3_dim)

        self.tf_linear2 = nn.Linear(hidden3_dim, output_dim)
        self.target_linear2 = nn.Linear(hidden3_dim, output_dim)

        if self.type == 'MLP':
            self.linear = nn.Linear(2 * output_dim, 2)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.tf_linear1.weight, gain=1.414)
        nn.init.xavier_uniform_(self.target_linear1.weight, gain=1.414)
        nn.init.xavier_uniform_(self.tf_linear2.weight, gain=1.414)
        nn.init.xavier_uniform_(self.target_linear2.weight, gain=1.414)

    def encode(self, x, adj):
        # 全部转为稠密矩阵
        adj_dense = adj.to_dense()
        degree = torch.sum(adj_dense, dim=1)
        degree_sqrt_inv = torch.pow(degree, -0.5)
        degree_sqrt_inv[degree_sqrt_inv == float('inf')] = 0
        D_inv = torch.diag(degree_sqrt_inv)
        adj_norm = D_inv @ adj_dense @ D_inv

        # 第一层 GCN
        layer1_output = self.gcn1(x, adj_norm)
        layer1_output = F.relu(layer1_output)
        self.layer1_output = layer1_output

        # 第二层 GCN
        layer2_output = self.gcn2(layer1_output, adj_norm)
        layer2_output = F.relu(layer2_output)
        self.layer2_output = layer2_output
        return layer2_output

    def decode(self, tf_embed, target_embed):
        if self.type == 'dot':
            prob = torch.mul(tf_embed, target_embed)
            prob = torch.sum(prob, dim=1).view(-1, 1)
            return prob

        elif self.type == 'cosine':
            prob = torch.cosine_similarity(tf_embed, target_embed, dim=1).view(-1, 1)
            return prob

        elif self.type == 'MLP':
            h = torch.cat([tf_embed, target_embed], dim=1)
            prob = self.linear(h)
            return prob

        else:
            raise TypeError(r'{} is not available'.format(self.type))

    def forward(self, x, adj, train_sample):
        embed = self.encode(x, adj)

        tf_embed = self.tf_linear1(embed)
        tf_embed = F.leaky_relu(tf_embed)
        tf_embed = F.dropout(tf_embed, p=0.01)
        tf_embed = self.tf_linear2(tf_embed)
        tf_embed = F.leaky_relu(tf_embed)

        target_embed = self.target_linear1(embed)
        target_embed = F.leaky_relu(target_embed)
        target_embed = F.dropout(target_embed, p=0.01)
        target_embed = self.target_linear2(target_embed)
        target_embed = F.leaky_relu(target_embed)

        self.tf_ouput = tf_embed
        self.target_output = target_embed

        train_tf = tf_embed[train_sample[:, 0]]
        train_target = target_embed[train_sample[:, 1]]

        pred = self.decode(train_tf, train_target)

        return pred

    def get_embedding(self):
        return self.tf_ouput, self.target_output

    def get_layer_embeddings(self):
        return self.layer1_output, self.layer2_output



class GCNLayer(nn.Module):
    """Single graph-convolution step: a linear map followed by neighbourhood aggregation.

    Computes ``adj_norm @ (x W^T + b)``. No activation is applied here; the
    caller is expected to pass an already normalised dense adjacency matrix.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Learnable feature transform applied to every node before aggregation.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj_norm):
        """Transform node features ``x`` and aggregate them with ``adj_norm``.

        Args:
            x: Node feature matrix fed to the linear layer.
            adj_norm: Normalised dense adjacency matrix; must be 2-D because
                it is multiplied with ``torch.mm``.

        Returns:
            The aggregated node features, one row per row of ``adj_norm``.
        """
        projected = self.linear(x)
        return torch.mm(adj_norm, projected)