from typing import Optional

import torch
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from torch_scatter import scatter_sum

from molfreesolv_layer import FLGnnConv, NodeEmbed


class FLGnnA(torch.nn.Module):
    """Graph-level model: node embedding, stacked ``FLGnnConv`` layers, sum pooling, linear head.

    Pipeline (see ``forward``): self-loops are added to the graph, nodes are
    embedded with ``NodeEmbed``, passed through ``layer`` ``FLGnnConv`` layers
    (each preceded by dropout), summed per graph, then batch-normalized,
    passed through GELU and projected with a linear layer.

    Args:
        hidden: Width of the node embeddings and of every ``FLGnnConv`` layer
            (used as both its input and output channels).
        out_channels: Output size of the final linear layer.
        num_mf: Forwarded to ``FLGnnConv(num_mf=...)``; presumably the number
            of membership functions -- confirm in ``molfreesolv_layer``.
        fix: Forwarded to ``FLGnnConv`` as ``fix_mf``.
        norm, order, A_P2, refine_ratio, refiner, cross, attention, residual,
        value_intervals: Forwarded unchanged to ``FLGnnConv``; their meaning
            is defined there.
        windows: Forwarded to ``FLGnnConv`` as ``windows_size``.
        stride: Forwarded to ``FLGnnConv`` as ``stride_size``.
        edge_attr: Accepted for interface compatibility but not used here.
        dropout: Dropout probability applied to node features before each
            ``FLGnnConv`` layer.
        layer: Number of stacked ``FLGnnConv`` layers.
        **kwargs: Accepted and ignored.

    Note: every ``FLGnnConv`` is built with the hard-coded ``method="mad"``.
    """

    def __init__(self,
                 hidden: int,
                 out_channels: int,
                 num_mf: int = 3,
                 fix: bool = False,
                 norm: bool = False,
                 windows: int = 3,
                 stride: int = 3,
                 order: int = 1,
                 A_P2: bool = True,
                 refine_ratio: float = 1,
                 refiner: str = "pool",
                 cross: float = 0.5,
                 edge_attr: int = None,
                 value_intervals: list = None,
                 dropout: float = 0.1,
                 layer: int = 1,
                 attention: bool = True,
                 residual: bool = True,
                 **kwargs):
        super().__init__()

        # Submodule creation order is kept as-is: it determines parameter
        # initialization order (RNG consumption) and state_dict key order.
        self.fl = torch.nn.ModuleList([
            FLGnnConv(in_channels=hidden, out_channels=hidden, fix_mf=fix,
                      norm=norm, method="mad", order=order, num_mf=num_mf,
                      windows_size=windows, residual=residual,
                      stride_size=stride, A_P2=A_P2, refine_ratio=refine_ratio,
                      refiner=refiner, cross=cross, attention=attention,
                      value_intervals=value_intervals)
            for _ in range(layer)
        ])

        self.node_embed = NodeEmbed(out_feature=hidden)
        self.batch_norm = torch.nn.BatchNorm1d(num_features=hidden)
        self.lin = torch.nn.Linear(in_features=hidden, out_features=out_channels)
        self.dropout = dropout

    def forward(self, graph: Data, Train: Optional[bool] = None):
        """Compute one output row per graph in ``graph``.

        Args:
            graph: A PyG ``Data``/``Batch`` providing ``x``, ``edge_index``,
                ``edge_attr`` and optionally ``batch``. When ``batch`` is
                ``None`` all nodes are treated as belonging to one graph.
            Train: Whether dropout is active. ``None`` (default) follows the
                module's ``self.training`` flag, so ``model.eval()`` disables
                dropout just as it switches BatchNorm to running statistics.
                Pass an explicit bool to override.

        Returns:
            Tensor with one row per graph index in ``batch``
            (``batch.max() + 1`` rows) and ``out_channels`` columns.
        """
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch

        if Train is None:
            # Keep dropout consistent with BatchNorm, which uses self.training.
            Train = self.training

        num_nodes = x.size(0)
        if batch is None:
            # An un-batched Data object has no batch vector: treat it as a single graph.
            batch = torch.zeros(num_nodes, dtype=torch.long, device=x.device)

        # Self-loops are added so that isolated nodes are not lost in the
        # scatter-based steps that follow.
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=num_nodes)

        x = self.node_embed(x, edge_attr, edge_index)

        for conv in self.fl:
            x = torch.nn.functional.dropout(x, p=self.dropout, training=Train)
            x = conv(x, edge_index)

        # Sum-pool node features into one vector per graph.
        x = scatter_sum(src=x, index=batch, dim=0)
        x = self.batch_norm(x)
        x = torch.nn.functional.gelu(x)
        return self.lin(x)


if __name__ == "__main__":
    # Nothing runs by default. The commented snippets below are developer
    # notes for manual debugging; they reference names (e.g. ``FLGNNConv``,
    # ``device``, ``get_parameter_number``, ``mf_visual``, ``plt``, ``np``,
    # ``LineProfiler``, ``RulerEdgeSample``) that this file does not import.
    pass

    # Inspect the model structure:
    # m = FLGNNConv(in_channels=5, hidden=128, out_channels=1, num_mf=3, fix_mf=False, p=0.6, norm=True, method="ein").to(
    #     device)
    # print(get_parameter_number(m))

    # Visualize the membership functions:
    # mf_visual(m.fuzzier.MFs, 5)

    # Visualize the rule layer:
    # plt.scatter(np.concatenate(m.ruler.recorder[0], axis=0), np.concatenate(m.ruler.recorder[1], axis=0), s=3)
    # plt.show()

    # Profile the run time:
    # lp = LineProfiler()
    # lp.add_function(RulerEdgeSample.forward)
    # lp.add_function(m.forward)
    # lp_wrapper = lp(m.fit)
    # lp_wrapper(1)
    # lp.print_stats()
