from .gnn_model_fixed import SimpleGNN
from .digl_model import GIPLD
from .env_generators import EnvironmentGenerator
from .subgraph_extractor import SubgraphExtractor
from .prototype_aligner import WassersteinPrototype

# 1. Models used for disc training. They are expected to be importable, but a
#    failure is downgraded to a printed warning plus empty stand-in classes so
#    that importing this package does not abort.
try:
    from .digl_model import DIGLModel
    from .simple_digl import SimpleDIGL
except ImportError as import_error:
    print(f"Warning: Could not import disc models: {import_error}")

    # If either import fails, BOTH names are rebound to empty stand-ins, even
    # when the first import had already succeeded.
    class DIGLModel:
        """Empty stand-in bound when the disc models cannot be imported."""

    class SimpleDIGL:
        """Empty stand-in bound when the disc models cannot be imported."""

# 2. GOOD-related names are optional. Every one of them is bound to None here
#    so the name exists at module level even if a later import of it fails.
#    NOTE: this also rebinds any of these names that were bound earlier in the
#    module.
CompleteDIGL = GIPLD = None
EnvironmentGenerator = SubgraphExtractor = WassersteinPrototype = None
DisentangleLoss = CausalIntervention = None

# 3. CompleteDIGL comes from good_model when that module can be imported;
#    otherwise a minimal pass-through module takes its place.
try:
    from .good_model import CompleteDIGL
except ImportError:
    import torch.nn as nn

    class CompleteDIGL(nn.Module):
        """Simplified stand-in: accepts any constructor arguments and returns
        its input unchanged from ``forward``."""

        def __init__(self, *args, **kwargs):
            # Constructor arguments are accepted for call compatibility and
            # otherwise ignored.
            super().__init__()
            print("Note: Using simplified CompleteDIGL")

        def forward(self, x):
            return x

# Whichever CompleteDIGL was bound above is also exposed under the GIPLD name.
GIPLD = CompleteDIGL

# 4. Encoders. Two outcomes are possible:
#    - encoders.py provides the three concrete encoders, and BaseEncoder is an
#      alias of InvariantEncoder;
#    - the import fails, and all four names refer to one small MLP encoder.
BaseEncoder = InvariantEncoder = VariantEncoder = SharedEncoder = None

try:
    from .encoders import InvariantEncoder, VariantEncoder, SharedEncoder
except ImportError:
    import torch.nn as nn

    class BaseEncoder(nn.Module):
        """Fallback encoder: Linear -> ReLU -> Linear.

        Extra keyword arguments are accepted and ignored.
        """

        def __init__(self, input_dim=8, hidden_dim=128, output_dim=64, **kwargs):
            super().__init__()
            layers = [
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            ]
            self.encoder = nn.Sequential(*layers)

        def forward(self, x):
            return self.encoder(x)

    # Every encoder name falls back to the same class.
    InvariantEncoder = VariantEncoder = SharedEncoder = BaseEncoder
else:
    # The real encoders were imported; BaseEncoder aliases InvariantEncoder.
    BaseEncoder = InvariantEncoder

# 5. Other components (optional).
#    Each component is imported in its own try block so that one missing
#    module does not prevent the others from loading. A component that cannot
#    be imported is bound to None right here, so the name is always defined
#    after this section regardless of what ran before it.
try:
    from .env_generators import EnvironmentGenerator
except ImportError:
    EnvironmentGenerator = None

try:
    from .subgraph_extractor import SubgraphExtractor
except ImportError:
    SubgraphExtractor = None

try:
    from .prototype_aligner import WassersteinPrototype
except ImportError:
    WassersteinPrototype = None

try:
    from .disentangle_loss import DisentangleLoss
except ImportError:
    DisentangleLoss = None

try:
    from .causal_intervention import CausalIntervention
except ImportError:
    CausalIntervention = None

# GIPLD compatibility shim: take GIPLD from digl_model when it can be imported,
# otherwise derive a stand-in from SimpleGNN.
# NOTE: both branches bind GIPLD, so any earlier module-level value of that
# name is replaced here.
try:
    from .digl_model import GIPLD
except ImportError:

    class GIPLD(SimpleGNN):
        """Stand-in for GIPLD that reuses SimpleGNN unchanged."""

        def __init__(self, in_dim, hidden_dim, out_dim, num_environments):
            # Construction is delegated to SimpleGNN with the same four
            # arguments in the same order; the notice is printed only after
            # that succeeds.
            super().__init__(in_dim, hidden_dim, out_dim, num_environments)
            notice = "⚠️  Using SimpleGNN as compatibility version of GIPLD"
            print(notice)

# Public API of the package: the names exported by `from <package> import *`.
# Each name appears exactly once.
__all__ = [
    # Disc models
    'DIGLModel',
    'SimpleDIGL',

    # GOOD models
    'CompleteDIGL',
    'GIPLD',
    'EnvironmentGenerator',
    'SubgraphExtractor',
    'WassersteinPrototype',
    'DisentangleLoss',
    'CausalIntervention',

    # Encoders
    'BaseEncoder',
    'InvariantEncoder',
    'VariantEncoder',
    'SharedEncoder',

    # Baseline GNN
    'SimpleGNN',
]