import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv, GINConv
from dgl.nn.pytorch.glob import SumPooling

class GCN(nn.Module):
    """
    Graph Convolutional Network (GCN) model.

    Stacks ``num_layers`` ``GraphConv`` layers that map ``in_dim`` node
    features to ``out_dim`` node features through ``hidden_dim``-wide
    intermediate layers.

    Args:
        in_dim: Size of the input node features.
        hidden_dim: Size of every intermediate node representation.
        out_dim: Size of the node representation returned by ``forward``.
        num_layers: Total number of graph convolutions (>= 1). With a single
            layer the model maps ``in_dim`` straight to ``out_dim`` and
            ``hidden_dim`` is unused.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int) -> None:
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Feature width before/after every layer, e.g. [in, hid, hid, out]
        # for num_layers == 3. Building from this list keeps the number of
        # convolutions equal to num_layers even when num_layers == 1.
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        self.layers = nn.ModuleList(
            [GraphConv(src, dst) for src, dst in zip(dims[:-1], dims[1:])]
        )

    def forward(self, graph, features: torch.Tensor) -> torch.Tensor:
        """
        Run every graph convolution in order, each followed by ReLU.

        Args:
            graph: Graph object handed unchanged to every ``GraphConv`` layer.
            features: Input node features.

        Returns:
            The node representation produced by the last layer. ReLU is
            applied after the final layer as well, so the output is
            non-negative.
        """
        h = features
        for layer in self.layers:
            h = F.relu(layer(graph, h))
        return h

class GIN(nn.Module):
    """
    Graph Isomorphism Network (GIN) model.

    Stacks ``num_layers`` ``GINConv`` layers, each followed by batch
    normalisation and ReLU, then sum-pools the node representations into a
    graph-level readout.

    Args:
        in_dim: Size of the input node features.
        hidden_dim: Size of every intermediate node representation.
        out_dim: Size of the final node representation and of the readout.
        num_layers: Total number of GIN convolutions (>= 1). With a single
            layer the model maps ``in_dim`` straight to ``out_dim`` and
            ``hidden_dim`` is unused.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int) -> None:
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # Feature width before/after every layer, e.g. [in, hid, hid, out]
        # for num_layers == 3. Building from this list keeps the number of
        # convolutions equal to num_layers even when num_layers == 1.
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        for src, dst in zip(dims[:-1], dims[1:]):
            self.layers.append(GINConv(self._make_mlp(src, dst)))
            self.batch_norms.append(nn.BatchNorm1d(dst))

        self.pool = SumPooling()

    @staticmethod
    def _make_mlp(src_dim: int, dst_dim: int) -> nn.Sequential:
        """Build the two-layer MLP that a GINConv applies after aggregation."""
        return nn.Sequential(
            nn.Linear(src_dim, dst_dim),
            nn.ReLU(),
            nn.Linear(dst_dim, dst_dim),
        )

    def forward(self, graph, features: torch.Tensor) -> torch.Tensor:
        """
        Encode the nodes of ``graph`` and sum-pool them into a readout.

        Args:
            graph: Graph object handed unchanged to every ``GINConv`` layer
                and to the pooling layer.
            features: Input node features.

        Returns:
            The output of ``SumPooling`` over the final node representations.
        """
        h = features
        for layer, batch_norm in zip(self.layers, self.batch_norms):
            h = F.relu(batch_norm(layer(graph, h)))
        return self.pool(graph, h)

class PXGL_GNN(nn.Module):
    """
    PXGL-GNN model for learning graph representations based on subgraph patterns.

    One ``GIN`` encoder is kept per pattern. Every subgraph of a pattern is
    encoded by that pattern's encoder, the encodings are averaged into a
    pattern embedding, and the pattern embeddings are combined with learnable
    weights normalised by a softmax.

    Args:
        in_dim: Size of the input node features.
        hidden_dim: Hidden size passed to every ``GIN`` encoder.
        out_dim: Output size passed to every ``GIN`` encoder.
        num_layers: Number of layers passed to every ``GIN`` encoder.
        num_patterns: Number of patterns, i.e. number of encoders and weights.
    """
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 num_layers: int, num_patterns: int) -> None:
        super().__init__()
        self.num_patterns = num_patterns
        self.pattern_gnns = nn.ModuleList(
            [GIN(in_dim, hidden_dim, out_dim, num_layers) for _ in range(num_patterns)]
        )
        # Uniform start; the softmax in forward() turns these into a
        # distribution over patterns.
        self.pattern_weights = nn.Parameter(torch.ones(num_patterns) / num_patterns)

    @staticmethod
    def _subgraph_features(features: torch.Tensor, subgraph) -> torch.Tensor:
        """Select the rows of ``features`` that belong to the nodes of ``subgraph``."""
        # Imported locally so this block does not rely on a new file-level import.
        import dgl

        # NOTE(review): assumes the subgraphs were extracted with DGL's
        # subgraph utilities, which record parent node IDs under dgl.NID;
        # confirm against the code that builds ``pattern_subgraphs``.
        if dgl.NID in subgraph.ndata:
            node_ids = subgraph.ndata[dgl.NID].long().to(features.device)
            return features[node_ids]
        # Without stored IDs the features can only be used as-is when they
        # already line up one-to-one with the subgraph's nodes.
        if subgraph.num_nodes() == features.shape[0]:
            return features
        raise ValueError(
            "cannot map node features onto a subgraph that has no parent "
            "node IDs and whose node count differs from the feature matrix"
        )

    def forward(self, graph, features: torch.Tensor, pattern_subgraphs) -> torch.Tensor:
        """
        Compute the pattern-weighted graph embedding.

        Args:
            graph: The full graph. Not read by this method; kept so the call
                signature stays unchanged.
            features: Node features of the full graph, one row per node.
            pattern_subgraphs: Indexable by pattern number ``0..num_patterns-1``;
                entry ``i`` is an iterable of the subgraphs matching pattern ``i``.

        Returns:
            The softmax-weighted sum of the per-pattern embeddings. It has the
            shape of a single encoder output (presumably ``(1, out_dim)`` for
            unbatched subgraphs, since the encoders end in sum pooling).

        Raises:
            ValueError: If node features cannot be aligned with a subgraph.
            RuntimeError: From ``torch.stack`` if a pattern has no subgraphs.
        """
        pattern_embeddings = []
        for i, encoder in enumerate(self.pattern_gnns):
            subgraph_embeddings = [
                encoder(subgraph, self._subgraph_features(features, subgraph))
                for subgraph in pattern_subgraphs[i]
            ]
            pattern_embeddings.append(torch.stack(subgraph_embeddings).mean(dim=0))
        pattern_embeddings = torch.stack(pattern_embeddings)

        weights = F.softmax(self.pattern_weights, dim=0)
        # Reshape to (P, 1, ..., 1) so the weights line up with the leading
        # pattern axis whatever the rank of the encoder output. A fixed
        # (P, 1) view would broadcast against (P, 1, D) embeddings as
        # (1, P, 1) and mix every pattern with every weight.
        weights = weights.view(-1, *([1] * (pattern_embeddings.dim() - 1)))
        return (pattern_embeddings * weights).sum(dim=0)