import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleHGNNPConv(nn.Module):
    """Simplified HGNNP-style layer: linear map, optional batch norm, dropout and ReLU.

    Despite the name, this layer performs no propagation over hyperedges:
    ``forward`` accepts a ``hypergraph`` argument for interface compatibility
    but never reads it, so every node row is transformed by the same linear
    map (rows interact only through batch-norm statistics in training mode).

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of output features per node.
        use_bn: If True, apply ``nn.BatchNorm1d`` after the linear map.
        drop_rate: Dropout probability. An ``nn.Dropout`` submodule is
            created only when this is positive, and it is applied only when
            ``is_last`` is False.
        is_last: If True, skip dropout and ReLU and return the (optionally
            batch-normalized) linear output directly.
    """

    def __init__(self, in_channels, out_channels, use_bn=False, drop_rate=0.5, is_last=False):
        super().__init__()
        # Plain configuration attributes, kept for introspection and for forward().
        self.in_channels, self.out_channels = in_channels, out_channels
        self.use_bn = use_bn
        self.drop_rate = drop_rate
        self.is_last = is_last

        # The linear map is registered first; it is the only submodule here
        # whose initialisation draws random numbers.
        self.linear = nn.Linear(in_channels, out_channels)
        if use_bn:
            self.bn = nn.BatchNorm1d(out_channels)
        if drop_rate > 0:
            # Built even when is_last is True, although forward() never uses it then.
            self.dropout = nn.Dropout(drop_rate)

    def forward(self, X, hypergraph):
        """Transform node features; ``hypergraph`` is accepted but unused.

        Args:
            X: Node feature matrix, documented as ``[N, in_channels]``.
            hypergraph: Ignored by this simplified layer.

        Returns:
            Tensor with ``out_channels`` features per row. For a non-last
            layer, dropout (when ``drop_rate > 0``) and then ReLU are applied;
            a last layer returns the linear/batch-norm output unchanged.
        """
        out = self.linear(X)
        if self.use_bn:
            out = self.bn(out)

        # Output layers return raw scores: no dropout, no activation.
        if self.is_last:
            return out

        if self.drop_rate > 0:
            out = self.dropout(out)
        return F.relu(out)

class SimpleHGNNP(nn.Module):
    """Two-layer stack of ``SimpleHGNNPConv`` layers.

    The hidden layer maps ``in_channels`` to ``hid_channels`` and applies
    dropout (when ``drop_rate > 0``) followed by ReLU. The output layer is
    built with ``is_last=True``, so it maps ``hid_channels`` to
    ``num_classes`` and returns its linear output (batch-normalized when
    ``use_bn`` is set) with no activation or dropout.

    Args:
        in_channels: Number of input features per node.
        hid_channels: Width of the hidden layer.
        num_classes: Number of output features per node.
        use_bn: Whether both layers apply batch normalization.
        drop_rate: Dropout probability for the hidden layer only.
    """

    def __init__(self, in_channels, hid_channels, num_classes, use_bn=False, drop_rate=0.5):
        super().__init__()
        # Construct hidden before output so parameter initialisation order
        # is hidden-layer first, then output layer.
        hidden = SimpleHGNNPConv(in_channels, hid_channels, use_bn=use_bn, drop_rate=drop_rate)
        output = SimpleHGNNPConv(hid_channels, num_classes, use_bn=use_bn, is_last=True)
        self.layers = nn.ModuleList([hidden, output])

    def forward(self, X, hypergraph):
        """Run the node features through each layer in order.

        Args:
            X: Node feature matrix, documented as ``[N, in_channels]``.
            hypergraph: Forwarded unchanged to every layer (the current
                ``SimpleHGNNPConv`` does not read it).

        Returns:
            The output layer's tensor, with ``num_classes`` features per row.
        """
        out = X
        for conv in self.layers:
            out = conv(out, hypergraph)
        return out
