"""
DIGL Framework for DisC Datasets
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time
import os
import json
import argparse
import random
import math
from pathlib import Path
import sys
from collections import defaultdict, OrderedDict

# Add parent directory to path for importing modules
sys.path.append(str(Path(__file__).parent.parent))

# Startup banner, written once when the module is loaded.
print("=" * 80, "Enhanced DIGL Framework for DisC Datasets", "=" * 80, sep="\n")


# ============================================================================
# 1. Import DIGL Framework Modules (Simplified)
# ============================================================================

def import_digl_modules():
    """Load the optional DIGL project classes, tolerating missing packages.

    Returns:
        dict: maps ``'wasserstein'`` and ``'alternating_optimizer_class'`` to
        the imported class, or to ``None`` when the import raised
        ``ImportError``. A status line is printed for each attempt.
    """
    print("📦 Importing DIGL framework modules...")
    loaded = {}

    # Wasserstein distance from the project's loss package.
    try:
        from loss.wasserstein import WassersteinDistance
    except ImportError as err:
        print(f"⚠️  Could not import WassersteinDistance: {err}")
        loaded['wasserstein'] = None
    else:
        loaded['wasserstein'] = WassersteinDistance
        print("✅ Imported WassersteinDistance")

    # Alternating optimizer from the project's training package.
    try:
        from training.alternating_opt import AlternatingOptimizer
    except ImportError as err:
        print(f"⚠️  Could not import AlternatingOptimizer: {err}")
        loaded['alternating_optimizer_class'] = None
    else:
        loaded['alternating_optimizer_class'] = AlternatingOptimizer
        print("✅ Imported AlternatingOptimizer")

    return loaded


# ============================================================================
# 2. Command Line Arguments
# ============================================================================

def parse_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            script did before this parameter existed. Passing an explicit
            list lets other code and tests build a configuration without
            touching the process arguments.

    Returns:
        argparse.Namespace: the parsed options (dashes in option names become
        underscores, e.g. ``--batch-size`` -> ``batch_size``).
    """
    parser = argparse.ArgumentParser(description='Enhanced DIGL Framework for DisC Datasets')
    parser.add_argument('--dataset', type=str, default='cmnist',
                        choices=['cmnist', 'cfashion', 'ckuzushiji', 'all'],
                        help='Dataset to train on')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--hidden-dim', type=int, default=128,
                        help='Hidden dimension size')
    parser.add_argument('--num-envs', type=int, default=3,
                        help='Number of augmented environments')
    parser.add_argument('--lambda-align', type=float, default=1.0,
                        help='Weight for prototype alignment loss')
    parser.add_argument('--lambda-disentangle', type=float, default=0.5,
                        help='Weight for representation disentanglement loss')
    parser.add_argument('--lambda-class', type=float, default=1.0,
                        help='Weight for classification loss')
    parser.add_argument('--lambda-adv', type=float, default=0.1,
                        help='Weight for adversarial environment loss')
    parser.add_argument('--output-dir', type=str, default='./results/digl_disc',
                        help='Output directory for results')
    parser.add_argument('--patience', type=int, default=15,
                        help='Early stopping patience')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    # The default device is resolved when the parser is built.
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--use-wasserstein', action='store_true',
                        help='Use Wasserstein distance for prototype alignment')
    parser.add_argument('--quick-test', action='store_true',
                        help='Quick test mode')
    parser.add_argument('--color-bias', type=float, default=0.9,
                        help='Color bias strength for training set')
    parser.add_argument('--img-size', type=int, default=28,
                        help='Image size')

    return parser.parse_args(argv)


# ============================================================================
# 3. DisC Dataset
# ============================================================================

