"""
Unified DIGL Trainer - Supports GOOD and DisC Datasets
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import json
import argparse
import random
from pathlib import Path
from collections import defaultdict

# Put the directory two levels above this file on the import path so the
# project-local `data.` and `experiments.` packages can be imported when the
# file is run directly as a script.
sys.path.append(os.fspath(Path(__file__).parent.parent))

# Start-up banner
print("=" * 80, "DIGL Unified Trainer", "=" * 80, sep="\n")


class UnifiedConfig:
    """Build the nested configuration dictionaries used by the experiments."""

    @staticmethod
    def get_config(dataset_type, args):
        """Return the configuration dict for ``dataset_type``.

        Args:
            dataset_type: ``'good'`` or ``'disc'``.
            args: Namespace that must provide ``epochs``, ``batch_size``,
                ``lr``, ``patience`` and ``hidden_dim``. The optional
                ``lambda_adv`` / ``lambda_irm`` / ``lambda_vrex`` (default 0.1)
                and ``lambda_w`` (default 0.05) fall back to their defaults
                when absent.

        Returns:
            A dict with ``'training'``, ``'model'``, ``'data'`` and ``'loss'``
            sections.

        Raises:
            ValueError: if ``dataset_type`` is neither ``'good'`` nor ``'disc'``
                (previously ``None`` was returned, failing much later).
        """
        # Optional loss weights; getattr keeps callers without these attrs working
        lambda_adv = getattr(args, 'lambda_adv', 0.1)
        lambda_irm = getattr(args, 'lambda_irm', 0.1)
        lambda_vrex = getattr(args, 'lambda_vrex', 0.1)
        lambda_w = getattr(args, 'lambda_w', 0.05)

        base_config = {
            'training': {
                'epochs': args.epochs,
                'batch_size': args.batch_size,
                'lr': args.lr,
                'weight_decay': 1e-4,
                'patience': args.patience
            },
            'model': {
                'hidden_dim': args.hidden_dim,
                'num_environments': 2
            }
        }

        if dataset_type == 'disc':
            disc_config = {
                'data': {
                    'color_bias': 0.9,
                    'num_colors': 10,
                    'img_size': 28
                },
                'loss': {
                    'lambda_w': lambda_w,  # Wasserstein weight
                    'lambda_adv': lambda_adv
                }
            }
            return {**base_config, **disc_config}

        if dataset_type == 'good':
            good_config = {
                'data': {
                    'feature_dim': 16,
                    'num_classes': 2
                },
                'loss': {
                    'lambda_adv': lambda_adv,
                    'lambda_irm': lambda_irm,
                    'lambda_vrex': lambda_vrex
                }
            }
            return {**base_config, **good_config}

        raise ValueError(
            f"Unknown dataset type: {dataset_type!r} (expected 'good' or 'disc')"
        )


class SimpleGoodModel(nn.Module):
    """MLP stand-in for a GOOD graph classifier (avoids a torch_geometric dependency).

    Each graph's node features are mean-pooled into one row, zero-padded (or
    truncated) to a fixed ``input_dim * max_nodes`` width, encoded by a
    two-layer MLP and classified. A separate environment head runs on
    *detached* encodings, so its loss trains only that head and never reaches
    the encoder.
    """

    def __init__(self, input_dim=16, hidden_dim=256, num_classes=2, num_envs=2, max_nodes=20):
        """
        Args:
            input_dim: Per-node feature size.
            hidden_dim: Width of the first encoder layer (the second is half).
            num_classes: Number of task classes.
            num_envs: Number of environment classes for the env head.
            max_nodes: Together with ``input_dim``, fixes the encoder's input
                width (``input_dim * max_nodes``).
        """
        super().__init__()
        self.input_dim = input_dim
        self.max_nodes = max_nodes
        # Width every graph feature is padded/truncated to. It must equal the
        # first Linear's in_features (it used to be hard-coded to 20 * 16,
        # which broke any input_dim other than 16).
        self.feature_width = input_dim * max_nodes
        self.encoder = nn.Sequential(
            nn.Linear(self.feature_width, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.classifier = nn.Linear(hidden_dim // 2, num_classes)
        self.env_classifier = nn.Linear(hidden_dim // 2, num_envs)

    def _graph_feature(self, x, device):
        """Pool one graph's features into a float32 ``(1, feature_width)`` row."""
        # Cast to float32: int / float64 inputs would clash with the Linear weights
        x = torch.as_tensor(x, dtype=torch.float32, device=device)
        if x.dim() == 2:
            # Average over rows (node features) as the graph representation
            pooled = x.mean(dim=0, keepdim=True)
        elif x.dim() == 1:
            pooled = x.unsqueeze(0)
        else:
            # reshape (not view) also works for non-contiguous inputs
            pooled = x.reshape(1, -1)

        width = pooled.size(1)
        if width < self.feature_width:
            pooled = F.pad(pooled, (0, self.feature_width - width))  # zero-pad on the right
        return pooled[:, :self.feature_width]

    def extract_features(self, batch):
        """Convert ``batch`` into ``(features, labels, envs)`` on the model's device.

        Accepted formats:
            * a list of per-graph dicts with ``'x'``, ``'label'`` and optionally
              ``'environment'``;
            * a dict with ``'x'``, ``'label'`` and optional ``'environment'``;
            * an object exposing ``x`` and optionally ``y``, ``env_id`` and a
              per-node ``batch`` graph-index vector (torch_geometric style).

        Returns:
            features: float32 tensor of shape ``(num_graphs, feature_width)``.
            labels: label tensor, or ``None`` if unavailable.
            envs: environment tensor, or ``None`` if unavailable.

        Raises:
            ValueError: for an empty batch or an unrecognised input format.
        """
        device = next(self.parameters()).device

        if isinstance(batch, list):
            x_list = [item['x'] for item in batch]
            labels = torch.tensor([item['label'] for item in batch], device=device)
            has_env = bool(batch) and 'environment' in batch[0]
            envs = (torch.tensor([item['environment'] for item in batch], device=device)
                    if has_env else None)
        elif isinstance(batch, dict):
            x_list = batch['x']
            raw_labels = batch['label']
            labels = (raw_labels.to(device) if torch.is_tensor(raw_labels)
                      else torch.tensor(raw_labels, device=device))
            envs = batch.get('environment')
            if envs is not None:
                envs = envs.to(device) if torch.is_tensor(envs) else torch.tensor(envs, device=device)
        else:
            try:
                node_x = batch.x
            except AttributeError as err:
                raise ValueError(f"Unrecognized input format: {type(batch)}") from err
            y = getattr(batch, 'y', None)
            env_id = getattr(batch, 'env_id', None)
            labels = y.to(device) if y is not None else None
            envs = env_id.to(device) if env_id is not None else None

            graph_index = getattr(batch, 'batch', None)
            if graph_index is None:
                x_list = [node_x]
            else:
                # The graph-index vector maps each node to its graph. Pool per
                # graph so the feature rows line up with the per-graph labels
                # (pooling all nodes together produced a single row).
                num_graphs = int(graph_index.max()) + 1 if graph_index.numel() else 0
                x_list = [node_x[graph_index == g] for g in range(num_graphs)]

        if not isinstance(x_list, list):
            x_list = [x_list]
        if not x_list:
            raise ValueError("Empty batch")

        features = torch.cat([self._graph_feature(x, device) for x in x_list], dim=0)
        return features, labels, envs

    def forward(self, batch, env_labels=None):
        """Return ``{'task_logits': ...}``, plus ``'env_logits'`` when env info exists.

        ``env_logits`` is produced when ``env_labels`` is given or the batch
        itself carries environment labels.
        """
        features, _labels, batch_envs = self.extract_features(batch)
        encoded = self.encoder(features)
        output = {'task_logits': self.classifier(encoded)}

        if env_labels is not None or batch_envs is not None:
            # detach(): the env head learns on frozen encodings; no gradient reaches the encoder
            output['env_logits'] = self.env_classifier(encoded.detach())

        return output


