import torch
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin
from utils.data_util import get_atom_feature_dims, get_bond_feature_dims
import numpy as np

def create_label_induced_negative_graph_sparse(y, train_mask=None):
    """Build a sparse +1/-1 matrix marking node pairs whose labels differ.

    For every ordered pair ``(i, j)`` where both ``train_mask[i]`` and
    ``train_mask[j]`` are set, the result stores ``+1.0`` when
    ``y[i] != y[j]`` and ``-1.0`` when the labels are equal. Pairs that
    involve a masked-out node get no stored entry (they read as 0 after
    ``to_dense()``).

    Args:
        y: per-node labels indexed along dim 0; ``y.size(0)`` is taken as the
            node count (assumed 1-D — TODO confirm multi-label ``y`` is never
            passed here).
        train_mask: optional boolean mask of length ``y.size(0)`` selecting
            the nodes to pair up; ``None`` selects every node.

    Returns:
        An (uncoalesced) sparse COO tensor of shape ``(num_nodes, num_nodes)``
        with float32 values, on ``y.device``.
    """
    num_nodes = y.size(0)

    # Without a mask, every node participates.
    if train_mask is None:
        train_mask = torch.ones(num_nodes, dtype=torch.bool, device=y.device)

    # Enumerate pairs of selected nodes directly instead of materialising the
    # full (num_nodes x num_nodes) grid and filtering it. repeat_interleave /
    # repeat produce the same row-major (i, j) order as the filtered grid.
    selected = torch.arange(num_nodes, device=y.device)[train_mask]
    k = selected.numel()
    row_indices = selected.repeat_interleave(k)
    col_indices = selected.repeat(k)

    # Compare labels: +1 where they differ, -1 where they agree.
    values = (y[row_indices] != y[col_indices]).float()
    values = 2 * values - 1

    indices = torch.stack([row_indices, col_indices])
    size = torch.Size([num_nodes, num_nodes])

    # sparse_coo_tensor replaces the deprecated, CPU-only legacy constructor
    # torch.sparse.FloatTensor, so CUDA inputs work as well.
    return torch.sparse_coo_tensor(indices, values, size, device=y.device)

##############################
#    Basic layers
##############################
class PairNorm(nn.Module):
    """Standardise each row of ``x`` to zero mean / unit std, then scale it.

    NOTE(review): normalisation is along ``dim=1`` (per row). The PairNorm of
    Zhao & Akoglu (2020) instead centres across nodes — confirm this per-row
    variant is what is intended.
    """

    def __init__(self, scale=1.0, eps=1e-5):
        """
        Args:
            scale: multiplier applied after standardisation.
            eps: floor for the per-row standard deviation. Rows whose std is
                zero (e.g. all-zero rows) or NaN (single-column input) would
                otherwise yield inf/NaN.
        """
        super(PairNorm, self).__init__()
        self.scale = scale
        self.eps = eps

    def forward(self, x):
        """Return ``scale * (x - row_mean) / max(row_std, eps)``."""
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True)
        # Only degenerate rows are affected; rows with std > eps are unchanged.
        # torch.where (rather than clamp) also replaces NaN std values.
        std = torch.where(std > self.eps, std, torch.full_like(std, self.eps))
        return self.scale * (x - mean) / std

class ContraNorm(nn.Module):
    """Push features away from a softmax-weighted mix of themselves.

    Computes ``x' = (1 + scale) * x - scale * (x @ sim)``. Here ``sim`` is the
    row-softmax of ``norm_x.T @ norm_x / tau``, and ``norm_x`` is ``x``
    L2-normalised along ``dim=1``.
    """

    def __init__(self, scale=1.0, tau=1.0):
        """
        Args:
            scale: strength of the repulsion term.
            tau: temperature dividing the similarity logits.
        """
        super(ContraNorm, self).__init__()
        self.scale = scale
        self.tau = tau

    def forward(self, x, adj):
        """
        Args:
            x: 2-D feature matrix (assumed (num_nodes, num_features) — TODO
                confirm).
            adj: currently unused; kept for interface compatibility.

        Returns:
            Tensor with the same shape as ``x``.
        """
        norm_x = nn.functional.normalize(x, dim=1)
        # Both operands are dense, so use a regular matmul (torch.spmm is a
        # sparse @ dense routine). Result is (num_features, num_features).
        sim = (norm_x.T @ norm_x) / self.tau
        # NOTE(review): the disabled masking below indexes ``sim`` with node
        # indices from ``adj``. That does not match this feature-by-feature
        # ``sim`` unless num_nodes == num_features.
        # if adj.size(1) == 2:
        #     sim[adj[0], adj[1]] = -np.inf
        # else:
        #     sim.masked_fill_(adj.to_dense() > 1e-5, -np.inf)
        sim = nn.functional.softmax(sim, dim=1)
        x_neg = x @ sim
        x = (1 + self.scale) * x - self.scale * x_neg
        return x

