import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader
from utils.register import register

class SAGE(torch.nn.Module):
    """GraphSAGE stack of ``SAGEConv`` layers with ReLU + dropout in between.

    If an ``encoder`` submodule has been attached to the instance (it is not
    created here), node features are passed through it and flattened to 2-D
    before the first convolution, both in :meth:`forward` and
    :meth:`inference`.

    Args:
        in_channels: Input feature size of the first convolution.
        hidden_channels: Feature size of every intermediate layer.
        out_channels: Output feature size of the last convolution.
        layer_num: Total number of ``SAGEConv`` layers (must be >= 1).
        dropout: Dropout probability applied after every non-final layer
            while training.

    Raises:
        ValueError: If ``layer_num`` is smaller than 1.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        layer_num: int,
        dropout: float,
    ):
        super().__init__()
        if layer_num < 1:
            raise ValueError(f"layer_num must be >= 1, got {layer_num}")

        self.dropout_rate = dropout
        self.convs = torch.nn.ModuleList()
        if layer_num == 1:
            # A single layer maps inputs straight to outputs (previously this
            # case silently built two layers).
            self.convs.append(SAGEConv(in_channels, out_channels))
        else:
            self.convs.append(SAGEConv(in_channels, hidden_channels))
            for _ in range(layer_num - 2):
                self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.convs.append(SAGEConv(hidden_channels, out_channels))

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution layer."""
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Run the whole model on a (sub)graph.

        Args:
            x: Node features; passed through ``self.encoder`` first when one
                is attached.
            edge_index: Graph connectivity in any form ``SAGEConv`` accepts.

        Returns:
            Output of the last convolution (no activation applied).
        """
        if hasattr(self, "encoder"):
            x = torch.flatten(self.encoder(x), start_dim=1)
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = torch.relu(x)
                x = F.dropout(x, p=self.dropout_rate, training=self.training)
        return x

    @staticmethod
    def _gather_rows(x_all: Tensor, n_id: Tensor) -> Tensor:
        """Return ``x_all[n_id]``, using all-zero rows for ids past the end of ``x_all``.

        The result is on ``x_all``'s device and keeps its dtype and trailing
        dimensions.
        """
        n_id = n_id.to(x_all.device)
        out = x_all.new_zeros((n_id.size(0),) + tuple(x_all.shape[1:]))
        valid = n_id < x_all.size(0)
        out[valid] = x_all[n_id[valid]]
        return out

    @staticmethod
    def _graph_input(batch, device: torch.device):
        """Return the batch's connectivity on ``device``, preferring ``adj_t`` over ``edge_index``.

        Raises:
            RuntimeError: If the batch has neither attribute, or both are None.
        """
        for attr in ("adj_t", "edge_index"):
            value = getattr(batch, attr, None)
            if value is not None:
                return value.to(device)
        raise RuntimeError("Batch does not contain 'edge_index' or 'adj_t'. Cannot perform GNN message passing.")

    @torch.no_grad()
    def inference(self, x_all: Tensor, device: torch.device, subgraph_loader: NeighborLoader) -> Tensor:
        """Compute final representations for every node, one layer at a time.

        Each layer is applied to all nodes, using every edge the loader
        provides, before moving on to the next layer. This is cheaper than
        pushing each batch through its full multi-hop neighbourhood.

        Row order and loader requirements:
            Each pass emits rows in the loader's seed-node order
            (``batch.n_id[:batch.batch_size]``). The next pass indexes those
            rows by global node id. The result is therefore only correct
            when the loader visits every node exactly once in ascending id
            order (e.g. ``input_nodes=None, shuffle=False``).
            NOTE(review): confirm this at call sites.

        Node ids at or beyond ``x_all.size(0)`` receive all-zero input
        features. No dropout is applied here. An attached ``encoder`` runs
        in whatever train/eval mode it is currently in.

        Args:
            x_all: Features of all nodes, indexed by node id.
            device: Device on which the encoder and convolutions run.
            subgraph_loader: Loader yielding subgraphs that carry ``n_id``,
                ``batch_size`` and ``edge_index`` or ``adj_t``.

        Returns:
            A CPU tensor with one output row per seed node.
        """
        if hasattr(self, "encoder"):
            # Encode only each batch's seed nodes so the result stays aligned
            # with node ids. Encoding whole sampled subgraphs (seeds plus
            # neighbours) and concatenating them produces far more rows than
            # nodes, so row k would no longer belong to node k.
            encoded = []
            for batch in subgraph_loader:
                seeds = batch.n_id[: batch.batch_size]
                h = self._gather_rows(x_all, seeds).to(device)
                encoded.append(torch.flatten(self.encoder(h), start_dim=1).cpu())
            x_all = torch.cat(encoded, dim=0)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            outputs = []
            for batch in subgraph_loader:
                x = self._gather_rows(x_all, batch.n_id).to(device)
                x = conv(x, self._graph_input(batch, device))
                # Keep only the seed nodes; the remaining rows are sampled neighbours.
                x = x[: batch.batch_size]
                if i < last:
                    x = x.relu_()
                outputs.append(x.cpu())
            x_all = torch.cat(outputs, dim=0)

        return x_all