import torch
import torch.nn as nn
import random
from tqdm import tqdm  
import torch.nn.functional as F
from collections import defaultdict
import pandas as pd
import os
import torch.distributed as dist
from torch.utils.data import DistributedSampler, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import math
import numpy as np
import sys
import datetime
import json
import traceback as import_traceback
import traceback
import torch.profiler as profiler
from torch.profiler import profile, record_function, ProfilerActivity
from contextlib import nullcontext

import torch

import time
import gc  # Added for explicit garbage collection when handling memory errors

# Debug switch; not read by any code visible in this part of the file.
DEBUGGING = False
# Master switch for the torch.profiler instrumentation in VAETrainer._run_batch.
ENABLE_PROFILING = False  # Set to True only when you need to profile
# Global-step values at which a training batch is profiled (only consulted
# when ENABLE_PROFILING is True).
PROFILE_STEPS = [0, 1, 10, 100]  # Only profile these specific steps

def is_main_process():
    """Return True for rank 0, or when torch.distributed is not initialized."""
    if not torch.distributed.is_initialized():
        # No process group: the single running process is the main one.
        return True
    return torch.distributed.get_rank() == 0


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss in SimCLR.

    Args:
        temperature: divisor applied to the anchor/contrast dot products.
        contrast_mode: ``'all'`` uses every view as an anchor, ``'one'`` uses
            only the first view (``features[:, 0]``).
        base_temperature: the final loss is scaled by
            ``temperature / base_temperature``.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None, **kwargs):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
            **kwargs: accepted for call-site compatibility and ignored.
        Returns:
            A scalar loss tensor. With fewer than two samples a zero-valued
            scalar tensor is returned.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """
        if len(features) < 2:
            # Nothing to contrast against. Return a tensor (not the Python int
            # 0) so callers can treat every result uniformly (.item(),
            # .backward(), ...). Deriving it from `features` keeps it attached
            # to the autograd graph.
            return features.sum() * 0.0
        device = features.device

        if features.dim() < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if features.dim() > 3:
            # Flatten any trailing dimensions into a single feature axis.
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # SimCLR case: a sample is only positive with its own views.
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Stack the views along the batch axis: [bsz * n_views, dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        logits = torch.clamp(logits, min=-100, max=100)

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases: zero at (i, i) for every anchor row
        self_index = torch.arange(batch_size * anchor_count, device=device).view(-1, 1)
        logits_mask = torch.scatter(torch.ones_like(mask), 1, self_index, 0)
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        # Add epsilon to avoid log(0)
        epsilon = 1e-12
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + epsilon)

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(
            mask_pos_pairs < 1e-6,
            torch.ones_like(mask_pos_pairs),
            mask_pos_pairs,
        )
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss


class VAETrainer:
    def __init__(
        self,
        model,
        table_vectorizer,
        train_dataloader,
        val_dataloader,
        optimizer,
        scheduler=None,
        device="cuda",
        distributed=False,
        rank=0,
        dwa=False,
        world_size=1,
        init_beta=0.0,          # Initial beta value
        max_beta=1.0,          # Maximum beta value
        max_beta_steps=1000,   # Steps to reach maximum beta
        temperature=0.5,       # Temperature parameter for DWA
        vectorizer_warmup_epochs=0,  # How many epochs to warmup the vectorizer before training main VAE.
        validation_interval=None,  # Steps between validations (None for epoch-based)
        save_interval=50,      # Steps between saves (None for epoch-based)
        early_stop_patience=20,  # Steps between validations (None for epoch-based)
        interval_type="epoch",  # "epoch" or "step"
        scheduler_interval="epoch",  # "epoch" or "step" for scheduler updates
        max_grad_norm=1.0,    # Add this parameter
        gradient_accumulation_steps=1,  # Add this parameter
        mask_ratio=0.0,       # Ratio of columns to mask for masked autoencoder setup
        contrastive_weight=0.0,  # Weight for contrastive loss (0 = disabled)
        contrastive_temperature=0.07,  # Temperature parameter for contrastive loss
        base_contrastive_temperature=0.07,  # Base temperature for contrastive loss
        contrastive_dim=128,  # Output dimension for contrastive projection head
        log_dir=None,
        checkpoint_dir=None,
        load_scheduler_state=True,  # Whether to load scheduler state from checkpoint
    ):
        self.device = device
        self.distributed = distributed
        self.rank = rank
        self.world_size = world_size
        
        # Move model and vectorizer to device first
        self.model = model.to(self.device)
        self.table_vectorizer = table_vectorizer.to(self.device)
        
        # Wrap dataloaders for distributed training if needed
        self.train_dataloader = self.get_distributed_dataloader(train_dataloader) if distributed else train_dataloader
        self.val_dataloader = self.get_distributed_dataloader(val_dataloader) if distributed else val_dataloader

        # Create infinite iterator only for training, not for validation
        self.train_iter = iter(self.train_dataloader)
        # Remove the validation iterator since we'll use epoch-based validation
        self.val_iter = self.val_dataloader if self.val_dataloader else None
        if self.val_iter is not None:
            print(len(self.val_iter))
        self.optimizer = optimizer
        self.scheduler = scheduler
        
        # Add control for scheduler state loading
        self.load_scheduler_state = load_scheduler_state

        # Setup distributed training if needed
        if distributed:
            self.setup_distributed()

        # Beta-VAE parameters
        self.init_beta = init_beta
        self.max_beta = max_beta
        self.max_beta_steps = max_beta_steps
        self.beta = init_beta  # Start with initial beta
        
        # For tracking losses
        self.best_recon_loss = float('inf')

        # Add DWA-related attributes
        self.dwa = dwa
        self.temperature = temperature
        self.loss_weights = {
            "numerical": 1.0,
            "categorical": 0.5,
            "text": 1.0,
            "datetime": 1.0
        }
        self.prev_losses = {
            "numerical": None,
            "categorical": None,
            "text": None,
            "datetime": None
        }

        # Add vectorizer warmup attribute
        self.vectorizer_warmup_epochs = vectorizer_warmup_epochs
        self.current_epoch = 0  # Add to track current epoch

        # Add step-based learning parameters
        self.validation_interval = validation_interval
        self.save_interval = save_interval
        self.early_stop_patience = early_stop_patience
        self.interval_type = interval_type
        self.scheduler_interval = scheduler_interval
        
        self.global_step = 0
        self.steps_without_improvement = 0
        self.best_step_recon_loss = float('inf')
        # Create persistent tensor for global step synchronization
        self.global_step_tensor = torch.zeros(1, dtype=torch.long, device=self.device)

        # Add gradient clipping parameter
        self.max_grad_norm = max_grad_norm

        # Add gradient accumulation parameter
        self.gradient_accumulation_steps = gradient_accumulation_steps

        # Add mask ratio parameter
        self.mask_ratio = mask_ratio

        # Add contrastive loss parameters
        self.contrastive_weight = contrastive_weight
        self.contrastive_temperature = contrastive_temperature
        self.base_contrastive_temperature = base_contrastive_temperature
        self.contrastive_dim = contrastive_dim
        
        if self.contrastive_weight > 0:
            self.contrastive_loss_fn = SupConLoss(
                temperature=self.contrastive_temperature,
                base_temperature=self.base_contrastive_temperature
            )
            
            self.projection_head = None    #  built in setup_model 

        self.setup_model()

        # Add log directory and checkpoint directory
        self.log_dir = log_dir
        self.checkpoint_dir = checkpoint_dir

    def setup_distributed(self):
        """
        Set up distributed training components (dataloaders and models)
        """
        # Modify dataloaders for distributed training
        self.train_dataloader = self.get_distributed_dataloader(self.train_dataloader)
        self.val_dataloader = self.get_distributed_dataloader(self.val_dataloader)

    def get_distributed_dataloader(self, loader):
        """
        Wrap the DataLoader with a DistributedSampler for distributed training.
        """
        # no shuffling for distributed training.
        sampler = DistributedSampler(loader.dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False)
        # Keep the original batch size - each GPU will process full batches
        return DataLoader(loader.dataset, sampler=sampler, batch_size=loader.batch_size, num_workers=4, collate_fn=loader.collate_fn)

    def setup_model(self):
        if self.distributed:
            local_rank = int(os.environ.get("LOCAL_RANK", "0"))
            torch.cuda.set_device(local_rank)
            self.device = torch.device(f"cuda:{local_rank}")

        # (1) move the main modules
        self.model           = self.model.to(self.device)
        self.table_vectorizer = self.table_vectorizer.to(self.device)

        # (2) wrap with DDP if necessary
        if self.distributed:
            self.model           = DDP(self.model, device_ids=[local_rank])
            self.table_vectorizer = DDP(self.table_vectorizer, device_ids=[local_rank])

        # (3) projection head: create lazily *after* device is fixed
        if self.contrastive_weight > 0 and (not hasattr(self, "projection_head") or self.projection_head is None):
            in_dim = self.model.d_latent_len * self.model.d_latent_width
            self.projection_head = self._create_projection_head(in_dim).to(self.device)

        # (4) DDP-wrap the head as well
        if self.distributed and self.contrastive_weight > 0:
            self.projection_head = DDP(self.projection_head, device_ids=[local_rank])

    def compute_kl_divergence(self, mu, logvar):
        """
        Compute KL divergence between N(mu, var) and N(0, 1).
        It includes beta factor.
        """
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_div = kl_div / mu.shape[-1]
        return self.beta * kl_div

    def update_beta(self):
        """
        Update beta value using linear scheduling based on global step.
        Beta increases linearly from init_beta to max_beta over max_beta_steps.
        """
        if self.global_step >= self.max_beta_steps:
            self.beta = self.max_beta
        else:
            # Linear interpolation between init_beta and max_beta
            progress = self.global_step / self.max_beta_steps
            self.beta = self.init_beta + (self.max_beta - self.init_beta) * progress
        
        #if is_main_process():
        #    print(f"\nUpdated beta to {self.beta:.6f} at step {self.global_step}")

    def get_model(self):
        return self.model

    def _update_loss_weights(self, current_losses):
        """
        Update loss weights using Dynamic Weight Averaging (DWA).
        
        Args:
            current_losses (dict): Dictionary of current epoch losses by type
        """
        # Skip if this is the first epoch
        if any(v is None for v in self.prev_losses.values()):
            self.prev_losses = {k: v for k, v in current_losses.items() if v > 0}
            return
        
        # Calculate relative loss changes
        rel_changes = {}
        total_change = 0
        n_active = 0
        
        for loss_type in self.loss_weights.keys():
            if current_losses.get(loss_type, 0) > 0 and self.prev_losses.get(loss_type, 0) > 0:
                # Calculate relative change in loss
                rel_change = current_losses[loss_type] / (self.prev_losses[loss_type] + 1e-8)
                rel_changes[loss_type] = rel_change
                total_change += math.exp(rel_change / self.temperature)
                n_active += 1
        
        # Update weights if we have active losses
        if n_active > 0:
            for loss_type in self.loss_weights.keys():
                if loss_type in rel_changes:
                    # Calculate new weight
                    weight = n_active * math.exp(rel_changes[loss_type] / self.temperature) / (total_change + 1e-8)
                    # Smooth the transition
                    self.loss_weights[loss_type] = 0.8 * self.loss_weights[loss_type] + 0.2 * weight
        
        # Update previous losses
        self.prev_losses = {k: v for k, v in current_losses.items() if v > 0}
        
        if is_main_process():
            print("\nDWA Weights:", {k: f"{v:.4f}" for k, v in self.loss_weights.items()})

    def create_contrastive_labels(self, df_batch, config):
        """
        Create labels for contrastive learning based on the target column (assumed to be the last column).
        
        Args:
            df_batch: The batch of data as a DataFrame
            config: The configuration dictionary containing variable information
            
        Returns:
            labels: Tensor of labels for contrastive learning
        """
        if not config or 'variables' not in config or len(config['variables']) == 0:
            # If no variables defined, return None
            return None
            
        # Get the last variable (assumed to be the target)
        target_var = config['variables'][-1]
        target_name = target_var['variable_name']
        
        if target_name not in df_batch.columns:
            # If target not in data, return None
            return None
            
        # Process based on variable type
        var_type = target_var['variable_type']
        
        if var_type == 'categorical':
            # For categorical, use category indices directly
            categories = target_var.get('categories', [])
            if not categories:
                return None
            
            try:    
                # Get values as list to handle any data type
                target_values = df_batch[target_name].tolist()
                
                # Convert values to indices
                indices = []
                for val in target_values:
                    try:
                        # Try to find the category index
                        idx = categories.index(val) if val in categories else -1
                        indices.append(idx)
                    except (ValueError, TypeError):
                        # If comparison fails, use -1 as a fallback
                        indices.append(-1)
                
                # Filter out -1 indices which represent unknown categories
                valid_indices = [i for i in indices if i != -1]
                
                # If no valid indices were found, return None
                if not valid_indices:
                    if is_main_process():
                        print(f"No valid category indices found for contrastive learning")
                    return None
                
                # Create labels tensor
                labels = torch.tensor(indices, dtype=torch.long, device=self.device)
                
                # Replace -1 with the most common valid index as a fallback
                if -1 in indices:
                    most_common = max(set(valid_indices), key=valid_indices.count) if valid_indices else 0
                    labels[labels == -1] = most_common
                    
                return labels
            except Exception as e:
                if is_main_process():
                    print(f"Categorical indexing failed: {e}")
                return None
            
        elif var_type == 'numerical':
            # For numerical, create bins (default 10)
            num_bins = 10
            
            try:
                # Get values as a list first
                values_list = df_batch[target_name].tolist()
                
                # Check if we have any non-numeric or sequence values
                has_complex_values = False
                numeric_values = []
                
                for val in values_list:
                    if isinstance(val, (list, tuple, dict, str)) or hasattr(val, '__iter__') and not isinstance(val, (int, float, bool, np.number)):
                        has_complex_values = True
                        break
                    
                    try:
                        # Try to convert to float
                        numeric_values.append(float(val))
                    except (ValueError, TypeError):
                        has_complex_values = True
                        break
                
                if has_complex_values:
                    # For complex data, hash the values and use the hash for binning
                    hashed_values = []
                    for val in values_list:
                        try:
                            # Convert various types to string and hash
                            hashed_val = hash(str(val)) % 10000  # limit to moderate values
                            hashed_values.append(hashed_val)
                        except:
                            hashed_values.append(0)  # fallback
                    
                    # Use the hashed values for binning
                    target_values = np.array(hashed_values, dtype=np.float64)
                else:
                    # Use the numeric values directly
                    target_values = np.array(numeric_values, dtype=np.float64)
                
                # Continue with binning as before
                if len(target_values) == 0:
                    return None
                
                # Get min/max values for binning using np.nanmin/np.nanmax to handle NaN values
                min_val = np.nanmin(target_values)
                max_val = np.nanmax(target_values)
                
                # Check if min and max are equal or if all values are NaN
                if min_val == max_val or np.isnan(min_val) or np.isnan(max_val):
                    # All values are the same or all are NaN
                    labels = torch.zeros(len(target_values), dtype=torch.long, device=self.device)
                else:
                    # Create bin edges
                    bin_edges = np.linspace(min_val, max_val, num_bins + 1)
                    
                    # Handle NaN values by assigning to a separate bin
                    bin_indices = np.zeros_like(target_values, dtype=np.int64)
                    
                    # Assign non-NaN values to bins (last bin includes the max value)
                    non_nan_mask = ~np.isnan(target_values)
                    if np.any(non_nan_mask):
                        bin_indices[non_nan_mask] = np.digitize(target_values[non_nan_mask], bin_edges[:-1]) - 1
                    
                    # Ensure indices are within bounds
                    bin_indices = np.clip(bin_indices, 0, num_bins - 1)
                    
                    labels = torch.tensor(bin_indices, dtype=torch.long, device=self.device)
                
                return labels
            except Exception as e:
                # If numerical binning fails, return None
                if is_main_process():
                    print(f"Numerical binning failed: {e}")
                return None
        else:
            # For other types (like text, datetime), not supported for now
            return None

    def _create_projection_head(self, input_dim):
        """
        Build the contrastive projection head and register it with the optimizer.

        The head is a two-layer MLP (Linear -> ReLU -> Linear) mapping
        ``input_dim`` to ``input_dim // 2`` and then to
        ``self.contrastive_dim``. Its parameters are appended to
        ``self.optimizer`` as a new parameter group which copies the
        hyperparameters of the optimizer's first group when one exists.

        Args:
            input_dim: The input dimension (l*d from latent)

        Returns:
            nn.Sequential: the projection head, placed on ``self.device``
        """
        hidden_dim = input_dim // 2
        mlp_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.contrastive_dim),
        ]
        projection_head = nn.Sequential(*mlp_layers).to(self.device)

        # The head is created after the optimizer, so its parameters have to
        # be added to the optimizer explicitly.
        projection_params = list(projection_head.parameters())

        has_groups = hasattr(self.optimizer, 'param_groups') and len(self.optimizer.param_groups) > 0
        if not has_groups:
            # Fallback: just add the parameters without specifying hyperparameters
            new_group = {'params': projection_params}
        else:
            # Reuse every hyperparameter of the first group except its params
            template_group = self.optimizer.param_groups[0]
            new_group = {}
            for key, value in template_group.items():
                if key != 'params':
                    new_group[key] = value
            new_group['params'] = projection_params
        self.optimizer.add_param_group(new_group)

        if is_main_process():
            print(f"Created contrastive projection head: {input_dim} → {input_dim // 2} → {self.contrastive_dim}")
            print(f"Added {sum(p.numel() for p in projection_params):,} parameters to optimizer")

        return projection_head

    def _run_batch(self, batch, is_train=True):
        """Process a single batch of data"""
        # Add timing info for debugging distributed performance
        start_time = time.time()
        timing_info = {}
        
        try:
            config = batch['config']
            df_batch = batch['df_batch']
            batch_size = df_batch.shape[0]
            
            timing_info['data_extraction'] = time.time() - start_time
            model_start = time.time()
            
            if is_train:
                self.model.train()
                self.table_vectorizer.train()
            else:
                self.model.eval()
                self.table_vectorizer.eval()

            # Get the actual model/vectorizer
            vectorizer = self.table_vectorizer.module if self.distributed else self.table_vectorizer
            model = self.model.module if self.distributed else self.model
            
            timing_info['model_setup'] = time.time() - model_start
            vectorize_start = time.time()
            
            # Only profile if enabled and on specific steps
            should_profile = ENABLE_PROFILING and is_train and self.global_step in PROFILE_STEPS and is_main_process()
            
            # Context manager for profiling (or a dummy context manager if not profiling)
            prof_context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
                with_flops=True
            ) if should_profile else nullcontext()
            
            # Record function context (or dummy if not profiling)
            def record_fn(name):
                return record_function(name) if should_profile else nullcontext()
            
            with prof_context as prof:
                # Vectorize table and move to device
                with record_fn("vectorize"):
                    # Enable batch processing for better performance
                    table_tensor = vectorizer.vectorize(df_batch, config, batch_processing=True).to(self.device)
                    n_cols = table_tensor.shape[1]
                    
                    # Create attention mask
                    attention_mask = torch.ones(batch_size, n_cols, dtype=torch.bool, device=self.device)
                    
                    # Apply masking if mask_ratio > 0
                    if self.mask_ratio > 0 and is_train:
                        # For each example in the batch, randomly select columns to mask
                        for i in range(batch_size):
                            # Calculate number of columns to mask
                            num_cols_to_mask = int(n_cols * self.mask_ratio)
                            if num_cols_to_mask > 0:
                                # Randomly select column indices to mask
                                mask_indices = torch.randperm(n_cols, device=self.device)[:num_cols_to_mask]
                                # Set mask to 0 for selected columns
                                attention_mask[i, mask_indices] = False
                
                timing_info['vectorize'] = time.time() - vectorize_start
                forward_start = time.time()
                
                # Forward pass - modified to handle warmup
                if self.current_epoch < self.vectorizer_warmup_epochs:
                    with record_fn("warmup_forward"):
                        decoded_embedding = table_tensor
                else:
                    # Encode metadata and column information
                    with record_fn("encode_meta"):
                        meta, column_names, dtypes, dist = vectorizer.encode_meta(config)
                        meta = meta.unsqueeze(0).repeat(batch_size, 1).to(self.device)
                        column_names = column_names.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device)
                        dtypes = dtypes.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device)
                        dist = dist.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device)
                        
                        # Check for NaN values in metadata tensors
                        if torch.isnan(meta).any():
                            raise ValueError("NaN values detected in metadata tensor")
                        if torch.isnan(table_tensor).any():
                            raise ValueError("NaN values detected in table_tensor tensor")
                        if torch.isnan(column_names).any():
                            raise ValueError("NaN values detected in column_names tensor")
                        if torch.isnan(dtypes).any():
                            raise ValueError("NaN values detected in dtypes tensor")
                        if torch.isnan(dist).any():
                            print("Debugging dist")
                            print(dist)
                            raise ValueError("NaN values detected in distribution tensor")
                    
                    with record_fn("model_encode"):
                        latent, mu, logvar = model.encode(table_tensor, column_names, dtypes, meta, attention_mask, dist=dist)
                        
                        # Check for NaN values in encoder outputs
                        if torch.isnan(latent).any():
                            raise ValueError("NaN values detected in latent representation")
                        if torch.isnan(mu).any():
                            raise ValueError("NaN values detected in mu (mean) tensor")
                        if torch.isnan(logvar).any():
                            raise ValueError("NaN values detected in logvar tensor")
                    
                    with record_fn("model_decode"):
                        decoded_embedding = model.decode(latent, column_names, dtypes, meta, dist=dist)
                        
                        # Check for NaN values in decoder output
                        if torch.isnan(decoded_embedding).any():
                            raise ValueError("NaN values detected in decoded embeddings")

                # Inverse vectorize for loss computation
                with record_fn("inverse_vectorize"):
                    reconstructed_values = vectorizer.inverse_vectorize(
                        decoded_embedding, config, mode="train", target_df=df_batch, batch_processing=True
                    )
            
            # Save profiler results to file if profiling was enabled
            if should_profile and prof is not None:
                profile_dir = os.path.join("debug_trainer", "profiler")
                os.makedirs(profile_dir, exist_ok=True)
                
                # Save trace file for Chrome trace viewer
                prof.export_chrome_trace(os.path.join(profile_dir, f"trace_step_{self.global_step}.json"))
                
                # Print summary to log file
                with open(os.path.join(profile_dir, f"profile_summary_step_{self.global_step}.txt"), "w") as f:
                    f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
                    f.write("\n\nCPU Time Summary:\n")
                    f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
                    
                # Print memory summary
                with open(os.path.join(profile_dir, f"memory_summary_step_{self.global_step}.txt"), "w") as f:
                    f.write(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=20))
                
                # Print stack traces for top operations
                with open(os.path.join(profile_dir, f"stack_traces_step_{self.global_step}.txt"), "w") as f:
                    f.write(prof.key_averages(group_by_stack_n=5).table(sort_by="cuda_time_total", row_limit=10))
            
            # Compute losses by variable type with DWA weights
            loss_by_type = {"numerical": [], "categorical": [], "text": [], "datetime": []}
            accuracies = []
            current_losses = {k: 0.0 for k in loss_by_type.keys()}
            n_vars_by_type = {k: 0 for k in loss_by_type.keys()}

            for idx, var_config in enumerate(config['variables']):
                var_type = var_config['variable_type']
                col_name = var_config['variable_name']

                # Extract prediction and loss from reconstructed_values
                # Handle different return formats from vectorizer
                column_result = reconstructed_values[idx]
                
                # Each result is expected to be a (prediction, loss) pair; the unpacking below fails otherwise
                pred, loss = column_result

                # Accumulate loss by type
                if var_type == 'numerical':
                    loss_by_type['numerical'].append(loss)
                    n_vars_by_type['numerical'] += 1
                elif var_type == 'categorical':
                    loss_by_type['categorical'].append(loss)
                    n_vars_by_type['categorical'] += 1
                    
                    # Calculate accuracy for categorical variables
                    target = torch.tensor(
                        [var_config['categories'].index(val) for val in df_batch[col_name]], 
                        device=self.device
                    )
                    pred_classes = torch.argmax(pred, dim=1)
                    accuracy = (pred_classes == target).float().mean()
                    accuracies.append(accuracy)
                elif var_type == 'text':
                    loss_by_type['text'].append(loss)
                    n_vars_by_type['text'] += 1
                elif var_type == 'datetime':
                    loss_by_type['datetime'].append(loss)
                    n_vars_by_type['datetime'] += 1

            # Aggregate losses with DWA weights
            total_loss = 0
            losses = {}
            
            # Storage for gradient debugging
            type_losses = {}
            type_grads_before_weighting = {}
            type_grads_after_weighting = {}
            
            # Calculate average loss per type
            for var_type, type_losses_list in loss_by_type.items():
                if type_losses_list:
                    avg_type_loss = torch.stack(type_losses_list).mean()
                    
                    # Store unweighted loss
                    type_losses[var_type] = avg_type_loss
                    
                    # For debugging gradients - calculate for numerical and categorical only
                    if is_train and DEBUGGING and var_type in ['numerical', 'categorical']:
                        # Need to get gradient before weighting
                        if avg_type_loss.requires_grad:
                            # Collect the model parameters (references, not copies) for gradient checking
                            model_params = list(model.parameters())
                            # Zero existing gradients
                            self.optimizer.zero_grad(set_to_none=True)
                            # Backward pass on unweighted loss
                            avg_type_loss.backward(retain_graph=True)
                            # Store gradients before weighting
                            type_grads_before_weighting[var_type] = []
                            for param in model_params:
                                if param.grad is not None:
                                    # Store a copy of the gradient
                                    type_grads_before_weighting[var_type].append(param.grad.detach().clone().cpu())
                                else:
                                    type_grads_before_weighting[var_type].append(None)
                            # Zero gradients for next computation
                            self.optimizer.zero_grad(set_to_none=True)
                    
                    # Apply weighting
                    weighted_loss = avg_type_loss * self.loss_weights[var_type]
                    total_loss += weighted_loss
                    losses[f"{var_type}_loss"] = avg_type_loss.item()
                    current_losses[var_type] = avg_type_loss.item()
                    
                    # For debugging gradients - after weighting
                    if is_train and DEBUGGING and var_type in ['numerical', 'categorical']:
                        if weighted_loss.requires_grad:
                            # Zero existing gradients
                            self.optimizer.zero_grad(set_to_none=True)
                            # Backward pass on weighted loss
                            weighted_loss.backward(retain_graph=True)
                            # Store gradients after weighting
                            type_grads_after_weighting[var_type] = []
                            for param in model_params:
                                if param.grad is not None:
                                    # Store a copy of the gradient
                                    type_grads_after_weighting[var_type].append(param.grad.detach().clone().cpu())
                                else:
                                    type_grads_after_weighting[var_type].append(None)
                            # Zero gradients for next computation
                            self.optimizer.zero_grad(set_to_none=True)

            # Update DWA weights if in training mode
            if is_train and self.dwa:
                self._update_loss_weights(current_losses)

            # Add contrastive loss if enabled and in train mode
            if self.contrastive_weight > 0 and is_train and self.current_epoch >= self.vectorizer_warmup_epochs:
                # Create contrastive labels
                contrastive_labels = self.create_contrastive_labels(df_batch, config)
                
                if contrastive_labels is not None and batch_size > 1:
                    # Flatten latent from (batch_size, l, d) to (batch_size, l*d); the singleton axis is added after projection
                    l, d = latent.shape[1], latent.shape[2]
                    flat_latent = latent.reshape(batch_size, l*d)
                    
                    # Create projection head if it doesn't exist yet
                    #if self.projection_head is None:
                    #    self.projection_head = self._create_projection_head(l*d)
                    
                    # Project features through MLP head
                    projected_features = self.projection_head(flat_latent)
                    
                    # Reshape for contrastive loss: (batch_size, 1, contrastive_dim)
                    contrastive_features = projected_features.unsqueeze(1)
                    
                    # Compute contrastive loss
                    cont_loss = self.contrastive_loss_fn(contrastive_features, labels=contrastive_labels)
                    
                    # Add to total loss and store (apply weight)
                    weighted_cont_loss = self.contrastive_weight * cont_loss
                    total_loss += weighted_cont_loss
                    losses["cont_loss"] = weighted_cont_loss.item()

            # Add KL divergence
            if self.current_epoch >= self.vectorizer_warmup_epochs:
                kl_loss = self.compute_kl_divergence(mu, logvar)
                total_loss += kl_loss
                losses["kl_loss"] = kl_loss.item()

            if accuracies:
                losses["accuracy"] = torch.stack(accuracies).mean().item()
                
            # Check for NaN values in losses
            for loss_name, loss_value in losses.items():
                if isinstance(loss_value, (int, float)) and math.isnan(loss_value):
                    raise ValueError(f"NaN detected in {loss_name} loss. Training cannot continue.")

            if is_train:
                # Scale loss by accumulation steps
                scaled_loss = total_loss / self.gradient_accumulation_steps
                scaled_loss.backward()
                
                # Save gradient information if debugging is enabled
                if DEBUGGING and (type_grads_before_weighting or type_grads_after_weighting):
                    # Create directory for gradient debug
                    grad_debug_dir = os.path.join("debug_trainer", "gradient_debug")
                    os.makedirs(grad_debug_dir, exist_ok=True)
                    
                    # Calculate gradient statistics (norm, mean, max) for comparison
                    grad_stats = {
                        "before_weighting": {},
                        "after_weighting": {},
                        "ratio": {}
                    }
                    
                    # Process gradients for each type
                    for var_type in ['numerical', 'categorical']:
                        if var_type in type_grads_before_weighting and var_type in type_grads_after_weighting:
                            grad_stats["before_weighting"][var_type] = {}
                            grad_stats["after_weighting"][var_type] = {}
                            grad_stats["ratio"][var_type] = {}
                            
                            # Calculate statistics per parameter group
                            for i, (before_grad, after_grad) in enumerate(zip(
                                type_grads_before_weighting[var_type], 
                                type_grads_after_weighting[var_type]
                            )):
                                if before_grad is not None and after_grad is not None:
                                    # Calculate L2 norm
                                    before_norm = torch.norm(before_grad).item()
                                    after_norm = torch.norm(after_grad).item()
                                    norm_ratio = after_norm / (before_norm + 1e-8)
                                    
                                    # Calculate mean and max
                                    before_mean = torch.mean(torch.abs(before_grad)).item()
                                    after_mean = torch.mean(torch.abs(after_grad)).item()
                                    mean_ratio = after_mean / (before_mean + 1e-8)
                                    
                                    before_max = torch.max(torch.abs(before_grad)).item()
                                    after_max = torch.max(torch.abs(after_grad)).item()
                                    max_ratio = after_max / (before_max + 1e-8)
                                    
                                    # Store statistics
                                    grad_stats["before_weighting"][var_type][f"param_group_{i}"] = {
                                        "norm": before_norm,
                                        "mean": before_mean,
                                        "max": before_max
                                    }
                                    grad_stats["after_weighting"][var_type][f"param_group_{i}"] = {
                                        "norm": after_norm,
                                        "mean": after_mean,
                                        "max": after_max
                                    }
                                    grad_stats["ratio"][var_type][f"param_group_{i}"] = {
                                        "norm_ratio": norm_ratio,
                                        "mean_ratio": mean_ratio,
                                        "max_ratio": max_ratio
                                    }
                    
                    # Calculate overall gradient statistics for each variable type
                    for var_type in ['numerical', 'categorical']:
                        if var_type in type_grads_before_weighting and var_type in type_grads_after_weighting:
                            # Combine all gradients for this type
                            before_grads_flat = []
                            after_grads_flat = []
                            
                            for before_grad, after_grad in zip(
                                type_grads_before_weighting[var_type], 
                                type_grads_after_weighting[var_type]
                            ):
                                if before_grad is not None and after_grad is not None:
                                    before_grads_flat.append(before_grad.flatten())
                                    after_grads_flat.append(after_grad.flatten())
                            
                            if before_grads_flat and after_grads_flat:
                                before_flat = torch.cat(before_grads_flat)
                                after_flat = torch.cat(after_grads_flat)
                                
                                grad_stats["before_weighting"][var_type]["overall"] = {
                                    "norm": torch.norm(before_flat).item(),
                                    "mean": torch.mean(torch.abs(before_flat)).item(),
                                    "max": torch.max(torch.abs(before_flat)).item()
                                }
                                grad_stats["after_weighting"][var_type]["overall"] = {
                                    "norm": torch.norm(after_flat).item(),
                                    "mean": torch.mean(torch.abs(after_flat)).item(),
                                    "max": torch.max(torch.abs(after_flat)).item()
                                }
                                
                                # Calculate ratios
                                before_norm = grad_stats["before_weighting"][var_type]["overall"]["norm"]
                                after_norm = grad_stats["after_weighting"][var_type]["overall"]["norm"]
                                norm_ratio = after_norm / (before_norm + 1e-8)
                                
                                before_mean = grad_stats["before_weighting"][var_type]["overall"]["mean"]
                                after_mean = grad_stats["after_weighting"][var_type]["overall"]["mean"]
                                mean_ratio = after_mean / (before_mean + 1e-8)
                                
                                before_max = grad_stats["before_weighting"][var_type]["overall"]["max"]
                                after_max = grad_stats["after_weighting"][var_type]["overall"]["max"]
                                max_ratio = after_max / (before_max + 1e-8)
                                
                                grad_stats["ratio"][var_type]["overall"] = {
                                    "norm_ratio": norm_ratio,
                                    "mean_ratio": mean_ratio,
                                    "max_ratio": max_ratio,
                                    "weight": self.loss_weights[var_type]
                                }
                    
                    # Save gradient statistics to file
                    stats_file = os.path.join(grad_debug_dir, f"grad_stats_step_{self.global_step}_rank_{self.rank}.json")
                    with open(stats_file, 'w') as f:
                        json.dump(grad_stats, f, indent=2)
                    
                    # Create summary file with analysis
                    if 'numerical' in grad_stats["ratio"] and 'categorical' in grad_stats["ratio"]:
                        num_ratio = grad_stats["ratio"]["numerical"]["overall"]["norm_ratio"]
                        cat_ratio = grad_stats["ratio"]["categorical"]["overall"]["norm_ratio"]
                        
                        imbalance_ratio = num_ratio / (cat_ratio + 1e-8)
                        
                        summary_file = os.path.join(grad_debug_dir, f"grad_imbalance_summary.txt")
                        with open(summary_file, 'a') as f:
                            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                            f.write(f"[{timestamp}] Step: {self.global_step}, Epoch: {self.current_epoch}\n")
                            f.write(f"  Numerical loss: {current_losses['numerical']:.6f}, "
                                   f"weight: {self.loss_weights['numerical']:.4f}, "
                                   f"gradient norm ratio: {num_ratio:.6f}\n")
                            f.write(f"  Categorical loss: {current_losses['categorical']:.6f}, "
                                   f"weight: {self.loss_weights['categorical']:.4f}, "
                                   f"gradient norm ratio: {cat_ratio:.6f}\n")
                            f.write(f"  Imbalance ratio (numerical/categorical): {imbalance_ratio:.6f}\n")
                            f.write(f"  Analysis: {'Imbalance detected' if abs(imbalance_ratio - 1.0) > 0.5 else 'Balance acceptable'}\n\n")
                
                # Only step optimizer after accumulating gradients
                if (self.global_step + 1) % self.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), 
                        self.max_grad_norm
                    )
                    torch.nn.utils.clip_grad_norm_(
                        self.table_vectorizer.parameters(), 
                        self.max_grad_norm
                    )
                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)
                
            # Record total time
            timing_info['total_time'] = time.time() - start_time
            
            # Add timing info to debug log if debugging
            if DEBUGGING:
                debug_dir = os.path.join("debug_trainer", "timing_debug")
                os.makedirs(debug_dir, exist_ok=True)
                timing_file = os.path.join(debug_dir, f"timing_rank{self.rank}.log")
                with open(timing_file, "a") as f:
                    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
                    timing_str = ", ".join([f"{k}: {v:.4f}s" for k, v in timing_info.items()])
                    f.write(f"[{timestamp}] Step: {self.global_step}, {timing_str}\n")
            
            losses["total_loss"] = total_loss.item()

            if is_train:
                # First ensure all processes completed their batch
                if self.distributed:
                    dist.barrier()
                
                # Now each process can increment
                self.global_step += 1
                if self.distributed:
                    # Use all_reduce to ensure all processes are aligned
                    self.global_step_tensor[0] = self.global_step
                    dist.all_reduce(self.global_step_tensor, op=dist.ReduceOp.MAX)
                    self.global_step = int(self.global_step_tensor.item())

            # Explicitly delete intermediate tensors
            del table_tensor, attention_mask
            if 'meta' in locals():
                del meta, column_names, dtypes, dist
            if self.current_epoch >= self.vectorizer_warmup_epochs and 'latent' in locals():
                del latent, mu, logvar
            del decoded_embedding, reconstructed_values

            return losses

        except Exception as e:
            # Create debug directory with timestamp and device/rank info
            os.makedirs("debug_trainer", exist_ok=True)
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            device_info = f"rank{self.rank}_device{self.device}"
            debug_dir = os.path.join("debug_trainer", f"error_{timestamp}_{device_info}")
            os.makedirs(debug_dir, exist_ok=True)

            # Save config
            config_path = os.path.join(debug_dir, "error_config.json")
            with open(config_path, 'w') as f:
                json.dump(config, f, indent=2)

            # Save DataFrame
            df_path = os.path.join(debug_dir, "error_df.csv")
            df_batch.to_csv(df_path, index=False)

            # Save error information with additional device details
            error_info = {
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": import_traceback.format_exc(),
                "is_training": is_train,
                "global_step": self.global_step,
                "current_epoch": self.current_epoch,
                "rank": self.rank,
                "device": str(self.device),
                "world_size": self.world_size,
                "distributed": self.distributed
            }
            error_path = os.path.join(debug_dir, "error_info.json")
            with open(error_path, 'w') as f:
                json.dump(error_info, f, indent=2)

            print(f"\nError occurred during batch processing on rank {self.rank}, device {self.device}.")
            print(f"Debug files saved to: {debug_dir}")
            print(f"Error: {str(e)}")

            # Re-raise the exception
            raise

    def train(self, num_epochs=None, max_steps=None, checkpoint_path="best_model.pth", resume_from_checkpoint=False, skip_iters=0):
        """Main training loop with support for step-based training"""
        # Validate input parameters
        if (num_epochs is None and max_steps is None) or (num_epochs is not None and max_steps is not None):
            raise ValueError("Exactly one of num_epochs or max_steps must be provided")

        # Improved skip iterations for distributed training
        if skip_iters > 0:
            self.train_dataloader = self.prepare_dataset_with_skipped_iterations(skip_iters)
            self.train_iter = iter(self.train_dataloader)

        # Convert epochs to steps if needed
        if num_epochs is not None and self.interval_type == "epoch":
            max_steps = num_epochs * len(self.train_dataloader)
            self.validation_interval = len(self.train_dataloader)
            self.save_interval = len(self.train_dataloader)
            self.early_stop_patience = len(self.train_dataloader)
            self.scheduler_interval = 'step'
            if is_main_process():
                print(f"Converting {num_epochs} epochs to {max_steps} steps")

        # Initialize training state
        best_loss = float('inf')
        best_model_state = None
        loss_history = []
        self.steps_without_improvement = 0

        # Resume from checkpoint if requested
        if resume_from_checkpoint and os.path.exists(checkpoint_path):
            try:
                if is_main_process():
                    print(f"Loading checkpoint from {checkpoint_path}")
                # Load checkpoint with backward compatibility
                checkpoint = torch.load(
                    checkpoint_path,
                    map_location=self.device,
                    weights_only=False  # Allow loading all types of data
                )
                self._restore_checkpoint(checkpoint)
                if is_main_process():
                    print(f"Resumed training from step {self.global_step}")
            except Exception as e:
                if is_main_process():
                    print(f"Failed to load checkpoint: {str(e)}")
                if not isinstance(e, FileNotFoundError):
                    print("Starting fresh training session")

        # Initialize progress bar for main process
        if is_main_process():
            pbar = tqdm(total=max_steps, 
                       desc='Training', 
                       ncols=200,  # Increased width further for all metrics
                       position=0,
                       initial=self.global_step, 
                       leave=True)
            # Initialize postfix dict with all metrics
            postfix_dict = {
                'tot': float('inf'),    # Shortened names for space
                'num': float('inf'),
                'cat': float('inf'),
                'txt': float('inf'),    # Shortened 'text'
                'dt': float('inf'),
                "cont": float('inf'),
                'kl': 0.0,
                'val': float('inf'),
                'lr': self.optimizer.param_groups[0]['lr'],
                'β': self.beta          # Using unicode beta symbol
            }
            pbar.set_postfix(postfix_dict)

        # Get total GPU memory for percentage calculation
        if torch.cuda.is_available():
            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
        
        def should_clear_cache():
            """Check if memory usage exceeds 80% threshold"""
            if not torch.cuda.is_available():
                return False
            memory_allocated = torch.cuda.memory_allocated(self.device)
            memory_reserved = torch.cuda.memory_reserved(self.device)
            memory_used = max(memory_allocated, memory_reserved)  # Use the larger value
            memory_percentage = memory_used / total_gpu_memory * 100
            
            if DEBUGGING and is_main_process() and memory_percentage >= 80:
                print(f"\nGPU Memory Usage: {memory_percentage:.1f}% ({memory_used / 1024**3:.2f}GB / {total_gpu_memory / 1024**3:.2f}GB)")
            
            return memory_percentage >= 80

        # Main training loop
        while self.global_step < max_steps:
            try:
                # Update current epoch for logging and distributed sampler
                self.current_epoch = self.global_step // len(self.train_dataloader)
                
                # Get next batch
                skip_counter = getattr(self, 'skip_counter', 0)
                max_consecutive_skips = 10  # Adjust based on your tolerance
                
                try:
                    batch = self.get_train_batch()
                    # Reset skip counter on success
                    self.skip_counter = 0
                except Exception as e:
                    self.skip_counter = skip_counter + 1
                    
                    # Always log the error, but with different severity
                    error_msg = f"[Rank {self.rank}] Failed to get train batch: {e}"
                    if self.skip_counter >= max_consecutive_skips:
                        print(f"ERROR: {error_msg} - {self.skip_counter} consecutive failures. Training may be compromised.")
                        # Optional: raise an exception if too many consecutive failures
                        # raise RuntimeError(f"Too many consecutive batch failures: {self.skip_counter}")
                    elif self.skip_counter > 1:
                        print(f"WARNING: {error_msg} - {self.skip_counter} consecutive failures")
                    else:
                        print(f"INFO: {error_msg} - Skipping batch and continuing")
                    
                    # If using distributed training, make sure all processes stay in sync
                    if self.distributed:
                        # Broadcast skip counter to ensure all processes have same value
                        skip_tensor = torch.tensor([self.skip_counter], device=self.device)
                        dist.broadcast(skip_tensor, 0)
                        
                    continue
                
                # Debug logging to track batch processing across ranks
                if DEBUGGING:
                    debug_dir = os.path.join("debug_trainer", "batch_debug")
                    os.makedirs(debug_dir, exist_ok=True)
                    debug_file = os.path.join(debug_dir, "batch_processing.log")
                    with open(debug_file, "a") as f:
                        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
                        batch_key = batch.get('key', 'unknown_key')
                        f.write(f"[{timestamp}] Rank: {self.rank}, Global Step: {self.global_step}, "
                               f"Epoch: {self.current_epoch}, Batch Key: {batch_key}\n")
                
                # Run training step
                batch_losses = self._run_batch(batch, is_train=True)
                
                # Store metrics
                metrics = {
                    "step": self.global_step,
                    "epoch": self.current_epoch,
                    **batch_losses
                }
                loss_history.append(metrics)

                # Update progress bar if main process
                if is_main_process():
                    postfix_dict.update({
                        'tot': f"{batch_losses['total_loss']:.4f}",      # Reduced decimals
                        'num': f"{batch_losses.get('numerical_loss', 0):.4f}",
                        'cat': f"{batch_losses.get('categorical_loss', 0):.4f}",
                        'txt': f"{batch_losses.get('text_loss', 0):.4f}",
                        'dt': f"{batch_losses.get('datetime_loss', 0):.4f}",
                        "cont": f"{batch_losses.get('cont_loss', 0):.4f}",
                        'kl': f"{batch_losses.get('kl_loss', 0):.4f}",
                        'lr': f"{self.optimizer.param_groups[0]['lr']:.2e}",
                        'β': f"{self.beta:.4f}"
                    })
                    pbar.set_postfix(postfix_dict)
                    pbar.update(1)

                # Step-based validation check
                if self.validation_interval and self.global_step % self.validation_interval == 0:
                    if self.val_dataloader is not None:
                        val_losses = self._run_validation()
                        monitoring_loss = val_losses['total_loss']
                        monitoring_recon_loss = sum(v for k, v in val_losses.items() 
                                           if k not in ['kl_loss', 'total_loss', 'accuracy'])
                        
                        # Store validation metrics
                        metrics.update({f"val_{k}": v for k, v in val_losses.items()})
                        
                        # Print detailed validation report
                        if is_main_process():
                            pbar.write("\nValidation Losses:")
                            pbar.write(f"{'Loss Type':<15} {'Value':>10}")
                            pbar.write("-" * 30)
                            for k, v in val_losses.items():
                                pbar.write(f"{k:<15} {v:>10.6f}")
                            pbar.write("-" * 30)
                        
                            # Update progress bar with validation loss
                            postfix_dict['val'] = f"{monitoring_recon_loss:.4f}"
                            pbar.set_postfix(postfix_dict)
                    else:
                        monitoring_loss = batch_losses['total_loss']
                        monitoring_recon_loss = sum(v for k, v in batch_losses.items() 
                                           if k not in ['kl_loss', 'total_loss', 'accuracy'])
                        
                        # Print detailed training loss report when no validation set
                        if is_main_process():
                            pbar.write("\nTraining Losses:")
                            pbar.write(f"{'Loss Type':<15} {'Value':>10}")
                            pbar.write("-" * 30)
                            for k, v in batch_losses.items():
                                pbar.write(f"{k:<15} {v:>10.6f}")
                            pbar.write("-" * 30)

                    # Check for improvement
                    if monitoring_recon_loss < best_loss:
                        best_loss = monitoring_recon_loss
                        best_model_state = self._get_model_state(self.current_epoch, best_loss)
                        self.steps_without_improvement = 0
                        
                        # Save best model checkpoint
                        if is_main_process():
                            self._save_checkpoint(best_model_state, checkpoint_path)
                            pbar.write(f"\nSaved best model at step {self.global_step} with loss {best_loss:.6f}")
                    else:
                        self.steps_without_improvement += 1
                        
                    # Early stopping check
                    if self.early_stop_patience and self.steps_without_improvement >= self.early_stop_patience:
                        if is_main_process():
                            pbar.write(f"\nEarly stopping triggered after {self.early_stop_patience} steps without improvement")
                        break
                    
                if self.scheduler and self.scheduler_interval == "step":
                    self.scheduler.step()

                self.update_beta()
                
                # Regular checkpoint saving
                if self.save_interval and self.global_step % self.save_interval == 0 and is_main_process():
                    current_state = self._get_model_state(self.current_epoch, 
                        monitoring_loss if 'monitoring_loss' in locals() else batch_losses['total_loss'])
                    save_path = f"{os.path.splitext(checkpoint_path)[0]}_step_{self.global_step}.pth"
                    self._save_checkpoint(current_state, save_path)
                    if is_main_process():
                        print(f"\nSaved checkpoint at step {self.global_step}")

                # Clear memory only if usage is high
                if should_clear_cache():
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                        if DEBUGGING and is_main_process():
                            print(f"Cache cleared at step {self.global_step}")
                
                # Keep the explicit tensor deletions
                del batch
                
                # Keep efficient grad clearing
                self.optimizer.zero_grad(set_to_none=True)
                
            except Exception as e:
                # Enhanced memory error handling
                err_msg = str(e).lower()

                memory_error_substrings = [
                    "out of memory",          # generic / PyTorch
                    "cuda out of memory",      # explicit CUDA OOM
                    "memoryerror",            # Python MemoryError
                    "failed to allocate",     # allocation failures
                    "cannot allocate memory", # numpy / OS errors
                    "could not allocate"      # alternative phrasing
                ]

                is_memory_error = isinstance(e, MemoryError) or any(sub in err_msg for sub in memory_error_substrings)

                # Torch specific OutOfMemoryError (sub-class of RuntimeError in recent PyTorch versions)
                try:
                    from torch.cuda import OutOfMemoryError as TorchOOMError  # PyTorch ≥ 2.1
                    if isinstance(e, TorchOOMError):
                        is_memory_error = True
                except Exception:
                    # torch.cuda.OutOfMemoryError might not exist in older versions – ignore
                    pass

                if is_memory_error:
                    # Attempt to free up GPU and CPU memory
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    gc.collect()

                    # Clear any accumulated gradients to prevent further memory usage
                    self.optimizer.zero_grad(set_to_none=True)

                    print(f"[Rank {self.rank}] Memory error encountered (step {self.global_step}). Cleared cache and skipping this step. Error detail: {e}")

                    # Optionally delete batch if it exists in the current scope
                    try:
                        del batch
                    except NameError:
                        pass

                    continue  # Skip to next training iteration

                # Re-raise non-memory related exceptions
                raise e

        # Close progress bar
        if is_main_process():
            pbar.close()

        # Training completed - save final state and restore best model
        if is_main_process():   
            # Save final checkpoint
            final_state = self._get_model_state(self.current_epoch, 
                monitoring_loss if 'monitoring_loss' in locals() else batch_losses['total_loss'])
            final_path = f"{os.path.splitext(checkpoint_path)[0]}_final.pth"
            self._save_checkpoint(final_state, final_path)
            print(f"\nSaved final model at step {self.global_step}")

            # Restore best model
            if best_model_state:
                self._restore_checkpoint(best_model_state)
                print(f"\nRestored best model with loss {best_loss:.6f}")

        # Return training history
        return pd.DataFrame(loss_history) if loss_history else None

    def _run_validation(self):
        """Run one pass over the validation loader and return averaged losses.

        Switches ``self.model`` and ``self.table_vectorizer`` to eval mode (they
        are not switched back here), calls ``self._run_batch(batch,
        is_train=False)`` on every batch under ``torch.no_grad()`` and averages
        the finite scalar losses over the number of batches that completed.
        In distributed mode the per-rank averages are summed over all ranks and
        divided by ``self.world_size``.

        Returns:
            dict: loss name -> averaged value. Empty when no batch completed
            (and, in distributed mode, no rank reported any key).

        Notes:
            * A batch that raises is skipped. The number of skipped batches is
              reported once per rank after the loop, so a validation run that
              silently lost data is visible even with ``DEBUGGING`` off.
            * A non-finite value is left out of its key's sum, but the divisor
              is still the number of processed batches.
            * A rank that never saw a key contributes 0.0 to that key's
              cross-rank mean.
        """
        if is_main_process():
            print(f"[Rank {self.rank}] Starting validation...")

        self.model.eval()
        self.table_vectorizer.eval()

        val_losses = defaultdict(float)
        processed_batches = 0
        failed_batches = 0
        last_error = None  # stored as text so no traceback/batch is kept alive

        if self.distributed:
            # Line up all ranks before they start iterating the loader.
            dist.barrier()
            self.val_dataloader.sampler.set_epoch(self.current_epoch)

        with torch.no_grad():
            val_iterator = self.val_dataloader
            if is_main_process():
                from tqdm import tqdm
                val_iterator = tqdm(self.val_dataloader, desc=f"Epoch {self.current_epoch} Validation")

            for batch_idx, batch in enumerate(val_iterator):
                try:
                    batch_losses = self._run_batch(batch, is_train=False)

                    for k, v in batch_losses.items():
                        if isinstance(v, torch.Tensor):
                            v = v.item()
                        if isinstance(v, (float, int)) and math.isfinite(v):
                            val_losses[k] += v
                        elif DEBUGGING:
                            print(f"[Rank {self.rank}] WARNING: Non-finite or invalid loss value detected for key '{k}' in batch {batch_idx}. Value: {v}. Skipping accumulation for this key in this batch.")

                    processed_batches += 1
                except Exception as e:
                    # Best effort: one bad batch must not abort validation,
                    # but it is counted and reported below.
                    failed_batches += 1
                    last_error = repr(e)
                    if DEBUGGING:
                        print(f"[Rank {self.rank}] ERROR processing validation batch {batch_idx}: {e}")

        if failed_batches:
            print(f"[Rank {self.rank}] WARNING: {failed_batches} validation batch(es) raised and were skipped. Last error: {last_error}")

        if self.distributed:
            if DEBUGGING:
                print(f"[Rank {self.rank}] Completed validation loop. Processed {processed_batches} batches. Waiting at barrier post-loop.")
            dist.barrier()
            if DEBUGGING:
                print(f"[Rank {self.rank}] Passed barrier post-loop.")

        if processed_batches == 0:
            if DEBUGGING:
                print(f"[Rank {self.rank}] WARNING: No batches processed during validation.")
            avg_losses = {}
        else:
            avg_losses = {k: v / processed_batches for k, v in val_losses.items()}

        local_keys = sorted(avg_losses)
        if DEBUGGING:
            print(f"[Rank {self.rank}] Calculated local avg_losses. Local Keys: {local_keys}")

        final_losses = {}
        if self.distributed:
            if DEBUGGING:
                print(f"[Rank {self.rank}] Starting loss synchronization")

            # Ranks may have seen different loss keys; agree on the union so
            # every rank reduces a tensor of identical length and ordering.
            all_local_keys = [None] * self.world_size
            if DEBUGGING:
                print(f"[Rank {self.rank}] BEFORE all_gather_object for keys")
            dist.all_gather_object(all_local_keys, local_keys)
            if DEBUGGING:
                print(f"[Rank {self.rank}] AFTER all_gather_object for keys. Received: {all_local_keys}")

            union_keys = set()
            for keys in all_local_keys:
                union_keys.update(keys)
            synchronized_keys = sorted(union_keys)
            if DEBUGGING:
                print(f"[Rank {self.rank}] Union of keys for reduction: {synchronized_keys}")
                print(f"[Rank {self.rank}] BEFORE barrier pre-reduction loop")
            dist.barrier()
            if DEBUGGING:
                print(f"[Rank {self.rank}] AFTER barrier pre-reduction loop")

            if synchronized_keys:
                local_values = []
                for k in synchronized_keys:
                    local_value = avg_losses.get(k, 0.0)
                    is_finite = isinstance(local_value, (int, float)) and math.isfinite(local_value)
                    if DEBUGGING:
                        print(f"[Rank {self.rank}] Value for key '{k}': {local_value} (Is Finite: {is_finite})")
                    if not is_finite:
                        if DEBUGGING:
                            print(f"[Rank {self.rank}] WARNING: Non-finite local average value detected for key '{k}'. Using 0.0 for reduction.")
                        local_value = 0.0
                    local_values.append(float(local_value))

                # One collective for all keys instead of one per key.
                reduced = torch.tensor(local_values, device=self.device)
                if DEBUGGING:
                    print(f"[Rank {self.rank}] BEFORE all_reduce for keys: {synchronized_keys}")
                dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
                if DEBUGGING:
                    print(f"[Rank {self.rank}] AFTER all_reduce for keys: {synchronized_keys}")

                for k, total in zip(synchronized_keys, reduced.tolist()):
                    final_losses[k] = total / self.world_size
                    if DEBUGGING:
                        print(f"[Rank {self.rank}] Synchronized value for key '{k}': {final_losses[k]}")

            if DEBUGGING:
                print(f"[Rank {self.rank}] Finished loss reduction loop.")

        else:
            final_losses = avg_losses

        if self.distributed:
            if DEBUGGING:
                print(f"[Rank {self.rank}] BEFORE final barrier")
            dist.barrier()
            if DEBUGGING:
                print(f"[Rank {self.rank}] AFTER final barrier")

        if is_main_process():
            print(f"Validation Summary (Epoch {self.current_epoch}): {final_losses}")

        if DEBUGGING:
            print(f"[Rank {self.rank}] Validation completed!")

        return final_losses

    def _get_model_state(self, epoch, loss):
        """Build a checkpoint dictionary describing the current training state.

        The model, vectorizer, optimizer and projection-head state dicts are
        deep-copied. ``state_dict()`` hands back tensors that share storage
        with the live parameters and optimizer buffers, so without the copy a
        state kept in memory (e.g. as the "best" model) would keep changing as
        training continues and restoring it later would be a no-op. The copy
        stays on the tensors' current device, so holding a state costs one
        extra copy of those tensors.

        Args:
            epoch: epoch number stored under ``'epoch'``.
            loss: loss value stored under ``'best_loss'``.

        Returns:
            dict: weights, optimizer/scheduler state, progress counters, DWA
            state, RNG states and the interval/patience configuration.
        """
        import copy  # local import: ``copy`` is not imported at module level

        model = self.model.module if self.distributed else self.model
        vectorizer = self.table_vectorizer.module if self.distributed else self.table_vectorizer

        state = {
            # Model states (snapshots, detached from the live tensors)
            'model': copy.deepcopy(model.state_dict()),
            'vectorizer': copy.deepcopy(vectorizer.state_dict()),
            'optimizer': copy.deepcopy(self.optimizer.state_dict()),
            'scheduler': self.scheduler.state_dict() if self.scheduler else None,

            # Training progress
            'epoch': epoch,
            'step': self.global_step,
            'best_loss': loss,

            # DWA states
            # NOTE(review): stored by reference; if these are mutated in place
            # elsewhere, an in-memory state will follow them - confirm.
            'loss_weights': self.loss_weights,
            'prev_losses': self.prev_losses,

            # RNG states for reproducibility
            'torch_rng_state': torch.get_rng_state().cpu().clone(),
            'cuda_rng_state': torch.cuda.get_rng_state().cpu().clone() if torch.cuda.is_available() else None,
            # Only element [1] of NumPy's state tuple is stored here.
            'numpy_rng_state': np.random.get_state()[1].tobytes() if 'numpy' in sys.modules else None,
            'python_rng_state': random.getstate(),

            # Training configuration
            'interval_type': self.interval_type,
            'scheduler_interval': self.scheduler_interval,
            'validation_interval': self.validation_interval,
            'save_interval': self.save_interval,
            'early_stop_patience': self.early_stop_patience,
            'vectorizer_warmup_epochs': self.vectorizer_warmup_epochs,
            'current_epoch': self.current_epoch
        }

        # Save projection head if it exists
        if hasattr(self, 'projection_head') and self.projection_head is not None:
            state['projection_head'] = copy.deepcopy(self.projection_head.state_dict())

        return state

    def _save_checkpoint(self, state, path):
        """Write ``state`` to ``path`` with ``torch.save``, replacing it atomically.

        The checkpoint is first written to a temporary file in the target
        directory and then moved over ``path`` with ``os.replace``. A run that
        is interrupted while saving therefore leaves the previous checkpoint
        intact instead of a truncated file.

        Args:
            state: mapping to serialize. Top-level ``numpy.ndarray`` values are
                converted with ``torch.from_numpy``; everything else is stored
                unchanged.
            path: destination file; missing parent directories are created.

        Raises:
            Whatever ``torch.save`` / ``os.replace`` raise; the temporary file
            is removed before the error propagates.
        """
        target = os.path.abspath(path)
        os.makedirs(os.path.dirname(target), exist_ok=True)

        # Convert numpy arrays to tensors before saving
        processed_state = {
            key: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
            for key, value in state.items()
        }

        # Same directory as the target so the final rename stays on one filesystem.
        tmp_path = f"{target}.tmp.{os.getpid()}"
        try:
            torch.save(processed_state, tmp_path)
            os.replace(tmp_path, target)
        finally:
            # Only present here if saving or the rename failed.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _restore_checkpoint(self, state):
        """Load a checkpoint dictionary back into the trainer.

        Restores, in this order: model and vectorizer weights, the projection
        head (created through ``self._create_projection_head`` when the
        trainer has none), optimizer state (after trying to match the number
        of param groups), scheduler state, step/epoch counters, DWA state and
        the torch / CUDA / NumPy / Python RNG states.

        Args:
            state: mapping with at least ``'model'``, ``'vectorizer'``,
                ``'optimizer'``, ``'scheduler'``, ``'step'`` and ``'epoch'``.
                ``'projection_head'``, the DWA keys and the RNG keys are
                optional.

        Any exception raised while restoring is caught, printed and swallowed,
        so on failure the trainer may be left partially restored.
        """
        try:
            # Restore model states
            if self.distributed:
                self.model.module.load_state_dict(state['model'])
                self.table_vectorizer.module.load_state_dict(state['vectorizer'])
            else:
                self.model.load_state_dict(state['model'])
                self.table_vectorizer.load_state_dict(state['vectorizer'])

            # Handle projection head before optimizer to ensure param groups alignment
            if 'projection_head' in state and state['projection_head'] is not None:
                # Input dimension = second axis of the first entry whose key contains 'weight'
                weight_key = next(key for key in state['projection_head'].keys() if 'weight' in key)
                in_dim = state['projection_head'][weight_key].shape[1]

                # Create projection head with correct dimensions if needed
                if not hasattr(self, 'projection_head') or self.projection_head is None:
                    self.projection_head = self._create_projection_head(in_dim)

                # Load state dict into projection head
                self.projection_head.load_state_dict(state['projection_head'])

                if is_main_process():
                    print(f"Loaded projection head from checkpoint (input dim: {in_dim})")
            elif self.contrastive_weight > 0 and (not hasattr(self, 'projection_head') or self.projection_head is None):
                # Contrastive learning is active but the checkpoint carries no
                # projection head; nothing is created here, only reported.
                if is_main_process():
                    print("Contrastive learning active but no projection head in checkpoint.")
                    print("A new projection head will be created during the first forward pass.")

            # Now load optimizer state after projection head is set up
            def _align_optimizer_param_groups(saved_optim_state):
                """Ensure the current optimizer has the same number of param groups as the saved state.

                If the checkpoint contains *more* groups, we will try to create corresponding
                groups in the live optimiser by attaching parameters that are not yet
                registered (typically projection-head params).
                If it contains fewer groups, we simply skip loading the optimiser - you lose
                momentum but training can continue.

                Returns True when the group counts match on exit, False when
                the optimizer state should not be loaded.
                """
                cur_groups = self.optimizer.param_groups
                saved_groups = saved_optim_state.get('param_groups', [])

                cur_len = len(cur_groups)
                saved_len = len(saved_groups)

                if saved_len == cur_len:
                    return True  # already aligned

                if saved_len < cur_len:
                    # Too many groups in live optimiser - safest is to bail out
                    if is_main_process():
                        print("Optimizer has more parameter groups than checkpoint. "
                              "Will skip loading optimiser state.")
                    return False

                # We need to *add* groups so that len matches
                # Gather all params currently tracked to avoid duplicates
                tracked_ids = {id(p) for group in cur_groups for p in group['params']}

                # Iterate over the extra param groups in the checkpoint and create them
                for extra_pg in saved_groups[cur_len:]:
                    # Identify candidate parameters that are *not* yet in any group
                    extra_params = []
                    for param in self.__dict__.values():
                        if isinstance(param, nn.Parameter):
                            if id(param) not in tracked_ids:
                                extra_params.append(param)
                                tracked_ids.add(id(param))
                    # Projection-head parameters are candidates as well
                    if hasattr(self, 'projection_head') and self.projection_head is not None:
                        for p in self.projection_head.parameters():
                            if id(p) not in tracked_ids:
                                extra_params.append(p)
                                tracked_ids.add(id(p))

                    if not extra_params:
                        # Could not find unique parameters to assign; abort alignment
                        if is_main_process():
                            print("Unable to find unused parameters for new optimizer param group; skipping optimizer state load.")
                        return False  # abort alignment

                    # Build new group dictionary copying hyper-parameters except the param list
                    new_pg = {k: v for k, v in extra_pg.items() if k != 'params'}
                    new_pg['params'] = extra_params
                    self.optimizer.add_param_group(new_pg)

                if is_main_process():
                    print(f"Added {saved_len - cur_len} parameter group(s) to optimizer to match checkpoint.")
                return True

            saved_optim_state = state['optimizer']
            aligned = _align_optimizer_param_groups(saved_optim_state)
            optim_state_loaded = False
            if aligned:
                try:
                    self.optimizer.load_state_dict(saved_optim_state)
                    optim_state_loaded = True
                except ValueError as e:
                    if is_main_process():
                        print(f"WARNING: Still could not load optimizer state after alignment: {str(e)}")
                        print("Continuing with freshly initialized optimizer parameters.")
                        traceback.print_exc()
            else:
                if is_main_process():
                    print("Optimizer state was not loaded due to irrecoverable param-group mismatch.")

            # Load scheduler state only if optimizer state was successfully loaded (to avoid param-group mismatch)
            if optim_state_loaded and self.scheduler and state['scheduler'] is not None and self.load_scheduler_state:
                try:
                    self.scheduler.load_state_dict(state['scheduler'])
                except ValueError as e:
                    if is_main_process():
                        print(f"WARNING: Could not load scheduler state: {str(e)}")
                        print("Scheduler will continue with fresh state.")
                        traceback.print_exc()
            elif optim_state_loaded and self.scheduler and state['scheduler'] is not None and not self.load_scheduler_state:
                if is_main_process():
                    print("Skipping scheduler state loading as requested (load_scheduler_state=False)")
                    print("This is useful when changing max_steps between training runs.")
                    print(f"Resetting scheduler to step {state['step']} under new learning rate schedule.")
                # Fast-forward on EVERY rank: each process owns its own
                # scheduler, so doing this on the main process only would
                # leave the other ranks on a different learning rate.
                for _ in range(state['step']):
                    self.scheduler.step()

            # Restore training progress
            self.global_step = state['step']
            self.current_epoch = state['epoch']

            # Restore DWA states if present
            if 'loss_weights' in state:
                self.loss_weights = state['loss_weights']
            if 'prev_losses' in state:
                self.prev_losses = state['prev_losses']

            # Restore RNG states
            if 'torch_rng_state' in state:
                torch.set_rng_state(state['torch_rng_state'].cpu())
            # The CUDA entry is None when the checkpoint was written without CUDA.
            if state.get('cuda_rng_state') is not None and torch.cuda.is_available():
                torch.cuda.set_rng_state(state['cuda_rng_state'].cpu())
            if 'numpy_rng_state' in state and state['numpy_rng_state'] is not None:
                # Only element [1] is replaced by the saved bytes; the other
                # elements come from the current NumPy state.
                np_state = np.random.get_state()
                np_state = (np_state[0], np.frombuffer(state['numpy_rng_state'], dtype=np.uint32), *np_state[2:])
                np.random.set_state(np_state)
            if 'python_rng_state' in state:
                random.setstate(state['python_rng_state'])

            if is_main_process():
                print(f"Checkpoint restored from step {self.global_step}, epoch {self.current_epoch}")

        except Exception as e:
            if is_main_process():
                print(f"Failed to load checkpoint: {str(e)}")
                traceback.print_exc()  # Print full traceback for better debugging
            if not isinstance(e, FileNotFoundError):
                print("Starting fresh training session")

    def get_next_batch(self, iterator, dataloader):
        """
        Pull one batch from ``iterator``, restarting from ``dataloader`` when it is exhausted.

        On exhaustion in distributed mode, the process first waits at a barrier
        and calls ``dataloader.sampler.set_epoch(self.current_epoch)`` before
        the new iterator is created.

        Returns:
            tuple: ``(batch, iterator)`` where ``iterator`` is the one the batch
            was drawn from (the original, or the freshly created one).
        """
        try:
            return next(iterator), iterator
        except StopIteration:
            if self.distributed:
                # Every rank must have finished its pass before reshuffling.
                dist.barrier()
                dataloader.sampler.set_epoch(self.current_epoch)
            restarted = iter(dataloader)
            return next(restarted), restarted

    def get_train_batch(self):
        """
        Return the next training batch, advancing ``self.train_iter``.

        The batch comes from ``self.get_next_batch``, which also hands back the
        iterator to keep. When ``DEBUGGING`` is on in a distributed run, fetches
        that take longer than one second are appended to
        ``debug_trainer/sync_debug/sync_rank<rank>.log``.
        """
        timing_enabled = DEBUGGING and self.distributed

        if not timing_enabled:
            batch, self.train_iter = self.get_next_batch(self.train_iter, self.train_dataloader)
            return batch

        start_wait = time.time()
        batch, self.train_iter = self.get_next_batch(self.train_iter, self.train_dataloader)
        wait_time = time.time() - start_wait

        # Only slow fetches are worth a log line.
        if wait_time > 1.0:
            debug_dir = os.path.join("debug_trainer", "sync_debug")
            os.makedirs(debug_dir, exist_ok=True)
            with open(os.path.join(debug_dir, f"sync_rank{self.rank}.log"), "a") as f:
                timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
                f.write(f"[{timestamp}] Step: {self.global_step}, Wait time: {wait_time:.4f}s\n")

        return batch

    def get_val_batch(self):
        """
        Fetch a single validation batch for ad-hoc use outside the regular validation pass.

        A new iterator over ``self.val_dataloader`` is created on every call and
        only its first item is consumed.

        Returns:
            The first batch of the validation dataloader, or None when there is
            no validation dataloader or it yields nothing.
        """
        loader = self.val_dataloader
        if loader is None:
            return None

        try:
            first_batch = next(iter(loader))
        except StopIteration:
            # Loader produced no batches at all.
            return None
        return first_batch

    def prepare_dataset_with_skipped_iterations(self, skip_iters):
        """
        Build a dataloader that omits the samples consumed by the first ``skip_iters`` iterations.

        The number of skipped samples is ``skip_iters * batch_size *
        self.world_size``. The new loader wraps a ``Subset`` of the original
        dataset, is never shuffled, and reuses the original loader's
        ``batch_size``, ``num_workers``, ``collate_fn``, ``pin_memory`` and
        ``drop_last``. In distributed mode it uses a fresh non-shuffling
        ``DistributedSampler`` for ``self.rank``.

        Args:
            skip_iters (int): Number of iterations to skip

        Returns:
            torch.utils.data.DataLoader: ``self.train_dataloader`` itself when
            ``skip_iters <= 0``, otherwise the new dataloader.

        Raises:
            ValueError: if the samples to skip cover the whole dataset.
        """
        if skip_iters <= 0:
            return self.train_dataloader

        if is_main_process():
            print(f"Preparing to skip first {skip_iters} iterations across all processes...")

        # Get the original dataset and properties
        source_loader = self.train_dataloader
        original_dataset = source_loader.dataset
        batch_size = source_loader.batch_size
        total_dataset_size = len(original_dataset)

        # Calculate total samples to skip
        total_samples_to_skip = skip_iters * batch_size * self.world_size

        if total_samples_to_skip >= total_dataset_size:
            raise ValueError(f"Cannot skip {total_samples_to_skip} samples as dataset only contains {total_dataset_size} samples")

        # A non-shuffled distributed sampler gives rank r the indices
        # r, r + world_size, r + 2*world_size, ...  Taking the first
        # skip_iters * batch_size of them on every rank covers exactly the
        # contiguous prefix [0, total_samples_to_skip) - the check above
        # guarantees each rank has that many - so the distributed and the
        # single-process case skip the same indices.
        remaining_indices = list(range(total_samples_to_skip, total_dataset_size))

        # Create subset with remaining indices
        remaining_dataset = torch.utils.data.Subset(original_dataset, remaining_indices)

        # Settings carried over from the original loader so that resuming does
        # not silently change batching behavior.
        loader_kwargs = {
            'batch_size': batch_size,
            'num_workers': source_loader.num_workers,
            'collate_fn': source_loader.collate_fn,
            'pin_memory': source_loader.pin_memory,
            'drop_last': source_loader.drop_last,
        }

        # Recreate dataloader with the subset
        if self.distributed:
            loader_kwargs['sampler'] = DistributedSampler(
                remaining_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=False
            )
        else:
            loader_kwargs['shuffle'] = False

        new_dataloader = torch.utils.data.DataLoader(remaining_dataset, **loader_kwargs)

        if is_main_process():
            print(f"Skipped {total_samples_to_skip} samples. Remaining dataset size: {len(remaining_indices)}")

        return new_dataloader