class DisCCompatibleModel(nn.Module):
    """Adapter that lets a flat-input DisC model consume dictionary batches.

    Dict batches have their ``'image'`` tensor flattened to ``(batch, -1)``
    before being passed to the wrapped model; any other batch type is handed
    over unchanged.
    """

    def __init__(self, disc_model):
        super().__init__()
        self.disc_model = disc_model

    def forward(self, batch, env_labels=None):
        """Run the wrapped model on ``batch``.

        For dict batches that carry an ``'environment'`` entry, and when the
        wrapped model has an ``env_classifier``, ``'env_logits'`` is added to
        the model's output. It is computed from detached
        ``feature_extractor`` features. ``env_labels`` is accepted for
        interface compatibility but not used.
        """
        if not isinstance(batch, dict):
            # Non-dictionary inputs go straight to the wrapped model
            return self.disc_model(batch)

        pixels = batch['image']
        environments = batch.get('environment')

        # Give single-channel images an explicit channel axis before flattening
        if pixels.dim() == 3:
            pixels = pixels.unsqueeze(1)
        flat_pixels = pixels.view(pixels.size(0), -1)

        result = self.disc_model(flat_pixels)

        needs_env_head = environments is not None and hasattr(self.disc_model, 'env_classifier')
        if needs_env_head:
            # Detached: the env head's loss does not flow back into the feature extractor
            frozen_features = self.disc_model.feature_extractor(flat_pixels).detach()
            result['env_logits'] = self.disc_model.env_classifier(frozen_features)

        return result


