import torch
from torch_geometric.nn import MessagePassing
import torch.nn.init as init
import math
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.nn import global_mean_pool, global_add_pool, MessagePassing, \
    GCNConv, BatchNorm, GINConv, GINEConv, GraphNorm

from infrastructure import pytorch_util as ptu

import torch.nn.functional as F
def glorot(tensor):
    """Fill ``tensor`` in place with Glorot (Xavier) uniform values.

    Values are drawn from U(-b, b) with ``b = sqrt(6 / (size(-2) + size(-1)))``,
    i.e. the bound is derived from the last two dimensions of ``tensor``.
    Passing ``None`` is a no-op.
    """
    if tensor is None:
        return
    fan_sum = tensor.size(-2) + tensor.size(-1)
    bound = math.sqrt(6.0 / fan_sum)
    # Write through .data so the initialisation is not tracked by autograd.
    tensor.data.uniform_(-bound, bound)


def zeros(tensor):
    """Set every element of ``tensor`` to zero in place; ``None`` is ignored."""
    if tensor is None:
        return
    # Write through .data so the reset is not tracked by autograd.
    tensor.data.zero_()

class GNNEdgeConv(MessagePassing):
    """GCN-style convolution whose messages are also scaled by edge attributes.

    Node features are linearly projected, each message is weighted by the
    symmetric degree normalisation ``deg^-1/2[row] * deg^-1/2[col]`` and
    multiplied by ``edge_attr``, messages are summed per target node, and a
    learnable bias is added.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super(GNNEdgeConv, self).__init__(aggr='add')  # "Add" aggregation of MessagePassing class.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        # torch.empty leaves the values undefined; reset_parameters() below
        # gives the bias a well-defined starting value.
        self.bias = torch.nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear projection and zero the bias."""
        self.lin.reset_parameters()
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of edge-weighted, degree-normalised message passing.

        Args:
            x: Node feature matrix; its first dimension is used as the number
                of nodes.
            edge_index: Pair of index tensors ``(row, col)``.
            edge_attr: Per-edge factor multiplied into each message.
                NOTE(review): must broadcast against ``(num_edges, out_channels)``
                -- confirm the shape supplied by callers.

        Returns:
            The aggregated node features with the bias added.
        """
        x = self.lin(x)

        row, col = edge_index
        # Degree is counted over the `col` indices only.
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes with degree 0 produce inf; zero them so they contribute nothing.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, norm=norm)
        out = out + self.bias

        return out

    def message(self, x_j, edge_attr, norm):
        """Scale each neighbour feature row by its edge attribute and norm."""
        # norm is one scalar per edge; view(-1, 1) lets it broadcast over features.
        return edge_attr * x_j * norm.view(-1, 1)
    
class GCNNet(torch.nn.Module):
    """Two-layer GCN that sums node outputs into a single log-probability row.

    ``conv3`` and the three ``BatchNorm`` layers are constructed and
    initialised but are not used by ``forward``.

    Args:
        num_features: Input node feature size.
        hidden: Output size of the first convolution.
        num_classes: Output size of the second convolution.
    """

    def __init__(self, num_features, hidden, num_classes):
        super(GCNNet, self).__init__()
        self.conv1 = GCNConv(num_features, hidden)
        self.batch_norm1 = BatchNorm(hidden)
        self.conv2 = GCNConv(hidden, num_classes)
        self.batch_norm2 = BatchNorm(num_classes)
        self.conv3 = GNNEdgeConv(hidden, num_classes)
        self.batch_norm3 = BatchNorm(num_classes)
        self.init_weights()

    def init_weights(self):
        """Initialise layer parameters that are exposed as ``weight`` / ``bias``.

        Convolution weights get Xavier-uniform values and biases get standard
        normal values; for the norm layers both ``bias`` and ``weight`` of the
        wrapped ``module`` get standard normal values. Attributes that are
        missing or ``None`` are skipped.
        """
        for layer in (self.conv1, self.conv2, self.conv3):
            weight = getattr(layer, 'weight', None)
            if weight is not None:
                init.xavier_uniform_(weight)
            bias = getattr(layer, 'bias', None)
            # A layer may expose `bias` as None; init.normal_ would fail on it.
            if bias is not None:
                init.normal_(bias)
        for norm in (self.batch_norm1, self.batch_norm2, self.batch_norm3):
            bias = getattr(norm.module, 'bias', None)
            if bias is not None:
                init.normal_(bias)
            weight = getattr(norm.module, 'weight', None)
            if weight is not None:
                init.normal_(weight)

    def forward(self, x, edge_index, edge_attr, dropout):
        """Compute log-probabilities from the node-summed convolution output.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to both convolutions.
            edge_attr: Passed as the third positional argument of both
                ``GCNConv`` layers.
                NOTE(review): GCNConv treats this argument as a per-edge
                weight -- confirm callers pass a compatible tensor.
            dropout: Dropout probability, applied only in training mode.

        Returns:
            Log-softmax over dim 1 of the features summed over dim 0, with a
            leading dimension of size 1, on ``ptu.device``.
        """
        x = self.conv1(x, edge_index, edge_attr)
        x = F.leaky_relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.leaky_relu(x)
        x = F.dropout(x, p=dropout, training=self.training)
        # Sum over dim 0 (all nodes), then restore a leading dimension of 1.
        x = torch.sum(x, dim=0)
        x = x.unsqueeze(0)
        # The original called the resulting tensor (`...to(device)()`), which
        # raises TypeError; return the tensor itself.
        out = F.log_softmax(x, dim=1).to(ptu.device)
        return out
           

def uniform(size, tensor):
    """Fill ``tensor`` in place with values from U(-1/sqrt(size), 1/sqrt(size)).

    The bound is computed before the ``None`` check, so an invalid ``size``
    raises even when ``tensor`` is ``None``. A ``None`` tensor is otherwise
    ignored.
    """
    limit = 1.0 / math.sqrt(size)
    if tensor is None:
        return
    # Write through .data so the initialisation is not tracked by autograd.
    tensor.data.uniform_(-limit, limit)


class SAGEConv(MessagePassing):
    """Message-passing layer that mixes a node with its aggregated neighbours.

    Each node's own features are concatenated with the aggregated neighbour
    features, projected by one weight matrix of shape
    ``(2 * in_channels, out_channels)``, and L2-normalised along the last
    dimension.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        aggr: Aggregation scheme handed to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, aggr):
        super().__init__(aggr=aggr)

        self.in_channels, self.out_channels = in_channels, out_channels

        # Rows cover [own features | aggregated neighbour features].
        self.weight = torch.nn.Parameter(torch.Tensor(2 * in_channels, out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight matrix uniformly, bounded by 1/sqrt(num rows)."""
        fan_in = self.weight.size(0)
        uniform(fan_in, self.weight)

    def forward(self, x, edge_index):
        """Propagate node features ``x`` along ``edge_index``."""
        # NOTE: propagate() needs edge_index plus every tensor consumed by
        # message()/update(). Further keyword arguments (e.g. norm=...) can be
        # passed here as long as message() declares them too.
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        """Send the neighbour features unchanged."""
        # NOTE: with "source_to_target" flow, x_j holds the features of the
        # sending nodes and x_i those of the receiving (aggregating) nodes;
        # both names are resolved automatically by MessagePassing.
        return x_j

    def update(self, aggr_out, x):
        """Combine own and aggregated features, project, and L2-normalise."""
        combined = torch.cat([x, aggr_out], dim=-1)
        projected = torch.matmul(combined, self.weight)
        # Scale every output row to unit L2 norm.
        return F.normalize(projected, p=2, dim=-1)

class SAGENet(torch.nn.Module):
    """Two stacked ``SAGEConv`` layers producing per-node log-probabilities.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        hidden: Output size of the first convolution.
        aggr: Aggregation scheme forwarded to both convolutions.
    """

    def __init__(self, dataset, hidden, aggr='mean'):
        super().__init__()
        in_dim, out_dim = dataset.num_features, dataset.num_classes
        self.conv1 = SAGEConv(in_dim, hidden, aggr=aggr)
        self.conv2 = SAGEConv(hidden, out_dim, aggr=aggr)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both convolution layers, first to last."""
        for conv in (self.conv1, self.conv2):
            conv.reset_parameters()

    def forward(self, data, dropout):
        """Return log-softmax (over dim 1) of the second convolution's output.

        Args:
            data: Object exposing ``x`` and ``edge_index``.
            dropout: Dropout probability, applied only in training mode.
        """
        edge_index = data.edge_index
        activated = F.relu(self.conv1(data.x, edge_index))
        activated = F.dropout(activated, p=dropout, training=self.training)
        logits = self.conv2(activated, edge_index)
        return F.log_softmax(logits, dim=1)
        

def reset(nn):
    """Call ``reset_parameters()`` on a module's children, or on the module itself.

    If ``nn`` exposes ``children()`` and has at least one child, every direct
    child that defines ``reset_parameters`` is reset (``nn`` itself is not).
    Otherwise ``nn`` is reset if it defines ``reset_parameters``. ``None`` is
    ignored.
    """
    if nn is None:
        return

    has_children = hasattr(nn, 'children') and len(list(nn.children())) > 0
    targets = nn.children() if has_children else (nn,)

    for target in targets:
        if hasattr(target, 'reset_parameters'):
            target.reset_parameters()


class GIN0(torch.nn.Module):
    """Stack of ``GINConv`` layers built with ``train_eps=False``.

    Node embeddings are mean-pooled with ``data.batch`` and classified by a
    two-layer head.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Total number of ``GINConv`` layers (the first one plus
            ``num_layers - 1`` hidden-to-hidden layers).
        hidden: Width of every hidden representation.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()

        def make_mlp(in_dim):
            # Linear -> ReLU -> Linear -> ReLU -> BatchNorm, used inside every conv.
            return torch.nn.Sequential(
                torch.nn.Linear(in_dim, hidden),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden, hidden),
                torch.nn.ReLU(),
                torch.nn.BatchNorm1d(hidden),
            )

        self.conv1 = GINConv(make_mlp(dataset.num_features), train_eps=False)
        self.convs = torch.nn.ModuleList(
            [GINConv(make_mlp(hidden), train_eps=False)
             for _ in range(num_layers - 1)]
        )
        self.lin1 = torch.nn.Linear(hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise every convolution and both linear head layers."""
        for layer in [self.conv1, *self.convs, self.lin1, self.lin2]:
            layer.reset_parameters()

    def forward(self, data):
        """Return log-softmax scores (over the last dim) for each pooled group.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        edge_index = data.edge_index
        h = self.conv1(data.x, edge_index)
        for conv in self.convs:
            h = conv(h, edge_index)
        pooled = global_mean_pool(h, data.batch)
        pooled = F.relu(self.lin1(pooled))
        # Dropout rate is fixed at 0.5 and is active only in training mode.
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        logits = self.lin2(pooled)
        return F.log_softmax(logits, dim=-1)


class GIN(torch.nn.Module):
    """Stack of ``GINConv`` layers built with ``train_eps=True``.

    Node embeddings are mean-pooled with ``data.batch`` and classified by a
    two-layer head.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Total number of ``GINConv`` layers (the first one plus
            ``num_layers - 1`` hidden-to-hidden layers).
        hidden: Width of every hidden representation.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()

        def make_mlp(in_dim):
            # Linear -> ReLU -> Linear -> ReLU -> BatchNorm, used inside every conv.
            return torch.nn.Sequential(
                torch.nn.Linear(in_dim, hidden),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden, hidden),
                torch.nn.ReLU(),
                torch.nn.BatchNorm1d(hidden),
            )

        self.conv1 = GINConv(make_mlp(dataset.num_features), train_eps=True)
        self.convs = torch.nn.ModuleList(
            [GINConv(make_mlp(hidden), train_eps=True)
             for _ in range(num_layers - 1)]
        )
        self.lin1 = torch.nn.Linear(hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise every convolution and both linear head layers."""
        for layer in [self.conv1, *self.convs, self.lin1, self.lin2]:
            layer.reset_parameters()

    def forward(self, data):
        """Return log-softmax scores (over the last dim) for each pooled group.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        edge_index = data.edge_index
        h = self.conv1(data.x, edge_index)
        for conv in self.convs:
            h = conv(h, edge_index)
        pooled = global_mean_pool(h, data.batch)
        pooled = F.relu(self.lin1(pooled))
        # Dropout rate is fixed at 0.5 and is active only in training mode.
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        logits = self.lin2(pooled)
        return F.log_softmax(logits, dim=-1)

class GINE(torch.nn.Module):
    """Stack of ``GINEConv`` layers whose per-layer sum-pooled outputs are concatenated.

    Every layer is followed by a ``GraphNorm``; after each layer the node
    features are sum-pooled with ``batch`` and the pooled tensors of all
    layers are concatenated along dim 1.

    Args:
        num_features: Input node feature size.
        hidden: Width of every convolution's MLP output.
        num_classes: Channel count given to every ``GraphNorm``.
            NOTE(review): the norms are applied to the ``hidden``-wide conv
            outputs, so this appears to require ``num_classes`` to be
            compatible with ``hidden`` -- confirm with callers.
        num_layers: Total number of ``GINEConv`` layers.
        edge_dim: Edge feature size passed to every ``GINEConv``
            (defaults to 7, the previously hard-coded value).
    """

    def __init__(self, num_features, hidden, num_classes, num_layers, edge_dim=7):
        super(GINE, self).__init__()
        # Only the convolutions are moved to ptu.device here; the norm layers
        # are left wherever the enclosing module is placed.
        self.conv1 = GINEConv(self._make_mlp(num_features, hidden),
                              train_eps=True, edge_dim=edge_dim).to(ptu.device)
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(
                GINEConv(self._make_mlp(hidden, hidden),
                         train_eps=True, edge_dim=edge_dim).to(ptu.device))
            self.norms.append(GraphNorm(num_classes))

        self.norm1 = GraphNorm(num_classes)

    @staticmethod
    def _make_mlp(in_dim, hidden):
        """Build the three-layer Linear/Tanh MLP used inside each GINEConv."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.Tanh(),
        )

    def reset_parameters(self):
        """Re-initialise every convolution and every norm layer.

        The previous version referenced ``lin1``/``lin2``, which this class
        never defines, and so always raised ``AttributeError``.
        """
        self.conv1.reset_parameters()
        self.norm1.reset_parameters()
        for conv, norm in zip(self.convs, self.norms):
            conv.reset_parameters()
            norm.reset_parameters()

    def forward(self, x, edge_index, edge_attr, dropout, batch):
        """Return the concatenated sum-pooled features of every layer.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to every convolution.
            edge_attr: Edge features passed to every convolution.
            dropout: Accepted for interface compatibility; not used.
            batch: Assignment vector used by the norms and the pooling.

        Returns:
            Tensor on ``ptu.device`` holding one pooled block per layer,
            concatenated along dim 1.
        """
        x = self.norm1(self.conv1(x, edge_index, edge_attr), batch)
        pooled = [global_add_pool(x, batch=batch)]
        for norm, conv in zip(self.norms, self.convs):
            # No residual connection: the original computed `new_x + x` but
            # overwrote it immediately, so only the normalised conv output is
            # carried forward. NOTE(review): confirm a skip connection was
            # not intended.
            x = norm(conv(x, edge_index, edge_attr), batch)
            pooled.append(global_add_pool(x, batch=batch))

        # One concatenation instead of re-concatenating inside the loop.
        return torch.cat(pooled, dim=1).to(ptu.device)