import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gcn_conv import GCNConv
from torch_sparse import SparseTensor, fill_diag, matmul, mul


class GCN(nn.Module):
    """Stack of ``GCNConv`` layers with optional batch norm and skip branch.

    Every hidden layer applies: conv -> (batch norm) -> ReLU -> dropout.
    The output layer applies the conv only.  When ``use_nasc`` is set, the
    second value returned by each conv (``ah``) is passed through a
    per-layer projection and added to that layer's output.

    NOTE(review): ``GCNConv`` is the project-local implementation imported
    from ``gcn_conv``; it must return a pair ``(x, ah)``.  ``ah`` is
    presumably the propagated input before the layer's linear transform
    (its projection maps the layer's input width to its output width) --
    confirm against ``gcn_conv``.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of every hidden layer.
        out_channels: Width of the output layer.
        num_layers: Total number of conv layers.  At least two layers are
            always built, so values below 2 behave like 2.
        dropout: Dropout probability applied after each hidden activation.
        save_mem: Each conv is built with ``cached=not save_mem``.
        use_bn: Apply batch norm after each hidden conv.  The batch-norm
            modules are created regardless, so checkpoints keep the same
            keys whichever way this flag is set.
        use_nasc: Enable the additive skip branch described above.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2,
                 dropout=0.5, save_mem=True, use_bn=True, use_nasc=False):
        super().__init__()

        # Containers are registered in the order convs, nasc, bns so that
        # parameter ordering matches what existing checkpoints expect.
        self.convs = nn.ModuleList()
        self.use_nasc = use_nasc
        if self.use_nasc:
            self.nasc = nn.ModuleList()
        self.bns = nn.ModuleList()

        cached = not save_mem
        # Layer widths: in -> hidden (x (num_layers - 1), at least once) -> out.
        num_hidden = max(num_layers - 2, 0) + 1
        widths = [in_channels] + [hidden_channels] * num_hidden + [out_channels]
        last = len(widths) - 2

        for i, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            self.convs.append(GCNConv(width_in, width_out, cached=cached))
            if self.use_nasc:
                # A projection is only needed where the layer changes width.
                self.nasc.append(
                    nn.Identity() if width_in == width_out
                    else nn.Linear(width_in, width_out)
                )
            if i < last:
                # The output layer has no batch norm.
                self.bns.append(nn.BatchNorm1d(width_out))

        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn

    def reset_parameters(self):
        """Re-initialize every learnable submodule, skip projections included."""
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        if self.use_nasc:
            for skip in self.nasc:
                # nn.Identity has nothing to reset and no reset_parameters().
                if hasattr(skip, 'reset_parameters'):
                    skip.reset_parameters()

    def forward(self, data):
        """Run the network on ``data.graph``.

        Args:
            data: Object whose ``graph`` mapping holds ``'node_feat'``,
                ``'edge_index'`` and, optionally, ``'edge_weight'``.

        Returns:
            The output of the last layer (plus its skip term when
            ``use_nasc`` is enabled).
        """
        graph = data.graph
        x = graph['node_feat']
        edge_index = graph['edge_index']
        edge_weight = graph['edge_weight'] if 'edge_weight' in graph else None

        # Only pass edge_weight when one is present, so unweighted graphs use
        # the conv's own default.  The same arguments go to every layer,
        # including the output layer, which previously dropped the weights.
        if edge_weight is None:
            conv_args = (edge_index,)
        else:
            conv_args = (edge_index, edge_weight)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x, ah = conv(x, *conv_args)
            if i < last:
                if self.use_bn:
                    x = self.bns[i](x)
                x = self.activation(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
            if self.use_nasc:
                # The skip term is added after dropout, so it is never dropped.
                x = x + self.nasc[i](ah)
        return x
