import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader
from utils.register import register

class LabelSAGE(torch.nn.Module):
    """GraphSAGE model that classifies nodes by similarity to label embeddings.

    Node features go through a stack of ``SAGEConv`` layers. The resulting
    node embeddings are compared (dot product) with a fixed matrix of label
    embeddings, the similarities are divided by ``sqrt(temperature)`` and a
    softmax turns them into class probabilities.

    If an ``encoder`` attribute has been attached to the instance, it is
    applied to the raw features (followed by a flatten from dim 1) before the
    graph convolutions.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the intermediate representations.
        layer_num: Requested number of graph convolution layers. An input and
            an output layer are always created, so values below 2 still
            produce two layers.
        label_embeddings: Label embedding matrix of shape ``[C, D]``; stored
            as a (non-trainable) buffer.
        dropout: Dropout probability used after every non-final layer in
            ``forward``.
    """

    def __init__(self, in_channels: int, hidden_channels: int,
                 layer_num: int, label_embeddings: Tensor, dropout: float):
        super().__init__()
        self.register_buffer('label_embeddings', label_embeddings)  # [C, D]
        self.dropout_rate = dropout

        # Graph convolution layers. The widths are chained
        # in -> hidden -> ... -> hidden -> D so that consecutive layers always
        # fit together, and the last layer emits vectors in the label
        # embedding space (required by the dot product in `_class_probs`).
        out_channels = label_embeddings.size(-1)
        num_middle = max(layer_num - 2, 0)
        widths = ([in_channels]
                  + [hidden_channels] * (num_middle + 1)
                  + [out_channels])
        self.convs = torch.nn.ModuleList(
            [SAGEConv(w_in, w_out) for w_in, w_out in zip(widths, widths[1:])]
        )

        # Learnable temperature coefficient for the similarity softmax.
        self.temperature = nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Re-initialise every convolution and reset the temperature to 1."""
        for conv in self.convs:
            conv.reset_parameters()
        nn.init.constant_(self.temperature, 1.0)

    def _class_probs(self, x: Tensor) -> Tensor:
        """Turn node embeddings into class probabilities.

        Shared by ``forward`` and ``inference`` so that both apply exactly the
        same temperature scaling.

        NOTE(review): ``sqrt`` yields NaN/inf if the learned temperature ever
        becomes non-positive; nothing in this class constrains it.
        """
        sim_matrix = x @ self.label_embeddings.T  # [N, C]
        scaled_sim = sim_matrix / torch.sqrt(self.temperature)
        return torch.softmax(scaled_sim, dim=-1)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Compute class probabilities for the nodes of one (sub)graph.

        Args:
            x: Node features.
            edge_index: Edge index passed to every convolution.

        Returns:
            Class probabilities of shape ``[N, C]``.
        """
        if hasattr(self, "encoder"):
            x = self.encoder(x)
            x = torch.flatten(x, start_dim=1)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = torch.relu(x)
                x = F.dropout(x, p=self.dropout_rate, training=self.training)

        return self._class_probs(x)

    @torch.no_grad()
    def inference(self, x_all: Tensor, device: torch.device,
                  subgraph_loader: "NeighborLoader") -> Tensor:
        """Layer-wise inference over all nodes using mini-batches.

        Each stage iterates over ``subgraph_loader`` once, computes the output
        for the seed nodes of every batch (the first ``batch.batch_size``
        entries of ``batch.n_id``) and concatenates the results on the CPU.

        NOTE(review): the concatenated result of one stage is indexed by
        global node id in the next stage, so this assumes the loader visits
        every node exactly once as a seed, in node-id order (no shuffling) —
        confirm against the caller.

        Args:
            x_all: Features of all nodes.
            device: Device on which each batch is processed.
            subgraph_loader: Loader yielding batches that expose ``n_id``,
                ``edge_index`` and ``batch_size``.

        Returns:
            Class probabilities of shape ``[N, C]`` (on the CPU).
        """
        # Feature extraction. The encoder works per node, so only the seed
        # nodes are encoded; encoding every sampled node would emit one row
        # per (batch, sampled node) pair and break the node-id indexing below.
        if hasattr(self, "encoder"):
            xs = []
            for batch in subgraph_loader:
                seed_ids = batch.n_id[:batch.batch_size]
                x = x_all[seed_ids.to(x_all.device)].to(device)
                x = self.encoder(x)
                x = torch.flatten(x, start_dim=1)
                xs.append(x.cpu())
            x_all = torch.cat(xs, dim=0)

        # Graph convolutions, one layer at a time over the whole graph. The
        # convolution needs the sampled neighbours as input, but only the seed
        # rows of its output are kept.
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id.to(x_all.device)].to(device)
                x = conv(x, batch.edge_index.to(device))[:batch.batch_size]
                if i < last:
                    x = x.relu_()
                xs.append(x.cpu())
            x_all = torch.cat(xs, dim=0)

        # Final class probabilities: one row per seed node, using the same
        # scaling as `forward`.
        class_probs = []
        for batch in subgraph_loader:
            seed_ids = batch.n_id[:batch.batch_size]
            x = x_all[seed_ids.to(x_all.device)].to(device)
            class_probs.append(self._class_probs(x).cpu())

        return torch.cat(class_probs, dim=0)  # [N, C]