import os
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import wandb
from datetime import datetime, timedelta
import pickle
import json
import torch.nn as nn

from data.loaders.physionet_2012_dataset import PhysioNet2012
from data.loaders.physionet_2012_raindrop_dataset import PhysioNet2012FromRaindrop
from model.utils import OneHotEmbedder, SymmetricStabilizedBCEWithLogitsLoss
from config_ddp import get_config
from model.GMAN import GMAN
from model.GMAN_Ablation import GMAN_Ablation
from data.collate_fns.GMAN.physionet import distance_collate_fn_physionet
from model.set_trainer import train_epoch, test_epoch


class DistributedBalancedBatchSampler:
    """
    Distributed batch sampler that yields class-balanced batches, mirroring
    Raindrop's per-batch minority upsampling.

    - Every batch holds ``batch_size // 2`` indices whose label is 0 and
      ``batch_size // 2`` indices whose label is 1, shuffled within the batch.
    - ``balance_strategy='downsample'``: after shuffling, both classes are
      truncated to the size of the smaller class, so no index repeats within
      an epoch.
    - Any other ``balance_strategy`` value (default ``'upsample'``): the
      label-1 indices are repeated ``max(1, upsample_factor)`` times and
      reshuffled.
    - The global batch list is built identically on every rank (the RNG seed
      is ``12345 + epoch``) and then split into equal *contiguous* chunks, one
      per rank; batches that do not divide evenly across ranks are dropped.
    - ``drop_last`` is stored for interface compatibility only: the number of
      batches is computed so that every batch is full, so there is never a
      short trailing batch to drop or pad.

    Call :meth:`set_epoch` before each epoch to obtain a different shuffle.
    """
    def __init__(self, labels_tensor: torch.Tensor, batch_size: int, world_size: int, rank: int, upsample_factor: int = 3, drop_last: bool = True, balance_strategy: str = 'upsample'):
        assert batch_size % 2 == 0, "Batch size must be even for balanced batches"
        self.labels = labels_tensor.detach().cpu().long().numpy().ravel()
        self.batch_size = int(batch_size)
        self.half = self.batch_size // 2
        self.world_size = int(world_size)
        self.rank = int(rank)
        self.upsample_factor = int(upsample_factor)
        self.drop_last = bool(drop_last)
        self.balance_strategy = str(balance_strategy)
        self.epoch = 0

        self.neg_indices = np.where(self.labels == 0)[0]
        self.pos_indices = np.where(self.labels == 1)[0]

        # Number of full balanced batches in one global epoch. Only the
        # *length* of the (possibly repeated) positive pool matters here, so
        # it is computed arithmetically instead of materialising the array.
        n_neg = len(self.neg_indices)
        n_pos = len(self.pos_indices)
        if self.balance_strategy == 'downsample':
            usable_per_class = min(n_neg, n_pos)
        else:
            usable_per_class = min(n_neg, n_pos * self._repeat_count())
        self.global_n_batches = max(0, usable_per_class // self.half)

    def _repeat_count(self) -> int:
        """Return how many times the positive indices are repeated when upsampling."""
        return max(1, self.upsample_factor)

    def set_epoch(self, epoch: int):
        """Record the epoch number; it is added to the RNG seed in ``__iter__``."""
        self.epoch = int(epoch)

    def __iter__(self):
        """Return an iterator over this rank's batches (lists of dataset indices)."""
        if self.global_n_batches == 0:
            return iter([])

        # Same seed on every rank -> every rank builds the same global batch
        # list and the per-rank chunks below never overlap.
        rng = np.random.default_rng(seed=12345 + self.epoch)
        I0 = self.neg_indices.copy()
        I1 = self.pos_indices.copy()
        rng.shuffle(I0)
        rng.shuffle(I1)

        if self.balance_strategy == 'downsample':
            # Truncate both classes to the minority count (no-op for the minority).
            minority_count = min(len(I0), len(I1))
            I0 = I0[:minority_count]
            I1 = I1[:minority_count]
        else:
            # Upsample positives by repetition, then reshuffle the repeated pool.
            I1 = np.concatenate([I1] * self._repeat_count(), axis=0)
            rng.shuffle(I1)

        # global_n_batches * half <= min(len(I0), len(I1)), so every slice
        # below is exactly `half` long and no padding/dropping is required.
        batches = []
        for n in range(self.global_n_batches):
            lo = n * self.half
            hi = lo + self.half
            batch_idx = np.concatenate([I0[lo:hi], I1[lo:hi]], axis=0)
            # Shuffle within the batch so classes are interleaved.
            rng.shuffle(batch_idx)
            batches.append(batch_idx.tolist())

        # Equal contiguous chunk per rank; any remainder batches are dropped.
        per_rank = len(batches) // self.world_size
        if per_rank == 0:
            return iter([])
        start = self.rank * per_rank
        return iter(batches[start:start + per_rank])

    def __len__(self):
        """Return the number of batches each rank yields per epoch (identical on all ranks)."""
        if self.global_n_batches == 0:
            return 0
        return self.global_n_batches // self.world_size


class CrossEntropyFromLogit(nn.Module):
    """
    Two-class cross-entropy driven by a single logit per sample.

    Each logit ``z`` is expanded to the pair ``[0, z]`` (class 0 pinned at a
    zero logit) and passed to ``nn.CrossEntropyLoss``. Softmax over that pair
    assigns class 1 the probability ``sigmoid(z)``, so the result matches
    ``BCEWithLogitsLoss`` on ``z`` for targets in {0, 1}.
    """
    def __init__(self):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits_1d: torch.Tensor, targets_float: torch.Tensor):
        """Return the cross-entropy of ``[0, z]`` against integer-cast targets.

        Args:
            logits_1d: one logit per sample (documented upstream as shape [B]).
            targets_float: per-sample targets in {0, 1}; cast to ``long``
                because ``CrossEntropyLoss`` takes class indices.
        """
        reference_logits = torch.zeros_like(logits_1d)
        paired_logits = torch.stack((reference_logits, logits_1d), dim=1)
        class_ids = targets_float.long()
        return self.ce(paired_logits, class_ids)

