import torch
from reddit_Layer import FLGnnConv, NodeEmbed

# Default device for this module: CUDA when available, otherwise CPU, so that
# importing code which moves tensors with `device` does not crash on CPU-only hosts.
device = "cuda" if torch.cuda.is_available() else "cpu"


class FLGnnA(torch.nn.Module):
    """Stack of ``FLGnnConv`` layers over (sampled) adjacencies plus a linear head.

    Forward pipeline: ``NodeEmbed`` -> one ``FLGnnConv`` per adjacency ->
    ``BatchNorm1d`` -> sigmoid -> ``Linear``.

    Args:
        in_channels: Raw node feature size, passed to ``NodeEmbed`` as
            ``node_features``.
        hidden: Width of the node embedding and of every ``FLGnnConv`` layer.
        out_channels: Output size of the final linear layer.
        num_mf, fix, norm, windows, stride, order, A_P2, refine_ratio,
        refiner, cross, attention, residual, value_intervals:
            Forwarded to every ``FLGnnConv`` (``fix`` as ``fix_mf``,
            ``windows`` as ``windows_size``, ``stride`` as ``stride_size``;
            ``method`` is fixed to ``"mad"``). Their meaning is defined in
            ``reddit_Layer``.
        edge_attr: Accepted for signature compatibility; not used here.
        dropout: Stored on ``self.dropout`` but not applied in ``forward``.
        layer: Number of stacked ``FLGnnConv`` layers.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_channels:int,
                 hidden: int,
                 out_channels: int,
                 num_mf: int = 3,
                 fix: bool = False,
                 norm: bool = False,
                 windows: int = 3,
                 stride: int = 3,
                 order: int = 1,
                 A_P2: bool = True,
                 refine_ratio: float = 1,
                 refiner: str = "pool",
                 cross: float = 0.5,
                 edge_attr: int = None,
                 value_intervals: list = None,
                 dropout: float = 0.1,
                 layer: int = 1,
                 attention: bool = True,
                 residual: bool = True,
                 **kwargs):
        super().__init__()

        # One hidden->hidden FLGnnConv per layer; all layers share the same config.
        self.fl = torch.nn.ModuleList([FLGnnConv(in_channels=hidden, out_channels=hidden, fix_mf=fix,
                                                 norm=norm, method="mad", order=order, num_mf=num_mf,
                                                 windows_size=windows, residual=residual,
                                                 stride_size=stride, A_P2=A_P2, refine_ratio=refine_ratio,
                                                 refiner=refiner, cross=cross, attention=attention,
                                                 value_intervals=value_intervals) for _ in range(layer)])

        self.node_embed = NodeEmbed(out_feature=hidden, node_features=in_channels)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden)
        self.lin = torch.nn.Linear(in_features=hidden, out_features=out_channels)
        # NOTE(review): stored but never used in forward; kept for compatibility.
        self.dropout = dropout

    def forward(self, x, adjs):
        """Embed node features, apply each conv layer, then the output head.

        Args:
            x: Node features fed to ``NodeEmbed``.
            adjs: A single ``(edge_index, e_id, size)`` triple or a list of
                them, one per layer. Presumably these come from PyG's
                ``NeighborSampler``; confirm with the caller. ``zip`` pairs
                entries with layers, so any surplus on either side is ignored.

        Returns:
            ``self.lin`` applied to every row of the embedded ``x``. Each layer
            replaces only the first ``size[1]`` rows; the other rows pass
            through unchanged.
        """
        x = self.node_embed(x)

        if not isinstance(adjs, list):
            adjs = [adjs]

        for (edge_index, _e_id, size), conv in zip(adjs, self.fl):
            # Follow x's device rather than a hard-coded "cuda", so the model
            # runs wherever its parameters/inputs live (CPU included).
            edge_index = edge_index.to(x.device)
            n_target = size[1]

            # NOTE: self-loops are not added; a disabled earlier variant added
            # them so isolated nodes would not be lost after scatter.
            updated = conv(x, edge_index, n_target)
            # Write into a fresh copy instead of `x` itself. An in-place write
            # would bump the version of a tensor autograd may have saved, for
            # example node_embed's output or the conv's input. That makes
            # backward fail. Slice assignment keeps the original
            # broadcasting/dtype-casting semantics.
            x = x.clone()
            x[:n_target] = updated

        x = self.batch_norm(x)
        x = torch.sigmoid(x)  # torch.nn.functional.sigmoid is deprecated
        return self.lin(x)



# Script entry point: currently a no-op. The commented-out snippets below are
# debugging aids. NOTE(review): the names they use (FLGNNConv, device helpers
# such as get_parameter_number, mf_visual, plt, np, LineProfiler,
# RulerEdgeSample) are not imported in the visible part of this file.
if __name__ == '__main__':
    pass
    # Inspect the model structure
    # m = FLGNNConv(in_channels=5, hidden=128, out_channels=1, num_mf=3, fix_mf=False, p=0.6, norm=True, method="ein").to(
    #     device)
    # print(get_parameter_number(m))

    # Visualize the membership functions
    # mf_visual(m.fuzzier.MFs, 5)

    # Visualize the rule layer
    # plt.scatter(np.concatenate(m.ruler.recorder[0], axis=0), np.concatenate(m.ruler.recorder[1], axis=0), s=3)
    # plt.show()

    # Runtime profiling
    # lp = LineProfiler()
    # lp.add_function(RulerEdgeSample.forward)
    # lp.add_function(m.forward)
    # lp_wrapper = lp(m.fit)
    # lp_wrapper(1)
    # lp.print_stats()
