"""
DIGL Framework for GOOD Datasets
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time
import os
import json
import argparse
import random
import math
from pathlib import Path
import sys
from collections import defaultdict, OrderedDict

# Append the parent of this file's directory to sys.path so that the optional
# project packages imported lazily further down (e.g. `models`, `loss`) can be
# resolved when this file is run as a script.
sys.path.append(str(Path(__file__).parent.parent))

# Banner printed at import time (module-level side effect).
print("=" * 80)
print("DIGL Framework Training for GOOD Datasets - Fixed Version")
print("=" * 80)


# ============================================================================
# 1. Command Line Arguments
# ============================================================================

def parse_args():
    """Parse the training script's command-line options.

    Returns:
        argparse.Namespace: parsed options. Dashes in flag names become
        underscores (e.g. ``--batch-size`` -> ``batch_size``). ``device``
        defaults to ``'cuda'`` when ``torch.cuda.is_available()`` is true and
        to ``'cpu'`` otherwise.
    """
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # (flag, add_argument keyword arguments), listed in --help order.
    option_table = [
        ('--dataset', dict(type=str, default='good-motif',
                           choices=['good-motif', 'good-cmnist', 'good-sst2', 'good-hiv'],
                           help='GOOD dataset to use')),
        ('--epochs', dict(type=int, default=50,
                          help='Number of training epochs')),
        ('--batch-size', dict(type=int, default=8,
                              help='Batch size (number of graphs)')),
        ('--lr', dict(type=float, default=0.001,
                      help='Learning rate')),
        ('--hidden-dim', dict(type=int, default=64,
                              help='Hidden dimension size')),
        ('--num-envs', dict(type=int, default=3,
                            help='Number of augmented environments')),
        ('--lambda-align', dict(type=float, default=1.0,
                                help='Weight for prototype alignment loss')),
        ('--lambda-disentangle', dict(type=float, default=0.5,
                                      help='Weight for representation disentanglement loss')),
        ('--lambda-class', dict(type=float, default=1.0,
                                help='Weight for classification loss')),
        ('--lambda-adv', dict(type=float, default=0.1,
                              help='Weight for adversarial environment loss')),
        ('--output-dir', dict(type=str, default='./results/digl_good',
                              help='Output directory for results')),
        ('--patience', dict(type=int, default=10,
                            help='Early stopping patience')),
        ('--seed', dict(type=int, default=42,
                        help='Random seed')),
        ('--device', dict(type=str, default=default_device,
                          help='Device to use (cuda/cpu)')),
        ('--quick-test', dict(action='store_true',
                              help='Quick test mode')),
    ]

    cli = argparse.ArgumentParser(description='DIGL Framework for GOOD Datasets')
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


# ============================================================================
# 2. Simplified DIGL Model Components
# ============================================================================

class GraphPooling(nn.Module):
    """Aggregate node features into one feature row per graph.

    Args:
        pooling_type: ``'mean'`` or ``'max'``. Any other value falls back to
            mean pooling (kept for backward compatibility).
    """

    def __init__(self, pooling_type='mean'):
        super().__init__()
        self.pooling_type = pooling_type

    def forward(self, x, batch_mask):
        """Pool node features into graph-level features.

        Args:
            x: 2-D tensor of node features; row ``j`` belongs to graph
                ``batch_mask[j]``.
            batch_mask: 1-D integer tensor of graph ids, one per row of ``x``.

        Returns:
            Tensor with ``batch_mask.max() + 1`` rows and ``x.size(1)``
            columns. Graph ids that own no node produce a zero row. An empty
            ``batch_mask`` yields a ``(0, x.size(1))`` tensor.
        """
        if self.pooling_type == 'max':
            return self._max_pool(x, batch_mask)
        # 'mean' and every unrecognised pooling type use mean pooling.
        return self._mean_pool(x, batch_mask)

    def _mean_pool(self, x, batch_mask):
        """Mean pooling over the nodes of each graph."""
        return self._pool(x, batch_mask,
                          lambda nodes: nodes.mean(dim=0, keepdim=True))

    def _max_pool(self, x, batch_mask):
        """Element-wise max pooling over the nodes of each graph."""
        return self._pool(x, batch_mask,
                          lambda nodes: nodes.max(dim=0)[0].unsqueeze(0))

    @staticmethod
    def _pool(x, batch_mask, reduce_fn):
        """Apply ``reduce_fn`` to each graph's node rows and stack the results."""
        if batch_mask.numel() == 0:
            # max() on an empty tensor raises; an empty batch has no graphs.
            return x.new_zeros((0, x.size(1)))

        num_graphs = int(batch_mask.max().item()) + 1
        pooled = []
        for graph_id in range(num_graphs):
            node_mask = batch_mask == graph_id
            if node_mask.any():
                pooled.append(reduce_fn(x[node_mask]))
            else:
                # No node carries this graph id: emit zeros with x's own
                # dtype/device so the concatenated result keeps x's dtype.
                pooled.append(x.new_zeros((1, x.size(1))))
        return torch.cat(pooled, dim=0)


