"""
@Description :   GCN building blocks: residual GAT layers, a stacked DeepGCN, and a variational GCN encoder
@Author      :   tqychy 
@Time        :   2025/01/02 17:52:41
"""
import torch.nn as nn
from torch_geometric.nn import DeepGCNLayer, GATConv


class GCNLayer(nn.Module):
    """A single residual graph-attention layer built on PyG's ``DeepGCNLayer``.

    The layer is configured entirely from ``cfg.NET``:

    * ``FEATURE_EXTRACT_DIM`` -- width of the node features fed into the
      layer; also the number of features normalised by ``BatchNorm1d``.
    * ``RESGCN.GCN_FILTERS`` -- output width of the ``GATConv``.
    * ``RESGCN.GCN_GAT_HEADS`` -- number of attention heads. Heads are
      averaged (``concat=False``), so the conv output width is
      ``GCN_FILTERS`` regardless of the head count.
    * ``RESGCN.GCN_BLOCK_TYPE`` -- forwarded as ``DeepGCNLayer``'s ``block``
      argument (PyG accepts ``'res+'``, ``'res'``, ``'dense'``, ``'plain'``).

    Args:
        *args: exactly ``(cfg, logger)``. ``logger`` is stored on the
            instance but not used by this class.

    Raises:
        ValueError: if ``GCN_FILTERS`` differs from ``FEATURE_EXTRACT_DIM``.
            The single ``BatchNorm1d`` is sized for ``FEATURE_EXTRACT_DIM``.
            Depending on the block type, ``DeepGCNLayer`` applies it either
            to the input (then adds input and conv output) or to the conv
            output. Either way the widths must match, so fail here instead
            of with a shape error inside the first forward pass.
    """

    def __init__(self, *args):
        super().__init__()
        self.cfg, self.logger = args
        net_cfg = self.cfg.NET
        in_dim = net_cfg.FEATURE_EXTRACT_DIM
        out_dim = net_cfg.RESGCN.GCN_FILTERS
        if out_dim != in_dim:
            raise ValueError(
                f"NET.RESGCN.GCN_FILTERS ({out_dim}) must equal "
                f"NET.FEATURE_EXTRACT_DIM ({in_dim}) for GCNLayer: the shared "
                f"BatchNorm1d and the residual connection require matching widths."
            )

        # conv/norm/act are kept as direct attributes (and are also owned by
        # deep_gcn), so their parameters appear under both prefixes in
        # state_dict; keep this layout for checkpoint compatibility.
        self.conv = GATConv(in_dim, out_dim,
                            heads=net_cfg.RESGCN.GCN_GAT_HEADS, concat=False)
        self.norm = nn.BatchNorm1d(in_dim)
        self.act = nn.ReLU(inplace=True)
        self.deep_gcn = DeepGCNLayer(
            self.conv, self.norm, self.act, block=net_cfg.RESGCN.GCN_BLOCK_TYPE)

    def forward(self, x, a):
        """Apply the layer.

        Args:
            x: node features; presumably ``(num_nodes, FEATURE_EXTRACT_DIM)``
                -- TODO confirm against callers.
            a: graph connectivity, passed straight through to ``GATConv`` as
                its ``edge_index`` argument.

        Returns:
            The output of the wrapped ``DeepGCNLayer`` for ``x``.
        """
        return self.deep_gcn(x, a)


class DeepGCN(nn.Module):
    """A stack of ``cfg.NET.BLOCKS`` identically configured ``GCNLayer`` blocks.

    Node features flow through the blocks in order; every block receives the
    same connectivity ``a``.

    Args:
        *args: exactly ``(cfg, logger)``; both are stored and passed to every
            ``GCNLayer``.
    """

    def __init__(self, *args):
        super().__init__()
        self.cfg, self.logger = args
        n_blocks = self.cfg.NET.BLOCKS
        # ModuleList rather than Sequential: each block takes two inputs
        # (x, a), so the container only registers the blocks and forward()
        # iterates it manually -- Sequential's own forward could never be
        # called. Submodule keys ("deep_gcn.0", "deep_gcn.1", ...) are the
        # same as with Sequential, so existing checkpoints still load.
        self.deep_gcn = nn.ModuleList(
            GCNLayer(self.cfg, self.logger) for _ in range(n_blocks))

    def forward(self, x, a):
        """Run ``x`` through every block in order, reusing ``a`` for each.

        Returns ``x`` unchanged when the stack is empty (``BLOCKS <= 0``).
        """
        for block in self.deep_gcn:
            x = block(x, a)
        return x

class VariationalGCNEncoder(nn.Module):
    """Graph encoder with a shared trunk and two parallel output heads.

    ``conv1`` is a ``DeepGCN`` trunk whose output goes through a ReLU and is
    then fed, with the same ``edge_index``, to two independent ``DeepGCN``
    heads: ``conv_mu`` and ``conv_logstd``. Going by the names, these
    presumably give the mean and log standard deviation of a variational
    latent -- confirm against the training code.

    Args:
        *args: passed unchanged to each of the three ``DeepGCN`` stacks
            (which unpack them as ``(cfg, logger)``).
    """

    def __init__(self, *args):
        super().__init__()
        # Construction order (trunk, then mu head, then logstd head) fixes the
        # submodule registration order and the order of parameter
        # initialisation; keep it.
        self.conv1 = DeepGCN(*args)
        self.conv_mu = DeepGCN(*args)
        self.conv_logstd = DeepGCN(*args)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Encode node features.

        Returns:
            ``(mu, logstd)`` -- the outputs of ``conv_mu`` and
            ``conv_logstd`` on the shared, ReLU-activated trunk features.
        """
        trunk_out = self.conv1(x, edge_index)
        shared = self.relu(trunk_out)
        mu = self.conv_mu(shared, edge_index)
        logstd = self.conv_logstd(shared, edge_index)
        return mu, logstd