class DisCDataset:
    """Synthetic, in-memory dataset used for the CMNIST / CFashion / CKuzushiji runs.

    Every sample is a single-channel ``img_size`` x ``img_size`` image built
    from a simple shape, Gaussian noise and a scalar "color" offset. The data
    is generated with NumPy's global random state when the object is created.

    Args:
        name: dataset identifier; stored on the instance only.
        num_samples: number of samples to generate.
        color_bias: for the ``'train'`` split, probability that a sample gets
            the color associated with its label group.
        img_size: side length of the generated images.
        split: ``'train'`` applies the color bias; any other value draws
            color and environment uniformly at random.
    """

    def __init__(self, name="cmnist", num_samples=2000, color_bias=0.9,
                 img_size=28, split="train"):
        self.name = name
        self.num_samples = num_samples
        self.color_bias = color_bias
        self.img_size = img_size
        self.split = split

        # All tensors are produced up front and kept in memory.
        generated = self._generate_data()
        self.images, self.labels, self.colors, self.environments = generated

        # Fixed dataset metadata.
        self.num_classes = 10
        self.num_environments = 2

    def _create_digit_pattern(self, digit, img_size=28):
        """Return a ``(1, img_size, img_size)`` float array holding the shape for ``digit``.

        0 is drawn as a filled disc, 1 as a vertical bar, and every other
        digit as the same centered square.
        """
        canvas = np.zeros((1, img_size, img_size))
        mid = img_size // 2

        if digit == 0:
            # Filled disc: squared distance from the center below 100.
            rows, cols = np.ogrid[:img_size, :img_size]
            canvas[0][(rows - mid) ** 2 + (cols - mid) ** 2 < 100] = 0.8
            return canvas

        if digit == 1:
            # Four-pixel-wide vertical bar through the center.
            canvas[0, :, mid - 2:mid + 2] = 0.8
            return canvas

        # All remaining digits share one centered square.
        half = 8
        canvas[0, mid - half:mid + half, mid - half:mid + half] = 0.6
        return canvas

    def _generate_data(self):
        """Draw all samples and return ``(images, labels, colors, environments)`` tensors."""
        side = self.img_size
        tint_strength = 0.3
        image_list, label_list, color_list, env_list = [], [], [], []

        for _ in range(self.num_samples):
            label = np.random.randint(0, 10)

            # Base shape plus per-pixel Gaussian noise.
            img = self._create_digit_pattern(label, side)
            img += np.random.randn(1, side, side) * 0.1

            if self.split == 'train':
                # With probability color_bias the sample takes its group's
                # color (0 for labels 0-4, 1 for labels 5-9) and environment
                # 0; otherwise it takes the other color and environment 1.
                follows_bias = np.random.random() < self.color_bias
                group_color = 0 if label < 5 else 1
                color = group_color if follows_bias else 1 - group_color
                env = 0 if follows_bias else 1
            else:
                # Non-training splits: color and environment drawn independently.
                color = np.random.randint(0, 2)
                env = np.random.randint(0, 2)

            # The "color" is a scalar brightness offset; color 1 uses 70% of the strength.
            offset = tint_strength if color == 0 else tint_strength * 0.7
            img += offset + np.random.randn() * 0.05

            img = np.clip(img, -1, 1)

            image_list.append(torch.FloatTensor(img))
            label_list.append(label)
            color_list.append(color)
            env_list.append(env)

        return (
            torch.stack(image_list),
            torch.LongTensor(label_list),
            torch.LongTensor(color_list),
            torch.LongTensor(env_list),
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors."""
        return {
            'image': self.images[idx],
            'label': self.labels[idx],
            'color': self.colors[idx],
            'environment': self.environments[idx],
        }


# ============================================================================
# 4. Simplified DIGL Model Components
# ============================================================================

class SimpleCNNEncoder(nn.Module):
    """
    Three-stage convolutional encoder with batch normalization.

    Maps a single-channel image batch to ``hidden_dim``-wide feature vectors.
    Batch norm is applied after every convolution and after the first
    fully connected layer.
    """

    def __init__(self, img_size=28, hidden_dim=128):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # One 3x3 convolution (spatial size preserved) + batch norm + ReLU.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        self.conv_layers = nn.Sequential(
            *conv_stage(1, 32),
            nn.MaxPool2d(2),
            *conv_stage(32, 64),
            nn.MaxPool2d(2),
            *conv_stage(64, 128),
        )

        # Two 2x2 poolings shrink each spatial side by a factor of four.
        pooled_side = img_size // 4
        self.flattened_size = 128 * pooled_side * pooled_side

        self.fc_layers = nn.Sequential(
            nn.Linear(self.flattened_size, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x):
        """Encode ``x`` of shape ``[batch, 1, img_size, img_size]`` into ``[batch, hidden_dim]``."""
        feature_maps = self.conv_layers(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.fc_layers(flat)


class SimpleSubgraphExtractor(nn.Module):
    """Splits a feature vector into invariant and variant parts with a learned gate."""

    def __init__(self, hidden_dim):
        super().__init__()
        bottleneck = hidden_dim // 2
        # Produces one sigmoid gate value per sample.
        self.mask_predictor = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
            nn.Sigmoid(),
        )

    def forward(self, features):
        """Return ``(invariant, variant, gate)``; the two parts sum back to ``features``."""
        gate = self.mask_predictor(features)
        return features * gate, features * (1 - gate), gate


class SimpleEnvironmentGenerator(nn.Module):
    """Fallback environment generator: one small MLP per environment.

    Args:
        feature_dim: width of the incoming and outgoing feature vectors.
        num_envs: number of environment-specific transforms to create.
    """

    def __init__(self, feature_dim, num_envs=3):
        super().__init__()
        self.num_envs = num_envs
        self.env_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(feature_dim, feature_dim),
                nn.ReLU(),
                nn.Linear(feature_dim, feature_dim)
            ) for _ in range(num_envs)
        ])

    def forward(self, features, env_idx=None):
        """Generate environment-specific features.

        Args:
            features: tensor whose last dimension is ``feature_dim``.
            env_idx: optional index of a single environment. The model's
                forward pass calls the generator as
                ``generator(features, env_idx)``; accepting the index here
                keeps this fallback call-compatible with that usage.

        Returns:
            A list with one transformed tensor per environment when
            ``env_idx`` is ``None``; otherwise the single tensor produced by
            the transform at ``env_idx``.
        """
        if env_idx is None:
            return [transform(features) for transform in self.env_transforms]
        return self.env_transforms[env_idx](features)


class SimplePrototypeAligner(nn.Module):
    """Holds one learnable prototype vector per class and scores a batch against them."""

    def __init__(self, feature_dim, num_classes):
        super().__init__()
        # Small random initialization, one row per class.
        initial = torch.randn(num_classes, feature_dim) * 0.1
        self.prototypes = nn.Parameter(initial)

    def forward(self, features, labels):
        """Return the mean squared error between ``features`` and each sample's class prototype."""
        targets = self.prototypes[labels]
        return F.mse_loss(features, targets, reduction='mean')


class SimpleDisentanglementLoss(nn.Module):
    """Orthogonality penalty between invariant and variant feature vectors."""

    def forward(self, inv_features, var_features):
        """Return the batch mean of the absolute per-sample cosine similarity."""
        # Unit-length rows turn the row-wise dot product into a cosine.
        unit_inv = F.normalize(inv_features, p=2, dim=1)
        unit_var = F.normalize(var_features, p=2, dim=1)
        per_sample = (unit_inv * unit_var).sum(dim=1).abs()
        return per_sample.mean()


# ============================================================================
# 5. Complete DIGL Model
# ============================================================================

class DIGLModel(nn.Module):
    """Complete DIGL model: encoder, invariant/variant split, environment
    generation, prototype alignment, disentanglement and two classifier heads.

    Each optional project module (``models.*``, ``loss.mutual_info``) is
    imported inside the constructor; when an import fails, the simplified
    component defined in this file is used instead.

    Args:
        img_size: side length of the square single-channel input images.
        hidden_dim: width of the feature vector produced by the encoder.
        num_classes: number of task classes.
        num_envs: number of generated environments; also the output width of
            the environment classifier.
    """

    def __init__(self, img_size=28, hidden_dim=128, num_classes=10, num_envs=3):
        super().__init__()

        self.img_size = img_size
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_envs = num_envs

        # Feature encoder
        self.feature_encoder = SimpleCNNEncoder(img_size, hidden_dim)

        # Subgraph extractor. The import below only probes whether the
        # project package is available; the imported class itself is not
        # instantiated because the inputs here are images, so a
        # feature-level MLP producing mask logits is used instead.
        try:
            from models.subgraph_extractor import SubgraphExtractor
            self.subgraph_extractor = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            print("✅ Using feature-level subgraph extractor")
        except ImportError:
            print("⚠️  Using fallback subgraph extractor")
            self.subgraph_extractor = SimpleSubgraphExtractor(hidden_dim)

        # Environment generator
        try:
            from models.env_generators import EnvironmentGenerator
            self.env_generator = EnvironmentGenerator(
                input_dim=hidden_dim,
                hidden_dim=hidden_dim // 2,
                num_environments=num_envs,
                num_classes=num_classes
            )
            print("✅ Using full EnvironmentGenerator with adversarial training")
        except ImportError:
            print("⚠️  Using fallback environment generator")
            self.env_generator = SimpleEnvironmentGenerator(hidden_dim, num_envs)

        # Prototype aligner
        try:
            from models.prototype_aligner import WassersteinPrototype
            self.prototype_aligner = WassersteinPrototype(
                feature_dim=hidden_dim,
                num_classes=num_classes,
                num_environments=num_envs,
                wasserstein_method='trace'
            )
            print("✅ Using WassersteinPrototype for alignment")
        except ImportError:
            print("⚠️  Using fallback prototype aligner")
            self.prototype_aligner = SimplePrototypeAligner(hidden_dim, num_classes)

        # Disentanglement loss
        try:
            from loss.mutual_info import MutualInformationLoss
            self.mi_loss_fn = MutualInformationLoss(method='hsic')  # or 'infonce'
            print("✅ Using MutualInformationLoss for disentanglement")
            self.use_mi_loss = True
        except ImportError:
            print("⚠️  Using fallback disentanglement loss")
            self.disentangle_loss = SimpleDisentanglementLoss()
            self.use_mi_loss = False

        # Classifiers
        self.task_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self.env_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_envs)
        )

    def _split_features(self, features):
        """Split encoder features into ``(invariant, variant, mask)``.

        Handles both extractor flavours: the fallback extractor already
        returns the three tensors, while the feature-level MLP returns mask
        logits that are squashed with a sigmoid here.
        """
        extracted = self.subgraph_extractor(features)
        if isinstance(extracted, (tuple, list)):
            inv_features, var_features, mask = extracted
        else:
            mask = torch.sigmoid(extracted)
            inv_features = features * mask
            var_features = features * (1 - mask)
        return inv_features, var_features, mask

    def _generate_env_features(self, inv_features):
        """Return one environment-specific feature tensor per environment."""
        if isinstance(self.env_generator, SimpleEnvironmentGenerator):
            # The fallback generator produces all environments in one call.
            return list(self.env_generator(inv_features))
        # NOTE(review): the project EnvironmentGenerator is assumed to take
        # (features, env_idx), as the original per-index loop did — confirm.
        return [
            self.env_generator(inv_features, env_idx)
            for env_idx in range(self.num_envs)
        ]

    def forward(self, x, labels=None, env_labels=None, return_losses=False, epoch=0):
        """Forward pass through the DIGL framework.

        Args:
            x: image batch ``[batch, 1, img_size, img_size]``; a 2-D input is
                treated as flattened images and reshaped.
            labels: task labels, required for loss computation.
            env_labels: optional environment labels for the environment head.
            return_losses: when true and ``labels`` is given, the individual
                loss terms are added to the output under ``'losses'``.
            epoch: forwarded to the loss computation.

        Returns:
            dict with ``'task_logits'``, ``'env_logits'``, ``'inv_features'``,
            ``'var_features'``, ``'env_features'`` (list), ``'mask'`` and,
            when requested, ``'losses'``.
        """
        if x.dim() == 2:
            # Flattened image, reshape
            x = x.view(-1, 1, self.img_size, self.img_size)

        features = self.feature_encoder(x)
        inv_features, var_features, mask = self._split_features(features)
        env_features_list = self._generate_env_features(inv_features)

        # The task head always sees L2-normalized invariant features. The
        # classification loss is computed on exactly these logits, so the
        # head is trained and evaluated on inputs of the same scale.
        task_logits = self.task_classifier(F.normalize(inv_features, p=2, dim=1))

        # Environment classification (adversarial)
        env_logits = self.env_classifier(var_features)

        output = {
            'task_logits': task_logits,
            'env_logits': env_logits,
            'inv_features': inv_features,
            'var_features': var_features,
            'env_features': env_features_list,
            'mask': mask
        }

        if return_losses and labels is not None:
            output['losses'] = self._compute_full_losses(
                inv_features, var_features, env_features_list,
                task_logits, env_logits, labels, env_labels, epoch
            )

        return output

    def _compute_full_losses(self, inv_features, var_features, env_features_list,
                             task_logits, env_logits, labels, env_labels, epoch):
        """Compute the individual loss terms with whichever modules are loaded.

        Returns:
            dict with ``'cls_loss'``, ``'align_loss'``, ``'disentangle_loss'``
            and ``'env_loss'``; ``'env_adv_loss'`` / ``'env_inv_loss'`` are
            added when the environment generator provides them.
        """
        losses = {}
        # Alignment and disentanglement operate on unit-length features.
        inv_features = F.normalize(inv_features, p=2, dim=1)
        var_features = F.normalize(var_features, p=2, dim=1)

        # 1. Classification loss on the logits produced by the forward pass.
        losses['cls_loss'] = F.cross_entropy(task_logits, labels, label_smoothing=0.1)

        # 2. Prototype alignment loss
        if hasattr(self.prototype_aligner, 'alignment_loss'):
            # Full aligner: compares features across generated environments.
            align_loss = self.prototype_aligner.alignment_loss(
                {i: feat for i, feat in enumerate(env_features_list)},
                labels
            )
        else:
            # Simplified aligner: distance to per-class prototypes.
            align_loss = self.prototype_aligner(inv_features, labels)
        losses['align_loss'] = align_loss

        # 3. Representation disentanglement loss
        if self.use_mi_loss:
            losses['disentangle_loss'] = self.mi_loss_fn(inv_features, var_features)
        else:
            # Orthogonality fallback.
            losses['disentangle_loss'] = self.disentangle_loss(inv_features, var_features)

        # 4. Environment classification loss
        if env_labels is None:
            # Without real labels, train against uniformly random ones.
            env_labels = torch.randint(
                0, self.num_envs, (labels.size(0),), device=inv_features.device
            )
        losses['env_loss'] = F.cross_entropy(env_logits, env_labels)

        # 5. Environment generator losses (only when the generator offers them)
        if hasattr(self.env_generator, 'adversarial_loss'):
            try:
                # Prefer the simplified variant to avoid numerical problems.
                if hasattr(self.env_generator, 'simple_adversarial_loss'):
                    adv_loss = self.env_generator.simple_adversarial_loss(env_features_list)
                else:
                    adv_loss = self.env_generator.adversarial_loss(env_features_list)

                inv_loss = self.env_generator.invariance_loss(env_features_list, labels)

                # Cap both terms from above.
                adv_loss = torch.clamp(adv_loss, max=5.0)
                inv_loss = torch.clamp(inv_loss, max=10.0)

                losses['env_adv_loss'] = adv_loss
                losses['env_inv_loss'] = inv_loss
            except Exception as e:
                # Best effort: these terms are optional, so report and go on.
                print(f"⚠️  Environment generator loss failed: {e}")

        return losses

    def alternating_optimization(self, features_list, labels_list, num_iterations=3):
        """Pull each feature batch part of the way toward its class prototypes.

        Args:
            features_list: iterable of feature tensors.
            labels_list: iterable of label tensors, paired with the features.
            num_iterations: accepted for interface compatibility; not used.

        Returns:
            list of tensors, each ``0.7 * features + 0.3 * prototypes``.
        """
        # NOTE(review): reads ``self.prototype_aligner.prototypes``, which the
        # simplified aligner defines; confirm the full aligner exposes it too.
        with torch.no_grad():
            return [
                0.7 * features + 0.3 * self.prototype_aligner.prototypes[labels]
                for features, labels in zip(features_list, labels_list)
            ]