class SimpleGNN(nn.Module):
    """Per-node feature extractor: two Linear+ReLU stages.

    Despite the name, the visible code applies the same MLP to every node row
    independently; no edge or adjacency information is consumed.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.node_encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Map node features (last dim ``input_dim``) to ``hidden_dim`` features."""
        hidden = self.node_encoder(x)
        return hidden


class SimpleSubgraphExtractor(nn.Module):
    """Split node features into two complementary parts with a soft mask.

    A small MLP ending in a sigmoid predicts one scalar in (0, 1) per node;
    that scalar weights the "invariant" part and its complement weights the
    "variant" part.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        bottleneck = hidden_dim // 2
        self.mask_predictor = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(x * mask, x * (1 - mask), mask)`` for node features ``x``."""
        mask = self.mask_predictor(x)
        complement = 1 - mask
        invariant_part = mask * x
        variant_part = complement * x
        return invariant_part, variant_part, mask


class SimpleEncoder(nn.Module):
    """Node-wise MLP followed by mean pooling to one row per graph."""

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.pooling = GraphPooling('mean')

    def forward(self, x, batch_mask):
        """Encode every node, then pool per graph.

        Args:
            x: node features whose last dimension is ``hidden_dim``.
            batch_mask: graph id of each node row, forwarded to the pooling layer.

        Returns:
            Graph-level features produced by the pooling layer.
        """
        return self.pooling(self.encoder(x), batch_mask)


class SimpleEnvironmentGenerator(nn.Module):
    """Fallback environment generator: one small MLP per environment.

    Args:
        feature_dim: size of the input and output feature vectors.
        num_envs: number of environment-specific networks to create.
    """

    def __init__(self, feature_dim, num_envs=3):
        super().__init__()
        self.num_envs = num_envs
        self.env_nets = nn.ModuleList([
            nn.Sequential(
                nn.Linear(feature_dim, feature_dim),
                nn.ReLU(),
                nn.Linear(feature_dim, feature_dim)
            ) for _ in range(num_envs)
        ])

    def forward(self, features, env_idx=None):
        """Generate environment-specific features.

        Args:
            features: tensor whose last dimension is ``feature_dim``.
            env_idx: optional environment index. Accepting it lets callers use
                the ``generator(features, env_idx)`` calling convention;
                without this parameter that call raised ``TypeError``.

        Returns:
            When ``env_idx`` is ``None``: a list with one tensor per
            environment, in network order (the original behaviour).
            Otherwise: the single tensor produced by network ``env_idx``.

        Raises:
            IndexError: if ``env_idx`` is outside the range of available
                networks.
        """
        if env_idx is None:
            return [net(features) for net in self.env_nets]
        return self.env_nets[env_idx](features)


# ============================================================================
# 3. Complete DIGL Model
# ============================================================================

