import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleDIGL(nn.Module):
    """
    Simplified DIGL model for virtual dataset testing.

    Works without an adjacency matrix: node features are mean-pooled into a
    graph-level vector, passed through a two-layer MLP encoder (with dropout)
    and classified by a single linear layer.

    Args:
        in_dim: Size of each node feature vector (last dimension of ``x``).
        hidden_dim: Width of the encoder's hidden layers.
        out_dim: Number of output classes.
        num_environments, use_wasserstein, use_causal_intervention,
        memory_size: Accepted for signature compatibility with existing
            callers; this simplified model does not use them.
    """

    def __init__(self, in_dim=16, hidden_dim=64, out_dim=4,
                 num_environments=2, use_wasserstein=True,
                 use_causal_intervention=True, memory_size=100):
        super().__init__()

        print(f"🔧 SimpleDIGL: in_dim={in_dim}, hidden_dim={hidden_dim}, out_dim={out_dim}")

        # Simple encoder: Linear -> ReLU -> Dropout -> Linear -> ReLU
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Classifier
        self.classifier = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, adj=None, labels=None, training=False):
        """
        Classify a graph (or a batch of graphs) from its node features.

        Args:
            x: Node features. A 3-D tensor is treated as a batch and
                mean-pooled over dim 1 (presumably
                ``[batch_size, num_nodes, in_dim]``). Any other rank is
                treated as a single graph and mean-pooled over dim 0 into a
                batch of one (``[1, in_dim]`` for a 2-D input).
            adj: Ignored; accepted so callers that pass an adjacency matrix
                keep working.
            labels: Class-index targets, used only when ``training`` is true.
                A scalar label (0-d tensor or Python int) is promoted to
                shape ``[1]`` so it matches the single-graph batch of one.
            training: When true and ``labels`` is given, also compute the
                cross-entropy loss and accuracy. This flag does NOT toggle
                dropout; dropout follows the module's ``train()``/``eval()``
                mode.

        Returns:
            dict with ``'logits'`` (``[batch, out_dim]`` for 2-D/3-D inputs).
            When the loss is computed it also holds
            ``'losses': {'total_loss': <scalar tensor>}`` and
            ``'training_stats': {'accuracy': <float>}``.
        """
        # Global mean pooling. There is no masking, so if x is padded the
        # padding rows are averaged in too.
        if x.dim() == 3:
            x_pooled = x.mean(dim=1)  # [batch_size, in_dim]
        else:
            x_pooled = x.mean(dim=0, keepdim=True)  # [1, in_dim]

        # Encode
        encoded = self.encoder(x_pooled)

        # Classify
        logits = self.classifier(encoded)

        output = {'logits': logits}

        if training and labels is not None:
            # Accept plain Python ints / lists as labels.
            if not torch.is_tensor(labels):
                labels = torch.as_tensor(labels, device=logits.device)
            # cross_entropy needs the target's batch dim to match logits';
            # a scalar label for a single graph would otherwise fail.
            if labels.dim() == 0:
                labels = labels.unsqueeze(0)

            loss = F.cross_entropy(logits, labels)
            output['losses'] = {'total_loss': loss}

            accuracy = (logits.argmax(dim=1) == labels).float().mean()
            output['training_stats'] = {'accuracy': accuracy.item()}

        return output

# Backward-compatible alias: existing code that refers to ``DIGLModel``
# gets this simplified model.
DIGLModel = SimpleDIGL
