import torch
from torch_geometric.nn import GINConv, JumpingKnowledge
from torch.nn import Linear, Parameter, ModuleList, BatchNorm1d
from .utils import MLP
class GINExtractor(torch.nn.Module):
    """Stack of GIN convolutions with optional jumping-knowledge aggregation.

    Each layer is a ``GINConv`` wrapping an ``MLP`` (from ``.utils``) built
    with a ``"leaky_relu"`` activation (``negative_slope=0.1``) and a fixed,
    non-trainable epsilon (``train_eps=False``). The first layer takes
    ``input_channels`` features; every later layer takes ``hidden_channels``.
    Every layer is constructed with ``hidden_channels`` as its output size.

    Args:
        input_channels: Feature size of ``data.x``.
        hidden_channels: Feature size produced by every GIN layer.
        num_layer_mlp: Forwarded to ``MLP`` as its third positional argument
            (presumably the number of MLP layers -- confirm against ``.utils``).
        num_layer_gin: Number of stacked ``GINConv`` layers; must be >= 1.
        jump_mode: ``None`` to return only the last layer's output, otherwise
            a ``JumpingKnowledge`` mode (``"cat"``, ``"max"`` or ``"lstm"``)
            used to combine the outputs of all layers.
        seed: Passed to ``torch.manual_seed`` before any parameters are
            created so initialisation is reproducible. This reseeds the
            *global* torch RNG; pass ``None`` to leave the RNG untouched.

    Raises:
        ValueError: If ``num_layer_gin`` is smaller than 1.
    """

    def __init__(self,
                 input_channels,
                 hidden_channels,
                 num_layer_mlp,
                 num_layer_gin,
                 jump_mode=None,
                 seed=1111):
        super().__init__()
        if num_layer_gin < 1:
            # forward() needs at least one layer output to return/aggregate.
            raise ValueError(
                f"num_layer_gin must be >= 1, got {num_layer_gin}")
        if seed is not None:
            torch.manual_seed(seed)
        if jump_mode is not None:
            # "lstm" mode requires channels and num_layers; "cat"/"max"
            # ignore them, so passing them is safe for every mode.
            self.jump_layer = JumpingKnowledge(
                jump_mode, channels=hidden_channels, num_layers=num_layer_gin)
        else:
            self.jump_layer = None
        self.GIN_layers = ModuleList()
        self._num_layer_gin = num_layer_gin
        for layer in range(num_layer_gin):
            local_input_channels = (
                input_channels if layer == 0 else hidden_channels)
            self.GIN_layers.append(
                GINConv(MLP(
                    local_input_channels,
                    hidden_channels,
                    num_layer_mlp,
                    hidden_channels,
                    "leaky_relu",
                    negative_slope=0.1,
                ), train_eps=False)
            )

    def forward(self, data):
        """Run the GIN stack over a graph.

        Args:
            data: Object exposing ``x`` (node features; cast to float here)
                and ``edge_index`` -- presumably a ``torch_geometric`` ``Data``
                or ``Batch``; verify against callers.

        Returns:
            The last layer's output when ``jump_mode`` is ``None``, otherwise
            the ``JumpingKnowledge`` combination of every layer's output.
        """
        x = data.x.to(torch.float)
        # NOTE: asserts are stripped under ``python -O``; these are debug aids.
        assert not torch.isnan(x).any(), "x contains NaN"
        edge_index = data.edge_index
        xs = []
        curr_x = x
        for conv in self.GIN_layers:
            curr_x = conv(x=curr_x, edge_index=edge_index)
            assert not torch.isnan(curr_x).any(), "curr_x in GIN contains NaN"
            xs.append(curr_x)
        if self.jump_layer is None:
            return xs[-1]
        return self.jump_layer(xs)

    def reset_parameters(self):
        """Re-initialise every GIN layer and, if present, the jump layer."""
        for layer in self.GIN_layers:
            layer.reset_parameters()
        if self.jump_layer is not None:
            # Only the "lstm" mode owns parameters, but the call is safe for
            # every JumpingKnowledge mode.
            self.jump_layer.reset_parameters()