class DIGLModel(nn.Module):
    """DIGL model: node encoder, invariant/variant split, and task/env heads.

    Optional project modules are used when they can be imported
    (``models.env_generators.EnvironmentGenerator``,
    ``models.prototype_aligner.WassersteinPrototype``,
    ``loss.mutual_info.MutualInformationLoss``); otherwise the simplified
    fallbacks defined in this file are used.

    Args:
        input_dim: size of the raw node feature vectors.
        hidden_dim: width of all hidden representations.
        num_classes: number of task classes.
        num_envs: number of generated environments / environment classes.
    """

    def __init__(self, input_dim, hidden_dim=64, num_classes=2, num_envs=3):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_envs = num_envs

        # Node-level feature extractor
        self.node_encoder = SimpleGNN(input_dim, hidden_dim)

        # Subgraph extractor (feature-level soft mask)
        self.subgraph_extractor = SimpleSubgraphExtractor(hidden_dim)

        # Graph-level encoders (with pooling)
        self.invariant_encoder = SimpleEncoder(hidden_dim, hidden_dim)
        self.variant_encoder = SimpleEncoder(hidden_dim, hidden_dim)

        # Environment generator: full project module if available, else fallback
        try:
            from models.env_generators import EnvironmentGenerator
            self.env_generator = EnvironmentGenerator(
                input_dim=hidden_dim,
                hidden_dim=hidden_dim // 2,
                num_environments=num_envs
            )
            print("✅ Using full EnvironmentGenerator")
        except ImportError:
            print("⚠️  Using simplified environment generator")
            self.env_generator = SimpleEnvironmentGenerator(hidden_dim, num_envs)

        # Classifiers
        self.task_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self.env_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_envs)
        )

        # Prototype parameters. Not referenced by the active code in this
        # class; kept so the parameter set / state-dict layout is unchanged.
        self.prototypes = nn.Parameter(torch.randn(num_classes, hidden_dim) * 0.1)

        try:
            from models.prototype_aligner import WassersteinPrototype
            self.prototype_aligner = WassersteinPrototype(
                feature_dim=hidden_dim,
                num_classes=num_classes,
                num_environments=num_envs
            )
            print("✅ Using WassersteinPrototype for GOOD datasets")
        except ImportError:
            print("⚠️  No prototype aligner available")
            self.prototype_aligner = None

        try:
            from loss.mutual_info import MutualInformationLoss
            self.mi_loss_fn = MutualInformationLoss(method='infonce')  # InfoNCE works better for graph data
            print("✅ Using MutualInformationLoss")
            self.use_mi_loss = True
        except ImportError:
            print("⚠️  Using simplified disentanglement loss")
            # Always define the attribute so later code need not probe for it.
            self.mi_loss_fn = None
            self.use_mi_loss = False

    def forward(self, features, batch_mask, labels=None, return_losses=False):
        """Run the model on a batch of graphs stored as concatenated node rows.

        Args:
            features: node features of all graphs in the batch, one row per node.
            batch_mask: graph id of every node row.
            labels: optional task labels, one per graph. Required for losses.
            return_losses: when true and ``labels`` is given, the result also
                contains a ``'losses'`` dict (see ``_compute_full_losses``).

        Returns:
            dict with keys ``'task_logits'``, ``'env_logits'``,
            ``'inv_features'``, ``'var_features'``, ``'env_features'`` (a list
            with one entry per environment) and, optionally, ``'losses'``.
        """
        # Extract node features
        node_features = self.node_encoder(features)

        if hasattr(self.subgraph_extractor, '_generate_edge_mask'):
            # A structure-aware extractor was plugged in. The previous code ran
            # it per graph on an all-ones dummy adjacency and then discarded
            # every result, feeding the unmasked node features to both
            # encoders. Only the part that affected the outputs is kept, so
            # the extractor is no longer invoked here.
            inv_x = node_features
            var_x = node_features
        else:
            # Simplified version: feature-level soft-mask extraction
            inv_x, var_x, _ = self.subgraph_extractor(node_features)

        # Encode to graph-level features
        inv_features = self.invariant_encoder(inv_x, batch_mask)
        var_features = self.variant_encoder(var_x, batch_mask)

        # Generate environment features
        env_features_list = self._generate_env_features(inv_features)

        # Task classification from the invariant part, environment
        # classification from the variant part
        task_logits = self.task_classifier(inv_features)
        env_logits = self.env_classifier(var_features)

        output = {
            'task_logits': task_logits,
            'env_logits': env_logits,
            'inv_features': inv_features,
            'var_features': var_features,
            'env_features': env_features_list
        }

        if return_losses and labels is not None:
            output['losses'] = self._compute_full_losses(
                inv_features, var_features, env_features_list,
                task_logits, env_logits, labels
            )

        return output

    def _generate_env_features(self, inv_features):
        """Return one feature tensor per environment for ``inv_features``.

        The fallback generator returns all environments from a single
        ``generator(features)`` call, whereas the project generator is called
        as ``generator(features, env_idx)``. Calling the fallback with two
        positional arguments raised ``TypeError``, so the two conventions are
        dispatched explicitly here.
        """
        if isinstance(self.env_generator, SimpleEnvironmentGenerator):
            return list(self.env_generator(inv_features))
        return [
            self.env_generator(inv_features, env_idx)
            for env_idx in range(self.num_envs)
        ]

    def _optional_generator_loss(self, method_name, *args):
        """Call an optional loss method on the environment generator.

        Returns the loss, or ``None`` when the generator has no such method,
        the method returns ``None``, or the call fails. Failures stay
        best-effort (training continues) but are reported through a
        ``RuntimeWarning`` instead of being swallowed silently.
        """
        loss_fn = getattr(self.env_generator, method_name, None)
        if loss_fn is None:
            return None
        try:
            return loss_fn(*args)
        except Exception as exc:  # deliberate best-effort; see docstring
            import warnings
            warnings.warn(
                f"env_generator.{method_name} failed and was skipped: {exc!r}",
                RuntimeWarning,
            )
            return None

    def _compute_full_losses(self, inv_features, var_features, env_features_list,
                             task_logits, env_logits, labels):
        """Compute the individual (unweighted) loss terms.

        Returns:
            dict that always contains ``'task_loss'``, ``'disentangle_loss'``
            and ``'env_loss'``; ``'prototype_loss'``, ``'env_adv_loss'`` and
            ``'env_inv_loss'`` are added when the corresponding optional
            component is available and succeeds.
        """
        losses = {}

        # 1. Task loss
        losses['task_loss'] = F.cross_entropy(task_logits, labels)

        # 2. Prototype alignment loss (if available)
        if self.prototype_aligner is not None:
            losses['prototype_loss'] = self.prototype_aligner.alignment_loss(
                {i: feat for i, feat in enumerate(env_features_list)},
                labels
            )

        # 3. Disentanglement: mutual-information loss when available,
        #    otherwise mean absolute cosine similarity between the invariant
        #    and variant representation of each graph.
        mi_loss_fn = getattr(self, 'mi_loss_fn', None)
        if self.use_mi_loss and mi_loss_fn is not None:
            losses['disentangle_loss'] = mi_loss_fn(inv_features, var_features)
        else:
            inv_norm = F.normalize(inv_features, dim=1)
            var_norm = F.normalize(var_features, dim=1)
            correlation = (inv_norm * var_norm).sum(dim=1)
            losses['disentangle_loss'] = torch.mean(torch.abs(correlation))

        # 4. Environment loss. Targets are drawn uniformly at random on every
        #    call (unchanged behaviour): the batch's real environment ids are
        #    not passed to the model.
        batch_size = labels.size(0)
        env_labels = torch.randint(0, self.num_envs, (batch_size,), device=inv_features.device)
        losses['env_loss'] = F.cross_entropy(env_logits, env_labels)

        # 5. Optional environment-generator losses
        adv_loss = self._optional_generator_loss('adversarial_loss', env_features_list)
        if adv_loss is not None:
            losses['env_adv_loss'] = adv_loss

        inv_loss = self._optional_generator_loss('invariance_loss', env_features_list, labels)
        if inv_loss is not None:
            losses['env_inv_loss'] = inv_loss

        return losses


