import torch
from CiteSeer_layer import FLGnnConv, NodeEmbed

# Default device string for this module. Falls back to CPU when CUDA is not
# available so that using ``device`` does not fail on CPU-only hosts.
# (FLGnnA itself does not reference it.)
device = "cuda" if torch.cuda.is_available() else "cpu"


class FLGnnA(torch.nn.Module):
    """Node model: ``NodeEmbed`` -> ``layer`` x ``FLGnnConv`` -> BatchNorm -> GELU -> Linear.

    Args:
        in_channels: Number of raw input node features (given to ``NodeEmbed``
            as ``node_features``).
        hidden: Width shared by the embedding, every ``FLGnnConv`` layer and
            the batch-norm.
        out_channels: Output size of the final linear projection.
        num_mf, fix, norm, windows, stride, order, A_P2, refine_ratio, refiner,
        cross, attention, residual, value_intervals: Forwarded unchanged to
            every ``FLGnnConv`` (``fix`` as ``fix_mf``, ``windows`` as
            ``windows_size``, ``stride`` as ``stride_size``); see that class
            for their meaning.
        edge_attr: Accepted for interface compatibility; not used.
        dropout: Dropout probability applied to the input features and again
            before each ``FLGnnConv`` layer (active only in training mode).
        layer: Number of stacked ``FLGnnConv`` layers.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_channels: int,
                 hidden: int,
                 out_channels: int,
                 num_mf: int = 3,
                 fix: bool = False,
                 norm: bool = False,
                 windows: int = 3,
                 stride: int = 3,
                 order: int = 1,
                 A_P2: bool = True,
                 refine_ratio: float = 1,
                 refiner: str = "pool",
                 cross: float = 0.5,
                 edge_attr: int = None,
                 value_intervals: list = None,
                 dropout: float = 0.1,
                 layer: int = 1,
                 attention: bool = True,
                 residual: bool = True,
                 **kwargs):
        super().__init__()

        # All conv layers share one configuration; ``method`` is fixed to "mad".
        conv_config = dict(
            in_channels=hidden,
            out_channels=hidden,
            fix_mf=fix,
            norm=norm,
            method="mad",
            order=order,
            num_mf=num_mf,
            windows_size=windows,
            residual=residual,
            stride_size=stride,
            A_P2=A_P2,
            refine_ratio=refine_ratio,
            refiner=refiner,
            cross=cross,
            attention=attention,
            value_intervals=value_intervals,
        )

        # NOTE: submodules are created in this exact order (convs, embedding,
        # norm, head) so random initialisation and state_dict ordering stay
        # the same.
        self.fl = torch.nn.ModuleList(
            [FLGnnConv(**conv_config) for _ in range(layer)]
        )
        self.node_embed = NodeEmbed(out_feature=hidden, node_features=in_channels)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden)
        self.lin = torch.nn.Linear(in_features=hidden, out_features=out_channels)
        self.dropout = dropout

    def forward(self, graph):
        """Apply the model to ``graph``.

        Args:
            graph: Object exposing ``x`` (node features) and ``edge_index``
                attributes — presumably a PyTorch Geometric ``Data``; TODO
                confirm against callers.

        Returns:
            The output of the final ``Linear`` layer (``out_channels`` values
            per row). Assumes ``x`` is 2-D (num_nodes, in_channels) — TODO
            confirm.
        """
        drop = torch.nn.functional.dropout
        h, edge_index = graph.x, graph.edge_index

        # Optional input embedding: dropout on raw features, then NodeEmbed.
        if self.node_embed is not None:
            h = drop(h, p=self.dropout, training=self.training)
            h = self.node_embed(h, edge_index)

        # Each conv layer sees dropped-out features from the previous stage.
        for conv in self.fl:
            h = conv(drop(h, training=self.training, p=self.dropout), edge_index)

        h = torch.nn.functional.gelu(self.batch_norm(h))
        return self.lin(h)


if __name__ == '__main__':
    # No script behaviour: executing this file directly does nothing.
    ...