class UnifiedTrainer:
    """Shared train / validate / test loop for the GOOD and DisC experiments.

    ``method`` selects the training objective (see ``compute_loss``). The
    weights with the best validation accuracy are restored when ``train``
    finishes.
    """

    def __init__(self, model, train_loader, val_loader, test_loader, config,
                 dataset_type='good', method='erm'):
        """Store loaders/config, move ``model`` to the device and build the optimiser.

        Args:
            model: Module whose forward returns a dict with ``'task_logits'``
                and optionally ``'env_logits'``.
            train_loader, val_loader, test_loader: Sized iterables of batches.
                A batch may be a list of dicts, a dict, or an object exposing
                ``y`` / ``env_id``.
            config: Dict with a ``'training'`` section (``lr``,
                ``weight_decay``, ``epochs``, ``patience``). The regularised
                methods also read a ``'loss'`` section.
            dataset_type: ``'good'`` or ``'disc'``; stored on the instance.
            method: ``'erm'``, ``'adversarial'`` or ``'wasserstein'``.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        self.dataset_type = dataset_type
        self.method = method

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['training']['lr'],
            weight_decay=config['training']['weight_decay'],
            betas=(0.9, 0.999)
        )

        # Halve the LR after 5 epochs without val-loss improvement. The
        # deprecated `verbose` flag is omitted; the per-epoch log prints the LR.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=5, factor=0.5
        )

        self.history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': [],
            'best_val_acc': 0, 'best_val_loss': float('inf')
        }

        self.patience_counter = 0
        self.best_model_state = None

    # ------------------------------------------------------------------ helpers

    def _as_long_tensor(self, values):
        """Return ``values`` as a tensor on ``self.device`` (non-tensors become int64)."""
        if isinstance(values, torch.Tensor):
            return values.to(self.device)
        return torch.tensor(values, device=self.device, dtype=torch.long)

    def _move_to_device(self, batch_data):
        """Move every tensor value of a dict batch to ``self.device`` (in place); other types pass through."""
        if isinstance(batch_data, dict):
            for key, value in batch_data.items():
                if isinstance(value, torch.Tensor):
                    batch_data[key] = value.to(self.device)
        return batch_data

    def _snapshot_state(self):
        """Return an independent copy of the model weights.

        ``state_dict()`` returns references to the live tensors, so a plain
        ``dict.copy()`` would be overwritten by later optimiser steps.
        """
        return {name: tensor.detach().clone() for name, tensor in self.model.state_dict().items()}

    # --------------------------------------------------------------- batch / loss

    def extract_batch_data(self, batch):
        """Split ``batch`` into ``(batch_data, labels, envs)``.

        ``labels`` / ``envs`` are ``None`` when the batch does not carry them.
        Lists of dicts get CPU label tensors. Dicts return their ``'label'`` and
        ``'environment'`` entries. Any other object is read through its ``y`` /
        ``env_id`` attributes.
        """
        if isinstance(batch, list):
            labels = torch.tensor([item['label'] for item in batch])
            has_env = bool(batch) and 'environment' in batch[0]
            envs = torch.tensor([item['environment'] for item in batch]) if has_env else None
            return batch, labels, envs

        if isinstance(batch, dict):
            return batch, batch['label'], batch.get('environment')

        # torch_geometric Batch or other objects: missing attributes become None
        return batch, getattr(batch, 'y', None), getattr(batch, 'env_id', None)

    def compute_loss(self, output, labels, env_labels=None):
        """Return the training loss for one batch, or ``None`` if ``labels`` is None.

        * ``'erm'`` and any unrecognised method: cross-entropy on ``task_logits``.
        * ``'adversarial'``: cross-entropy + ``lambda_adv`` (default 0.1) x env
          cross-entropy.
        * ``'wasserstein'``: cross-entropy + ``lambda_w`` (default 0.05) x env
          cross-entropy. NOTE(review): this is an env cross-entropy term, not an
          actual Wasserstein distance.

        The env term is only used when ``output`` contains ``'env_logits'``
        and ``env_labels`` is given.
        """
        if labels is None:
            return None

        task_loss = F.cross_entropy(output['task_logits'], self._as_long_tensor(labels))

        if self.method not in ('adversarial', 'wasserstein'):
            return task_loss
        if 'env_logits' not in output or env_labels is None:
            return task_loss

        env_loss = F.cross_entropy(output['env_logits'], self._as_long_tensor(env_labels))
        if self.method == 'adversarial':
            weight = self.config['loss'].get('lambda_adv', 0.1)
        else:
            weight = self.config['loss'].get('lambda_w', 0.05)

        # Always added, never subtracted. The models in this file compute
        # env_logits from detached encodings, so subtracting (the previous
        # GOOD behaviour) only pushed the env head to maximise its own
        # unbounded loss and inflated the clipped gradient norm. A true
        # adversarial objective needs a gradient-reversal layer in the model,
        # and with one the term must still be added.
        return task_loss + weight * env_loss

    # ------------------------------------------------------------------ loops

    def train_epoch(self, epoch):
        """Run one training epoch; return ``(mean_loss, accuracy)``."""
        self.model.train()
        total_loss = 0.0
        loss_batches = 0
        total_correct = 0
        total_samples = 0
        num_batches = len(self.train_loader)

        for batch_idx, batch in enumerate(self.train_loader):
            batch_data, labels, envs = self.extract_batch_data(batch)
            batch_data = self._move_to_device(batch_data)

            if envs is not None:
                output = self.model(batch_data, env_labels=envs)
            else:
                output = self.model(batch_data)

            loss = self.compute_loss(output, labels, envs)
            if loss is None:  # only when the batch has no labels
                continue

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item()
            loss_batches += 1

            preds = output['task_logits'].argmax(dim=1)
            labels_tensor = self._as_long_tensor(labels)
            batch_correct = (preds == labels_tensor).sum().item()
            total_correct += batch_correct
            total_samples += labels_tensor.size(0)

            if (batch_idx + 1) % 10 == 0:
                batch_acc = batch_correct / labels_tensor.size(0)
                print(f"  Batch {batch_idx + 1}/{num_batches}: Loss={loss.item():.4f}, Acc={batch_acc:.2%}")

        # Average over batches that actually contributed a loss
        avg_loss = total_loss / loss_batches if loss_batches > 0 else 0
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        return avg_loss, accuracy

    def validate(self):
        """Evaluate plain task cross-entropy and accuracy on the validation set."""
        self.model.eval()
        total_loss = 0.0
        loss_batches = 0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in self.val_loader:
                batch_data, labels, _ = self.extract_batch_data(batch)
                batch_data = self._move_to_device(batch_data)
                output = self.model(batch_data)

                if labels is None:
                    continue

                labels_tensor = self._as_long_tensor(labels)
                total_loss += F.cross_entropy(output['task_logits'], labels_tensor).item()
                loss_batches += 1

                preds = output['task_logits'].argmax(dim=1)
                total_correct += (preds == labels_tensor).sum().item()
                total_samples += labels_tensor.size(0)

        avg_loss = total_loss / loss_batches if loss_batches > 0 else 0
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        return avg_loss, accuracy

    def train(self):
        """Train with LR scheduling and early stopping; restore the best weights.

        Returns:
            The ``history`` dict (per-epoch losses/accuracies and best values).
        """
        epochs = self.config['training']['epochs']
        patience = self.config['training']['patience']
        print(f"\n🎯 Starting training ({epochs} epochs)...")

        for epoch in range(epochs):
            start_time = time.time()

            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()

            # The plateau scheduler tracks validation loss
            self.scheduler.step(val_loss)

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)

            epoch_time = time.time() - start_time
            current_lr = self.optimizer.param_groups[0]['lr']

            print(f"Epoch {epoch + 1:3d}/{epochs}: "
                  f"Loss={train_loss:.4f}/{val_loss:.4f}, "
                  f"Acc={train_acc:.2%}/{val_acc:.2%}, "
                  f"LR={current_lr:.6f}, Time={epoch_time:.1f}s")

            if val_acc > self.history['best_val_acc']:
                self.history['best_val_acc'] = val_acc
                self.history['best_val_loss'] = val_loss
                self.patience_counter = 0
                # Deep snapshot: a shallow state_dict copy would track the live weights
                self.best_model_state = self._snapshot_state()

                print(f"  🔥 New best validation accuracy: {val_acc:.2%}")
            else:
                self.patience_counter += 1
                if self.patience_counter >= patience:
                    print(f"  ⏹️ Early stopping at epoch {epoch + 1}")
                    break

        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        return self.history

    def test(self):
        """Evaluate on the test set.

        Returns:
            dict with the overall ``accuracy``, per-environment
            ``env_accuracies``, ``fairness_gap`` (max - min env accuracy, 0 with
            fewer than two environments), ``total_samples`` and
            ``total_correct``.
        """
        self.model.eval()
        total_correct = 0
        total_samples = 0
        env_stats = {}

        with torch.no_grad():
            for batch in self.test_loader:
                batch_data, labels, envs = self.extract_batch_data(batch)
                batch_data = self._move_to_device(batch_data)

                output = self.model(batch_data)
                preds = output['task_logits'].argmax(dim=1).cpu().numpy()

                if labels is None:
                    continue

                labels_np = labels.cpu().numpy() if isinstance(labels, torch.Tensor) else np.array(labels)
                total_correct += (preds == labels_np).sum()
                total_samples += len(labels_np)

                if envs is None:
                    continue

                envs_np = envs.cpu().numpy() if isinstance(envs, torch.Tensor) else np.array(envs)
                for env in np.unique(envs_np):
                    mask = envs_np == env  # non-empty by construction of np.unique
                    stats = env_stats.setdefault(env, {'correct': 0, 'total': 0})
                    stats['correct'] += (preds[mask] == labels_np[mask]).sum()
                    stats['total'] += mask.sum()

        overall_acc = total_correct / total_samples if total_samples > 0 else 0

        env_accuracies = {env: stats['correct'] / stats['total'] for env, stats in env_stats.items()}

        if len(env_accuracies) >= 2:
            env_accs = list(env_accuracies.values())
            fairness_gap = max(env_accs) - min(env_accs)
        else:
            fairness_gap = 0

        return {
            'accuracy': overall_acc,
            'env_accuracies': env_accuracies,
            'fairness_gap': fairness_gap,
            'total_samples': total_samples,
            'total_correct': total_correct
        }


def get_virtual_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=32):
    """Wrap the three datasets in DataLoaders whose batches stay plain Python lists.

    Only the training loader shuffles. Items are not collated into tensors;
    each batch is the list of dataset items as returned by ``__getitem__``.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    from torch.utils.data import DataLoader

    def keep_as_list(samples):
        # Identity collate: leave per-item dicts untouched for the model to unpack
        return samples

    splits = ((train_dataset, True), (val_dataset, False), (test_dataset, False))
    return tuple(
        DataLoader(dataset, batch_size=batch_size, shuffle=is_train, collate_fn=keep_as_list)
        for dataset, is_train in splits
    )