# ============================================================================
# 4. Data Loading and Processing
# ============================================================================

def create_synthetic_graph(dataset_name, graph_id):
    """Build one random graph sample for the given dataset name.

    Randomness comes from the global ``random`` and ``torch`` generators.

    Args:
        dataset_name: ``'good-motif'`` and ``'good-cmnist'`` have dedicated
            recipes; every other name uses the generic recipe.
        graph_id: index of the graph. Only the ``'good-motif'`` recipe reads
            it: ids below 100 are environment 0, all others environment 1.

    Returns:
        dict with ``'features'`` (``num_nodes x feature_dim`` float tensor),
        ``'label'`` and ``'env'`` (one-element integer tensors) and
        ``'num_nodes'`` (int).
    """
    if dataset_name == 'good-motif':
        # 10-20 nodes with 10 features each.
        num_nodes = random.randint(10, 20)
        features = torch.randn(num_nodes, 10)

        # Spurious correlation: the label equals the environment id unless the
        # draw is <= 0.3, in which case it is flipped.
        env = 0 if graph_id < 100 else 1
        keep_majority = random.random() > 0.3
        label = env if keep_majority else 1 - env

    elif dataset_name == 'good-cmnist':
        # One node holding a flattened 28x28 (MNIST-like) feature vector.
        num_nodes = 1
        features = torch.randn(num_nodes, 28 * 28)
        label = random.randint(0, 9)

        # Digits 0-4 mostly land in environment 0 and digits 5-9 in
        # environment 1; the assignment is flipped when the draw is <= 0.2.
        majority_env = 0 if label < 5 else 1
        keep_majority = random.random() > 0.2
        env = majority_env if keep_majority else 1 - majority_env

    else:
        # Generic recipe: 5-15 nodes, 16 features, independent label and env.
        num_nodes = random.randint(5, 15)
        features = torch.randn(num_nodes, 16)
        label = random.randint(0, 1)
        env = random.randint(0, 1)

    sample = {
        'features': features,
        'label': torch.tensor([label]),
        'env': torch.tensor([env]),
        'num_nodes': num_nodes,
    }
    return sample