class Sign(nn.Module):
    """Remove an ``h0``-driven feature mixture from ``x``, then LayerNorm it.

    ``scale`` and ``tau`` are stored on the module but ``forward`` does not
    read them.
    """

    def __init__(self, nc, scale=1.0, tau=1.0):
        super().__init__()
        self.scale, self.tau = scale, tau
        # Learnable LayerNorm over a trailing dimension of size ``nc``.
        self.layer = nn.LayerNorm(nc, elementwise_affine=True)

    def forward(self, x, h0):
        """
        Args:
            x: 2-D features whose last dim matches h0's last dim (required by
                ``x @ weights``).
            h0: 2-D reference features; rows are L2-normalised.

        Returns:
            ``LayerNorm(x - x @ softmax(-(h0n.T @ h0n), dim=1))``, where h0n is
            the row-normalised h0.
        """
        h0_unit = nn.functional.normalize(h0, dim=1)
        # Negated (d, d) Gram matrix of the row-normalised h0, softmaxed per row.
        weights = nn.functional.softmax(torch.mm(h0_unit.T, h0_unit).neg(), dim=1)
        return self.layer(x - x @ weights)

class Label(nn.Module):
    """Subtract a label-induced aggregation of ``x`` from ``x``.

    ``scale`` and ``tau`` are stored on the module but ``forward`` does not
    read them.
    """

    def __init__(self, scale=1.0, tau=1.0):
        super(Label, self).__init__()
        self.scale = scale
        self.tau = tau

    def forward(self, x, adj, y, train_mask):
        """
        Args:
            x: node features; dim 0 must equal ``y.size(0)`` for the matmul.
            adj: unused; kept for interface compatibility.
            y: per-node labels forwarded to
                ``create_label_induced_negative_graph_sparse``.
            train_mask: node mask forwarded to the same helper (``None`` means
                all nodes).

        Returns:
            ``x - softmax(neg, dim=1) @ x``, where ``neg`` is the dense +1/-1
            label matrix.
        """
        # ``neg`` was previously used without being defined (its construction
        # was commented out), so every call raised NameError.
        neg = create_label_induced_negative_graph_sparse(y, train_mask)
        # Row-softmax of the dense matrix. Pairs with no stored entry
        # (masked-out nodes) read as 0 and still receive weight exp(0).
        neg = nn.functional.softmax(neg.to_dense(), dim=1)
        x_neg = neg @ x
        x = x - x_neg
        return x

def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    """Create an activation module from its case-insensitive name.

    Args:
        act_type: one of 'relu', 'leakyrelu', 'prelu' (any casing).
        inplace: forwarded to ReLU / LeakyReLU.
        neg_slope: LeakyReLU slope, also the PReLU initial weight.
        n_prelu: number of PReLU parameters.

    Raises:
        NotImplementedError: for any other name.
    """
    name = act_type.lower()
    if name == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    if name == 'leakyrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if name == 'relu':
        return nn.ReLU(inplace)
    raise NotImplementedError('activation layer [%s] is not found' % name)