# ============================================================================
# 6. DIGL Trainer
# ============================================================================

class DIGLTrainer:
    """Trainer for the DIGL model with staged loss terms and early stopping.

    Args:
        model: module whose forward returns a dict with ``'task_logits'`` and,
            when called with ``return_losses=True``, ``'losses'``.
        train_loader / val_loader / test_loader: iterables yielding dict
            batches with ``'image'``, ``'label'`` and ``'environment'``.
        args: namespace providing ``device``, ``lr``, ``epochs``, ``patience``,
            ``output_dir``, ``dataset`` and the ``lambda_*`` weights.
    """

    def __init__(self, model, train_loader, val_loader, test_loader, args):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.args = args
        self.device = args.device

        # Move model to device
        self.model.to(self.device)

        # Optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=args.lr,
            weight_decay=2e-4
        )

        # Halve the learning rate every 5 epochs.
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=5,
            gamma=0.5
        )

        # Early stopping
        self.best_val_acc = 0
        self.patience_counter = 0
        self.best_model_state = None

        # Training history
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': []
        }

    def compute_total_loss(self, losses_dict, epoch):
        """Combine the individual loss terms according to the epoch schedule.

        The classification loss is always used; the alignment term joins from
        epoch 5 and the disentanglement term from epoch 10.

        Args:
            losses_dict: mapping with ``'cls_loss'`` and ``'align_loss'``; the
                disentanglement term is read from ``'disentangle_loss'`` (the
                key the model emits), falling back to ``'dis_loss'``.
            epoch: zero-based epoch index.

        Returns:
            The weighted total loss tensor.
        """
        total_loss = self.args.lambda_class * losses_dict['cls_loss']

        if epoch >= 5:
            # Small weight plus log smoothing: even a very large alignment
            # distance adds only a bounded increment to the total.
            safe_align = torch.log(losses_dict['align_loss'] + 1.0)
            total_loss = total_loss + (self.args.lambda_align * 0.05) * safe_align

        if epoch >= 10:
            dis_loss = losses_dict.get('disentangle_loss')
            if dis_loss is None:
                dis_loss = losses_dict.get('dis_loss')
            if dis_loss is not None:
                total_loss = total_loss + self.args.lambda_disentangle * dis_loss

        return total_loss

    def _snapshot_model_state(self):
        """Return a copy of the model's state dict that later updates cannot change."""
        # state_dict() values share storage with the live parameters, so each
        # tensor is cloned; a shallow dict copy would track later updates.
        return OrderedDict(
            (key, value.detach().clone())
            for key, value in self.model.state_dict().items()
        )

    def train_epoch(self, epoch):
        """Train for one epoch and return ``(average batch loss, accuracy)``."""
        self.model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0

        for batch_idx, batch in enumerate(self.train_loader):
            images = batch['image'].to(self.device)
            labels = batch['label'].to(self.device)
            envs = batch['environment'].to(self.device)

            # Forward pass with loss computation
            output = self.model(images, labels, envs, return_losses=True, epoch=epoch)
            loss = self.compute_total_loss(output['losses'], epoch)

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Clip, then skip the update entirely if any gradient is NaN/inf.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.1)
            grads_finite = all(
                bool(torch.isfinite(p.grad).all())
                for p in self.model.parameters()
                if p.grad is not None
            )
            if grads_finite:
                self.optimizer.step()
            self.optimizer.zero_grad()

            # Statistics
            total_loss += loss.item()

            with torch.no_grad():
                preds = output['task_logits'].argmax(dim=1)
                correct = (preds == labels).sum().item()
                total_correct += correct
                total_samples += labels.size(0)

            # Print batch progress occasionally
            if batch_idx % 20 == 0:
                batch_acc = correct / labels.size(0) if labels.size(0) > 0 else 0
                print(f"   Batch {batch_idx:3d}/{len(self.train_loader)}: "
                      f"Loss: {loss.item():.4f}, Acc: {batch_acc:.2%}")

        avg_loss = total_loss / max(len(self.train_loader), 1)
        accuracy = total_correct / total_samples if total_samples > 0 else 0

        return avg_loss, accuracy

    def validate(self):
        """Evaluate on the validation loader; return ``(average batch loss, accuracy)``."""
        self.model.eval()
        total_correct = 0
        total_samples = 0
        total_loss = 0

        with torch.no_grad():
            for batch in self.val_loader:
                images = batch['image'].to(self.device)
                labels = batch['label'].to(self.device)

                output = self.model(images, return_losses=False)

                preds = output['task_logits'].argmax(dim=1)
                total_correct += (preds == labels).sum().item()
                total_samples += labels.size(0)

                # Plain cross-entropy, for monitoring only.
                total_loss += F.cross_entropy(output['task_logits'], labels).item()

        avg_loss = total_loss / max(len(self.val_loader), 1)
        accuracy = total_correct / total_samples if total_samples > 0 else 0

        return avg_loss, accuracy

    def test(self):
        """Evaluate on the test loader.

        Returns:
            dict with overall ``'accuracy'``, per-environment
            ``'env_accuracies'``, ``'fairness_gap'`` (largest minus smallest
            environment accuracy, 0 with fewer than two environments),
            ``'total_correct'`` and ``'total_samples'``.
        """
        self.model.eval()
        total_correct = 0
        total_samples = 0
        env_correct = defaultdict(int)
        env_total = defaultdict(int)

        with torch.no_grad():
            for batch in self.test_loader:
                images = batch['image'].to(self.device)
                labels = batch['label'].to(self.device)
                envs = batch['environment'].to(self.device)

                output = self.model(images, return_losses=False)

                preds = output['task_logits'].argmax(dim=1)
                hits = preds == labels
                total_correct += hits.sum().item()
                total_samples += labels.size(0)

                # Per-environment counts, one tensor reduction per environment
                # present in the batch.
                for env in envs.unique().tolist():
                    in_env = envs == env
                    env_total[env] += int(in_env.sum())
                    env_correct[env] += int((hits & in_env).sum())

        accuracy = total_correct / total_samples if total_samples > 0 else 0

        env_accuracies = {
            env: (env_correct[env] / count if count > 0 else 0)
            for env, count in env_total.items()
        }

        if len(env_accuracies) >= 2:
            fairness_gap = max(env_accuracies.values()) - min(env_accuracies.values())
        else:
            fairness_gap = 0

        return {
            'accuracy': accuracy,
            'env_accuracies': env_accuracies,
            'fairness_gap': fairness_gap,
            'total_correct': total_correct,
            'total_samples': total_samples
        }

    def train(self):
        """Run the full training loop with early stopping; return the history dict."""
        print(f"\n🚀 Starting training for {self.args.epochs} epochs...")

        for epoch in range(self.args.epochs):
            start_time = time.time()

            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()

            # Update learning rate
            self.scheduler.step()

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)

            epoch_time = time.time() - start_time
            current_lr = self.optimizer.param_groups[0]['lr']

            print(f"\nEpoch {epoch+1:3d}/{self.args.epochs}: "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2%}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2%}, "
                  f"LR: {current_lr:.6f}, Time: {epoch_time:.1f}s")

            # Early stopping check
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.patience_counter = 0
                self.best_model_state = self._snapshot_model_state()

                self.save_model(epoch + 1, val_acc)
                print(f"   🔥 New best validation accuracy: {val_acc:.2%}")
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.args.patience:
                    print(f"   ⏹️  Early stopping triggered at epoch {epoch + 1}")
                    break

        # Restore best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print(f"\n✅ Restored best model with validation accuracy: {self.best_val_acc:.2%}")

        return self.history

    def save_model(self, epoch, val_acc):
        """Write a checkpoint for ``epoch`` into ``args.output_dir``."""
        os.makedirs(self.args.output_dir, exist_ok=True)

        filename = f"digl_{self.args.dataset}_epoch{epoch}.pth"
        checkpoint_path = os.path.join(self.args.output_dir, filename)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_acc': val_acc,
            'best_val_acc': self.best_val_acc,
            'args': vars(self.args)
        }

        torch.save(checkpoint, checkpoint_path)
        print(f"   💾 Saved model: {checkpoint_path}")


