import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor

from utils.register import register

class DoorGate(torch.nn.Module):
    """Fuse two sets of class logits with a learned per-sample scalar gate.

    The input to :meth:`forward` is the concatenation ``[emb_logits,
    label_logits]`` along the last dimension, each half having ``n_labels``
    columns.  A single linear layer followed by a sigmoid maps the
    concatenated logits to one weight ``g`` per row, and the output is
    ``g * softmax(emb_logits) + (1 - g) * softmax(label_logits)``.
    """

    def __init__(self, n_labels: int):
        """Build the gate.

        Args:
            n_labels: Number of classes ``C``; ``forward`` expects ``2 * C``
                input columns.
        """
        super().__init__()
        self.n_labels = n_labels
        self.gate = nn.Sequential(
            nn.Linear(n_labels * 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the gated mixture of the two class distributions.

        Args:
            x: Tensor of shape ``[B, 2C]``; the first ``C`` columns are the
                embedding logits and the remaining columns the label logits.

        Returns:
            Tensor of shape ``[B, C]``: a convex combination of the two
            softmax distributions, so each row sums to 1.
        """
        # Split the input into its two halves: original logits and label logits.
        emb_logits = x[:, :self.n_labels]          # [B, C]
        label_logits = x[:, self.n_labels:]        # [B, C]

        # Turn each half into a probability distribution.
        emb_probs = torch.softmax(emb_logits, dim=-1)         # [B, C]
        label_probs = torch.softmax(label_logits, dim=-1)     # [B, C]

        # The gate is fed the raw logits rather than the probabilities
        # (the original author's note: this is more stable).
        gate_input = torch.cat([emb_logits, label_logits], dim=-1)  # [B, 2C]

        # The gate output is already [B, 1], so it broadcasts over the class
        # dimension without any squeeze/unsqueeze round trip.
        gate_weights = self.gate(gate_input)                        # [B, 1]

        # Weighted fusion of the two distributions.
        return gate_weights * emb_probs + (1 - gate_weights) * label_probs  # [B, C]

    @torch.no_grad()
    def inference(self, x_all: Tensor, device: torch.device, subgraph_loader: "NeighborLoader") -> Tensor:
        """Layer-wise full-graph inference over a neighbour-sampling loader.

        NOTE(review): this method reads ``self.convs``,
        ``self.label_embeddings`` and ``self.temperature`` (and optionally
        ``self.encoder``).  None of them is created by ``DoorGate.__init__``;
        they have to be provided by a subclass or attached to the instance
        before calling, otherwise attribute lookup raises ``AttributeError``.

        Args:
            x_all: Features of all nodes, indexed by node id along dim 0.
            device: Device on which each mini-batch is computed.
            subgraph_loader: Iterable of batches exposing ``n_id``,
                ``edge_index`` and ``batch_size``.  The annotation is a
                string because ``NeighborLoader`` is not imported in this
                module; an evaluated annotation would raise ``NameError``
                while the class body executes.  Assumes the first
                ``batch_size`` entries of ``n_id`` are the seed nodes and
                that the loader visits the seeds in node-id order without
                shuffling, so concatenated outputs line up with node ids
                -- TODO confirm against the loader construction.

        Returns:
            Tensor of shape ``[N, C]`` on the CPU with one row of class
            probabilities per seed node.
        """
        # Optional feature-extraction stage.
        if hasattr(self, "encoder"):
            xs = []
            for batch in subgraph_loader:
                # The encoder works per node, so only the seed nodes are
                # encoded; encoding every sampled neighbour would make the
                # rebuilt x_all longer than N and misaligned with node ids.
                seed_ids = batch.n_id[:batch.batch_size].to(x_all.device)
                x = x_all[seed_ids].to(device)
                x = self.encoder(x)
                x = torch.flatten(x, start_dim=1)
                xs.append(x.cpu())
            x_all = torch.cat(xs, dim=0)

        # Apply one conv layer to every node before moving to the next layer.
        last_layer = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id.to(x_all.device)].to(device)
                # Keep only the rows belonging to the seed nodes.
                x = conv(x, batch.edge_index.to(device))[:batch.batch_size]
                if i < last_layer:
                    x = x.relu_()
                xs.append(x.cpu())
            x_all = torch.cat(xs, dim=0)

        # Final class probabilities from temperature-scaled similarity to the
        # label embeddings.
        class_probs = []
        for batch in subgraph_loader:
            # Restrict to seed nodes so the result has exactly one row per
            # node ([N, C]) instead of one row per sampled node.
            seed_ids = batch.n_id[:batch.batch_size].to(x_all.device)
            x = x_all[seed_ids].to(device)
            probs = torch.softmax(x @ self.label_embeddings.T / self.temperature, dim=-1)
            class_probs.append(probs.cpu())

        return torch.cat(class_probs, dim=0)  # [N, C]