def norm_layer(norm_type, nc, scale=0.0):
    """Create a 1-D normalisation module from its case-insensitive name.

    Args:
        norm_type: 'batch', 'layer', 'instance', 'none', 'pair', 'contra',
            'sign' or 'label'.
        nc: number of features/channels.
        scale: forwarded to ContraNorm and Sign only.

    Returns:
        The module, or None for 'none'.

    Raises:
        NotImplementedError: for any other name.
    """
    # Factories are lambdas so only the requested layer is constructed.
    builders = {
        'batch': lambda: nn.BatchNorm1d(nc, affine=True),
        'layer': lambda: nn.LayerNorm(nc, elementwise_affine=True),
        'instance': lambda: nn.InstanceNorm1d(nc, affine=False),
        'none': lambda: None,
        'pair': lambda: PairNorm(),
        'contra': lambda: ContraNorm(scale=scale),
        'sign': lambda: Sign(nc, scale=scale),
        # Maps to a plain LayerNorm, not the Label module (whose forward also
        # needs y / train_mask).
        'label': lambda: nn.LayerNorm(nc, elementwise_affine=True),
    }
    key = norm_type.lower()
    if key not in builders:
        raise NotImplementedError('normalization layer [%s] is not found' % key)
    return builders[key]()


class MultiSeq(Seq):
    """Sequential container whose modules may take and return several values.

    Whenever the running value is exactly a ``tuple`` (not a subclass), it is
    unpacked into the next module's positional arguments. Otherwise it is
    passed as a single argument.
    """

    def __init__(self, *args):
        super().__init__(*args)

    def forward(self, *inputs):
        carried = inputs
        for layer in self._modules.values():
            # Exact-type check on purpose: tuple subclasses are passed whole.
            carried = layer(*carried) if type(carried) is tuple else layer(carried)
        return carried


class MLP(Seq):
    """Stack of Linear layers, each optionally followed by norm / act / dropout.

    Args:
        channels: layer widths; ``len(channels) - 1`` Linear layers are built.
        act: activation name for ``act_layer``, or None / 'none' to skip.
        norm: normalisation name for ``norm_layer``, or None / 'none' to skip.
        bias: forwarded to every Linear layer.
        drop: dropout probability; no dropout layer is added when 0.
        last_lin: if True, the final Linear gets no norm / act / dropout.
    """

    def __init__(self, channels, act='relu',
                 norm=None, bias=True,
                 drop=0., last_lin=False):
        m = []
        last_index = len(channels) - 1

        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))

            # Optionally leave the output projection bare.
            if last_lin and i == last_index:
                continue

            if norm is not None and norm.lower() != 'none':
                m.append(norm_layer(norm, channels[i]))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if drop > 0:
                # Element-wise dropout. Dropout2d is meant for channel-wise
                # dropout on 3-D/4-D inputs. PyTorch deprecates it on 2-D
                # (N, C) input, where it already acts like plain Dropout.
                m.append(nn.Dropout(drop))

        # Plain list of the layers (they are also registered as submodules by
        # Seq.__init__ below).
        self.m = m
        super(MLP, self).__init__(*self.m)


class AtomEncoder(nn.Module):
    """Embed categorical atom features by summing one embedding per column."""

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim: size of every embedding vector.
        """
        super().__init__()
        self.atom_embedding_list = nn.ModuleList()
        # One table per categorical column, sized by get_atom_feature_dims().
        for num_categories in get_atom_feature_dims():
            table = nn.Embedding(num_categories, emb_dim)
            nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

    def forward(self, x):
        """Sum the embeddings of every column of the integer matrix ``x``.

        Column ``c`` is looked up in table ``c``. If ``x`` has zero columns the
        result is the integer ``0``.
        """
        return sum(
            self.atom_embedding_list[col](x[:, col]) for col in range(x.shape[1])
        )


class BondEncoder(nn.Module):
    """Embed categorical bond features by summing one embedding per column."""

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim: size of every embedding vector.
        """
        super().__init__()
        self.bond_embedding_list = nn.ModuleList()
        # One table per categorical column, sized by get_bond_feature_dims().
        for vocab_size in get_bond_feature_dims():
            lookup = nn.Embedding(vocab_size, emb_dim)
            nn.init.xavier_uniform_(lookup.weight.data)
            self.bond_embedding_list.append(lookup)

    def forward(self, edge_attr):
        """Sum per-column embeddings of the integer matrix ``edge_attr``.

        Column ``c`` is looked up in table ``c``. If ``edge_attr`` has zero
        columns the result is the integer ``0``.
        """
        total = 0
        num_cols = edge_attr.shape[1]
        col = 0
        while col < num_cols:
            total = total + self.bond_embedding_list[col](edge_attr[:, col])
            col += 1
        return total