# ============================================================================
# 7. Main Training Function
# ============================================================================

def train_digl_disc(args):
    """Train and evaluate a DIGL model on one or all DisC datasets.

    For every selected dataset this builds train/val/test ``DisCDataset``
    splits, wraps them in data loaders, trains a fresh ``DIGLModel`` with
    ``DIGLTrainer``, evaluates it and saves the final weights to
    ``<args.output_dir>/<dataset>_model.pth``.

    Args:
        args: Parsed command-line namespace. The fields read here are
            ``seed``, ``dataset`` (``'all'`` selects every DisC dataset),
            ``device``, ``hidden_dim``, ``num_envs``, ``lr``, ``batch_size``,
            ``epochs``, ``quick_test``, ``color_bias``, ``img_size`` and
            ``output_dir``.

    Returns:
        dict: Maps each successfully trained dataset name to a dict with the
        keys ``'test_results'``, ``'history'`` and ``'best_val_acc'``.
        Datasets whose training raised an exception are reported on stdout
        and left out of the result.
    """

    # Set random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    print(f"\n⚙️  Configuration:")
    print(f"   Dataset: {args.dataset}")
    print(f"   Device: {args.device}")
    print(f"   Hidden dimension: {args.hidden_dim}")
    print(f"   Number of environments: {args.num_envs}")
    print(f"   Learning rate: {args.lr}")
    print(f"   Batch size: {args.batch_size}")
    print(f"   Epochs: {args.epochs}")

    # Import modules (optional). The returned mapping is not used below, so it
    # is not kept; the call only reports which optional modules are available.
    # Catch Exception rather than using a bare except so that Ctrl-C and
    # SystemExit still propagate.
    try:
        import_digl_modules()
    except Exception:
        print("⚠️  Some modules failed to import, using simplified implementations")

    # Determine datasets to train
    if args.dataset == 'all':
        datasets_to_train = ['cmnist', 'cfashion', 'ckuzushiji']
    else:
        datasets_to_train = [args.dataset]

    all_results = {}

    for dataset_name in datasets_to_train:
        print(f"\n{'=' * 60}")
        print(f"Training {dataset_name.upper()} Dataset")
        print(f"{'=' * 60}")

        # Best effort per dataset: a failure is reported and the loop moves on
        # to the next dataset instead of aborting the whole run.
        try:
            # Create datasets
            print(f"1. Creating {dataset_name} dataset...")

            num_samples = 1000 if args.quick_test else 2000

            train_dataset = DisCDataset(
                name=dataset_name,
                num_samples=num_samples,
                color_bias=args.color_bias,
                img_size=args.img_size,
                split='train'
            )

            # Held-out splits are one fifth of the training size and use a
            # fixed color_bias of 0.5 instead of args.color_bias.
            val_dataset = DisCDataset(
                name=dataset_name,
                num_samples=num_samples // 5,
                color_bias=0.5,
                img_size=args.img_size,
                split='val'
            )

            test_dataset = DisCDataset(
                name=dataset_name,
                num_samples=num_samples // 5,
                color_bias=0.5,
                img_size=args.img_size,
                split='test'
            )

            print(f"   Training samples: {len(train_dataset)}")
            print(f"   Validation samples: {len(val_dataset)}")
            print(f"   Test samples: {len(test_dataset)}")

            # Create data loaders (only the training loader is shuffled)
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=args.batch_size, shuffle=True
            )
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=args.batch_size, shuffle=False
            )
            test_loader = torch.utils.data.DataLoader(
                test_dataset, batch_size=args.batch_size, shuffle=False
            )

            # Create model
            print(f"2. Creating DIGL model...")

            model = DIGLModel(
                img_size=args.img_size,
                hidden_dim=args.hidden_dim,
                num_classes=10,
                num_envs=args.num_envs
            )

            print(f"   Model parameters: {sum(p.numel() for p in model.parameters()):,}")

            # Create trainer
            trainer = DIGLTrainer(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                test_loader=test_loader,
                args=args
            )

            # Train model
            print(f"3. Starting training...")
            history = trainer.train()

            # Test model
            print(f"4. Testing model...")
            test_results = trainer.test()

            print(f"\n🎯 Training completed!")
            print(f"   Best validation accuracy: {trainer.best_val_acc:.2%}")
            print(f"   Test accuracy: {test_results['accuracy']:.2%}")
            print(f"   Fairness gap: {test_results['fairness_gap']:.4f}")

            if test_results['env_accuracies']:
                print(f"\n   Environment accuracies:")
                for env, acc in test_results['env_accuracies'].items():
                    print(f"     Environment {env}: {acc:.2%}")

            # Store results
            all_results[dataset_name] = {
                'test_results': test_results,
                'history': history,
                'best_val_acc': trainer.best_val_acc
            }

            # Save model. Make sure the target directory exists so this
            # function also works when called without main() having created it.
            os.makedirs(args.output_dir, exist_ok=True)
            model_path = os.path.join(args.output_dir, f"{dataset_name}_model.pth")
            torch.save({
                'model_state_dict': model.state_dict(),
                'test_results': test_results,
                'history': history,
                'args': vars(args)
            }, model_path)
            print(f"   Model saved to: {model_path}")

        except Exception as e:
            print(f"❌ Training {dataset_name} failed: {e}")
            import traceback
            traceback.print_exc()

    return all_results