def run_experiment(dataset_name, method, config):
    """Run single experiment - fixed version"""
    print(f"\n{'=' * 60}")
    print(f"Dataset: {dataset_name.upper()}, Method: {method.upper()}")
    print(f"{'=' * 60}")

    # Create output directory
    output_dir = Path(f"results/unified/{dataset_name}_{method}")
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        # 1. Load data
        print("1. Loading data...")

        if dataset_name in ['good-motif', 'good-cmnist', 'good-sst2', 'good-hiv']:
            dataset_type = 'good'

            # Import virtual GOOD data
            from data.virtual_dataset import VirtualDataset

            # Create datasets
            train_dataset = VirtualDataset(
                num_graphs=800,
                feature_dim=config['data']['feature_dim'],
                num_classes=config['data']['num_classes'],
                num_environments=config['model']['num_environments'],
            )

            val_dataset = VirtualDataset(
                num_graphs=100,
                feature_dim=config['data']['feature_dim'],
                num_classes=config['data']['num_classes'],
                num_environments=config['model']['num_environments'],
            )

            test_dataset = VirtualDataset(
                num_graphs=100,
                feature_dim=config['data']['feature_dim'],
                num_classes=config['data']['num_classes'],
                num_environments=config['model']['num_environments'],
            )

            # Create data loaders
            train_loader, val_loader, test_loader = get_virtual_dataloaders(
                train_dataset, val_dataset, test_dataset,
                batch_size=config['training']['batch_size']
            )

            # Create model
            model = SimpleGoodModel(
                input_dim=config['data']['feature_dim'],
                hidden_dim=config['model']['hidden_dim'],
                num_classes=config['data']['num_classes'],
                num_envs=config['model']['num_environments']
            )

        elif dataset_name in ['cmnist', 'cfashion', 'ckuzushiji']:
            # Use DisC data generator
            dataset_type = 'disc'
            from experiments.train_disc import AdvancedDisCDataset
            train_dataset = AdvancedDisCDataset(
                name=dataset_name,
                num_samples=10000,
                color_bias=config['data']['color_bias'],
                num_colors=config['data']['num_colors'],
                split='train'
            )
            val_dataset = AdvancedDisCDataset(
                name=dataset_name,
                num_samples=2000,
                color_bias=0.5,
                num_colors=config['data']['num_colors'],
                split='val'
            )
            test_dataset = AdvancedDisCDataset(
                name=dataset_name,
                num_samples=2000,
                color_bias=0.5,
                num_colors=config['data']['num_colors'],
                split='test'
            )

            from torch.utils.data import DataLoader
            train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)
            test_loader = DataLoader(test_dataset, batch_size=config['training']['batch_size'], shuffle=False)

            dataset_type = 'disc'

            # Create model
            from experiments.train_disc import AdvancedDIGLModel
            disc_model = AdvancedDIGLModel(
                in_dim=config['data']['img_size'] * config['data']['img_size'],
                hidden_dim=config['model']['hidden_dim'],
                out_dim=10,
                num_environments=config['model']['num_environments'],
                use_wasserstein=(method == 'wasserstein')
            )

            # Wrap model to handle DisC data format
            model = DisCCompatibleModel(disc_model)

        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        print(f"   Training set: {len(train_loader)} batches")
        print(f"   Validation set: {len(val_loader)} batches")
        print(f"   Test set: {len(test_loader)} batches")

        # 2. Create trainer
        print("2. Creating trainer...")
        trainer = UnifiedTrainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            config=config,
            dataset_type=dataset_type,
            method=method
        )

        # 3. Training
        print("3. Starting training...")
        history = trainer.train()

        # 4. Testing
        print("4. Testing model...")
        test_results = trainer.test()

        # 5. Save results
        print("5. Saving results...")

        results_file = output_dir / 'results.json'
        with open(results_file, 'w') as f:
            json.dump({
                'accuracy': float(test_results['accuracy']),
                'fairness_gap': float(test_results['fairness_gap']),
                'env_accuracies': {str(k): float(v) for k, v in test_results['env_accuracies'].items()},
                'best_val_acc': float(history['best_val_acc']),
                'best_val_loss': float(history['best_val_loss']),
                'config': config,
                'history': {
                    'train_loss': [float(x) for x in history['train_loss']],
                    'train_acc': [float(x) for x in history['train_acc']],
                    'val_loss': [float(x) for x in history['val_loss']],
                    'val_acc': [float(x) for x in history['val_acc']]
                }
            }, f, indent=2)

        print(f"\n✅ Experiment completed!")
        print(f"   Test accuracy: {test_results['accuracy']:.2%}")
        print(f"   Fairness gap: {test_results['fairness_gap']:.4f}")
        print(f"   Best validation accuracy: {history['best_val_acc']:.2%}")
        print(f"   Results saved to: {results_file}")

        if test_results['env_accuracies']:
            print(f"   Environment accuracies:")
            for env, acc in test_results['env_accuracies'].items():
                print(f"     Environment {env}: {acc:.2%}")

        return {
            'dataset': dataset_name,
            'method': method,
            'config': config,
            'test_results': test_results,
            'history': history
        }

    except Exception as e:
        print(f"❌ Experiment failed: {e}")
        import traceback
        traceback.print_exc()
        return None