def setup(rank, world_size):
    """Configure the rendezvous/NCCL environment and join the gloo process group.

    Args:
        rank: index of this process within the group.
        world_size: total number of processes in the group.
    """
    os.environ.update({
        # Single-node rendezvous endpoint shared by all ranks.
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12355',
        # NCCL settings intended for containerized environments.
        'NCCL_IB_DISABLE': '1',
        'NCCL_P2P_DISABLE': '1',
        'NCCL_SOCKET_IFNAME': 'lo',
        'NCCL_DEBUG': 'INFO',
    })

    # gloo is used as the more reliable backend in containers. The timeout is
    # generous because rank 0 runs validation/testing while the others wait.
    dist.init_process_group(
        "gloo",
        rank=rank,
        world_size=world_size,
        timeout=timedelta(minutes=20),
    )


def cleanup():
    """Tear down the default process group if one is active.

    Safe to call when ``setup`` never ran or failed part-way: without the
    guard, ``destroy_process_group`` raises if no default group exists, which
    would mask the original error during shutdown.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def setup_model_and_data(rank, world_size, exp_config, device):
    """Build datasets, loaders, the DDP-wrapped model, optimizer and scheduler for one rank.

    The data source is selected from ``exp_config``:

    - ``use_cached_dataset``: cached ``<RecordID>.psv`` files in
      ``cached_dataset_dir``, selected through the split pickle at
      ``split_pkl_path``.
    - ``use_raindrop_data``: Raindrop's preprocessed P12 data and splits.
    - otherwise: the file lists in ``P12_data_splits/split_1.pkl``.

    Each dataset is built exactly once, in the branch for its source, with the
    configured ``predictive_label`` and ``los_threshold_days``.

    Args:
        rank: index of this process; also used as the CUDA device id for DDP.
        world_size: number of processes; the global batch size is divided by it.
        exp_config: experiment configuration with attribute access.
        device: device the model is moved to (presumably ``cuda:<rank>`` -
            confirm against the caller).

    Returns:
        Tuple ``(model, train_loader, val_loader, test_loader, optimizer,
        scheduler, train_sampler, pos_weight, per_device_batch_size)``.
        ``val_loader`` and ``test_loader`` are ``None`` on ranks other than 0.

    Raises:
        AssertionError: if the per-device batch size is odd.
    """

    print(f"Rank {rank}: Setting up model and data on device {device}")

    use_cached = bool(getattr(exp_config, 'use_cached_dataset', False))
    # Option A: Use Raindrop preprocessed data and splits (for fair comparison)
    use_raindrop = bool(getattr(exp_config, 'use_raindrop_data', False))
    los_threshold_days = int(getattr(exp_config, 'los_threshold_days', 3))

    if use_cached:
        # Option C: Load cached PSV files with pre-written newlabel from a directory using split_1.pkl
        split_pkl_path = getattr(exp_config, 'split_pkl_path', 'P12_data_splits/split_1.pkl')
        cached_dir = getattr(exp_config, 'cached_dataset_dir', '/tmp')
        if rank == 0:
            print(f"[Cached] Loading cached PSV dataset from {cached_dir} using splits at {split_pkl_path}")
        # NOTE: unpickling can execute arbitrary code; only load trusted split files.
        with open(split_pkl_path, 'rb') as f:
            data_split = pickle.load(f)

        # Paths in the split file are mapped to cached_dir by their RecordID stem.
        # Entries whose stem is not an integer, or whose cached file is missing,
        # are skipped.
        def map_to_cached(psv_paths):
            mapped = []
            for rel_path in psv_paths:
                try:
                    rid = int(os.path.splitext(os.path.basename(rel_path))[0])
                except (TypeError, ValueError):
                    continue
                cand = os.path.join(cached_dir, f"{rid}.psv")
                if os.path.exists(cand):
                    mapped.append(cand)
            return mapped

        train_files = map_to_cached(data_split['train_files'])
        val_files = map_to_cached(data_split['val_files'])
        test_files = map_to_cached(data_split['test_files'])
        if rank == 0:
            print(f"[Cached] Files -> train:{len(train_files)} val:{len(val_files)} test:{len(test_files)}")
        one_hot_embedder = OneHotEmbedder(input_dim=exp_config.num_biom, output_dim=exp_config.num_biom_embed)
        cached_kwargs = dict(
            load_cached_dataset=True,
            biom_one_hot_embedder=one_hot_embedder,
            predictive_label=getattr(exp_config, 'predictive_label', 'mortality'),
            los_threshold_days=los_threshold_days,
        )
        train_dataset = PhysioNet2012(files=train_files, **cached_kwargs)
        val_dataset = PhysioNet2012(files=val_files, **cached_kwargs)
        test_dataset = PhysioNet2012(files=test_files, **cached_kwargs)
    elif use_raindrop:
        # Lazy import Raindrop utilities
        from baselines.Raindrop.code.baselines.utils_phy12 import (
            get_data_split,
            getStats,
            getStats_static,
            tensorize_normalize,
        )

        base_path = getattr(exp_config, 'raindrop_base_path', '')
        split_idx = int(getattr(exp_config, 'raindrop_split_idx', 1))
        split_type = getattr(exp_config, 'raindrop_split_type', 'random')
        reverse = bool(getattr(exp_config, 'raindrop_reverse', False))
        predictive_label = getattr(exp_config, 'predictive_label', 'LoS')

        split_path = f'/splits/phy12_split{split_idx}.npy'
        if rank == 0:
            print(f"[Raindrop] Loading P12 from {base_path} split {split_path} (type={split_type}, reverse={reverse})")

        print(f"[Target] Predictive label: {predictive_label}")
        print(f"[Target] Los threshold days: {getattr(exp_config, 'los_threshold_days', 3)}")

        Ptrain, Pval, Ptest, ytrain, yval, ytest = get_data_split(
            base_path,
            split_path,
            split_type=split_type,
            reverse=reverse,
            baseline=True,
            dataset='P12',
            predictive_label=predictive_label,
            los_threshold_days=los_threshold_days,
        )

        # Compute normalization statistics on train as Raindrop does.
        # NOTE(review): these statistics are not consumed anywhere below in
        # this function; kept so the computation still runs as before.
        T, F = Ptrain[0]['arr'].shape
        Ptrain_arr = np.stack([p['arr'] for p in Ptrain], axis=0)
        Ptrain_static = np.stack([p['extended_static'] for p in Ptrain], axis=0)
        mf, stdf = getStats(Ptrain_arr)
        ms, ss = getStats_static(Ptrain_static, dataset='P12')

        # Raindrop per-patient dicts are handed to the dataset, which adapts
        # them to GMAN inputs.
        one_hot_embedder = OneHotEmbedder(input_dim=exp_config.num_biom, output_dim=exp_config.num_biom_embed)
        raindrop_kwargs = dict(
            predictive_label=predictive_label,
            los_threshold_days=los_threshold_days,
        )
        train_dataset = PhysioNet2012FromRaindrop(Ptrain, ytrain, one_hot_embedder, **raindrop_kwargs)
        val_dataset = PhysioNet2012FromRaindrop(Pval, yval, one_hot_embedder, **raindrop_kwargs)
        test_dataset = PhysioNet2012FromRaindrop(Ptest, ytest, one_hot_embedder, **raindrop_kwargs)
    else:
        # Option B: Original CSV file-based splits
        # NOTE: unpickling can execute arbitrary code; only load trusted split files.
        with open(os.path.join("P12_data_splits", "split_1.pkl"), "rb") as f:
            data_split = pickle.load(f)
        train_files = data_split["train_files"]
        val_files = data_split["val_files"]
        test_files = data_split["test_files"]

        if rank == 0:
            print("Loaded pre-computed CSV file splits:")
            print(f"Number of training samples: {len(train_files)}")
            print(f"Number of validation samples: {len(val_files)}")
            print(f"Number of test samples: {len(test_files)}")
            print("Class weights will be computed from training labels.")

        # Initialize one-hot embedder and datasets.
        one_hot_embedder = OneHotEmbedder(input_dim=exp_config.num_biom, output_dim=exp_config.num_biom_embed)
        csv_kwargs = dict(
            load_cached_dataset=False,
            biom_one_hot_embedder=one_hot_embedder,
            predictive_label=getattr(exp_config, 'predictive_label', 'LoS'),
            los_threshold_days=los_threshold_days,
        )
        train_dataset = PhysioNet2012(files=train_files, **csv_kwargs)
        val_dataset = PhysioNet2012(files=val_files, **csv_kwargs)
        test_dataset = PhysioNet2012(files=test_files, **csv_kwargs)

    if use_raindrop and rank == 0:
        print("Raindrop splits loaded:")
        print(f"Number of training samples: {len(train_dataset)}")
        print(f"Number of validation samples: {len(val_dataset)}")
        print(f"Number of test samples: {len(test_dataset)}")
        print("Class weights will be computed from training labels.")

    # Per-device batch size to keep global batch size constant across world_size
    global_batch_size = int(exp_config.batch_size)
    per_device_batch_size = max(1, global_batch_size // max(1, world_size))
    # Balanced sampler requires even per-device batch size
    assert per_device_batch_size % 2 == 0, (
        f"Per-device batch size must be even for balanced sampling, got {per_device_batch_size}. "
        f"Consider setting batch_size to a multiple of 2*world_size (current world_size={world_size})."
    )

    # The datasets built above are used as-is. They must not be rebuilt from
    # train_files/val_files/test_files here: those names do not exist on the
    # Raindrop path, and rebuilding would drop the configured
    # predictive_label / los_threshold_days on the CSV path.

    # Build labels tensor for balanced batch sampling (Raindrop-style upsampling)
    def extract_labels(dataset):
        labels = []
        n_failed = 0
        for i in range(len(dataset)):
            try:
                _, y = dataset[i]
                if isinstance(y, torch.Tensor):
                    labels.append(float(y.view(-1)[0].item()))
                else:
                    labels.append(float(y))
            except Exception:
                # Best-effort: an unreadable sample is treated as a negative,
                # but the count is reported below so it cannot go unnoticed.
                n_failed += 1
                labels.append(0.0)
        if n_failed:
            print(
                f"Rank {rank}: WARNING - failed to read labels for {n_failed} of "
                f"{len(labels)} samples; treating them as negatives (0.0)"
            )
        return torch.tensor(labels, dtype=torch.float)

    train_labels_tensor = extract_labels(train_dataset)

    # Compute pos_weight from training label imbalance: pos_weight = N_neg / N_pos
    num_samples = int(train_labels_tensor.numel())
    num_pos = int(train_labels_tensor.sum().item())
    num_neg = max(0, num_samples - num_pos)
    if num_pos == 0:
        pos_weight = 1.0
    else:
        pos_weight = float(num_neg) / float(num_pos)
    if rank == 0:
        print(f"[Imbalance] Train positives: {num_pos}, negatives: {num_neg}, pos_weight: {pos_weight:.4f}")
    # Ensure identical pos_weight across ranks (rank 0's value wins)
    if dist.is_initialized():
        pw_list = [float(pos_weight)]
        if rank != 0:
            pw_list = [0.0]
        dist.broadcast_object_list(pw_list, src=0)
        pos_weight = float(pw_list[0])

    # Create balanced distributed batch sampler (replaces DistributedSampler)
    train_sampler = DistributedBalancedBatchSampler(
        labels_tensor=train_labels_tensor,
        batch_size=per_device_batch_size,
        world_size=world_size,
        rank=rank,
        upsample_factor=int(getattr(exp_config, 'upsample_factor', 3)),
        drop_last=True,
        balance_strategy=str(getattr(exp_config, 'balance_strategy', 'upsample')).lower(),
    )

    # Create DataLoader with balanced batch sampler for training
    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=distance_collate_fn_physionet,
        num_workers=4,  # Reduced for stability
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )
    # For evaluation, only rank 0 builds non-distributed loaders to compute global metrics
    if rank == 0:
        val_loader = DataLoader(
            val_dataset,
            batch_size=exp_config.val_batch_size,
            shuffle=False,
            collate_fn=distance_collate_fn_physionet,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=exp_config.test_batch_size,
            shuffle=False,
            collate_fn=distance_collate_fn_physionet,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True
        )
    else:
        val_loader = None
        test_loader = None

    # Create model
    print(f"Rank {rank}: Creating model...")
    disable_deepset = getattr(exp_config, 'disable_deepset', False)
    disable_distance_embedding = getattr(exp_config, 'disable_distance_embedding', False)
    use_simple_aggregation = getattr(exp_config, 'use_simple_aggregation', False)
    feature_processor_type = getattr(exp_config, 'feature_processor_type', 'gnan')
    # Any ablation switch selects the ablation variant of the model.
    use_ablation_model = (
        disable_deepset or
        disable_distance_embedding or
        use_simple_aggregation or
        feature_processor_type != 'gnan'
    )

    # Constructor arguments common to GMAN and GMAN_Ablation.
    shared_model_kwargs = dict(
        feature_groups=exp_config.feature_groups,
        out_channels=exp_config.out_channels,
        is_graph_task=exp_config.is_graph_task,
        batch_size=per_device_batch_size,
        n_layers=exp_config.n_layers,
        hidden_channels=exp_config.hidden_channels,
        dropout=exp_config.dropout,
        device=device,
        normalize_rho=exp_config.normalize_rho,
        max_num_GNANs=exp_config.max_num_GNANs,
        biomarker_groups=exp_config.biomarker_groups,
        gnan_mode=exp_config.gnan_mode,
    )

    if use_ablation_model:
        print(f"Rank {rank}: Using GMAN_Ablation model")
        model = GMAN_Ablation(
            deepset_n_layers=2,
            # Ablation flags
            disable_deepset=disable_deepset,
            disable_distance_embedding=disable_distance_embedding,
            use_simple_aggregation=use_simple_aggregation,
            feature_processor_type=feature_processor_type,
            **shared_model_kwargs,
        ).to(device)
    else:
        print(f"Rank {rank}: Using GMAN model")
        model = GMAN(
            deepset_n_layers=3,
            **shared_model_kwargs,
        ).to(device)

    print(f"Rank {rank}: Model created successfully")

    # Ensure all ranks start from identical parameters/buffers
    for param in model.parameters():
        dist.broadcast(param.data, src=0)
    for buffer in model.buffers():
        dist.broadcast(buffer.data, src=0)

    # Wrap model with DDP
    print(f"Rank {rank}: Wrapping model with DDP...")
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
    print(f"Rank {rank}: DDP wrapping successful")

    # Create optimizer and scheduler (Plateau to mirror single-GPU behavior)
    optimizer = torch.optim.Adam(model.parameters(), lr=exp_config.lr, weight_decay=exp_config.wd)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.2, patience=5, min_lr=1e-4
    )

    return model, train_loader, val_loader, test_loader, optimizer, scheduler, train_sampler, pos_weight, per_device_batch_size


def train_worker(rank, world_size, exp_config):
    """Main training function for each process."""
    
    # Setup distributed training
    setup(rank, world_size)
    
    # Set device for this process
    device = torch.device(f"cuda:{rank}")
    
    # Use the SAME seed on all ranks so model init is identical
    SEED = int(exp_config.seed)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    
    # Setup model and data
    model, train_loader, val_loader, test_loader, optimizer, scheduler, train_sampler, pos_weight, per_device_batch_size = setup_model_and_data(
        rank, world_size, exp_config, device
    )

    # Enable TF32 for matmul and cudnn to speed up training on Ampere+
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # GradScaler for AMP - enable only for FP16 (bf16 does not use GradScaler)
    from torch import amp as torch_amp
    scaler = torch_amp.GradScaler("cuda", enabled=(torch.cuda.is_available() and not torch.cuda.is_bf16_supported()))
    
    # Create a consistent run name for checkpointing; wandb name comes from exp_config.run_name
    unique_run_name = f"{exp_config['config_path']}_layers-{exp_config['n_layers']}_hidden-{exp_config['hidden_channels']}_lr-{exp_config['lr']}_ddp_gloo"

    # Initialize wandb only on rank 0
    if rank == 0 and exp_config.wandb:
        config_dict = exp_config.copy()
        config_dict['device'] = torch.cuda.get_device_name() if torch.cuda.is_available() else 'cpu'
        config_dict['model'] = model.module.__class__.__name__
        config_dict['world_size'] = world_size
        config_dict['backend'] = 'gloo'
        config_dict['per_device_batch_size'] = per_device_batch_size
        wandb.init(
            project="GMAN_IC_P12_ablation",
            config=config_dict,
            settings=wandb.Settings(start_method='thread'),
            name=str(getattr(exp_config, 'run_name', unique_run_name)),
            reinit=True
        ) 
    
    # Training loop
    best_val_loss = float('inf')
    best_val_auc = float('-inf')
    best_val_auprc = float('-inf')
    
    # Track test metrics at best validation points
    best_test_at_val_loss = {}
    best_test_at_val_auc = {}
    best_test_at_val_auprc = {}
    
    if rank == 0:
        print(f"Started DDP training on {world_size} GPUs using gloo backend")
        print(f"Running experiment: {exp_config.exp_name}")
    
    for epoch in range(exp_config.epochs):
        # Set epoch for distributed sampler
        # Reshuffle balanced sampler each epoch
        train_sampler.set_epoch(epoch)
        
        # Learning rate warmup
        warmup_epochs = int(getattr(exp_config, 'warmup_epochs', 0))
        base_lr = float(exp_config.lr)
        if epoch < warmup_epochs:
            warmup_factor = float(epoch + 1) / float(max(1, warmup_epochs))
            new_lr = base_lr * warmup_factor
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
        current_lr = optimizer.param_groups[0]['lr']
        if rank == 0:
            print(f"Epoch {epoch}: LR={current_lr:.6f} (warmup_epochs={warmup_epochs})")
        
        # Match Raindrop's CrossEntropyLoss using a single-logit adapter
        # loss_fn = CrossEntropyFromLogit()
        # loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight, device=device, dtype=torch.float))
        loss_fn = nn.BCEWithLogitsLoss()
        
        # Training
        train_loss, train_acc, train_auc, train_auprc = train_epoch(
            epoch=epoch,
            model=model,
            dloader=train_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=None,  # Plateau will be driven by validation loss below
            device=device,
            writer=None,
            scaler=scaler,
            enable_plots=False,
        )
        # Ensure all ranks reached the end of training step before evaluation
        if dist.is_initialized():
            dist.barrier()

        # Validation and testing only on rank 0
        if rank == 0:
            val_loss, val_acc, val_auc, val_auprc = test_epoch(
                epoch=epoch,
                model=model,
                dloader=val_loader,
                loss_fn=loss_fn,
                device=device,
                writer=None,
                enable_plots=False,
            )
            test_loss, test_acc, test_auc, test_auprc = test_epoch(
                epoch=epoch,
                model=model,
                dloader=test_loader,
                loss_fn=loss_fn,
                device=device,
                writer=None,
                enable_plots=False,
            )
        else:
            # Placeholders for broadcasting
            val_loss = torch.tensor(0.0, device=device)
            val_acc = val_auc = val_auprc = 0.0
            test_loss = test_acc = test_auc = test_auprc = 0.0
        
        # Broadcast validation loss so all ranks step the Plateau scheduler identically
        if dist.is_initialized():
            metric_list = [0.0]
            if rank == 0:
                metric_list[0] = float(val_loss)
            dist.broadcast_object_list(metric_list, src=0)
            plateau_metric = float(metric_list[0])
        else:
            plateau_metric = float(val_loss)
        
        # Step Plateau scheduler with the (global) validation loss after warmup
        if epoch >= warmup_epochs:
            scheduler.step(plateau_metric)
        
        # Log metrics only on rank 0
        if rank == 0:
            print(
                f"Epoch {epoch}/{exp_config.epochs}: "
                f"Train Loss={train_loss:.4f}, Train AUC={train_auc:.4f}, Train AUPRC={train_auprc:.4f}"
            )
            print(
                f"Epoch {epoch}/{exp_config.epochs}: "
                f"Val Loss={val_loss:.4f}, Val AUC={val_auc:.4f}, Val AUPRC={val_auprc:.4f}"
            )
            print(
                f"Epoch {epoch}/{exp_config.epochs}: "
                f"Test Loss={test_loss:.4f}, Test AUC={test_auc:.4f}, Test AUPRC={test_auprc:.4f}"
            )
            
            if exp_config.wandb:
                wandb.log({
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "train_auc": train_auc,
                    "train_auprc": train_auprc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                    "val_auc": val_auc,
                    "val_auprc": val_auprc,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    "test_auc": test_auc,
                    "test_auprc": test_auprc,
                    "lr": current_lr,
                })
            
            # Save checkpoints only on rank 0
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_test_at_val_loss = {
                    'test_loss': test_loss,
                    'test_acc': test_acc, 
                    'test_auc': test_auc,
                    'test_auprc': test_auprc
                }
                checkpoint_dir = f"{exp_config.model_checkpoints_dir}/ddp_gloo_{unique_run_name}"
                os.makedirs(checkpoint_dir, exist_ok=True)
                checkpoint_name = os.path.join(checkpoint_dir, f'best_params_by_val_loss.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict()
                }, checkpoint_name)
                if exp_config.wandb:
                    wandb.save(checkpoint_name)

            if val_auc > best_val_auc:
                best_val_auc = val_auc
                best_test_at_val_auc = {
                    'test_loss': test_loss,
                    'test_acc': test_acc,
                    'test_auc': test_auc, 
                    'test_auprc': test_auprc
                }
                checkpoint_dir = f"{exp_config.model_checkpoints_dir}/ddp_gloo_{unique_run_name}"
                os.makedirs(checkpoint_dir, exist_ok=True)
                checkpoint_name = os.path.join(checkpoint_dir, f'best_params_by_val_auc.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict()
                }, checkpoint_name)
                if exp_config.wandb:
                    wandb.save(checkpoint_name)

            if val_auprc > best_val_auprc:
                best_val_auprc = val_auprc
                best_test_at_val_auprc = {
                    'test_loss': test_loss,
                    'test_acc': test_acc,
                    'test_auc': test_auc,
                    'test_auprc': test_auprc
                }
                checkpoint_dir = f"{exp_config.model_checkpoints_dir}/ddp_gloo_{unique_run_name}"
                os.makedirs(checkpoint_dir, exist_ok=True)
                checkpoint_name = os.path.join(checkpoint_dir, f'best_params_by_val_auprc.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict()
                }, checkpoint_name)
                if exp_config.wandb:
                    wandb.save(checkpoint_name)
        
        # Synchronize before next epoch
        if dist.is_initialized():
            dist.barrier()
    
    # Final evaluation
    # Match Raindrop's CrossEntropyLoss at final eval
    loss_fn = CrossEntropyFromLogit()
    # Only rank 0 has loaders; run there and broadcast metrics for completeness
    if rank == 0:
        test_loss, test_acc, test_auc, test_auprc = test_epoch(
            epoch=exp_config.epochs,
            model=model,
            dloader=test_loader,
            loss_fn=loss_fn,
            device=device,
            writer=None
        )
    else:
        test_loss = test_acc = test_auc = test_auprc = 0.0
    
    if rank == 0:
        print(f"Final Test Loss={test_loss:.4f}, Final Test Accuracy={test_acc:.4f}, Final Test AUC={test_auc:.4f}, Final Test AUPRC={test_auprc:.4f}")
        
        # Print best test performance during training
        if best_test_at_val_auc:
            print(f"Best Test Loss={best_test_at_val_auc['test_loss']:.4f}, Best Test Accuracy={best_test_at_val_auc['test_acc']:.4f}, Best Test AUC={best_test_at_val_auc['test_auc']:.4f}, Best Test AUPRC={best_test_at_val_auc['test_auprc']:.4f}")
        else:
            print(f"Best Test Loss={test_loss:.4f}, Best Test Accuracy={test_acc:.4f}, Best Test AUC={test_auc:.4f}, Best Test AUPRC={test_auprc:.4f}")
        
        if exp_config.wandb:
            wandb.log({
                "final_test_loss": test_loss,
                "final_test_acc": test_acc,
                "final_test_auc": test_auc,
                "final_test_auprc": test_auprc
            })
            wandb.finish()
        
        # Save results to file
        results_dir = "training_results"
        os.makedirs(results_dir, exist_ok=True)
        
        # Create comprehensive results dictionary
        results = {
            "experiment_info": {
                "config_path": getattr(exp_config, 'config_path', 'unknown'),
                "exp_name": exp_config.exp_name,
                "gnan_mode": getattr(exp_config, 'gnan_mode', 'unknown'),
                "seed": getattr(exp_config, 'seed', 'unknown'),
                "n_layers": exp_config.n_layers,
                "hidden_channels": exp_config.hidden_channels,
                "batch_size": int(exp_config.batch_size),
                "val_batch_size": int(exp_config.val_batch_size),
                "test_batch_size": int(exp_config.test_batch_size),
                "per_device_batch_size": per_device_batch_size,
                "lr": exp_config.lr,
                "epochs": exp_config.epochs,
                "world_size": world_size,
                "backend": "gloo",
                "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            },
            "final_metrics": {
                "test_loss": float(test_loss),
                "test_accuracy": float(test_acc),
                "test_auc": float(test_auc),
                "test_auprc": float(test_auprc)
            },
            "best_metrics": {
                "test_loss": float(best_test_at_val_auc.get('test_loss', test_loss)) if best_test_at_val_auc else float(test_loss),
                "test_accuracy": float(best_test_at_val_auc.get('test_acc', test_acc)) if best_test_at_val_auc else float(test_acc),
                "test_auc": float(best_test_at_val_auc.get('test_auc', test_auc)) if best_test_at_val_auc else float(test_auc),
                "test_auprc": float(best_test_at_val_auprc.get('test_auprc', test_auprc)) if best_test_at_val_auprc else float(test_auprc)
            },
            "validation_metrics": {
                "best_val_loss": float(best_val_loss),
                "best_val_auc": float(best_val_auc),
                "best_val_auprc": float(best_val_auprc)
            }
        }
        
        # Create filename with timestamp and key info
        timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
        config_name = os.path.basename(getattr(exp_config, 'config_path', 'unknown')).replace('.yaml', '')
        seed_str = getattr(exp_config, 'seed', 'unknown')
        results_filename = f"results_ddp_gloo_{config_name}_seed{seed_str}_gpus{world_size}_{timestamp_str}.json"
        results_path = os.path.join(results_dir, results_filename)
        
        # Save results to JSON file
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)
        
        print(f"\n📁 Results saved to: {results_path}")
        print(f"   Best Test AUC: {results['best_metrics']['test_auc']:.4f}")
        print(f"   Final Test AUC: {results['final_metrics']['test_auc']:.4f}")
    
    # Cleanup
    cleanup()


def main():
    """Launch distributed training with one spawned worker per CUDA device.

    Loads the experiment configuration via ``get_config()``, reads the
    device count from ``torch.cuda.device_count()``, and hands
    ``train_worker`` to ``mp.spawn`` with ``(world_size, exp_config)`` as
    its extra arguments. When fewer than two devices are reported, only a
    warning is printed and no process is spawned.

    Returns:
        None.
    """
    # The configuration is loaded before the device count is inspected.
    config = get_config()
    gpu_count = torch.cuda.device_count()

    if gpu_count >= 2:
        print(f"Starting distributed training on {gpu_count} GPUs using gloo backend")
        print(f"Configuration: {config}")

        # One process per device; join=True asks spawn to wait for the
        # workers before returning.
        mp.spawn(
            train_worker,
            args=(gpu_count, config),
            nprocs=gpu_count,
            join=True,
        )
    else:
        # Not enough devices for a multi-process run: warn and fall through.
        print("Warning: Less than 2 GPUs available. Consider using single GPU training.")

# Script entry point: start the launcher only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()