# ============================================================================
# 8. Main Function
# ============================================================================

def _to_jsonable(obj):
    """Recursively convert NumPy / torch values into plain JSON-compatible objects."""
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        # NumPy scalar keys are rejected by json, so unwrap them as well.
        return {
            (key.item() if isinstance(key, np.generic) else key): _to_jsonable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(item) for item in obj]
    # Anything else is passed through; json.dump still raises for types it
    # cannot encode, exactly as it did before.
    return obj


def main():
    """Command-line entry point: parse arguments, train, and write a summary.

    In quick-test mode the epoch count, batch size and hidden dimension are
    capped at 20, 32 and 64 respectively. After training, a per-dataset
    summary is printed and ``summary.json`` is written to
    ``args.output_dir``. Exceptions raised by training or by the summary
    step are caught, printed with a traceback, and do not propagate.
    """
    args = parse_args()

    # Adjust for quick test
    if args.quick_test:
        args.epochs = min(args.epochs, 20)
        args.batch_size = min(args.batch_size, 32)
        args.hidden_dim = min(args.hidden_dim, 64)
        print(f"\n🚀 Quick test mode: {args.epochs} epochs")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"\n🎯 Starting DIGL training for DisC datasets")
    print("=" * 60)

    try:
        results = train_digl_disc(args)

        if results:
            print(f"\n{'=' * 60}")
            print("DIGL Training Summary")
            print(f"{'=' * 60}")

            for dataset_name, result in results.items():
                test_acc = result['test_results']['accuracy']
                fairness_gap = result['test_results']['fairness_gap']
                best_val_acc = result['best_val_acc']

                print(f"{dataset_name.upper():12} Test Acc: {test_acc:.2%}, "
                      f"Fairness Gap: {fairness_gap:.4f}, "
                      f"Best Val Acc: {best_val_acc:.2%}")

            # Overall evaluation
            avg_acc = np.mean([r['test_results']['accuracy'] for r in results.values()])
            print(f"\n📊 Overall average accuracy: {avg_acc:.2%}")

            if avg_acc > 0.75:
                print("   🎉 Excellent overall performance!")
            elif avg_acc > 0.65:
                print("   👍 Good overall performance!")
            elif avg_acc > 0.55:
                print("   ⚠️  Moderate overall performance")
            else:
                print("   ❌ Needs improvement")

            # Save summary. The results may hold NumPy/torch scalars or arrays
            # (their exact types are produced elsewhere in this file), which
            # json cannot encode directly, so normalise them first.
            summary_path = os.path.join(args.output_dir, "summary.json")
            with open(summary_path, 'w', encoding='utf-8') as f:
                json.dump(_to_jsonable(results), f, indent=2)
            print(f"\n💾 Summary saved to: {summary_path}")

            print(f"\n✅ DIGL training completed successfully!")
        else:
            # train_digl_disc() swallows per-dataset failures, so an empty
            # result means nothing was trained; do not report success.
            print("\n❌ No dataset finished training; see the errors above")

    except Exception as e:
        print(f"\n❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

    print(f"\n🎉 Program execution completed!")


# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()