class GraphDataLoader:
    """Mini-batch iterator over a list of graph samples.

    Each sample is a dict with ``'features'`` (node-feature tensor),
    ``'label'`` and ``'env'`` (one-element tensors). A batch concatenates the
    node rows of its graphs and records, per row, which graph it came from.

    Args:
        data_list: list of sample dicts.
        batch_size: maximum number of graphs per batch; clamped to the dataset
            size. Must be at least 1.
        shuffle: reshuffle the sample order (via ``random.shuffle``) on every
            iteration.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, data_list, batch_size=8, shuffle=True):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.data_list = data_list
        # Clamp to the dataset size, but never down to 0: an empty dataset
        # used to produce batch_size == 0, which made __len__ divide by zero
        # and __iter__ call range() with a zero step.
        self.batch_size = max(1, min(batch_size, len(data_list)))
        self.shuffle = shuffle

    def __iter__(self):
        """Yield collated batches; yields nothing for an empty dataset."""
        order = list(range(len(self.data_list)))
        if self.shuffle:
            random.shuffle(order)

        for start in range(0, len(order), self.batch_size):
            chunk = order[start:start + self.batch_size]
            yield self._collate_batch([self.data_list[idx] for idx in chunk])

    def __len__(self):
        """Number of batches per iteration (ceiling division)."""
        return (len(self.data_list) + self.batch_size - 1) // self.batch_size

    def _collate_batch(self, batch):
        """Merge a non-empty list of samples into one batch dict.

        Returns:
            dict with ``'features'`` (all node rows concatenated), ``'labels'``
            and ``'envs'`` (one entry per graph), ``'batch_mask'`` (long
            tensor giving the position of the owning graph within the batch
            for every node row) and ``'num_graphs'``.
        """
        features_list = [item['features'] for item in batch]
        labels_list = [item['label'] for item in batch]
        envs_list = [item['env'] for item in batch]

        # One graph index per node row, in the same order as the concatenation.
        batch_mask = [
            graph_idx
            for graph_idx, feats in enumerate(features_list)
            for _ in range(feats.size(0))
        ]

        return {
            'features': torch.cat(features_list, dim=0),
            'labels': torch.cat(labels_list, dim=0),
            'envs': torch.cat(envs_list, dim=0),
            'batch_mask': torch.tensor(batch_mask, dtype=torch.long),
            'num_graphs': len(batch)
        }


def load_dataset(dataset_name, batch_size=8):
    """Create synthetic train/val/test splits and wrap them in data loaders.

    The splits hold 200, 50 and 50 graphs with consecutive graph ids
    (0-199, 200-249 and 250-299 respectively). Only the training loader
    shuffles.

    Args:
        dataset_name: name forwarded to ``create_synthetic_graph``.
        batch_size: number of graphs per batch for all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, input_dim, num_classes)``
        where ``input_dim`` is the feature width of the first training sample
        and ``num_classes`` is 10 for ``'good-cmnist'`` and 2 otherwise.
    """
    print(f"\n📂 Loading {dataset_name.upper()} dataset...")
    print("   Creating synthetic graph data...")

    def build_split(first_id, count):
        # Graph ids are consecutive, starting at first_id.
        return [
            create_synthetic_graph(dataset_name, graph_id)
            for graph_id in range(first_id, first_id + count)
        ]

    train_data = build_split(0, 200)
    val_data = build_split(200, 50)
    test_data = build_split(250, 50)

    train_loader = GraphDataLoader(train_data, batch_size, shuffle=True)
    val_loader = GraphDataLoader(val_data, batch_size, shuffle=False)
    test_loader = GraphDataLoader(test_data, batch_size, shuffle=False)

    # Dataset info is read off the first training sample.
    input_dim = train_data[0]['features'].size(-1)
    num_classes = 10 if dataset_name == 'good-cmnist' else 2

    summary_lines = (
        "✅ Dataset loaded:",
        f"   Training graphs: {len(train_data)}",
        f"   Validation graphs: {len(val_data)}",
        f"   Test graphs: {len(test_data)}",
        f"   Input dimension: {input_dim}",
        f"   Number of classes: {num_classes}",
        f"   Batch size: {batch_size} graphs",
    )
    for line in summary_lines:
        print(line)

    return train_loader, val_loader, test_loader, input_dim, num_classes


# ============================================================================
# 5. DIGL Trainer
# ============================================================================

class DIGLTrainer:
    """Training / evaluation driver for a DIGL-style model.

    The model is called as ``model(features, batch_mask, labels,
    return_losses=...)`` and must return a dict with ``'task_logits'`` and,
    when losses are requested, a ``'losses'`` dict. Loaders must be iterable,
    support ``len()`` and yield dicts with ``'features'``, ``'labels'``,
    ``'batch_mask'`` (and ``'envs'`` for the test loader).

    Args:
        model: the ``nn.Module`` to train.
        train_loader, val_loader, test_loader: batch iterables as described.
        args: namespace providing ``device``, ``lr``, ``epochs``,
            ``patience``, ``output_dir``, ``dataset``, ``lambda_class``,
            ``lambda_align``, ``lambda_disentangle`` and ``lambda_adv``.
    """

    # Fixed weights of the optional environment-generator loss terms.
    ENV_ADV_WEIGHT = 0.1
    ENV_INV_WEIGHT = 0.1

    def __init__(self, model, train_loader, val_loader, test_loader, args):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.args = args
        self.device = args.device

        # Move model to device
        self.model.to(self.device)

        # Optimizer
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=1e-4
        )

        # Learning rate scheduler: halve the learning rate every 10 epochs
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=10,
            gamma=0.5
        )

        # Early stopping state
        self.best_val_acc = 0
        self.patience_counter = 0
        self.best_model_state = None

        # Training history (one entry per completed epoch)
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': []
        }

    def compute_total_loss(self, losses_dict):
        """Combine the individual loss terms into one scalar.

        ``task_loss`` is weighted by ``lambda_class``; ``prototype_loss`` and
        ``align_loss`` (both names are accepted) by ``lambda_align``;
        ``disentangle_loss`` by ``lambda_disentangle``. ``env_loss`` is
        *subtracted* with weight ``lambda_adv`` (adversarial term).
        ``env_adv_loss`` and ``env_inv_loss`` use the fixed class-level
        weights. Missing terms are skipped.
        """
        total_loss = losses_dict.get('task_loss', 0) * self.args.lambda_class

        # Prototype alignment loss (compatible with both naming conventions)
        for key in ('prototype_loss', 'align_loss'):
            if key in losses_dict:
                total_loss = total_loss + losses_dict[key] * self.args.lambda_align

        # Mutual information / disentanglement loss
        if 'disentangle_loss' in losses_dict:
            total_loss = total_loss + losses_dict['disentangle_loss'] * self.args.lambda_disentangle

        # Environment adversarial loss (negative weight)
        if 'env_loss' in losses_dict:
            total_loss = total_loss - losses_dict['env_loss'] * self.args.lambda_adv

        # Environment generator adversarial and invariance losses
        if 'env_adv_loss' in losses_dict:
            total_loss = total_loss + self.ENV_ADV_WEIGHT * losses_dict['env_adv_loss']

        if 'env_inv_loss' in losses_dict:
            total_loss = total_loss + self.ENV_INV_WEIGHT * losses_dict['env_inv_loss']

        return total_loss

    def train_epoch(self, epoch):
        """Run one optimisation pass over the training loader.

        Args:
            epoch: zero-based epoch index (currently unused; kept for the
                existing call signature).

        Returns:
            ``(average loss per batch, accuracy over all graphs)``.
        """
        self.model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0

        for batch_idx, batch in enumerate(self.train_loader):
            # Move data to device
            features = batch['features'].to(self.device)
            labels = batch['labels'].to(self.device)
            batch_mask = batch['batch_mask'].to(self.device)

            # Forward pass with loss computation
            output = self.model(features, batch_mask, labels, return_losses=True)
            loss = self.compute_total_loss(output['losses'])

            # Backward pass with gradient clipping
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Statistics
            total_loss += loss.item()
            with torch.no_grad():
                preds = output['task_logits'].argmax(dim=1)
                correct = (preds == labels).sum().item()
                total_correct += correct
                total_samples += labels.size(0)

            # Print progress every fifth batch
            if batch_idx % 5 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"   Batch {batch_idx}/{len(self.train_loader)}: "
                      f"Loss: {loss.item():.4f}, Acc: {correct/labels.size(0):.2%}, LR: {current_lr:.6f}")

        avg_loss = total_loss / max(len(self.train_loader), 1)
        accuracy = total_correct / total_samples if total_samples > 0 else 0

        return avg_loss, accuracy

    def validate(self):
        """Evaluate on the validation loader without gradient tracking.

        Returns:
            ``(average loss per batch, accuracy over all graphs)``.
        """
        self.model.eval()
        total_correct = 0
        total_samples = 0
        total_loss = 0

        with torch.no_grad():
            for batch in self.val_loader:
                features = batch['features'].to(self.device)
                labels = batch['labels'].to(self.device)
                batch_mask = batch['batch_mask'].to(self.device)

                output = self.model(features, batch_mask, labels, return_losses=True)

                loss = self.compute_total_loss(output['losses'])
                total_loss += loss.item()

                preds = output['task_logits'].argmax(dim=1)
                total_correct += (preds == labels).sum().item()
                total_samples += labels.size(0)

        avg_loss = total_loss / max(len(self.val_loader), 1)
        accuracy = total_correct / total_samples if total_samples > 0 else 0

        return avg_loss, accuracy

    def test(self):
        """Evaluate on the test loader, overall and per environment.

        Returns:
            dict with ``'accuracy'``, ``'env_accuracies'`` (environment id ->
            accuracy), ``'fairness_gap'`` (max minus min environment accuracy,
            0 when fewer than two environments occur), ``'total_correct'``
            and ``'total_samples'``.
        """
        self.model.eval()
        total_correct = 0
        total_samples = 0
        env_correct = defaultdict(int)
        env_total = defaultdict(int)

        with torch.no_grad():
            for batch in self.test_loader:
                features = batch['features'].to(self.device)
                labels = batch['labels'].to(self.device)
                envs = batch['envs'].to(self.device)
                batch_mask = batch['batch_mask'].to(self.device)

                output = self.model(features, batch_mask, return_losses=False)

                preds = output['task_logits'].argmax(dim=1)
                hits = preds == labels
                total_correct += hits.sum().item()
                total_samples += labels.size(0)

                # Per-environment statistics
                for env, hit in zip(envs.tolist(), hits.tolist()):
                    env_total[env] += 1
                    if hit:
                        env_correct[env] += 1

        accuracy = total_correct / total_samples if total_samples > 0 else 0

        env_accuracies = {
            env: (env_correct[env] / count if count > 0 else 0)
            for env, count in env_total.items()
        }

        if len(env_accuracies) >= 2:
            fairness_gap = max(env_accuracies.values()) - min(env_accuracies.values())
        else:
            fairness_gap = 0

        return {
            'accuracy': accuracy,
            'env_accuracies': env_accuracies,
            'fairness_gap': fairness_gap,
            'total_correct': total_correct,
            'total_samples': total_samples
        }

    def train(self):
        """Train with early stopping and restore the best weights.

        After each epoch the model is validated; a strictly better validation
        accuracy snapshots the weights and writes a checkpoint. Training stops
        once ``args.patience`` consecutive epochs bring no improvement.

        Returns:
            The history dict with per-epoch losses and accuracies.
        """
        print(f"\n🚀 Starting training for {self.args.epochs} epochs...")

        for epoch in range(self.args.epochs):
            start_time = time.time()

            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()

            # Update learning rate
            self.scheduler.step()

            # Update history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)

            # Print progress
            epoch_time = time.time() - start_time
            current_lr = self.optimizer.param_groups[0]['lr']

            print(f"\nEpoch {epoch+1:3d}/{self.args.epochs}: "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2%}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2%}, "
                  f"LR: {current_lr:.6f}, Time: {epoch_time:.1f}s")

            # Early stopping check
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.patience_counter = 0
                # state_dict() hands out references to the live parameter
                # tensors, and dict.copy() is shallow, so each tensor is
                # cloned. Otherwise the "best" snapshot silently follows
                # later optimizer steps and restoring it is a no-op.
                self.best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }

                # Save model
                self.save_model(epoch + 1, val_acc)
                print(f"   🔥 New best validation accuracy: {val_acc:.2%}")
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.args.patience:
                    print(f"   ⏹️  Early stopping triggered at epoch {epoch + 1}")
                    break

        # Restore best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print(f"\n✅ Restored best model with validation accuracy: {self.best_val_acc:.2%}")

        return self.history

    def save_model(self, epoch, val_acc):
        """Write a checkpoint for ``epoch`` into ``args.output_dir``.

        The checkpoint holds the epoch number, model and optimizer state,
        the validation accuracy and the run arguments as a plain dict.
        """
        os.makedirs(self.args.output_dir, exist_ok=True)

        filename = f"digl_model_{self.args.dataset}_epoch{epoch}.pth"
        checkpoint_path = os.path.join(self.args.output_dir, filename)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_acc': val_acc,
            'args': vars(self.args)
        }

        torch.save(checkpoint, checkpoint_path)
        print(f"   💾 Saved model: {checkpoint_path}")


# ============================================================================
# 6. Main Training Function
# ============================================================================

def _seed_everything(seed):
    """Seed the torch, NumPy and ``random`` generators with ``seed``."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _print_run_configuration(args):
    """Print the run settings read from ``args``."""
    print("\n⚙️  Configuration:")
    print(f"   Dataset: {args.dataset}")
    print(f"   Device: {args.device}")
    print(f"   Hidden dimension: {args.hidden_dim}")
    print(f"   Number of environments: {args.num_envs}")
    print(f"   Learning rate: {args.lr}")
    print(f"   Batch size: {args.batch_size} graphs")
    print(f"   Epochs: {args.epochs}")


