"""
Training script for Adaptive Gating MetaNet using pre-computed features.

This script implements the training procedure for adaptive gating with
support for both MetaNet-based and Atlas-based models.
"""

import gc
import json
import math
import os
import random
import time
import traceback
from datetime import datetime

import numpy as np
import torch

from src.adaptive_gating_metanet import AdaptiveGatingMetaNet
from src.utils import cosine_lr
from src.datasets.common import maybe_dictionarize
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp

from src.args import parse_arguments


class DirectFeatureModel(torch.nn.Module):
    """A direct feature model (Atlas approach) with optional adaptive gating.

    Without gating the module returns its input unchanged (identity
    projection). With ``gating_no_metanet=True`` every feature is multiplied
    by a soft mask obtained from a learnable, uncertainty-dependent threshold.

    Args:
        feature_dim: Dimension of the pre-computed feature vectors
        gating_no_metanet: Whether to apply gating directly to features
        base_threshold: Base threshold for gating mechanism
        beta: Beta parameter for uncertainty weighting
        uncertainty_reg: Weight for uncertainty regularization
    """

    # Slope of the sigmoid that softens the hard "magnitude >= threshold" gate.
    _SIGMOID_SCALE = 20.0

    def __init__(self, feature_dim, gating_no_metanet=False, base_threshold=0.05,
                 beta=1.0, uncertainty_reg=0.01):
        super().__init__()
        self.feature_dim = feature_dim
        self.gating_no_metanet = gating_no_metanet

        # Identity projection (no transformation)
        self.projection = torch.nn.Identity()

        # Dummy parameters so that DDP always has something to synchronise.
        # The linear layer is zero-initialised and its output is additionally
        # multiplied by 0.0 in forward(), so it never alters the features.
        self.dummy_param = torch.nn.Parameter(torch.zeros(1), requires_grad=True)
        self.dummy_linear = torch.nn.Linear(feature_dim, feature_dim)
        torch.nn.init.zeros_(self.dummy_linear.weight)
        torch.nn.init.zeros_(self.dummy_linear.bias)

        if self.gating_no_metanet:
            # Gating parameters are learned in log-space so that the effective
            # values (exp of the parameter) stay strictly positive. The 1e-5
            # floor keeps log() finite for zero/negative inputs.
            self.log_base_threshold = torch.nn.Parameter(
                torch.tensor([math.log(max(base_threshold, 1e-5))], dtype=torch.float)
            )
            self.log_beta = torch.nn.Parameter(
                torch.tensor([math.log(max(beta * 0.95, 1e-5))], dtype=torch.float)
            )

            # Buffers holding the requested initial values, used for monitoring
            self.register_buffer('initial_base_threshold', torch.tensor([base_threshold], dtype=torch.float))
            self.register_buffer('initial_beta', torch.tensor([beta], dtype=torch.float))

            # Uncertainty related variables
            self.uncertainty_reg = uncertainty_reg
            self._forward_count = 0
            self._reg_loss_count = 0
            self.training_mode = True

            # Values cached by the most recent forward pass
            self.last_uncertainties = None
            self.last_gated_features = None
            self.last_thresholds = None
            self.last_orig_features = None
            self.last_gating_mask = None

            # Bottleneck MLP producing a per-feature uncertainty in (0, 1)
            self.uncertainty_net = torch.nn.Sequential(
                torch.nn.Linear(feature_dim, feature_dim // 4),
                torch.nn.ReLU(),
                torch.nn.Linear(feature_dim // 4, feature_dim),
                torch.nn.Sigmoid()
            )

    @property
    def base_threshold(self):
        """Get actual base threshold value (always positive when gating is on)"""
        if self.gating_no_metanet:
            return torch.exp(self.log_base_threshold)
        return torch.tensor(0.0)

    @property
    def beta(self):
        """Get actual beta value (always positive when gating is on)"""
        if self.gating_no_metanet:
            return torch.exp(self.log_beta)
        return torch.tensor(0.0)

    @staticmethod
    def _normalized_magnitudes(features):
        """Return |features| after L2-normalising every row."""
        feature_norms = torch.norm(features, dim=1, keepdim=True)
        return torch.abs(features / (feature_norms + 1e-8))

    @staticmethod
    def _percent_change(current, initial):
        """Relative change of ``current`` w.r.t. ``initial`` in percent.

        Returns 0.0 when ``initial`` is zero instead of dividing by zero.
        """
        if initial == 0:
            return 0.0
        return ((current - initial) / initial) * 100

    def compute_uncertainty(self, features):
        """Compute uncertainty based on feature characteristics

        Args:
            features: Input features [batch_size, feature_dim]

        Returns:
            uncertainties: Uncertainty scores for each feature
                [batch_size, feature_dim], clamped to [0.01, 1.0]
        """
        # Feature-specific uncertainty from the learned network
        feature_uncertainty = self.uncertainty_net(features)

        # Batch statistics component. The (unbiased) standard deviation of a
        # single sample is NaN, which would poison every downstream value, so
        # batches with fewer than two rows contribute zero spread instead.
        if features.size(0) > 1:
            batch_std = features.std(dim=0, keepdim=True)
        else:
            batch_std = features.new_zeros(1, features.size(1))
        batch_std = batch_std / (batch_std.max() + 1e-8)  # Normalize by largest std
        batch_std = batch_std.expand(features.size(0), -1)

        # Combine components
        combined_uncertainty = 0.7 * feature_uncertainty + 0.3 * batch_std

        # Random jitter while training, constant offset otherwise
        if self.training:
            random_noise = torch.rand_like(combined_uncertainty) * 0.1
        else:
            random_noise = torch.ones_like(combined_uncertainty) * 0.05

        combined_uncertainty = combined_uncertainty + random_noise

        # Clamp to [0.01, 1.0]
        return combined_uncertainty.clamp(min=0.01, max=1.0)

    def adaptive_gating(self, features, uncertainties):
        """Apply adaptive thresholding based on uncertainty

        Also stores the hard (non-differentiable) gating mask in
        ``self.last_gating_mask`` for statistics.

        Args:
            features: Original features [batch_size, feature_dim]
            uncertainties: Uncertainty scores [batch_size, feature_dim]

        Returns:
            gated_features: Features after gating
            thresholds: Computed thresholds
        """
        # Higher uncertainty raises the threshold a feature has to clear
        thresholds = self.base_threshold * (1.0 + self.beta * uncertainties)

        # Gate on the magnitude of the row-normalised features
        feature_magnitudes = self._normalized_magnitudes(features)

        # Smooth (differentiable) transition around the threshold
        gating_mask = torch.sigmoid(self._SIGMOID_SCALE * (feature_magnitudes - thresholds))
        gated_features = features * gating_mask

        # Store hard gating mask for statistics
        self.last_gating_mask = (feature_magnitudes >= thresholds).float().detach()

        return gated_features, thresholds

    def forward(self, features):
        """Forward pass with optional adaptive gating

        Args:
            features: Pre-computed feature vectors [batch_size, feature_dim]

        Returns:
            features: Original or gated features
        """
        if not self.gating_no_metanet:
            # Identity projection + dummy computation scaled to zero
            return self.projection(features) + self.dummy_linear(features) * 0.0

        self._forward_count = getattr(self, '_forward_count', 0) + 1
        self.last_orig_features = features.detach()

        # Compute uncertainty
        uncertainties = self.compute_uncertainty(features)
        self.last_uncertainties = uncertainties

        # Apply adaptive gating
        gated_features, thresholds = self.adaptive_gating(features, uncertainties)
        self.last_gated_features = gated_features
        self.last_thresholds = thresholds

        # Add dummy computation scaled to zero
        return gated_features + self.dummy_linear(features) * 0.0

    def uncertainty_regularization_loss(self):
        """Calculate regularization loss based on uncertainty and gating

        Uses the values cached by the most recent forward pass.

        Returns:
            Regularization loss
        """
        if not self.gating_no_metanet or self.uncertainty_reg < 1e-8:
            # Return dummy loss scaled to zero
            return self.dummy_param.sum() * 0.0

        self._reg_loss_count = getattr(self, '_reg_loss_count', 0) + 1

        # Without a preceding forward pass there is nothing to regularise;
        # fall back to a tiny penalty on the gating parameters themselves.
        if (self.last_uncertainties is None or
                self.last_gated_features is None or
                self.last_orig_features is None):
            return self.base_threshold * 0.001 + self.beta * 0.001

        # Sum the uncertainty of every feature whose gated value is non-zero
        active_mask = (self.last_gated_features != 0).float()
        return torch.sum(active_mask * self.last_uncertainties) * self.uncertainty_reg

    def get_gating_stats(self):
        """Get statistics about the gating process

        Returns:
            Dictionary with gating statistics (empty when gating is disabled)
        """
        if not self.gating_no_metanet:
            return {}

        # Calculate gating ratio
        if self.last_gated_features is None or self.last_gating_mask is None:
            # No forward pass yet: estimate the ratio on random sample data
            batch_size = 64
            features = torch.randn(batch_size, self.feature_dim, device=self.log_base_threshold.device)

            with torch.no_grad():
                uncertainties = self.compute_uncertainty(features)
                thresholds = self.base_threshold * (1.0 + self.beta * uncertainties)
                feature_magnitudes = self._normalized_magnitudes(features)
                gating_ratio = (feature_magnitudes >= thresholds).float().mean().item()
        else:
            # Use stored values
            gating_ratio = self.last_gating_mask.mean().item()

        # Get parameter values
        current_base_threshold = self.base_threshold.item()
        current_beta = self.beta.item()
        current_log_base_threshold = self.log_base_threshold.item()
        current_log_beta = self.log_beta.item()

        # Get initial values
        initial_base_threshold = self.initial_base_threshold.item()
        initial_beta = self.initial_beta.item()

        # Calculate changes (guarded against a zero initial value)
        threshold_change = self._percent_change(current_base_threshold, initial_base_threshold)
        beta_change = self._percent_change(current_beta, initial_beta)

        # Get average values from the last forward pass, if any
        last_thresholds = getattr(self, 'last_thresholds', None)
        last_uncertainties = getattr(self, 'last_uncertainties', None)
        avg_threshold = last_thresholds.mean().item() if last_thresholds is not None else current_base_threshold
        avg_uncertainty = last_uncertainties.mean().item() if last_uncertainties is not None else 0.0

        return {
            "gating_ratio": gating_ratio,
            "avg_threshold": avg_threshold,
            "avg_uncertainty": avg_uncertainty,
            "base_threshold": current_base_threshold,
            "beta": current_beta,
            "log_base_threshold": current_log_base_threshold,
            "log_beta": current_log_beta,
            "initial_base_threshold": initial_base_threshold,
            "initial_beta": initial_beta,
            "threshold_change_percent": threshold_change,
            "beta_change_percent": beta_change,
            "forward_count": getattr(self, '_forward_count', 0),
            "reg_loss_count": getattr(self, '_reg_loss_count', 0),
        }


class PrecomputedFeatureDataset(torch.utils.data.Dataset):
    """Dataset for precomputed features with augmentation support

    Each item is a dict with keys ``features``, ``labels``, ``index`` and
    ``augmented``. In training mode, and when augmented copies were loaded,
    an item is drawn from a randomly chosen augmented copy 80% of the time.

    Args:
        features_path: Path to precomputed features tensor
        labels_path: Path to labels tensor
        verbose: Whether to print detailed logs
        augmentation_paths: List of (features_path, labels_path) pairs of
            augmented versions
        use_augmentation: Whether to use augmented versions when available

    Raises:
        RuntimeError: If the main features or labels cannot be loaded
        ValueError: If features and labels differ in length
    """
    def __init__(self, features_path, labels_path, verbose=False,
                 augmentation_paths=None, use_augmentation=True):
        super().__init__()

        # Store augmentation settings
        self.training = True
        self.use_augmentation = use_augmentation
        self.augmentation_paths = augmentation_paths if augmentation_paths is not None else []

        # NOTE: torch.load is pickle-based; only point this at trusted files.
        try:
            self.features = torch.load(features_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load features from {features_path}: {e}") from e

        try:
            self.labels = torch.load(labels_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load labels from {labels_path}: {e}") from e

        # Validate matching dimensions
        if len(self.features) != len(self.labels):
            raise ValueError(f"Features ({len(self.features)}) and labels ({len(self.labels)}) count mismatch")

        # Load augmented versions (best effort: a bad pair is reported and skipped)
        self.augmented_features = []
        self.augmented_labels = []

        if augmentation_paths and use_augmentation:
            for aug_feat_path, aug_label_path in augmentation_paths:
                if not (os.path.exists(aug_feat_path) and os.path.exists(aug_label_path)):
                    if verbose:
                        print(f"Augmentation pair not found, skipping: {aug_feat_path}")
                    continue

                try:
                    aug_features = torch.load(aug_feat_path)
                    aug_labels = torch.load(aug_label_path)
                    shapes_match = (aug_features.shape == self.features.shape
                                    and aug_labels.shape == self.labels.shape)
                except Exception as e:
                    print(f"Warning: failed to load augmentation {aug_feat_path}: {e}")
                    continue

                # Augmented copies are indexed with the same idx as the
                # originals, so they must have identical shapes.
                if shapes_match:
                    self.augmented_features.append(aug_features)
                    self.augmented_labels.append(aug_labels)
                else:
                    print(f"Warning: skipping augmentation {aug_feat_path}: shape mismatch with original data")

            if verbose:
                print(f"Loaded {len(self.augmented_features)} augmented version(s)")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # During training, use an augmented copy with 80% probability
        use_augmented = (self.training and self.augmented_features
                         and self.use_augmentation and random.random() > 0.2)
        if use_augmented:
            aug_idx = random.randint(0, len(self.augmented_features) - 1)
            return {
                "features": self.augmented_features[aug_idx][idx],
                "labels": self.augmented_labels[aug_idx][idx],
                "index": idx,
                "augmented": True
            }

        # Original features (always used when evaluating)
        return {
            "features": self.features[idx],
            "labels": self.labels[idx],
            "index": idx,
            "augmented": False
        }

    def train(self, mode=True):
        """Sets the dataset in training mode (enables augmented sampling)"""
        self.training = mode
        return self


class PrecomputedFeatures:
    """Dataset container class for precomputed features with augmentation support

    Builds ``train_dataset``/``train_loader`` and ``test_dataset``/
    ``test_loader`` from ``train_*.pt`` / ``val_*.pt`` files in
    ``feature_dir`` and exposes ``classnames``. When no validation files
    exist, the training files are reused for the test split.

    Args:
        feature_dir: Path to directory with precomputed features
        batch_size: Batch size for dataloaders
        num_workers: Number of worker processes for dataloaders
        persistent_workers: Whether to keep worker processes alive
        use_augmentation: Whether to use augmentations during training

    Raises:
        FileNotFoundError: If ``feature_dir`` or the train features are missing
    """
    def __init__(self,
                 feature_dir,
                 batch_size=128,
                 num_workers=8,
                 persistent_workers=False,
                 use_augmentation=True):
        # Verify directory exists
        if not os.path.exists(feature_dir):
            raise FileNotFoundError(f"Feature directory not found: {feature_dir}")

        # Define file paths
        train_features_path = os.path.join(feature_dir, "train_features.pt")
        train_labels_path = os.path.join(feature_dir, "train_labels.pt")
        val_features_path = os.path.join(feature_dir, "val_features.pt")
        val_labels_path = os.path.join(feature_dir, "val_labels.pt")

        # Check for train files
        if not os.path.exists(train_features_path):
            raise FileNotFoundError(f"Train features not found at {train_features_path}")

        # Collect consecutively numbered augmented versions (aug1, aug2, ...);
        # discovery stops at the first missing pair.
        augmentation_paths = []
        aug_idx = 1
        while True:
            aug_feat_path = os.path.join(feature_dir, f"train_features_aug{aug_idx}.pt")
            aug_label_path = os.path.join(feature_dir, f"train_labels_aug{aug_idx}.pt")
            if not (os.path.exists(aug_feat_path) and os.path.exists(aug_label_path)):
                break
            augmentation_paths.append((aug_feat_path, aug_label_path))
            aug_idx += 1

        # Settings shared by both loaders. A positive timeout is only passed
        # with worker processes; num_workers == 0 gets timeout 0 because the
        # single-process loader is expected to reject a non-zero timeout.
        loader_kwargs = dict(
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=persistent_workers and num_workers > 0,
            pin_memory=True,
            drop_last=False,
            timeout=120 if num_workers > 0 else 0,
        )

        # Create train dataset
        self.train_dataset = PrecomputedFeatureDataset(
            train_features_path,
            train_labels_path,
            verbose=False,
            augmentation_paths=augmentation_paths,
            use_augmentation=use_augmentation
        )
        self.train_dataset.train(True)

        # Create train loader
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            **loader_kwargs,
        )

        # Create test dataset (falls back to the train files when no val files exist)
        self.test_dataset = PrecomputedFeatureDataset(
            val_features_path if os.path.exists(val_features_path) else train_features_path,
            val_labels_path if os.path.exists(val_labels_path) else train_labels_path,
            verbose=False,
            augmentation_paths=None,
            use_augmentation=False
        )
        self.test_dataset.train(False)

        # Create test loader
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            shuffle=False,
            **loader_kwargs,
        )

        # Load classnames if available
        classnames_path = os.path.join(feature_dir, "classnames.txt")
        if os.path.exists(classnames_path):
            with open(classnames_path, "r") as f:
                self.classnames = [line.strip() for line in f]
        else:
            # Create dummy classnames if file doesn't exist. The count must
            # cover the largest label id, not just the number of distinct
            # labels, otherwise non-contiguous ids would index past the end.
            # Assumes labels are non-negative integer class ids -- TODO confirm.
            labels = self.train_dataset.labels
            num_classes = len(torch.unique(labels))
            if labels.numel() > 0:
                num_classes = max(num_classes, int(labels.max().item()) + 1)
            self.classnames = [f"class_{i}" for i in range(num_classes)]


def cleanup_resources(dataset):
    """Drop loader/dataset references held by a dataset container.

    Afterwards garbage collection is forced and the CUDA cache is emptied.

    Args:
        dataset: Dataset container to clean up (``None`` is a no-op)
    """
    if dataset is None:
        return

    # Loaders are released before the datasets they wrap.
    attribute_names = ('test_loader', 'train_loader', 'test_dataset', 'train_dataset')

    try:
        for attr in attribute_names:
            if hasattr(dataset, attr) and getattr(dataset, attr) is not None:
                setattr(dataset, attr, None)
    except Exception as e:
        print(f"Warning during dataset cleanup: {e}")

    # Reclaim memory right away
    gc.collect()
    torch.cuda.empty_cache()


def evaluate_model(model, classifier, dataset, device):
    """Evaluate model on dataset

    Both ``model`` and ``classifier`` are switched to eval mode and are left
    in eval mode on return.

    Args:
        model: Model to evaluate
        classifier: Classification head
        dataset: Dataset with precomputed features (must expose ``test_loader``)
        device: Computation device

    Returns:
        Accuracy on test set (0.0 if the test loader yields no samples)
    """
    model.eval()
    classifier.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataset.test_loader:
            batch = maybe_dictionarize(batch)
            features = batch["features"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass
            outputs = classifier(model(features))

            # Compute accuracy
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # An empty test set would otherwise raise ZeroDivisionError
    if total == 0:
        return 0.0
    return correct / total


def train_with_adaptive_gating(rank, args):
    """Main training function with adaptive gating

    Args:
        rank: Process rank for distributed training
        args: Parsed command-line arguments
    """
    args.rank = rank

    # Initialize distributed setup
    setup_ddp(args.rank, args.world_size, port=args.port)

    # Set random seed
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    # Apply no-gating settings if specified
    if args.no_gating and not args.no_metanet:
        args.base_threshold = 1e-9
        args.beta = 1e-9
        args.uncertainty_reg = 0.0
        if is_main_process():
            print(f"No-gating mode enabled with minimal threshold/beta values")

    # Process datasets
    datasets_to_process = args.datasets

    # Create model-specific directory
    model_save_dir = os.path.join(args.save_dir, args.model)
    os.makedirs(model_save_dir, exist_ok=True)

    # Print configuration summary
    if rank == 0:
        print(f"\n=== Training Configuration ===")
        print(f"Model: {args.model}")
        print(f"Using MetaNet: {not args.no_metanet}")
        if args.no_metanet:
            print(f"Atlas with Gating: {args.gating_no_metanet}")
        else:
            print(f"Blockwise coefficients: {args.blockwise_coef}")
            print(f"No gating: {args.no_gating}")
        print(f"Base threshold: {args.base_threshold}")
        print(f"Beta: {args.beta}")
        print(f"Learning rate: {args.lr}")
        print(f"Batch size: {args.batch_size}")
        print(f"Epochs: {args.epochs}")
        print(f"Save directory: {model_save_dir}")
        print(f"Datasets: {datasets_to_process}")
        print("=" * 30)

    for dataset_name in datasets_to_process:
        if is_main_process():
            if args.no_metanet:
                if args.gating_no_metanet:
                    print(f"=== Training on {dataset_name} with Atlas + Gating ===")
                else:
                    print(f"=== Training on {dataset_name} with Atlas (no MetaNet) ===")
            elif args.no_gating:
                print(f"=== Training on {dataset_name} with MetaNet only (no gating) ===")
            else:
                print(f"=== Training on {dataset_name} with adaptive gating ===")

        # Setup save directory
        save_dir = os.path.join(model_save_dir, dataset_name + "Val")
        if is_main_process():
            os.makedirs(save_dir, exist_ok=True)

        dataset = None

        try:
            # Define feature directory
            feature_dir = os.path.join(args.data_location, "precomputed_features", args.model, dataset_name + "Val")

            # Verify directory exists
            if not os.path.exists(feature_dir):
                if is_main_process():
                    print(f"Error: Feature directory not found at {feature_dir}")
                continue

            # Load dataset
            dataset = PrecomputedFeatures(
                feature_dir=feature_dir,
                batch_size=args.batch_size,
                num_workers=2,
                use_augmentation=args.use_augmentation,
            )

            # Get feature dimension
            sample_batch = next(iter(dataset.train_loader))
            sample_batch = maybe_dictionarize(sample_batch)
            feature_dim = sample_batch["features"].shape[1]
            if is_main_process():
                print(f"Feature dimension: {feature_dim}")

            if args.no_metanet:
                # Create Atlas model
                model = DirectFeatureModel(
                    feature_dim=feature_dim,
                    gating_no_metanet=args.gating_no_metanet,
                    base_threshold=args.base_threshold,
                    beta=args.beta,
                    uncertainty_reg=args.uncertainty_reg
                )
                if is_main_process():
                    if args.gating_no_metanet:
                        print(f"Created Atlas model with gating mechanism")
                    else:
                        print(f"Created Atlas model (direct features)")
            else:
                # Create adaptive gating model
                model = AdaptiveGatingMetaNet(
                    feature_dim=feature_dim,
                    task_vectors=args.num_task_vectors,
                    blockwise=args.blockwise_coef,
                    base_threshold=args.base_threshold,
                    beta=args.beta,
                    uncertainty_reg=args.uncertainty_reg,
                    reg_coefficient=args.reg_coefficient,
                    margin_weight=args.margin_weight
                )
                if is_main_process():
                    print(f"Created AdaptiveGatingMetaNet with {args.num_task_vectors} task vectors")

            model = model.to(rank)

            # Setup training
            data_loader = dataset.train_loader
            num_batches = len(data_loader)
            print_every = max(num_batches // 5, 1)

            # Distributed training setup
            ddp_loader = distribute_loader(data_loader)
            ddp_model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.rank],
                find_unused_parameters=True
            )

            # Setup classifier
            num_classes = len(dataset.classnames)
            classifier = torch.nn.Linear(feature_dim, num_classes).cuda()

            # Setup optimizer with parameter groups
            # Group parameters by type for different learning rates
            if args.no_metanet:
                if args.gating_no_metanet:
                    # Atlas with gating
                    gating_params = []
                    uncertainty_net_params = []
                    other_params = []

                    for name, param in ddp_model.named_parameters():
                        if 'log_beta' in name or 'log_base_threshold' in name:
                            gating_params.append(param)
                        elif 'uncertainty_net' in name:
                            uncertainty_net_params.append(param)
                        else:
                            other_params.append(param)

                    other_params.extend(list(classifier.parameters()))

                    param_groups = [
                        {'params': gating_params, 'lr': args.lr * args.lr_multiplier, 'weight_decay': args.weight_decay},
                        {'params': uncertainty_net_params, 'lr': args.lr * 5.0, 'weight_decay': 0.001},
                        {'params': other_params, 'lr': args.lr, 'weight_decay': args.wd}
                    ]
                else:
                    # Simple Atlas
                    param_groups = [
                        {'params': list(ddp_model.parameters()) + list(classifier.parameters()),
                         'lr': args.lr, 'weight_decay': args.wd}
                    ]
            else:
                # MetaNet variants
                gating_log_params = []
                meta_net_params = []
                other_params = []

                for name, param in ddp_model.named_parameters():
                    if 'log_beta' in name or 'log_base_threshold' in name:
                        gating_log_params.append(param)
                    elif 'meta_net' in name:
                        meta_net_params.append(param)
                    else:
                        other_params.append(param)

                other_params.extend(list(classifier.parameters()))

                if args.no_gating:
                    param_groups = [
                        {'params': meta_net_params, 'lr': args.lr * 3, 'weight_decay': 0.001},
                        {'params': gating_log_params, 'lr': args.lr * 0.0001, 'weight_decay': 0.0},
                        {'params': other_params, 'lr': args.lr, 'weight_decay': args.wd}
                    ]
                else:
                    param_groups = [
                        {'params': gating_log_params, 'lr': args.lr * args.lr_multiplier, 'weight_decay': args.weight_decay},
                        {'params': meta_net_params, 'lr': args.lr * 3, 'weight_decay': 0.001},
                        {'params': other_params, 'lr': args.lr, 'weight_decay': args.wd}
                    ]

            optimizer = torch.optim.AdamW(param_groups)

            # Learning rate scheduler
            scheduler = cosine_lr(
                optimizer,
                args.lr,
                0,
                args.epochs * num_batches
            )

            # Loss function
            loss_fn = torch.nn.CrossEntropyLoss()

            # Training monitoring
            train_losses = []
            reg_losses = []
            val_accuracies = []
            gating_stats = []
            base_threshold_values = []
            beta_values = []
            log_base_threshold_values = []
            log_beta_values = []
            best_acc = 0.0
            best_model_state = None

            # Training loop
            for epoch in range(args.epochs):
                ddp_model.train()
                classifier.train()

                epoch_loss = 0.0
                epoch_reg_loss = 0.0
                batch_count = 0

                if is_main_process():
                    print(f"\nEpoch {epoch + 1}/{args.epochs} - Training")

                for i, batch in enumerate(ddp_loader):
                    start_time = time.time()

                    try:
                        batch = maybe_dictionarize(batch)
                        features = batch["features"].to(rank)
                        labels = batch["labels"].to(rank)

                        # Forward pass
                        transformed_features = ddp_model(features)
                        logits = classifier(transformed_features)

                        # Task loss
                        task_loss = loss_fn(logits, labels)

                        # Add uncertainty regularization
                        reg_loss = ddp_model.module.uncertainty_regularization_loss()
                        total_loss = task_loss + reg_loss

                        # Backward pass
                        total_loss.backward()

                        # Step optimizer
                        scheduler(i + epoch * num_batches)
                        optimizer.step()
                        optimizer.zero_grad()

                        # Record stats
                        task_loss_cpu = task_loss.item()
                        reg_loss_cpu = reg_loss.item()
                        batch_count += 1
                        epoch_loss += task_loss_cpu
                        epoch_reg_loss += reg_loss_cpu

                        if is_main_process():
                            train_losses.append(task_loss_cpu)
                            reg_losses.append(reg_loss_cpu)

                            # Get gating statistics
                            if hasattr(ddp_model.module, 'get_gating_stats'):
                                stats = ddp_model.module.get_gating_stats()
                                if stats:
                                    gating_stats.append(stats)

                        # Print progress with reduced frequency
                        if i % print_every == 0 and is_main_process():
                            # Simple output format
                            print(f"  Batch {i:4d}/{num_batches:4d} | "
                                  f"Loss: {task_loss_cpu:.4f} | "
                                  f"Reg: {reg_loss_cpu:.4f} | "
                                  f"Time: {time.time() - start_time:.2f}s")

                    except Exception as e:
                        if is_main_process():
                            print(f"  Error in batch {i}: {e}")
                        continue

                # Compute average epoch loss
                avg_epoch_loss = epoch_loss / batch_count if batch_count > 0 else 0
                avg_epoch_reg_loss = epoch_reg_loss / batch_count if batch_count > 0 else 0

                # Record parameter values for gating models
                if is_main_process():
                    # For gating models, track parameter evolution
                    if (args.no_metanet and args.gating_no_metanet) or (not args.no_metanet and not args.no_gating):
                        current_base_threshold = float(ddp_model.module.base_threshold.item())
                        current_beta = float(ddp_model.module.beta.item())
                        current_log_base_threshold = float(ddp_model.module.log_base_threshold.item())
                        current_log_beta = float(ddp_model.module.log_beta.item())

                        # Save to monitoring lists
                        base_threshold_values.append(current_base_threshold)
                        beta_values.append(current_beta)
                        log_base_threshold_values.append(current_log_base_threshold)
                        log_beta_values.append(current_log_beta)

                        print(f"  Summary: Loss: {avg_epoch_loss:.4f} | Reg: {avg_epoch_reg_loss:.4f} | "
                              f"αT: {current_base_threshold:.4f} | β: {current_beta:.4f}")
                    else:
                        # For non-gating models
                        base_threshold_values.append(0.0)
                        beta_values.append(0.0)
                        log_base_threshold_values.append(0.0)
                        log_beta_values.append(0.0)
                        print(f"  Summary: Loss: {avg_epoch_loss:.4f} | Reg: {avg_epoch_reg_loss:.4f}")

                # Evaluate on validation set
                if is_main_process():
                    print(f"Epoch {epoch+1}/{args.epochs} - Validation")

                    val_acc = evaluate_model(
                        model=ddp_model.module,
                        classifier=classifier,
                        dataset=dataset,
                        device=rank
                    )
                    val_accuracies.append(val_acc)

                    print(f"  Accuracy: {val_acc*100:.2f}% ({int(val_acc * len(dataset.test_dataset))}/{len(dataset.test_dataset)})")

                    # Save best model
                    if val_acc > best_acc:
                        best_acc = val_acc

                        # Create model configuration
                        if args.no_metanet:
                            if args.gating_no_metanet:
                                # Configuration for Atlas with Gating
                                config = {
                                    'feature_dim': feature_dim,
                                    'no_metanet': True,
                                    'gating_no_metanet': True,
                                    'base_threshold': current_base_threshold,
                                    'beta': current_beta,
                                    'uncertainty_reg': args.uncertainty_reg,
                                    'model_name': args.model
                                }
                            else:
                                # Configuration for Atlas
                                config = {
                                    'feature_dim': feature_dim,
                                    'no_metanet': True,
                                    'model_name': args.model
                                }
                        elif args.no_gating:
                            # Configuration for no-gating MetaNet
                            config = {
                                'feature_dim': feature_dim,
                                'num_task_vectors': args.num_task_vectors,
                                'blockwise': args.blockwise_coef,
                                'base_threshold': 0.0,
                                'beta': to.0,
                                'uncertainty_reg': 0.0,
                                'model_name': args.model,
                                'no_gating': True,
                                'no_metanet': False
                            }
                        else:
                            # Configuration for adaptive gating MetaNet
                            config = {
                                'feature_dim': feature_dim,
                                'num_task_vectors': args.num_task_vectors,
                                'blockwise': args.blockwise_coef,
                                'base_threshold': current_base_threshold,
                                'beta': current_beta,
                                'uncertainty_reg': args.uncertainty_reg,
                                'model_name': args.model,
                                'no_gating': False,
                                'no_metanet': False
                            }

                        best_model_state = {
                            'meta_net': ddp_model.module.state_dict(),
                            'classifier': classifier.state_dict(),
                            'epoch': epoch,
                            'acc': val_acc,
                            'config': config
                        }

                        print(f"  New best model! Accuracy: {best_acc*100:.2f}%")

            # Save results
            if is_main_process():
                # Choose appropriate suffix
                if args.no_metanet:
                    if args.gating_no_metanet:
                        model_type_suffix = "_atlas_with_gating"
                    else:
                        model_type_suffix = "_atlas"
                elif args.no_gating:
                    model_type_suffix = "_no_gating"
                else:
                    model_type_suffix = "_adaptive_gating"

                # Save best model
                if best_model_state:
                    best_model_path = os.path.join(save_dir, f"best{model_type_suffix}_model.pt")
                    torch.save(best_model_state, best_model_path)

                    # Save a copy with standard name
                    torch.save(best_model_state, os.path.join(save_dir, "best_model.pt"))

                # Save training history
                history = {
                    'train_losses': train_losses,
                    'reg_losses': reg_losses,
                    'val_accuracies': val_accuracies,
                    'gating_stats': gating_stats,
                    'base_threshold_values': base_threshold_values,
                    'beta_values': beta_values,
                    'log_base_threshold_values': log_base_threshold_values,
                    'log_beta_values': log_beta_values,
                    'best_acc': best_acc,
                    'config': config if 'config' in locals() else {},
                    'use_augmentation': args.use_augmentation,
                    'no_gating': args.no_gating,
                    'no_metanet': args.no_metanet,
                    'gating_no_metanet': args.gating_no_metanet
                }

                # Use appropriate suffix for the history file
                history_path = os.path.join(save_dir, f"{model_type_suffix.strip('_')}_training_history.json")

                with open(history_path, 'w') as f:
                    # Convert numpy values to Python types
                    for key in history:
                        if isinstance(history[key], (list, dict)) and key not in ['gating_stats']:
                            history[key] = [float(item) if isinstance(item, (np.floating, np.integer)) else item
                                          for item in history[key]]

                    json.dump(history, f, indent=4)

                print(f"Training completed for {dataset_name}. Best accuracy: {best_acc*100:.2f}%")

        except Exception as e:
            if is_main_process():
                print(f"Error processing dataset {dataset_name}: {e}")
                traceback.print_exc()
        finally:
            # Clean up resources
            cleanup_resources(dataset)
            torch.cuda.empty_cache()
            gc.collect()

    # Clean up distributed environment
    cleanup_ddp()


if __name__ == "__main__":
    # Script entry point: parse the command-line options, then launch
    # ``args.world_size`` worker processes that each run
    # ``train_with_adaptive_gating``. ``torch.multiprocessing.spawn`` supplies
    # the process index as the first positional argument, followed by the
    # entries of ``args=``.
    args = parse_arguments()
    try:
        torch.multiprocessing.spawn(
            train_with_adaptive_gating,
            args=(args,),
            nprocs=args.world_size,
        )
    except Exception as e:
        # Top-level boundary: report the failure with its traceback ...
        print(f"Training failed with error: {e}")
        traceback.print_exc()
        # ... and exit non-zero. Without this the interpreter would finish
        # normally and report status 0, hiding the failure from shells,
        # CI pipelines and job schedulers.
        raise SystemExit(1)