import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool
import math

class DGMMD(nn.Module):
    """GIN graph encoder with a mixture-of-Gaussians kernel density scorer.

    The encoder is a stack of ``GINConv`` layers (optionally followed by
    batch normalization); node embeddings are pooled per graph with
    ``global_mean_pool``.  Distances between graph embeddings can then be
    turned into kernel density scores using one Gaussian kernel per
    bandwidth, mixed with learnable softmax weights (``kde_logits``).

    Args:
        in_dim: Size of the input node features.
        hidden_dim: Width of the first and intermediate GIN layers.
        out_dim: Width of the final GIN layer (the embedding size).
        num_layers: Requested number of GIN layers.  An input layer and an
            output layer are always created, so at least two layers are
            built even when ``num_layers`` is smaller than 2.
        bandwidths: Iterable of kernel bandwidths; copied into a list.
        dropout: Dropout probability applied after every non-final layer.
        batch_norm: Whether each GIN layer is followed by ``BatchNorm1d``.
        kde_dimension: Dimension ``D`` used in the Gaussian normalizer
            ``(2 * pi * bw**2) ** (D / 2)``.

    Raises:
        ValueError: If ``bandwidths`` is empty.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, bandwidths=(1.0,),
                 dropout=0.1, batch_norm=True, kde_dimension=1):
        super().__init__()
        # Copy so that neither the default nor a caller-owned sequence is
        # shared with (and mutated through) this module.
        self.bandwidths = list(bandwidths)
        if not self.bandwidths:
            raise ValueError("bandwidths must contain at least one value")

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList() if batch_norm else None

        # Output widths of the successive layers: the first layer plus
        # (num_layers - 2) intermediate layers emit hidden_dim, the last
        # layer emits out_dim.
        layer_widths = [hidden_dim] * (max(num_layers - 2, 0) + 1) + [out_dim]
        prev_width = in_dim
        for width in layer_widths:
            self.convs.append(self._make_gin_layer(prev_width, width))
            if batch_norm:
                self.bns.append(nn.BatchNorm1d(width))
            prev_width = width

        self.dropout = nn.Dropout(dropout)
        # Uniform initial logits -> uniform mixture weights after softmax.
        self.kde_logits = nn.Parameter(
            torch.ones(len(self.bandwidths)) / len(self.bandwidths)
        )
        self.kde_dimension = kde_dimension

    @staticmethod
    def _make_gin_layer(in_features, out_features):
        """Build one GINConv whose update network is Linear -> ReLU -> Linear."""
        return GINConv(
            nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.Linear(out_features, out_features),
            )
        )

    def forward(self, x, edge_index, batch):
        """Encode nodes and pool them into graph-level embeddings.

        Args:
            x: Node features passed to the first GIN layer.
            edge_index: Edge index passed to every GIN layer.
            batch: Node-to-graph assignment passed to ``global_mean_pool``.

        Returns:
            Tuple ``(node_embeddings, graph_embeddings)``.
        """
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if self.bns is not None:
                x = self.bns[i](x)
            # The final layer's output is left without ReLU/dropout.
            if i != last:
                x = F.relu(x)
                x = self.dropout(x)

        # Graph-level representation.
        graph_emb = global_mean_pool(x, batch)
        return x, graph_emb

    def compute_distance_matrix(self, data_a, data_b=None):
        """Return Euclidean distances between graph embeddings.

        Args:
            data_a: Object exposing ``x``, ``edge_index`` and ``batch``.
            data_b: Optional second object with the same attributes.  When
                omitted, distances are computed within ``data_a``.

        Returns:
            Matrix whose entry ``[i, j]`` is the distance between graph ``i``
            of ``data_a`` and graph ``j`` of ``data_b`` (or of ``data_a``).
        """
        _, g_emb_a = self(data_a.x, data_a.edge_index, data_a.batch)

        if data_b is None:
            return self._pairwise_distances(g_emb_a)

        _, g_emb_b = self(data_b.x, data_b.edge_index, data_b.batch)
        return self._pairwise_distances_cross(g_emb_a, g_emb_b)

    def compute_kde_scores(self, dist_matrix):
        """Evaluate the Gaussian-mixture kernel density for each row.

        Args:
            dist_matrix: ``(M, N)`` matrix of (non-squared) distances from
                ``M`` query items to ``N`` reference items.

        Returns:
            Tensor of shape ``(M,)``: for each row, the softmax-weighted sum
            over bandwidths of the mean normalized Gaussian kernel value.
        """
        alpha = F.softmax(self.kde_logits, dim=0)
        M, N = dist_matrix.shape
        # new_zeros keeps the accumulator on the input's device and dtype.
        total_kde = dist_matrix.new_zeros(M)

        for k, bw in enumerate(self.bandwidths):
            kernel_vals = torch.exp(-0.5 * (dist_matrix / bw) ** 2)
            sum_over_ref = kernel_vals.sum(dim=1)

            gauss_factor = 1.0 / ((2 * math.pi * (bw ** 2)) ** (0.5 * self.kde_dimension))
            partial_kde = (1.0 / N) * gauss_factor * sum_over_ref
            # Out-of-place add so type promotion is handled by PyTorch.
            total_kde = total_kde + alpha[k] * partial_kde

        return total_kde

    def _pairwise_distances(self, embeddings):
        """Return the Euclidean distance matrix of ``embeddings`` with itself."""
        return self._pairwise_distances_cross(embeddings, embeddings)

    def _pairwise_distances_cross(self, emb_a, emb_b):
        """Return Euclidean distances between rows of ``emb_a`` and ``emb_b``.

        The result must be a true (non-squared) distance because
        ``compute_kde_scores`` squares it inside the Gaussian exponent.
        """
        return torch.cdist(emb_a, emb_b, p=2)

def batch_graphs(graph_list):
    """Pack a list of graphs into a single batch.

    Inputs that PyTorch's ``default_collate`` understands (tensors, numbers,
    dicts, sequences, ...) are collated by it as before.  ``default_collate``
    rejects torch_geometric ``Data`` objects with a ``TypeError``; in that
    case the graphs are merged with ``Batch.from_data_list``.

    Args:
        graph_list: List of graphs to merge.

    Returns:
        The collated batch.

    Raises:
        TypeError: If the elements are neither supported by
            ``default_collate`` nor all torch_geometric ``Data`` objects.
    """
    try:
        return torch.utils.data.dataloader.default_collate(graph_list)
    except TypeError:
        # Imported lazily: only needed for the PyG fallback path.
        from torch_geometric.data import Batch, Data

        if not all(isinstance(graph, Data) for graph in graph_list):
            # Not a PyG batch: surface the original collate error unchanged.
            raise
        return Batch.from_data_list(graph_list)