import torch
import torch.nn as nn 
from torch_sparse import spmm, matmul
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing

import numpy as np


def create_label_induced_negative_graph(y, train_mask=None):
    """Build a dense signed matrix marking which node pairs disagree on label.

    Entry ``(i, j)`` is ``+1.0`` when ``y[i] != y[j]`` and ``-1.0`` when the
    two labels are equal (so the diagonal is always ``-1.0``).  When
    ``train_mask`` is given, every row and every column that belongs to a
    node outside the mask is set to ``0``.

    Args:
        y: tensor of node labels; expected to be 1-D (one label per node).
        train_mask: optional boolean tensor with one entry per node; ``True``
            marks the nodes whose rows/columns are kept.

    Returns:
        A float tensor holding the signed pairwise label-disagreement matrix.
    """
    # Pairwise comparison via broadcasting: True where the labels differ.
    differs = (y.unsqueeze(1) != y.unsqueeze(0)).float()
    # Map {0, 1} -> {-1, +1}.
    label_matrix = 2 * differs - 1
    if train_mask is not None:
        # Drop every pair that involves a node outside the mask.
        label_matrix[~train_mask, :] = 0
        label_matrix[:, ~train_mask] = 0
    return label_matrix

def create_label_induced_negative_graph_sparse(y, train_mask=None):
    """Build the signed label-disagreement matrix as a sparse COO tensor.

    Only pairs ``(i, j)`` where both nodes are selected by ``train_mask`` are
    stored.  A stored entry is ``+1.0`` when ``y[i] != y[j]`` and ``-1.0``
    when the labels are equal; all other pairs are left implicit (zero).

    Args:
        y: 1-D tensor of node labels (indexed per node).
        train_mask: optional boolean tensor with one entry per node.  When
            omitted, every node is treated as selected.

    Returns:
        A float sparse COO tensor of shape ``(num_nodes, num_nodes)`` living
        on the same device as ``y``.
    """
    num_nodes = y.size(0)
    if train_mask is None:
        train_mask = torch.ones(num_nodes, dtype=torch.bool, device=y.device)

    # Enumerate pairs among the selected nodes only, instead of building a
    # full (num_nodes x num_nodes) index grid and filtering it afterwards.
    # The ordering is row-major over the selected nodes.
    selected = torch.nonzero(train_mask, as_tuple=False).view(-1)
    num_selected = selected.numel()
    row_indices = selected.repeat_interleave(num_selected)
    col_indices = selected.repeat(num_selected)

    # {same label, different label} -> {-1, +1}
    values = (y[row_indices] != y[col_indices]).float()
    values = 2 * values - 1

    indices = torch.stack([row_indices, col_indices])
    size = (num_nodes, num_nodes)

    # torch.sparse_coo_tensor infers dtype/device from its inputs, unlike the
    # deprecated, CPU-typed torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(indices, values, size)