def _print_test_report(test_results):
    """Print overall test accuracy and, when present, per-environment accuracy."""
    print("\n🎯 Final Results:")
    print(f"   Test Accuracy: {test_results['accuracy']:.2%}")
    print(f"   Total Correct: {test_results['total_correct']}/{test_results['total_samples']}")

    # The per-environment section is only shown when there is something to list.
    if test_results['env_accuracies']:
        print("\n🌍 Environment Accuracies:")
        for env, acc in sorted(test_results['env_accuracies'].items()):
            print(f"   Environment {env}: {acc:.2%}")
        print(f"   Fairness Gap: {test_results['fairness_gap']:.4f}")


def _write_results_json(args, test_results, best_val_acc, history):
    """Dump the run summary to ``results_<dataset>.json`` and return its path."""
    os.makedirs(args.output_dir, exist_ok=True)
    results_file = os.path.join(args.output_dir, f"results_{args.dataset}.json")

    results = {
        'dataset': args.dataset,
        'test_accuracy': float(test_results['accuracy']),
        'fairness_gap': float(test_results['fairness_gap']),
        'best_val_accuracy': float(best_val_acc),
        'training_history': history,
        'config': vars(args),
    }

    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    return results_file


def _plot_training_curves(args, history):
    """Best-effort: save accuracy and loss curves as a PNG in ``args.output_dir``.

    The plot is an optional extra produced after the results file has already
    been written, so neither a missing matplotlib nor an error while drawing
    or saving may abort the run; both only print a warning.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("⚠️  Matplotlib not available, skipping plot generation")
        return

    fig = None
    try:
        epochs = range(1, len(history['train_acc']) + 1)

        fig = plt.figure(figsize=(10, 4))

        # Accuracy plot
        plt.subplot(1, 2, 1)
        plt.plot(epochs, history['train_acc'], 'b-', label='Train', linewidth=2)
        plt.plot(epochs, history['val_acc'], 'r-', label='Validation', linewidth=2)
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title(f'DIGL - {args.dataset.upper()}')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Loss plot
        plt.subplot(1, 2, 2)
        plt.plot(epochs, history['train_loss'], 'b-', label='Train', linewidth=2)
        plt.plot(epochs, history['val_loss'], 'r-', label='Validation', linewidth=2)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()

        plot_file = os.path.join(args.output_dir, f"training_plot_{args.dataset}.png")
        plt.savefig(plot_file, dpi=150, bbox_inches='tight')
    except Exception as exc:
        # Training and the results export already succeeded; a plotting
        # problem must not turn the whole run into a reported failure.
        print(f"⚠️  Could not generate training plot: {exc}")
    else:
        print(f"📊 Training plot saved to: {plot_file}")
    finally:
        # Always release the figure, including when drawing/saving failed.
        if fig is not None:
            plt.close(fig)


def train_digl(args):
    """Run one full DIGL experiment: seed, load data, train, test, export.

    Steps, in order: seed the RNGs, print the configuration, load the dataset
    via ``load_dataset``, build a ``DIGLModel``, train and test it through a
    ``DIGLTrainer``, write ``results_<dataset>.json`` to ``args.output_dir``
    and finally try to save a training-curve PNG next to it.

    Args:
        args: Namespace of run settings. This function reads ``seed``,
            ``dataset``, ``device``, ``hidden_dim``, ``num_envs``, ``lr``,
            ``batch_size``, ``epochs`` and ``output_dir``; the whole
            namespace is also handed to the trainer and stored in the
            results file.

    Returns:
        tuple: ``(model, test_results, history)`` where ``test_results`` is
        the value returned by ``trainer.test()`` and ``history`` the value
        returned by ``trainer.train()``.
    """
    _seed_everything(args.seed)
    _print_run_configuration(args)

    # Load dataset
    train_loader, val_loader, test_loader, input_dim, num_classes = load_dataset(
        args.dataset, args.batch_size
    )

    # Create model
    print("\n🧠 Creating DIGL model...")
    model = DIGLModel(
        input_dim=input_dim,
        hidden_dim=args.hidden_dim,
        num_classes=num_classes,
        num_envs=args.num_envs
    )
    print(f"   Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create trainer
    trainer = DIGLTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        args=args
    )

    # Train model
    history = trainer.train()

    # Test model
    print("\n🧪 Testing model...")
    test_results = trainer.test()
    _print_test_report(test_results)

    # Save results before plotting so they survive any plotting problem.
    results_file = _write_results_json(args, test_results, trainer.best_val_acc, history)
    print(f"\n💾 Results saved to: {results_file}")

    _plot_training_curves(args, history)

    return model, test_results, history


# ============================================================================
# 7. Main Function
# ============================================================================

def main():
    """Parse the command line, run DIGL training and print a summary.

    In quick-test mode the epoch count, batch size and hidden dimension are
    capped before training. Any exception raised by training is reported
    with its traceback instead of propagating.

    Returns:
        int: ``0`` when training completed, ``1`` when it raised. The script
        entry point forwards this value to ``sys.exit`` so that a failed run
        is visible to shells and job schedulers.
    """
    args = parse_args()

    # Adjust for quick test
    if args.quick_test:
        args.epochs = min(args.epochs, 15)
        args.batch_size = min(args.batch_size, 4)
        args.hidden_dim = min(args.hidden_dim, 32)
        print(f"\n🚀 Quick test mode: {args.epochs} epochs, batch_size={args.batch_size}")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"\n🎯 Starting DIGL training for {args.dataset.upper()}")
    print("=" * 60)

    exit_code = 0
    try:
        _, test_results, _ = train_digl(args)

        print("\n✅ DIGL training completed!")
        print(f"   Final test accuracy: {test_results['accuracy']:.2%}")

        # Performance evaluation
        accuracy = test_results['accuracy']
        if accuracy > 0.75:
            print("   🎉 Excellent performance!")
        elif accuracy > 0.65:
            print("   👍 Good performance!")
        elif accuracy > 0.55:
            print("   ⚠️  Moderate performance")
        else:
            print("   ❌ Needs improvement")

    except Exception as e:
        # Top-level boundary: report the error in full, then signal failure
        # through the return value rather than pretending the run succeeded.
        print(f"\n❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        exit_code = 1

    print("\n🎉 Program execution completed")
    return exit_code


if __name__ == "__main__":
    # Propagate the outcome as the process exit status (non-zero on failure).
    sys.exit(main())