def main(argv=None):
    """Parse CLI arguments, run every dataset x method experiment and write a summary.

    Args:
        argv: Optional argument list; ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behaviour.
    """
    parser = argparse.ArgumentParser(description='Unified Trainer')
    parser.add_argument('--datasets', type=str, nargs='+',
                        default=['good-motif', 'cmnist'],
                        help='List of datasets to train on')
    parser.add_argument('--methods', type=str, nargs='+',
                        default=['erm', 'adversarial', 'wasserstein'],
                        help='List of training methods')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--hidden-dim', type=int, default=256,
                        help='Hidden dimension')
    parser.add_argument('--patience', type=int, default=20,
                        help='Early stopping patience')
    parser.add_argument('--lambda-adv', type=float, default=0.1,
                        help='Adversarial training weight')
    parser.add_argument('--lambda-irm', type=float, default=0.1,
                        help='IRM penalty weight')
    parser.add_argument('--lambda-vrex', type=float, default=0.1,
                        help='V-Rex penalty weight')
    parser.add_argument('--seed', type=int, default=None,
                        help='Seed for python/numpy/torch RNGs (default: unseeded)')
    parser.add_argument('--quick', action='store_true',
                        help='Quick test mode')

    args = parser.parse_args(argv)

    if args.quick:
        args.epochs = 10
        args.patience = 5
        print("🚀 Quick test mode")

    # Opt-in reproducibility; without --seed the RNGs are left untouched as before
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    print(f"\n📋 Experiment configuration:")
    print(f"   Datasets: {args.datasets}")
    print(f"   Methods: {args.methods}")
    print(f"   Epochs: {args.epochs}")
    print(f"   Batch size: {args.batch_size}")
    print(f"   Learning rate: {args.lr}")
    print(f"   Hidden dimension: {args.hidden_dim}")
    print(f"   lambda_adv: {args.lambda_adv}")
    print(f"   lambda_irm: {args.lambda_irm}")
    print(f"   lambda_vrex: {args.lambda_vrex}")
    if args.seed is not None:
        print(f"   Seed: {args.seed}")

    all_results = []
    for dataset in args.datasets:
        for method in args.methods:
            # Any name containing 'good' gets the GOOD config; everything else DisC
            dataset_type = 'good' if 'good' in dataset else 'disc'
            config = UnifiedConfig.get_config(dataset_type, args)

            result = run_experiment(dataset, method, config)
            if result:
                all_results.append(result)

    if all_results:
        print("\n" + "=" * 80)
        print("Experiment Summary Report")
        print("=" * 80)

        def format_row(cells):
            # Five left-aligned 12-character columns separated by single spaces
            return " ".join(f"{cell:<12}" for cell in cells)

        header = format_row(('Dataset', 'Method', 'Test Acc', 'Fairness Gap', 'Best Val'))
        rows = [
            format_row((
                result['dataset'],
                result['method'],
                f"{result['test_results']['accuracy']:.2%}",
                f"{result['test_results']['fairness_gap']:.4f}",
                f"{result['history']['best_val_acc']:.2%}",
            ))
            for result in all_results
        ]

        print(f"\n{header}")
        print("-" * 60)
        for row in rows:
            print(row)

        summary_file = "results/unified/summary.txt"
        # Ensure the directory exists even if no per-experiment dir was created
        Path(summary_file).parent.mkdir(parents=True, exist_ok=True)
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("DIGL Unified Trainer - Experiment Summary\n")
            f.write("=" * 60 + "\n\n")
            f.write(header + "\n")
            f.write("-" * 60 + "\n")
            for row in rows:
                f.write(row + "\n")

        print(f"\n📊 Summary saved to: {summary_file}")

    print("\n🎉 All experiments completed!")


# Run the experiment grid only when executed as a script, not on import.
if __name__ == '__main__':
    main()