import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
from mygraphconv_calib import MyGraphConv


class GCN(nn.Module):
    """Multi-layer graph convolutional network built from ``MyGraphConv`` layers.

    Every layer except the last is followed by ``BatchNorm1d``, ``activation``
    and dropout. Only the final convolution carries a bias term. The hidden
    convolutions omit it, presumably because the BatchNorm that follows
    supplies its own shift.

    When ``use_linear`` is true, each hidden layer also adds a bias-free linear
    projection of the convolution output to that output
    (``h = conv + linear(conv)``).

    Args:
        in_feats: Size of the input node features.
        n_hidden: Width of every hidden layer.
        n_classes: Output size of the final layer.
        n_layers: Number of graph convolution layers.
        activation: Callable applied after each hidden-layer BatchNorm.
        dropout: Dropout probability applied after each hidden layer. The
            input dropout uses ``min(0.1, dropout)``.
        use_linear: Whether hidden layers add the linear branch described
            above.
    """

    def __init__(
        self,
        in_feats,
        n_hidden,
        n_classes,
        n_layers,
        activation,
        dropout,
        use_linear,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.use_linear = use_linear

        self.convs = nn.ModuleList()
        # The linear branches exist only when requested; forward() reads
        # self.linear only when use_linear is true.
        if use_linear:
            self.linear = nn.ModuleList()
        self.norms = nn.ModuleList()

        for i in range(n_layers):
            is_last = i == n_layers - 1
            in_hidden = in_feats if i == 0 else n_hidden
            out_hidden = n_classes if is_last else n_hidden

            # "both" is passed through to MyGraphConv unchanged. It is
            # presumably the symmetric degree normalisation, as in DGL's
            # GraphConv(norm="both") — confirm against mygraphconv_calib.
            self.convs.append(
                MyGraphConv(in_hidden, out_hidden, "both", bias=is_last)
            )

            if not is_last:
                # Guarded: self.linear does not exist when use_linear is
                # false, so an unconditional append would raise AttributeError.
                if use_linear:
                    self.linear.append(
                        nn.Linear(out_hidden, out_hidden, bias=False)
                    )
                self.norms.append(nn.BatchNorm1d(out_hidden))

        self.input_drop = nn.Dropout(min(0.1, dropout))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def reset_message_stats(self, lfnorm_layer):
        """Clear a norm layer's running statistics and switch it to a cumulative average.

        Setting ``momentum`` to ``None`` makes a PyTorch batch-norm layer
        accumulate a cumulative moving average over later training-mode
        forward passes. ``reset_running_stats()`` then zeroes the running
        mean, sets the running variance to one and resets the batch counter.
        The method does not use ``self``.

        Args:
            lfnorm_layer: A normalisation layer exposing ``momentum`` and
                ``reset_running_stats()`` (e.g. ``nn.BatchNorm1d``).
        """
        lfnorm_layer.momentum = None
        lfnorm_layer.reset_running_stats()

    def forward(self, graph, feat):
        """Run every layer on ``graph``, starting from node features ``feat``.

        Args:
            graph: Graph passed unchanged to each ``MyGraphConv`` layer
                (presumably a DGLGraph — confirm).
            feat: Input node features; presumably ``(num_nodes, in_feats)``.

        Returns:
            The raw output of the final convolution, with no norm, activation
            or dropout applied. Presumably ``(num_nodes, n_classes)`` logits.
            With ``n_layers == 0`` this is just the input after input dropout.
        """
        h = self.input_drop(feat)

        for i in range(self.n_layers):
            conv = self.convs[i](graph, h)

            if i == self.n_layers - 1:
                # Final layer: return the convolution output untouched.
                return conv

            # Hidden layer: optional linear branch, then norm / act / dropout.
            h = (conv + self.linear[i](conv)) if self.use_linear else conv
            h = self.norms[i](h)
            h = self.activation(h)
            h = self.dropout(h)

        return h