class GraphConv(nn.Module):
    """Single graph-convolution layer computing ``(adj @ x) @ weight + bias``."""

    def __init__(self, in_features, out_features, mode, bias=True):
        """
            in_features  : number of input features per node
            out_features : number of output features per node
            mode         : stored on the layer as ``self.mode``; not used by
                           this class itself
            bias         : when False, no bias term is created or added
        """
        super(GraphConv, self).__init__()
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(1, out_features))
        else:
            # Register the attribute explicitly so the `self.bias is not None`
            # checks below work when the layer is built without a bias.
            self.register_parameter('bias', None)

        self.in_features = in_features
        self.out_features = out_features
        self.mode = mode
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight (and bias) uniformly from [-stdv, stdv], stdv = 1/sqrt(out_features)."""
        stdv = float(1. / np.sqrt(self.weight.size(1)))
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, x, adj, scale=0):
        """Aggregate features with ``adj`` and then apply the linear map.

        ``scale`` is accepted for interface compatibility and is not used.
        """
        x = torch.spmm(adj, x)        # neighbourhood aggregation
        x = torch.mm(x, self.weight)  # feature transform
        if self.bias is not None:
            x = x + self.bias
        return x

    def __repr__(self):
        return self.__class__.__name__ + "({}->{})".format(self.in_features, self.out_features)


class NormLayer(nn.Module):
    """Feature normalization layer whose behaviour is selected by ``args.norm_mode``."""

    def __init__(self, args):
        """
            args.norm_mode:
              'None' : No normalization (input returned unchanged)
              'res'  : No normalization (input returned unchanged)
              'LN'   : LayerNorm-style: center each row, scale it to unit L2 norm
              'BN'   : center each column, scale it to unit L2 norm
              'PN'   : PairNorm
              'CN'   : ContraNorm
            args.norm_scale:
              scale factor used by the 'PN' and 'CN' modes
        """
        super(NormLayer, self).__init__()
        self.mode = args.norm_mode
        self.scale = args.norm_scale

    def forward(self, x, adj=None, y=None, train_mask=None, tau=1.0):
        """Normalize ``x`` according to ``self.mode``.

        Args:
            x: feature matrix; rows are indexed with dim 0 and features with
                dim 1 by the operations below.
            adj: only read in 'CN' mode, where it is required.  It is either
                indexed as ``adj[0], adj[1]`` (index pairs) or densified with
                ``adj.to_dense()``.
            y, train_mask: accepted for interface compatibility; unused.
            tau: temperature dividing the similarity matrix in 'CN' mode.

        Returns:
            The normalized features.  Any mode not listed in ``__init__``
            returns ``x`` unchanged.
        """
        mode = self.mode

        # Membership test: `mode == 'None' or 'res'` would always be truthy
        # and would short-circuit every other mode.
        if mode in ('None', 'res'):
            return x

        if mode == 'LN':
            x = x - x.mean(dim=1, keepdim=True)
            x = nn.functional.normalize(x, dim=1)

        elif mode == 'BN':
            x = x - x.mean(dim=0, keepdim=True)
            x = nn.functional.normalize(x, dim=0)

        elif mode == 'PN':
            col_mean = x.mean(dim=0)
            x = x - col_mean
            # Root of the mean squared row norm; epsilon avoids division by 0.
            rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt()
            x = self.scale * x / rownorm_mean

        elif mode == 'CN':
            norm_x = nn.functional.normalize(x, dim=1)
            sim = norm_x @ norm_x.T / tau
            # Exclude the pairs given by `adj` from the softmax below.
            # NOTE(review): an edge list shaped (2, E) would normally be
            # detected with adj.size(0) == 2 -- confirm which layout callers
            # pass before relying on the first branch.
            if adj.size(1) == 2:
                sim[adj[0], adj[1]] = -np.inf
            else:
                sim.masked_fill_(adj.to_dense() > 1e-5, -np.inf)
            sim = nn.functional.softmax(sim, dim=1)
            x_neg = sim @ x
            x = (1 + self.scale) * x - self.scale * x_neg

        return x



"""
    helpers
"""

from torch_scatter import scatter_max, scatter_add


def softmax(src, index, num_nodes=None):
    """
        Sparse softmax: normalize ``src`` separately within each group of
        entries that share the same value in ``index``.

        When ``num_nodes`` is not given it is taken as ``index.max() + 1``.
    """
    if num_nodes is None:
        num_nodes = index.max().item() + 1

    # Subtract each group's maximum before exponentiating.
    group_max, _ = scatter_max(src, index, dim=0, dim_size=num_nodes)
    exp_shifted = (src - group_max[index]).exp()

    # Per-group normalizer, gathered back to one value per entry; the small
    # constant keeps the division well-defined.
    group_sum = scatter_add(exp_shifted, index, dim=0, dim_size=num_nodes)
    return exp_shifted / (group_sum[index] + 1e-16)


if __name__ == "__main__":
    # Smoke tests for create_label_induced_negative_graph.  The function
    # encodes a pair of nodes as -1 when their labels match and +1 when they
    # differ; pairs involving a node outside train_mask are 0.

    # Mixed labels, no mask.
    y = torch.tensor([0, 1, 2, 1])
    expected_output = torch.tensor([
        [-1., 1., 1., 1.],
        [1., -1., 1., -1.],
        [1., 1., -1., 1.],
        [1., -1., 1., -1.]
    ])
    assert torch.equal(create_label_induced_negative_graph(y), expected_output)

    # Same labels, with nodes 1 and 3 masked out.
    train_mask = torch.tensor([True, False, True, False])
    expected_output_with_mask = torch.tensor([
        [-1., 0., 1., 0.],
        [0., 0., 0., 0.],
        [1., 0., -1., 0.],
        [0., 0., 0., 0.]
    ])
    assert torch.equal(create_label_induced_negative_graph(y, train_mask), expected_output_with_mask)

    # All labels identical: every pair matches.
    y_same = torch.tensor([1, 1, 1, 1])
    expected_output_same = torch.tensor([
        [-1., -1., -1., -1.],
        [-1., -1., -1., -1.],
        [-1., -1., -1., -1.],
        [-1., -1., -1., -1.]
    ])
    assert torch.equal(create_label_induced_negative_graph(y_same), expected_output_same)

    # All labels distinct: only the diagonal matches.
    y_different = torch.tensor([0, 1, 2, 3])
    expected_output_different = torch.tensor([
        [-1., 1., 1., 1.],
        [1., -1., 1., 1.],
        [1., 1., -1., 1.],
        [1., 1., 1., -1.]
    ])
    assert torch.equal(create_label_induced_negative_graph(y_different), expected_output_different)
