#!/usr/bin/env python
# coding=utf-8

"""
Fine-tuning BERT models for GLUE tasks using optimal LoRA configurations.
Optimized for memory efficiency and computational performance while maintaining accuracy.
"""
import argparse
import logging
import math
import os
import random
import json
import time
import sys
import gc
from functools import partial
from pathlib import Path
import numpy as np
import datasets
import torch
from datasets import load_dataset, load_metric
import evaluate
from torch.utils.data import DataLoader
import torch.profiler as profiler
import transformers
from accelerate import Accelerator
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PretrainedConfig,
    SchedulerType,
    default_data_collator,
    get_scheduler,
    set_seed,
)

from trainer.optimizer import LoRAOptimizer
from models.optimal_lora import prepare_model_for_glue
from utils.logging_utils import setup_logger, measure_batch_inference_time, measure_evaluation_time
import loralib as lora

# Import progressive pruning modules
from pruning.progressive_pruning import ProgressivePruningManager
from pruning import pruning_utils

# Control flags for optimized execution
# NOTE(review): the unit of the interval below is presumably training steps — confirm against the training loop.
LOG_FLOPS_INTERVAL = 50  # How often to log computation metrics during training
MAX_OPTIMIZATION_SAMPLES = 500  # Maximum samples to use for LoRA optimization

# Map each supported GLUE task name to a 2-tuple of input field names.
# The second entry is None for tasks that list only one field.
# NOTE(review): presumably these are the dataset column names fed to the
# tokenizer (first / second sentence) — confirm against the preprocessing code.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)

def parse_args(argv=None):
    """Parse and post-process the command line arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.
            Passing an explicit list makes the parser usable from tests and
            notebooks.

    Returns:
        argparse.Namespace with all options. Two fields are post-processed:
        ``output_dir`` gets a timestamped default under ``runs/`` when it was
        not supplied, and ``lora_r_values`` is converted from the
        comma-separated string into a list of ints.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including a
            ``--lora_r_values`` string that is empty or contains a
            non-integer entry.
    """
    parser = argparse.ArgumentParser(description="Fine-tune BERT-based models with optimal LoRA for GLUE tasks")

    # Task and model arguments
    parser.add_argument(
        "--task_name",
        type=str,
        required=True,
        help=f"GLUE task name. Choices: {', '.join(task_to_keys.keys())}",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bert-base-uncased",
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=128,
        help="Maximum sequence length after tokenization",
    )

    # LoRA optimization arguments
    parser.add_argument(
        "--lora_r_values",
        type=str,
        default="0,1,2,4,8,16,32",
        help="Comma-separated list of r values to consider for LoRA optimization",
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=16,
        help="LoRA scaling parameter (alpha)",
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.1,
        help="Dropout probability for LoRA layers",
    )
    parser.add_argument(
        "--lora_budget",
        type=float,
        # NOTE(review): an earlier comment described this as "~5% of BERT-base
        # parameters", but 2M out of roughly 110M is closer to 1.8% — confirm
        # the intended budget and its unit.
        default=2000000.0,
        help="Budget constraint for LoRA optimization",
    )
    parser.add_argument(
        "--use_existing_lora_config",
        type=str,
        default=None,
        help="Path to existing LoRA configuration file (skips optimization phase)",
    )

    # Training arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save model, logs, and results",
    )
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        help="Overwrite the content of the output directory",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=32,
        help="Batch size per GPU/TPU for training",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=128,
        help="Batch size per GPU/TPU for evaluation",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after warmup period)",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to apply",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=3,
        help="Total number of training epochs",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps (overrides num_train_epochs)",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of steps for gradient accumulation before optimization",
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.0,
        help="Max gradient norm for gradient clipping (0 to disable)",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="LR scheduler type",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the LR scheduler",
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=None,
        help="Ratio of warmup steps to total steps (overrides num_warmup_steps if provided)",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help="For debugging: limit the number of training examples",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="For debugging: limit the number of evaluation examples",
    )

    # Other arguments
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for initialization",
    )
    parser.add_argument(
        "--optimize_only",
        action="store_true",
        help="Only perform LoRA optimization without training",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="info",
        choices=["debug", "info", "warning", "error", "critical"],
        help="Log level",
    )

    # Pruning arguments
    parser.add_argument(
        "--apply_pruning",
        action="store_true",
        help="Apply progressive pruning during training",
    )
    parser.add_argument(
        "--pruning_target_reduction",
        type=float,
        default=0.5,
        help="Target parameter reduction for progressive pruning (0.0-1.0)",
    )
    parser.add_argument(
        "--pruning_steps",
        type=int,
        default=4,
        help="Number of pruning steps for progressive pruning",
    )
    parser.add_argument(
        "--pruning_output_dir",
        type=str,
        default=None,
        help="Directory to save pruning results (defaults to output_dir/pruning)",
    )
    parser.add_argument(
        "--importance_ema_decay",
        type=float,
        default=0.9,
        help="EMA decay factor for layer importance scores (0.0-1.0)",
    )
    parser.add_argument(
        "--momentum_penalty_weight",
        type=float,
        default=0.1,
        help="Weight for momentum-based penalty in pruning",
    )

    parser.add_argument(
        "--recovery_steps",
        type=int,
        default=500,
        help="Number of recovery steps after pruning",
    )
    parser.add_argument(
        "--extended_recovery_steps",
        type=int,
        default=1000,
        help="Number of extended recovery steps after rollback",
    )

    parser.add_argument(
        "--disable_rollback",
        action="store_true",
        help="Disable rollback when pruning performance drops",
    )

    args = parser.parse_args(argv)

    # Default output directory: runs/<task>_<model basename>_<timestamp>
    if args.output_dir is None:
        args.output_dir = os.path.join(
            "runs",
            f"{args.task_name}_{args.model_name_or_path.split('/')[-1]}_{time.strftime('%Y%m%d-%H%M%S')}"
        )

    # Convert the comma-separated rank list to ints. Blank entries (e.g. from
    # a trailing comma) are ignored, and malformed input is reported through
    # argparse instead of surfacing as a raw ValueError traceback.
    raw_r_values = args.lora_r_values
    try:
        args.lora_r_values = [int(token) for token in raw_r_values.split(",") if token.strip()]
    except ValueError:
        parser.error(f"--lora_r_values must be a comma-separated list of integers, got {raw_r_values!r}")
    if not args.lora_r_values:
        parser.error("--lora_r_values must contain at least one integer")

    return args

def set_deterministic_environment(seed=42):
    """Seed every RNG used by the script and request deterministic kernels.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available,
    CUDA), switches cuDNN to deterministic mode and asks PyTorch for
    deterministic algorithms.

    Args:
        seed: Integer seed applied to all random number generators.
    """
    # PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    # affects child processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Workspace setting exported for cuBLAS before deterministic mode is
    # requested below.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        try:
            torch.use_deterministic_algorithms(True)
        except Exception as exc:
            # Best effort: keep going, but report the failure instead of
            # silently swallowing it (and never trap KeyboardInterrupt or
            # SystemExit, which a bare ``except`` would).
            logger.warning(f"Could not enable deterministic algorithms: {exc}")

    logger.info(f"Deterministic environment set with seed {seed}")

def seed_worker(worker_id):
    """Seed NumPy and ``random`` for a worker identified by ``worker_id``.

    The seed is ``args.seed + worker_id`` when a module-level ``args``
    namespace with a ``seed`` attribute exists (the original behavior).
    Otherwise it falls back to ``torch.initial_seed()`` instead of raising
    ``NameError`` on the missing global.

    Args:
        worker_id: Integer index of the worker being initialised.
    """
    run_args = globals().get("args")
    base_seed = getattr(run_args, "seed", None)

    if base_seed is None:
        worker_seed = torch.initial_seed()
    else:
        worker_seed = base_seed + worker_id

    # np.random.seed only accepts values in [0, 2**32).
    worker_seed %= 2**32

    np.random.seed(worker_seed)
    random.seed(worker_seed)

def format_flops(flops):
    """Render a FLOP count as a string using the largest fitting unit.

    Values of at least 1e12 / 1e9 / 1e6 are shown in TFLOPs / GFLOPs / MFLOPs
    with four decimals; anything smaller is shown as a whole number of FLOPs
    with thousands separators.
    """
    scales = ((1e12, "TFLOPs"), (1e9, "GFLOPs"), (1e6, "MFLOPs"))
    for threshold, unit in scales:
        if flops >= threshold:
            return f"{flops / threshold:.4f} {unit}"
    return f"{flops:,.0f} FLOPs"


def format_runtime(seconds):
    """Format a duration given in seconds as a string with two decimals."""
    return format(seconds, ".2f")

def get_memory_stats():
    """Report CUDA memory usage in GiB.

    Returns:
        An empty dict when CUDA is unavailable; otherwise a dict with
        ``peak_allocated_gb`` (from ``torch.cuda.max_memory_allocated``) and
        ``reserved_gb`` (from ``torch.cuda.memory_reserved``).
    """
    stats = {}
    if torch.cuda.is_available():
        bytes_per_gib = 1024**3
        stats["peak_allocated_gb"] = torch.cuda.max_memory_allocated() / bytes_per_gib
        stats["reserved_gb"] = torch.cuda.memory_reserved() / bytes_per_gib
    return stats


def setup_minimal_logging():
    """Reduce third-party log noise while keeping warnings visible.

    Sets the transformers library verbosity to WARNING and, when DeepSpeed
    has already been imported, raises the ``deepspeed`` logger level to ERROR.
    """
    # Allow transformers warnings but suppress less critical messages
    transformers.logging.set_verbosity_warning()

    # ``_ds_logger_silenced`` is an optional module-level flag. Look it up
    # defensively so that a missing definition does not raise NameError as
    # soon as DeepSpeed is imported; it defaults to "not silenced yet".
    already_silenced = globals().get("_ds_logger_silenced", False)

    # Silence DeepSpeed profiler logging
    if 'deepspeed' in sys.modules and not already_silenced:
        logging.getLogger('deepspeed').setLevel(logging.ERROR)


def validate_model_parameters(model):
    """Check that every trainable parameter holds only finite values.

    Logs one warning per trainable parameter containing NaN or Inf, plus a
    summary warning when any were found.

    Returns:
        True when all trainable parameters are finite, False otherwise.
    """
    offending = [
        param_name
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad and not bool(torch.isfinite(tensor).all())
    ]

    for param_name in offending:
        logger.warning(f"Parameter {param_name} contains NaN or Inf values!")

    if offending:
        logger.warning("Model contains problematic parameters. Consider adjusting learning rate.")
        return False
    return True


def get_layer_size(model, layer_name):
    """Look up a sub-module by dotted path and return its feature sizes.

    Args:
        model: Root object to start the attribute walk from.
        layer_name: Dotted attribute path, e.g. ``"encoder.layer.0.dense"``.

    Returns:
        ``(in_features, out_features)`` of the resolved module, or ``(0, 0)``
        when the path cannot be resolved or the module lacks either attribute.
    """
    try:
        target = model
        for attr in layer_name.split('.'):
            target = getattr(target, attr)

        if not (hasattr(target, 'in_features') and hasattr(target, 'out_features')):
            return 0, 0
        return target.in_features, target.out_features
    except (AttributeError, ValueError):
        return 0, 0

def check_lora_merge_status(model):
    """Log and count the merge state of every LoRA layer in ``model``.

    A module counts as a LoRA layer when it has both ``lora_A`` and
    ``lora_B``. Layers without a ``merged`` attribute are logged as unknown
    and not counted.

    Returns:
        Tuple ``(merged_count, unmerged_count)``.
    """
    banner = "=" * 80
    logger.info(banner)
    logger.info("CHECKING LoRA MERGE STATUS")
    logger.info(banner)

    n_merged = 0
    n_unmerged = 0

    for layer_name, layer in model.named_modules():
        is_lora_layer = hasattr(layer, 'lora_A') and hasattr(layer, 'lora_B')
        if not is_lora_layer:
            continue

        if not hasattr(layer, 'merged'):
            logger.info(f"Layer {layer_name} merge status unknown")
            continue

        if layer.merged:
            n_merged += 1
            logger.info(f"Layer {layer_name} is MERGED")
        else:
            n_unmerged += 1
            logger.info(f"Layer {layer_name} is UNMERGED")

    logger.info(f"Merged layers: {n_merged}, Unmerged layers: {n_unmerged}")
    logger.info(banner)

    return n_merged, n_unmerged

def measure_computation(model, batch, metrics=["flops", "macs"], accelerator=None, log_output=False, measure_backward=False):
    """
    Estimate FLOPs/MACs for one batch by running the model under the PyTorch profiler
    and then rescaling the profiler total with heuristic correction factors.

    Steps performed, in order:
      1. Run a forward pass (plus a dummy-loss backward pass when requested) inside
         ``torch.profiler.profile(with_flops=True)`` and sum the FLOPs of all events.
      2. Scale the total by an attention-mask (padding) correction factor.
      3. Scale the total by a LoRA correction factor derived from analytic per-layer
         FLOP counts of modules that expose ``lora_A``/``lora_B``.
      4. Derive MACs from the corrected FLOPs via ``calculate_mac_ratio``.

    Args:
        model: Model to profile
        batch: Input batch (exactly as it would be used in training/evaluation).
            Only the keys input_ids, attention_mask, token_type_ids and labels are
            passed to the model; if none of them is present the batch is passed as-is.
        metrics: List of metrics to calculate ("flops", "macs", or both). Only the
            presence of "macs" is checked; FLOPs are always computed.
        accelerator: Optional accelerator for model unwrapping
        log_output: Whether to log the output
        measure_backward: Whether to measure backward pass (only works when model in training mode)

    Returns:
        Dictionary with the keys "flops", "macs", "lora_flops", "original_flops",
        "attention_flops" and "correction_factors", plus "flops_per_sample" and
        "macs_per_sample" when the batch size is positive. On any exception during
        measurement the error is logged, "measurement_failed" is set to True and the
        numeric fields keep whatever value they had at that point (0 initially).
    """
    # NOTE(review): these local imports/assignments shadow the module-level
    # `profiler` import and `logger` object of the same names.
    import torch.profiler as profiler
    import logging
    
    logger = logging.getLogger(__name__)
    
    result = {
        "flops": 0, 
        "macs": 0,
        "lora_flops": 0,
        "original_flops": 0,
        "attention_flops": 0,
        "correction_factors": {}
    }
    
    # Without "input_ids" the batch is treated as one sample of length 0.
    batch_size = batch["input_ids"].size(0) if "input_ids" in batch else 1
    seq_length = batch["input_ids"].size(1) if "input_ids" in batch else 0
    
    # Calculate effective sequence length considering padding
    if "attention_mask" in batch:
        non_pad_tokens = batch["attention_mask"].sum().item()
        avg_seq_length = non_pad_tokens / batch_size
    else:
        avg_seq_length = seq_length
    
    # Get model in correct state - do this BEFORE profiling
    unwrapped_model = accelerator.unwrap_model(model) if accelerator else model
    model_was_training = unwrapped_model.training
    
    # IMPORTANT: Set model mode BEFORE creating the profiler context
    if not measure_backward:
        unwrapped_model.eval()
    
    # Wait for any pending operations to complete
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    # Prepare inputs for profiling
    allowed_keys = {"input_ids", "attention_mask", "token_type_ids", "labels"}
    profiling_batch = {k: v for k, v in batch.items() if k in allowed_keys}
    if not profiling_batch and batch:
        profiling_batch = batch
    
    try:
        # Configure PyTorch profiler with simpler settings
        activities = [profiler.ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(profiler.ProfilerActivity.CUDA)
        
        # Create profiler with simplified configuration
        with profiler.profile(
            activities=activities,
            record_shapes=True,
            profile_memory=False,
            with_flops=True,
            # Remove with_modules and with_stack to avoid internal errors
        ) as prof:
            # Synchronize before profiling
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            # Handle forward and backward passes
            if measure_backward and model_was_training:
                outputs = unwrapped_model(**profiling_batch)
                
                # Create dummy loss and run backward pass
                if hasattr(outputs, 'logits') and outputs.logits is not None:
                    if outputs.logits.size(0) > 0:
                        if outputs.logits.size(-1) > 1:  # Classification
                            dummy_target = torch.zeros(outputs.logits.size(0), dtype=torch.long, device=outputs.logits.device)
                            loss = torch.nn.functional.cross_entropy(outputs.logits, dummy_target)
                        else:  # Regression
                            dummy_target = torch.zeros_like(outputs.logits)
                            loss = torch.nn.functional.mse_loss(outputs.logits, dummy_target)
                        
                        # Run backward pass
                        # NOTE(review): this function never clears the gradients produced
                        # here; presumably the caller zeroes grads before its next real
                        # backward pass — confirm, otherwise dummy gradients leak into training.
                        loss.backward()
            else:
                # Just run forward pass with no_grad
                with torch.no_grad():
                    _ = unwrapped_model(**profiling_batch)
            
            # Synchronize after profiling
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        
        # Process profiler events to get FLOPs
        total_flops = 0
        # NOTE(review): lora_specific_flops is accumulated below but never stored in
        # `result` or used afterwards.
        lora_specific_flops = 0
        attention_flops = 0
        
        for evt in prof.key_averages():
            if evt.flops > 0:
                total_flops += evt.flops
                
                # Get event name
                event_key = evt.key.lower()
                
                # Track LoRA-specific operations
                if any(x in event_key for x in ['lora', 'low_rank', 'adapter']):
                    lora_specific_flops += evt.flops
                    if log_output:
                        logger.debug(f"LoRA operation detected: {evt.key}, FLOPs: {evt.flops}")
                
                # Track attention operations
                if any(x in event_key for x in ['attention', 'softmax', 'attn']):
                    attention_flops += evt.flops
                    if log_output:
                        logger.debug(f"Attention operation detected: {evt.key}, FLOPs: {evt.flops}")
        
        # Store detected FLOPs in result
        result["attention_flops"] = attention_flops
        
        # Get model type for architecture-specific corrections
        # NOTE(review): model_type is computed here but not read again in this function.
        model_type = "unknown"
        if hasattr(unwrapped_model, 'config'):
            if hasattr(unwrapped_model.config, 'model_type'):
                model_type = unwrapped_model.config.model_type.lower()
        
        # Apply attention mask correction
        if "attention_mask" in profiling_batch:
            # Fraction of non-padding positions in the batch.
            mask_ratio = avg_seq_length / seq_length if seq_length > 0 else 1.0
            
            attention_mask = profiling_batch["attention_mask"]
            attention_ratio = mask_ratio
            
            if attention_mask.dim() == 2:  # [batch, seq_len]
                valid_tokens = attention_mask.sum(dim=1).float().mean().item()
                attention_ratio = valid_tokens / seq_length if seq_length > 0 else 1.0
                attention_correction = attention_ratio ** 2
            else:
                attention_correction = mask_ratio ** 2
            
            # Calculate final correction
            # Heuristic blend: 65% weight on the linear ratio, 35% on the squared
            # ratio, then clamped to [0.1, 1.0].
            linear_weight = 0.65
            attention_weight = 0.35
            
            correction_factor = linear_weight * mask_ratio + attention_weight * attention_correction
            correction_factor = max(0.1, min(1.0, correction_factor))
            
            total_flops = int(total_flops * correction_factor)
            
            result["correction_factors"]["attention_mask"] = correction_factor
            
            if log_output:
                logger.info(f"Applied attention mask correction factor: {correction_factor:.4f} "
                           f"(mask_ratio={mask_ratio:.4f}, attention_ratio={attention_ratio:.4f})")
        
        # Apply LoRA architecture correction
        has_lora = False
        lora_layers = 0
        lora_flops_calculated = 0
        original_flops_calculated = 0
        
        # Extract r_config from model
        # Prefer an explicit `initial_r_config` attribute; otherwise read the rank
        # from the first dimension of each module's lora_A.
        r_config = {}
        if hasattr(unwrapped_model, 'initial_r_config'):
            r_config = unwrapped_model.initial_r_config
        else:
            for name, module in unwrapped_model.named_modules():
                if hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                    r = module.lora_A.shape[0] if hasattr(module, 'lora_A') else 0
                    if r > 0:
                        r_config[name] = r
        
        # Calculate LoRA correction
        for name, module in unwrapped_model.named_modules():
            if hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                has_lora = True
                lora_layers += 1
                
                # Get dimensions
                if hasattr(module, 'weight'):
                    in_features = module.weight.size(1)
                    out_features = module.weight.size(0)
                else:
                    in_features, out_features = get_layer_dimensions(module)
                
                r = r_config.get(name, 0)
                
                # Fallback matching for complex layer names
                # The first config entry whose name contains, or is contained in,
                # this module name wins.
                if r == 0:
                    for config_name, config_r in r_config.items():
                        if config_name in name or name in config_name:
                            r = config_r
                            if log_output:
                                logger.debug(f"Found r={r} for layer {name} via partial match with {config_name}")
                            break
                
                # If still no match but layer has LoRA matrices, use matrix dimensions
                if r == 0 and hasattr(module, 'lora_A'):
                    r = module.lora_A.shape[0]
                
                # Calculate actual computation dimensions
                batch_size_calc = profiling_batch["input_ids"].size(0) if "input_ids" in profiling_batch else 1
                seq_len_calc = profiling_batch["input_ids"].size(1) if "input_ids" in profiling_batch else 128
                
                # Use attention mask for effective sequence length
                if "attention_mask" in profiling_batch:
                    effective_seq_len = profiling_batch["attention_mask"].sum().item() / batch_size_calc
                else:
                    effective_seq_len = seq_len_calc
                
                # Original computation (without LoRA)
                original_layer_flops = 2 * batch_size_calc * effective_seq_len * in_features * out_features
                original_flops_calculated += original_layer_flops
                
                # LoRA computation with all operations
                if r > 0:
                    layer_lora_flops = calculate_lora_layer_flops(
                        in_features, out_features, r, batch_size_calc, effective_seq_len
                    )
                    lora_flops_calculated += layer_lora_flops
                    
                    if log_output:
                        logger.info(f"LoRA layer {name}: r={r}, "
                                  f"original_flops={original_layer_flops:,}, "
                                  f"lora_flops={layer_lora_flops:,}")
        
        result["lora_flops"] = lora_flops_calculated
        result["original_flops"] = original_flops_calculated
        
        # Apply LoRA correction if applicable
        # The profiler total is multiplied by 0.95 * (analytic LoRA FLOPs / analytic
        # dense FLOPs), clamped to [0.1, 0.9].
        # NOTE(review): this scales the whole measured total, not only the LoRA
        # layers' share — confirm this is the intended estimate.
        if has_lora and original_flops_calculated > 0 and lora_flops_calculated > 0:
            actual_reduction = lora_flops_calculated / original_flops_calculated
            lora_correction = actual_reduction * 0.95
            lora_correction = max(0.1, min(0.9, lora_correction))
            
            total_flops = int(total_flops * lora_correction)
            
            result["correction_factors"]["lora"] = lora_correction
            
            if log_output:
                logger.info(f"LoRA correction applied: reduction={actual_reduction:.4f}, "
                           f"correction={lora_correction:.4f}")
                logger.info(f"LoRA stats: {lora_layers} layers, "
                           f"original_flops={original_flops_calculated:,}, "
                           f"lora_flops={lora_flops_calculated:,}")
        
        # Store measured flops
        result["flops"] = int(total_flops)
        
        # Calculate MACs
        if "macs" in metrics:
            try:
                mac_ratio = calculate_mac_ratio(unwrapped_model, profiling_batch)
            except Exception:
                logger.warning("calculate_mac_ratio() failed. Using fallback MAC ratio of 0.5")
                mac_ratio = 0.5
            
            macs = int(total_flops * mac_ratio)
            result["macs"] = macs
            
            if log_output:
                logger.info(f"MACs estimated from FLOPs: {format_flops(macs)} (using {mac_ratio:.2f} ratio)")
        
        if log_output:
            logger.info(f"Measured FLOPs: {format_flops(total_flops)} "
                       f"(batch_size={batch_size}, avg_seq_length={avg_seq_length:.1f})")
            
    except Exception as e:
        # Measurement is best effort: log, flag the failure and fall through so the
        # model mode is still restored below.
        logger.error(f"Computation measurement error: {e}")
        # NOTE(review): `logger.level` is this logger's own level, not its effective
        # level; `logger.isEnabledFor(logging.DEBUG)` may be what was intended.
        if log_output and logger.level <= logging.DEBUG:
            import traceback
            logger.debug(f"Measurement error details: {traceback.format_exc()}")
        
        result["measurement_failed"] = True
    
    # Restore original model state - do this AFTER profiling
    if model_was_training:
        unwrapped_model.train()
    else:
        unwrapped_model.eval()
    
    # Add per-sample metrics
    if batch_size > 0:
        result["flops_per_sample"] = result["flops"] / batch_size
        result["macs_per_sample"] = result["macs"] / batch_size
    
    # Final cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        
    return result

def get_layer_dimensions(module):
    """Return ``(in_features, out_features)`` for a layer-like module.

    A ``weight`` with at least two dimensions takes priority and is read as
    ``(out_features, in_features, ...)``. Otherwise the module's own
    ``in_features``/``out_features`` attributes are used. ``(0, 0)`` is
    returned when neither source is available.
    """
    if hasattr(module, 'weight'):
        dims = module.weight.shape
        if len(dims) > 1:
            n_out, n_in = dims[0], dims[1]
            return n_in, n_out

    has_feature_attrs = hasattr(module, 'in_features') and hasattr(module, 'out_features')
    if has_feature_attrs:
        return module.in_features, module.out_features
    return 0, 0


def calculate_lora_layer_flops(in_features, out_features, r, batch_size, seq_len):
    """Analytic FLOP count for one layer at LoRA rank ``r``.

    For ``r > 0`` the result is the sum of the down projection, up
    projection, dropout, scaling and residual-addition costs. For any other
    ``r`` it is the cost of the plain dense projection,
    ``2 * batch_size * seq_len * in_features * out_features``.
    """
    if not r > 0:
        # No low-rank path: cost of the plain linear layer.
        return 2 * batch_size * seq_len * in_features * out_features

    per_token = batch_size * seq_len
    matmul_scale = 2 * batch_size * seq_len  # multiply + add per MAC

    project_down = matmul_scale * in_features * r   # [batch, seq, in] x [in, r]
    project_up = matmul_scale * r * out_features    # [batch, seq, r] x [r, out]
    dropout_cost = per_token * r                    # one op per rank activation
    scale_cost = per_token * out_features           # one multiply per output element
    residual_cost = per_token * out_features        # add onto the base output

    return project_down + project_up + dropout_cost + scale_cost + residual_cost


def calculate_mac_ratio(model, sample_batch):
    """Estimate the MAC-to-operation ratio of ``model`` from its module tree.

    Every ``nn.Linear`` and every module that is an ``nn.MultiheadAttention``
    or whose name contains "attention" contributes analytic MAC and total
    operation counts. Ranks come from ``model.initial_r_config`` when
    present, otherwise from the first dimension of each module's ``lora_A``.

    Args:
        model: Module tree to analyse (must provide ``named_modules()``).
        sample_batch: Mapping; ``input_ids`` supplies batch size and sequence
            length, defaulting to 1 and 128 when absent.

    Returns:
        ``macs / total_operations`` clamped to [0.3, 0.7], or 0.4 when no
        operations were counted.
    """
    import torch.nn as nn

    if "input_ids" in sample_batch:
        n_batch = sample_batch["input_ids"].size(0)
        n_seq = sample_batch["input_ids"].size(1)
    else:
        n_batch, n_seq = 1, 128
    n_tokens = n_batch * n_seq

    # Resolve the per-layer rank table.
    if hasattr(model, 'initial_r_config'):
        ranks = model.initial_r_config
    else:
        ranks = {}
        for mod_name, mod in model.named_modules():
            if hasattr(mod, 'lora_A') and hasattr(mod, 'lora_B'):
                rank = mod.lora_A.shape[0]
                if rank > 0:
                    ranks[mod_name] = rank

    def dense_cost(n_in, n_out):
        """(MACs, total ops) of a plain projection."""
        count = n_tokens * n_in * n_out
        return count, count * 2

    def low_rank_macs(n_in, n_out, rank):
        """MACs of the down plus up projection of a rank-``rank`` adapter."""
        return n_tokens * n_in * rank + n_tokens * rank * n_out

    def attention_projection_cost(width, rank):
        """(MACs, total ops) of one square attention projection."""
        if rank > 0:
            count = low_rank_macs(width, width, rank)
            return count, count * 2 + n_tokens * width
        return dense_cost(width, width)

    macs_sum = 0
    ops_sum = 0

    for mod_name, mod in model.named_modules():
        if isinstance(mod, nn.Linear):
            n_in, n_out = mod.in_features, mod.out_features
            rank = ranks.get(mod_name, 0)
            if rank > 0:
                count = low_rank_macs(n_in, n_out, rank)
                macs_sum += count
                # Dropout, scaling and addition are counted as non-MAC ops.
                ops_sum += count * 2 + n_tokens * rank + n_tokens * n_out + n_tokens * n_out
            else:
                count, ops = dense_cost(n_in, n_out)
                macs_sum += count
                ops_sum += ops
            continue

        if not (isinstance(mod, nn.MultiheadAttention) or 'attention' in mod_name.lower()):
            continue

        width = getattr(mod, 'embed_dim', 768)
        heads = getattr(mod, 'num_heads', 12)

        # Query / key / value projections, each possibly low-rank.
        for part in ('query', 'key', 'value'):
            count, ops = attention_projection_cost(width, ranks.get(f"{mod_name}.{part}", 0))
            macs_sum += count
            ops_sum += ops

        # Q @ K.T and attention @ V cost the same number of MACs.
        score_pairs = n_batch * heads * n_seq * n_seq
        matmul_macs = score_pairs * (width // heads)
        macs_sum += matmul_macs + matmul_macs
        # Softmax contributes 5 non-MAC ops per score.
        ops_sum += matmul_macs * 2 + score_pairs * 5 + matmul_macs * 2

        # Output projection, possibly low-rank.
        count, ops = attention_projection_cost(width, ranks.get(f"{mod_name}.out_proj", 0))
        macs_sum += count
        ops_sum += ops

    if ops_sum > 0:
        # Clamp to bounds considered reasonable for transformer models.
        return max(0.3, min(0.7, macs_sum / ops_sum))
    return 0.4  # Default fallback for transformers with LoRA

def _sum_local_metrics(valid_metrics):
    """Sum FLOPs, MACs and sample counts over this process's measurements.

    Args:
        valid_metrics: List of metric dictionaries. Missing "flops", "macs"
            or "batch_size" entries count as zero.

    Returns:
        Tuple ``(flops, macs, samples)``: two float totals and an int count.
    """
    flops = 0.0  # floats for accumulation, converted to int by the caller
    macs = 0.0
    samples = 0
    for entry in valid_metrics:
        flops += float(entry.get("flops", 0))
        macs += float(entry.get("macs", 0))
        samples += int(entry.get("batch_size", 0))
    return flops, macs, samples


def _all_reduce_totals(local_totals, accelerator):
    """Sum ``(flops, macs, samples)`` across processes via torch.distributed.

    Args:
        local_totals: Tuple ``(flops, macs, samples)`` for this process.
        accelerator: Accelerator whose ``device`` hosts the reduction tensor.

    Returns:
        The globally summed ``(flops, macs, samples)`` tuple, or None when the
        reduction could not be performed (the reason is logged), in which case
        the caller keeps its local totals.
    """
    try:
        import torch.distributed as dist

        if not dist.is_available():
            raise RuntimeError("Distributed package not available")

        if not dist.is_initialized():
            logger.warning("Distributed environment detected but not initialized. Falling back to local aggregation.")
            return None

        if not dist.is_nccl_available() and torch.cuda.is_available():
            logger.warning("NCCL not available, falling back to Gloo backend")

        flops, macs, samples = local_totals
        totals = torch.tensor(
            [flops, macs, float(samples)],
            device=accelerator.device,
            dtype=torch.float64  # Use double precision for large values
        )

        # all_reduce is itself a synchronizing collective, so no separate
        # barrier is needed. (dist.barrier() accepts no ``timeout`` argument;
        # passing one raised TypeError and disabled this whole path.)
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)

        if torch.isnan(totals).any() or torch.isinf(totals).any():
            raise ValueError("Invalid values after all-reduce operation")

        flops, macs, samples = totals.tolist()
        return flops, macs, int(samples)

    except Exception as e:
        # Deliberate best-effort: metrics must never take the training run down.
        logger.warning(f"Distributed metrics aggregation error: {e}")
        logger.info("Falling back to local aggregation")
        return None


def _mean_of_key(metrics, key):
    """Return the mean of ``m[key]`` over the entries that carry ``key`` (0 if none do)."""
    values = [m[key] for m in metrics if key in m]
    return sum(values) / len(values) if values else 0


def aggregate_computation_metrics(metrics_list, accelerator=None):
    """
    Aggregate computation metrics across multiple measurements.

    Entries flagged with ``measurement_failed`` are ignored. Without an
    accelerator, or with a single process, the totals are the sums over
    ``metrics_list``. With several processes the local sums are additionally
    all-reduced through ``torch.distributed``; if that is not possible the
    local sums are returned instead and a warning is logged.

    In a multi-process run every process has to call this function, including
    processes with nothing to report, because the all-reduce is a collective
    operation.

    Args:
        metrics_list: List of dictionaries with metrics to aggregate
        accelerator: Accelerator for distributed training

    Returns:
        Dictionary with aggregated metrics: "flops", "macs" (ints), "samples",
        "flops_per_sample" and "macs_per_sample". When nothing was measured,
        a dictionary of zeros without "samples" is returned, with
        "all_measurements_failed" set if entries existed but all had failed.
    """
    metrics_list = list(metrics_list) if metrics_list else []

    # Check if any measurement failed
    measurement_failures = sum(1 for m in metrics_list if m.get("measurement_failed", False))
    if measurement_failures:
        logger.warning(f"{measurement_failures} out of {len(metrics_list)} measurements failed")

    # Filter out failed measurements
    valid_metrics = [m for m in metrics_list if not m.get("measurement_failed", False)]

    totals = _sum_local_metrics(valid_metrics)

    # Distributed environment requires special handling: reduce even when this
    # process has no valid measurements so the other processes do not block.
    reduced = None
    if accelerator and accelerator.num_processes > 1:
        reduced = _all_reduce_totals(totals, accelerator)
    if reduced is not None:
        totals = reduced

    # Nothing measured locally and nothing received from other processes
    if not valid_metrics and (reduced is None or not any(reduced)):
        empty = {"flops": 0, "macs": 0, "flops_per_sample": 0, "macs_per_sample": 0}
        if metrics_list:
            logger.error("All measurements failed - returning zeros")
            empty["all_measurements_failed"] = True
        return empty

    flops, macs, samples = totals
    result = {"flops": flops, "macs": macs, "samples": samples}

    # Calculate per-sample metrics if samples are tracked
    if samples > 0:
        result["flops_per_sample"] = flops / samples
        result["macs_per_sample"] = macs / samples
    else:
        # Batch sizes missing: average whatever per-sample values the entries
        # carry, each key independently so a FLOPs-only entry does not fail.
        result["flops_per_sample"] = _mean_of_key(valid_metrics, "flops_per_sample")
        result["macs_per_sample"] = _mean_of_key(valid_metrics, "macs_per_sample")

    # Convert back to int for final metrics (after all calculations)
    result["flops"] = int(result["flops"])
    result["macs"] = int(result["macs"])

    return result

def measure_model_computation(model, dataloader, num_batches=None, metrics=["flops", "macs"],
                           accelerator=None, log_output=False, measure_backward=False):
    """
    Measure computation across multiple batches for more accurate estimation.

    Each batch is measured with ``measure_computation`` and the per-batch
    results are combined with ``aggregate_computation_metrics``. Batches whose
    measurement reports ``measurement_failed`` are skipped. The model's
    train/eval mode is restored before returning, also when an error occurs.

    Args:
        model: Model to profile
        dataloader: DataLoader to iterate over batches
        num_batches: Number of batches to measure (None for all)
        metrics: List of metrics to calculate (the default list is only read
            here and forwarded, never modified in this function)
        accelerator: Accelerator for distributed training
        log_output: Whether to log the output
        measure_backward: Whether to measure backward pass (only works when model in training mode)

    Returns:
        Dictionary with aggregated metrics, extended with "measured_batches"
        (batches measured successfully) and "total_samples".
    """
    metrics_list = []
    sample_count = 0
    # Batches actually handed to measure_computation, used for the success
    # rate. Stays 0 for an empty dataloader.
    attempted_batches = 0

    # Set appropriate mode based on measurement type
    model_was_training = model.training
    if not measure_backward:
        model.eval()

    # Save CUDA memory state to compare against after the measurement
    cuda_available = torch.cuda.is_available()
    initial_memory = 0
    if cuda_available:
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated()

    try:
        # Measure only the first batches of the dataloader if num_batches is specified
        batch_limit = min(num_batches, len(dataloader)) if num_batches is not None else None

        # Iterate through batches
        for batch_idx, batch in enumerate(dataloader):
            if batch_limit is not None and batch_idx >= batch_limit:
                break
            attempted_batches += 1

            # Log progress for long measurements
            if log_output and batch_idx > 0 and batch_idx % 10 == 0:
                logger.info(f"Measuring batch {batch_idx}/{batch_limit if batch_limit else len(dataloader)}")

            # Measure current batch
            batch_metrics = measure_computation(
                model=model,
                batch=batch,
                metrics=metrics,
                accelerator=accelerator,
                log_output=(log_output and batch_idx == 0),  # Log only first batch
                measure_backward=measure_backward
            )

            # Check if measurement failed
            if batch_metrics.get("measurement_failed", False):
                logger.warning(f"Measurement failed for batch {batch_idx} - skipping this batch")
                continue

            # Add batch size for weighted averaging
            batch_size = batch["input_ids"].size(0) if "input_ids" in batch else 1
            batch_metrics["batch_size"] = batch_size
            sample_count += batch_size

            # Save metrics
            metrics_list.append(batch_metrics)

            # Clear memory between measurements for large models
            if cuda_available and batch_idx % 5 == 0:
                torch.cuda.empty_cache()

        # Aggregate metrics
        results = aggregate_computation_metrics(metrics_list, accelerator)

        # Add metadata to results
        results["measured_batches"] = len(metrics_list)
        results["total_samples"] = sample_count

        # Log final results
        if log_output:
            # Divide by the batches really attempted: the loop index would be
            # unbound for an empty dataloader and one too high after the
            # num_batches break.
            success_rate = len(metrics_list) / attempted_batches if attempted_batches else 0
            logger.info(f"Measured {len(metrics_list)} batches ({sample_count} samples) with {success_rate:.1%} success rate:")
            if "flops" in metrics:
                logger.info(f"  Total FLOPs: {format_flops(results['flops'])}")
                logger.info(f"  FLOPs per sample: {format_flops(results['flops_per_sample'])}")
            if "macs" in metrics:
                logger.info(f"  Total MACs: {format_flops(results['macs'])}")
                logger.info(f"  MACs per sample: {format_flops(results['macs_per_sample'])}")
            if measure_backward:
                logger.info(f"  Note: Measurements include both forward and backward passes")

        return results

    finally:
        # Restore original model state
        if model_was_training:
            model.train()
        else:
            model.eval()

        # Clean up memory
        if cuda_available:
            torch.cuda.empty_cache()
            # Check and log memory leaks
            final_memory = torch.cuda.memory_allocated()
            if final_memory > initial_memory + 10 * 1024 * 1024:  # 10MB threshold
                logger.warning(f"Possible memory leak detected: {(final_memory - initial_memory) / (1024*1024):.2f}MB increase")

# ---------- END OF COMPUTATION MEASUREMENT FUNCTIONS ----------

# Backward-compatible wrappers: they keep the legacy function names and signatures
# and are thin layers over the measurement functions above.

def calculate_flops(model, batch, step=-1, force=False, accelerator=None, log_output=False):
    """Return the "flops" entry measured for one batch (legacy entry point).

    ``step`` and ``force`` are accepted so older call sites keep working;
    this wrapper does not use them.
    """
    measurement_kwargs = dict(
        model=model,
        batch=batch,
        metrics=["flops"],
        accelerator=accelerator,
        log_output=log_output,
    )
    return measure_computation(**measurement_kwargs)["flops"]

def calculate_macs(model, batch, step=-1, accelerator=None, log_output=False):
    """Return the "macs" entry measured for one batch (legacy entry point).

    ``step`` is accepted so older call sites keep working; this wrapper does
    not use it.
    """
    measurement_kwargs = dict(
        model=model,
        batch=batch,
        metrics=["macs"],
        accelerator=accelerator,
        log_output=log_output,
    )
    return measure_computation(**measurement_kwargs)["macs"]

def calculate_compute_metrics(model, batch, step=-1, force=False, accelerator=None,
                            log_output=False, metrics=["flops", "macs"], normalize_by_batch_size=True):
    """Measure the requested metrics for one batch (legacy entry point).

    Forwards to ``measure_computation`` and returns its result unchanged.
    ``step``, ``force`` and ``normalize_by_batch_size`` are accepted so older
    call sites keep working; this wrapper does not use them.
    """
    forwarded = {
        "model": model,
        "batch": batch,
        "metrics": metrics,
        "accelerator": accelerator,
        "log_output": log_output,
    }
    return measure_computation(**forwarded)

def store_compute_metrics(current_metrics, total_metrics, accelerator):
    """Add the current "flops" / "macs" values onto the running totals.

    Args:
        current_metrics: Metrics of the latest measurement; tensor values are
            converted with ``float``.
        total_metrics: Running totals, updated in place. A key is only
            updated when it is present in both dictionaries.
        accelerator: Not used by this function.

    Returns:
        The same ``total_metrics`` dictionary that was passed in.
    """
    # Turn every tensor value into a plain Python float first
    plain_values = {
        name: float(value) if isinstance(value, torch.Tensor) else value
        for name, value in current_metrics.items()
    }

    for name in ("flops", "macs"):
        if name not in plain_values or name not in total_metrics:
            continue
        total_metrics[name] = float(total_metrics[name]) + float(plain_values[name])

    return total_metrics

# ---------- END OF BACKWARD-COMPATIBLE WRAPPERS ----------

def preprocess_dataset(tokenizer, raw_datasets, task_name, max_length, max_train_samples=None, max_eval_samples=None):
    """Tokenize a GLUE task and return its train and validation splits.

    The text columns are looked up in ``task_to_keys``. Every split of
    ``raw_datasets`` is tokenized without padding and truncated to
    ``max_length``; the first label column found among "label", "labels",
    "target" and "targets" is copied to "labels".

    Args:
        tokenizer: Callable tokenizer applied to the one or two text columns.
        raw_datasets: Mapping of split name to dataset; must contain "train".
        task_name: Key into ``task_to_keys``.
        max_length: Truncation length passed to the tokenizer.
        max_train_samples: Optional cap; keeps the first N training examples.
        max_eval_samples: Optional cap; keeps the first N examples of each
            validation split.

    Returns:
        ``(train, validation_matched, validation_mismatched)`` for "mnli",
        otherwise ``(train, validation)``.
    """
    sentence1_key, sentence2_key = task_to_keys[task_name]

    def preprocess_function(examples):
        # Single-sentence tasks pass one column, pair tasks pass two
        if sentence2_key is None:
            texts = (examples[sentence1_key],)
        else:
            texts = (examples[sentence1_key], examples[sentence2_key])
        result = tokenizer(*texts, padding=False, max_length=max_length, truncation=True)

        # First candidate column that exists provides the labels
        label_key = next(
            (candidate for candidate in ("label", "labels", "target", "targets") if candidate in examples),
            None,
        )
        if label_key is None:
            logger.warning(f"No label key found in examples. Keys: {list(examples.keys())}")
        else:
            result["labels"] = examples[label_key]
            logger.debug(f"Found label key: {label_key}")

        return result

    def _first_n(dataset, limit):
        # Keep the leading ``limit`` rows; None means no cap
        if limit is None:
            return dataset
        return dataset.select(range(min(limit, len(dataset))))

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
        desc="Preprocessing dataset",
    )

    train_dataset = _first_n(processed_datasets["train"], max_train_samples)

    if task_name != "mnli":
        # Every other task has a single validation split
        eval_dataset = _first_n(processed_datasets["validation"], max_eval_samples)
        return train_dataset, eval_dataset

    # MNLI ships a matched and a mismatched validation split
    eval_matched_dataset = _first_n(processed_datasets["validation_matched"], max_eval_samples)
    eval_mismatched_dataset = _first_n(processed_datasets["validation_mismatched"], max_eval_samples)
    return train_dataset, eval_matched_dataset, eval_mismatched_dataset


def validate_lora_configuration(r_config, replaced_layers, logger):
    """Compare the requested LoRA ranks with the ranks recorded per layer.

    Logs one report per entry of ``replaced_layers`` and a closing summary.

    Args:
        r_config: Mapping of layer name to the expected rank; layers that are
            missing count as rank 0.
        replaced_layers: Mapping of layer name to an info dictionary read for
            "r", "applied" and, optionally, "in_features" / "out_features".
        logger: Logger that receives the report.

    Returns:
        Dictionary with "applied_layers", "skipped_layers", "mismatch_layers"
        and "total_params" (sum of ``r * (in_features + out_features)`` over
        the layers whose dimensions are recorded).
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("VALIDATING LoRA CONFIGURATION")
    logger.info(banner)

    tally = {
        "applied_layers": 0,
        "skipped_layers": 0,
        "mismatch_layers": 0,
        "total_params": 0,
    }

    for layer_name, layer_info in replaced_layers.items():
        expected_r = r_config.get(layer_name, 0)
        actual_r = layer_info.get('r', 0)
        was_applied = layer_info.get('applied', False)
        ranks_agree = expected_r == actual_r

        # Parameter count is only known when both dimensions were recorded
        if 'in_features' in layer_info and 'out_features' in layer_info:
            param_count = actual_r * (layer_info['in_features'] + layer_info['out_features'])
            tally["total_params"] += param_count
        else:
            param_count = 0

        if ranks_agree:
            status = "✓ MATCH"
        else:
            status = "✗ MISMATCH"
        if was_applied:
            applied_status = "APPLIED"
        else:
            applied_status = "SKIPPED"

        logger.info(f"Layer: {layer_name}")
        logger.info(f"  Expected r: {expected_r}, Actual r: {actual_r}, Status: {status}, {applied_status}")
        if param_count > 0:
            logger.info(f"  Parameters: {param_count:,}")

        if logger.level <= logging.DEBUG:
            logger.debug(f"Layer {layer_name}: Expected r={expected_r}, "
                        f"Actual r={actual_r}, Status: {status}, {applied_status}")

        if not ranks_agree:
            tally["mismatch_layers"] += 1
            logger.warning(f"MISMATCH in layer {layer_name}: Expected r={expected_r}, Actual r={actual_r}")

        tally["applied_layers" if was_applied else "skipped_layers"] += 1

    logger.info(f"LoRA SUMMARY: {tally['applied_layers']} layers applied, {tally['skipped_layers']} layers skipped")
    logger.info(f"Total LoRA parameters: {tally['total_params']:,}")

    if tally["mismatch_layers"] > 0:
        logger.error(f"CRITICAL: Found {tally['mismatch_layers']} layers with r-value mismatches!")
    else:
        logger.info("SUCCESS: All r-values correctly applied")
    logger.info(banner)

    return tally


def _to_json_compatible(value):
    """Recursively convert tensors and NumPy values into JSON-friendly Python objects."""
    if isinstance(value, torch.Tensor):
        # Single-element tensors keep the float conversion used so far;
        # larger tensors become (nested) lists instead of raising.
        return float(value) if value.numel() == 1 else value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, dict):
        return {key: _to_json_compatible(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_json_compatible(item) for item in value]
    return value


def save_metrics(metrics, output_path):
    """Save metrics to a JSON file, converting tensors to plain Python values.

    Tensors and NumPy values are converted at any nesting depth. Failures
    while writing the file are logged as a warning instead of being raised.

    Args:
        metrics: Dictionary of metric names to values.
        output_path: Target file. Missing parent directories are created; a
            bare file name is written to the current working directory.
    """
    serializable_metrics = {k: _to_json_compatible(v) for k, v in metrics.items()}

    # Ensure directory exists. dirname is "" for a bare file name, and
    # os.makedirs("") would raise FileNotFoundError.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Saving is best-effort: a failed write must not abort the run
    try:
        with open(output_path, 'w') as f:
            json.dump(serializable_metrics, f, indent=2)
    except Exception as e:
        logger.warning(f"Failed to save metrics: {e}")


def optimize_lora_config(model, dataloader, args, accelerator, output_dir):
    """Find the LoRA rank configuration and log a summary of it.

    Runs ``LoRAOptimizer.optimize`` on ``dataloader`` inside
    ``accelerator.main_process_first()``, writes the result to
    ``lora_r_config.json`` in ``output_dir`` on the main process, then logs
    the average rank per layer type and the total LoRA parameter count.

    Args:
        model: Model to optimize; it is moved to ``accelerator.device``.
        dataloader: Data passed to the optimizer.
        args: Parsed arguments; ``seed``, ``lora_r_values`` and
            ``lora_budget`` are read.
        accelerator: Accelerator providing the device and process info.
        output_dir: Directory for the optimizer output and the saved config.

    Returns:
        The configuration returned by the optimizer, iterated here as a
        mapping of layer name to rank.
    """
    logger.info(f"Starting LoRA optimization with seed {args.seed}...")

    # Use efficient optimizer with seed
    lora_optimizer = LoRAOptimizer(
        model=model,
        r_values=args.lora_r_values,
        budget=args.lora_budget,
        device=accelerator.device,
        output_dir=os.path.join(output_dir, "lora_optimization"),
        seed=args.seed
    )

    # Find optimal configuration
    with accelerator.main_process_first():
        model.to(accelerator.device)
        r_config = lora_optimizer.optimize(dataloader)

    # Save configuration for reference
    if accelerator.is_main_process:
        with open(os.path.join(output_dir, "lora_r_config.json"), 'w') as f:
            json.dump(r_config, f, indent=2)

    def _layer_type(layer_name):
        # First matching marker wins, so "attention.output" is tested before
        # the plain "output" rule, which additionally excludes attention layers.
        for marker in ("query", "key", "value", "attention.output", "intermediate"):
            if marker in layer_name:
                return marker
        if "output" in layer_name and "attention" not in layer_name:
            return "output"
        return None

    # Ranks grouped by layer type; the key order is the logging order
    ranks_by_type = {
        "query": [],
        "key": [],
        "value": [],
        "attention.output": [],
        "intermediate": [],
        "output": [],
    }
    for layer_name, r_value in r_config.items():
        layer_type = _layer_type(layer_name)
        if layer_type is not None:
            ranks_by_type[layer_type].append(r_value)

    # Log layer type summary
    logger.info("OPTIMAL LoRA RANK CONFIGURATION:")
    for layer_type, ranks in ranks_by_type.items():
        if not ranks:
            continue
        avg_r = sum(ranks) / len(ranks)
        logger.info(f"  {layer_type}: avg_r={avg_r:.2f}, count={len(ranks)}")

    # Calculate and log parameter count, looking up each layer's size once
    total_params = 0
    for layer_name, r_value in r_config.items():
        if r_value > 0:
            in_features, out_features = get_layer_size(model, layer_name)
            total_params += r_value * (in_features + out_features)
    logger.info(f"Total LoRA parameters: {total_params:,}")

    return r_config

def check_lora_merge_status(model):
    """Log the merge state of every LoRA layer and count the states.

    A module is treated as a LoRA layer when it has both ``lora_A`` and
    ``lora_B``. Layers with ``r == 0`` are counted separately and appear only
    in the logged summary.

    Args:
        model: Model whose ``named_modules()`` are inspected.

    Returns:
        Tuple ``(merged_count, unmerged_count)``.
    """
    divider = "=" * 80
    logger.info(divider)
    logger.info("CHECKING LoRA MERGE STATUS")
    logger.info(divider)

    merged_count = unmerged_count = zero_r_count = 0

    for name, module in model.named_modules():
        # Skip everything that is not a LoRA layer
        if not (hasattr(module, 'lora_A') and hasattr(module, 'lora_B')):
            continue

        # Without both attributes the state cannot be determined
        if not (hasattr(module, 'merged') and hasattr(module, 'r')):
            logger.info(f"Layer {name} merge status unknown")
            continue

        if module.r == 0:
            zero_r_count += 1
            logger.info(f"Layer {name} has r=0 (pruned out)")
        elif module.merged:
            merged_count += 1
            logger.info(f"Layer {name} is MERGED")
        else:
            unmerged_count += 1
            logger.info(f"Layer {name} is UNMERGED")

    logger.info(f"Summary: Merged={merged_count}, Unmerged={unmerged_count}, Zero-r={zero_r_count}")
    logger.info(divider)

    return merged_count, unmerged_count


def main():
    args = parse_args()
    
    # Set environment variables for reproducibility
    os.environ["PYTHONHASHSEED"] = str(args.seed)
    os.environ["GRB_NUM_THREADS"] = "1"  # Force Gurobi to use a single thread
    
    # Setup minimal logging
    log_level = getattr(logging, args.log_level.upper())
    logger = setup_logger("optimal_lora", os.path.join(args.output_dir, "training.log"), level=log_level)
    
    # Configure minimal logging for external libraries
    setup_minimal_logging()
    
    # Set random seed
    set_deterministic_environment(args.seed)
    
    # Initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(args.output_dir, "logs")
    )
    
    # Initialize tracking variables
    train_flops = 0
    train_macs = 0
    eval_flops = 0
    eval_macs = 0
    
    # Basic memory tracking
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        logger.info(f"Initial GPU memory: {torch.cuda.memory_allocated() / (1024**3):.2f}GB")
    
    # Load configuration
    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=3 if args.task_name == "mnli" else 1 if args.task_name == "stsb" else 2,
        finetuning_task=args.task_name,
    )
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        use_fast=True,
    )
    
    # Load dataset
    raw_datasets = load_dataset("glue", args.task_name, trust_remote_code=True)
    
    # Preprocess dataset
    if args.task_name == "mnli":
        train_dataset, eval_matched_dataset, eval_mismatched_dataset = preprocess_dataset(
            tokenizer=tokenizer,
            raw_datasets=raw_datasets,
            task_name=args.task_name,
            max_length=args.max_seq_length,
            max_train_samples=args.max_train_samples,
            max_eval_samples=args.max_eval_samples
        )
    else:
        train_dataset, eval_dataset = preprocess_dataset(
            tokenizer=tokenizer,
            raw_datasets=raw_datasets,
            task_name=args.task_name,
            max_length=args.max_seq_length,
            max_train_samples=args.max_train_samples,
            max_eval_samples=args.max_eval_samples
        )
    
    # Create data collator
    data_collator = DataCollatorWithPadding(
        tokenizer, 
        pad_to_multiple_of=8 if accelerator.mixed_precision == "fp16" else None
    )
    
    g = torch.Generator()
    g.manual_seed(args.seed)

    train_dataloader = DataLoader(
        train_dataset, 
        shuffle=True, 
        collate_fn=data_collator, 
        batch_size=args.per_device_train_batch_size,
        worker_init_fn=seed_worker,
        generator=g,
        persistent_workers=False
    )

    if args.task_name == "mnli":
        eval_dataloader = DataLoader(
            eval_matched_dataset, 
            collate_fn=data_collator, 
            batch_size=args.per_device_eval_batch_size,
            worker_init_fn=seed_worker,
            generator=g,
            persistent_workers=False
        )
        
        eval_mismatched_dataloader = DataLoader(
            eval_mismatched_dataset, 
            collate_fn=data_collator, 
            batch_size=args.per_device_eval_batch_size,
            worker_init_fn=seed_worker,
            generator=g,
            persistent_workers=False
        )
    else:
        eval_dataloader = DataLoader(
            eval_dataset, 
            collate_fn=data_collator, 
            batch_size=args.per_device_eval_batch_size,
            worker_init_fn=seed_worker,
            generator=g,
            persistent_workers=False
        )
    
    # Load base model
    logger.info("Loading base model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
    )
    
    # === LoRA Optimization Phase ===
    if args.use_existing_lora_config:
        logger.info(f"Using existing LoRA configuration from {args.use_existing_lora_config}")
        with open(args.use_existing_lora_config, 'r') as f:
            r_config = json.load(f)
    else:
        # Create a subset of the training data for optimization
        optimization_dataset = train_dataset.select(range(min(MAX_OPTIMIZATION_SAMPLES, len(train_dataset))))
        optimization_dataloader = DataLoader(
            optimization_dataset, 
            shuffle=False, 
            collate_fn=data_collator, 
            batch_size=args.per_device_eval_batch_size
        )
        
        # Find optimal LoRA config
        r_config = optimize_lora_config(
            model=model,
            dataloader=optimization_dataloader,
            args=args,
            accelerator=accelerator,
            output_dir=args.output_dir
        )

        # Explicitly log the optimal r configuration here in the main function
        if accelerator.is_main_process:
            # Group layers by type for detailed logging
            layer_types = {
                "query": [],
                "key": [],
                "value": [],
                "attention.output": [],
                "intermediate": [],
                "output": [],
            }
            
            # Group layers by type
            for layer_name, r_value in r_config.items():
                if "query" in layer_name:
                    layer_types["query"].append((layer_name, r_value))
                elif "key" in layer_name:
                    layer_types["key"].append((layer_name, r_value))
                elif "value" in layer_name:
                    layer_types["value"].append((layer_name, r_value))
                elif "attention.output" in layer_name:
                    layer_types["attention.output"].append((layer_name, r_value))
                elif "intermediate" in layer_name:
                    layer_types["intermediate"].append((layer_name, r_value))
                elif "output" in layer_name and "attention" not in layer_name:
                    layer_types["output"].append((layer_name, r_value))
            
            # Log detailed r configuration report
            logger.info("=" * 50)
            logger.info("OPTIMAL LoRA RANK (r) CONFIGURATION:")
            logger.info("=" * 50)
            for layer_type, layers in layer_types.items():
                if not layers:
                    continue
                
                logger.info(f"\n[{layer_type.upper()} LAYERS]")
                # Calculate average r value for this layer type
                r_values = [r for _, r in layers]
                avg_r = sum(r_values) / len(r_values) if r_values else 0
                
                for layer_name, r in sorted(layers, key=lambda x: x[0]):
                    logger.info(f"  {layer_name}: r = {r}")
                
                logger.info(f"  Average r for {layer_type} layers: {avg_r:.2f}")
            logger.info("=" * 50)
            
            # Calculate and log total LoRA parameters
            total_params = sum(
                r_value * (get_layer_size(model, layer_name)[0] + get_layer_size(model, layer_name)[1]) 
                for layer_name, r_value in r_config.items() if r_value > 0
            )
            logger.info(f"Total LoRA parameters: {total_params:,}")
            logger.info("=" * 50)

        # Free memory
        del optimization_dataset, optimization_dataloader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    # If optimize_only, exit here
    if args.optimize_only:
        logger.info("Optimization completed. Exiting as --optimize_only was specified.")
        return
    
    # === Training Phase ===
    logger.info("Preparing model with LoRA for training...")
    
    # Replace model with LoRA version
    model, replaced_layers = prepare_model_for_glue(
        base_model_name=args.model_name_or_path,
        r_config=r_config,
        num_labels=config.num_labels,
        dropout=args.lora_dropout
    )
    
    # Validate the LoRA configuration
    validate_lora_configuration(r_config, replaced_layers, logger)
    
    # Save model architecture for reference
    if accelerator.is_main_process:
        with open(os.path.join(args.output_dir, "model_architecture.txt"), 'w') as f:
            f.write(str(model))
    
    # Mark only LoRA parameters as trainable for efficiency
    lora.mark_only_lora_as_trainable(model)
    
    # Log parameter information
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%} of total)")
    
    # Calculate LoRA parameters
    total_lora_params = 0
    for layer_name, r_value in r_config.items():
        if r_value > 0:
            in_features, out_features = get_layer_size(model, layer_name)
            total_lora_params += r_value * (in_features + out_features)
    logger.info(f"Total LoRA parameters: {total_lora_params:,}")
    
    # Setup optimizer - use transformers AdamW for accuracy consistency
    optimizer = AdamW(
        [param for param in model.parameters() if param.requires_grad],
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        no_deprecation_warning=True  # Suppress deprecation warning
    )
    
    # Prepare with accelerator
    if args.task_name == "mnli":
        model, optimizer, train_dataloader, eval_dataloader, eval_mismatched_dataloader = accelerator.prepare(
            model, optimizer, train_dataloader, eval_dataloader, eval_mismatched_dataloader
        )
    else:
        model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
            model, optimizer, train_dataloader, eval_dataloader
        )
    
    # Calculate training steps
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    
    # Calculate num_warmup_steps from warmup_ratio if provided
    if args.warmup_ratio is not None:
        # Standard interpretation: warmup_ratio is ratio of TOTAL training steps
        # This is the standard way used by HuggingFace Transformers and most ML frameworks
        args.num_warmup_steps = int(args.max_train_steps * args.warmup_ratio)
        logger.info(f"Calculated num_warmup_steps={args.num_warmup_steps} from warmup_ratio={args.warmup_ratio} (total_steps={args.max_train_steps})")
        
    # Create learning rate scheduler
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )
    # Initialize metrics
    metric = evaluate.load("glue", args.task_name, trust_remote_code=True)
    
    # Initialize progressive pruning manager if enabled
    pruning_manager = None
    if args.apply_pruning:
        logger.info("=" * 80)
        logger.info("INITIALIZING PROGRESSIVE LORA PRUNING")
        logger.info("=" * 80)
        
        # Get initial r_config from model
        unwrapped_model = accelerator.unwrap_model(model)
        if not hasattr(unwrapped_model, 'initial_r_config'):
            logger.warning("Model doesn't have initial_r_config attribute. Extracting from model...")
            
            # Extract r_config from model
            initial_r_config = {}
            for name, module in unwrapped_model.named_modules():
                if hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                    r = getattr(module, 'lora_A').shape[0] if hasattr(module, 'lora_A') else 0
                    if r > 0:
                        initial_r_config[name] = r
        else:
            initial_r_config = unwrapped_model.initial_r_config
            
        # Set pruning output directory
        pruning_output_dir = args.pruning_output_dir or os.path.join(args.output_dir, "pruning")
        os.makedirs(pruning_output_dir, exist_ok=True)
        
        # Override config values with command line arguments
        from pruning import pruning_config
        pruning_config.IMPORTANCE_EMA_DECAY = args.importance_ema_decay
        pruning_config.MOMENTUM_PENALTY_WEIGHT = args.momentum_penalty_weight
        pruning_config.RECOVERY_STEPS = args.recovery_steps
        pruning_config.EXTENDED_RECOVERY_STEPS = args.extended_recovery_steps
        
        # Log EMA and momentum parameters
        logger.info(f"Using importance EMA decay: {pruning_config.IMPORTANCE_EMA_DECAY}")
        logger.info(f"Using momentum penalty weight: {pruning_config.MOMENTUM_PENALTY_WEIGHT}")
        logger.info(f"Using recovery steps: {pruning_config.RECOVERY_STEPS}")
        logger.info(f"Using extended recovery steps: {pruning_config.EXTENDED_RECOVERY_STEPS}")
        
        # Create pruning manager with seed
        pruning_manager = ProgressivePruningManager(
            model=unwrapped_model,
            initial_r_config=initial_r_config,
            train_dataloader=train_dataloader,
            eval_dataloader=eval_dataloader,
            target_reduction=args.pruning_target_reduction,
            num_pruning_steps=args.pruning_steps,
            total_training_steps=args.max_train_steps,
            device=accelerator.device,
            output_dir=pruning_output_dir,
            seed=args.seed,
            enable_rollback=not args.disable_rollback
        )
        
        # Initialize baseline performance
        pruning_manager.initialize_baseline_performance()
        merged_count, unmerged_count = check_lora_merge_status(model)

    # Training loop preparation

    logger.info(f"Starting training for {args.num_train_epochs} epochs ({args.max_train_steps} steps)")
    logger.info(f"Training parameters: batch_size={args.per_device_train_batch_size}, lr={args.learning_rate}, "
                f"grad_accum={args.gradient_accumulation_steps}, max_grad_norm={args.max_grad_norm}")
    completed_steps = 0
    best_metric = 0.0
    train_start_time = time.time()
    total_train_loss = 0
    total_train_samples = 0
    total_train_steps = 0
    
    # Measure computation with sample batch for baseline
    logger.info("Measuring baseline computation metrics...")
    sample_batch = next(iter(train_dataloader))
    sample_batch_size_1 = {k: v[0:1] for k, v in sample_batch.items() if isinstance(v, torch.Tensor)}
    
    # Measure both metrics for baseline
    baseline_metrics = measure_computation(
        model=model, 
        batch=sample_batch_size_1, 
        metrics=["flops", "macs"],
        accelerator=accelerator, 
        log_output=True
    )
    
    logger.info(f"Baseline computation metrics (single sample):")
    logger.info(f"  - FLOPs per sample: {format_flops(baseline_metrics['flops_per_sample'])}")
    if baseline_metrics['macs_per_sample'] > 0:
        logger.info(f"  - MACs per sample: {format_flops(baseline_metrics['macs_per_sample'])}")
    
    # Verify model parameters before training
    validate_model_parameters(model)
    
    # Training loop - optimized for performance and accuracy
    for epoch in range(args.num_train_epochs):
        # Reset memory stats at epoch start
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            logger.info(f"Epoch {epoch+1}/{args.num_train_epochs} - Starting memory: "
                       f"{torch.cuda.memory_allocated() / (1024**3):.2f}GB")
        
        model.train()
        train_loss = 0.0
        
        # Initialize list to store computation metrics for this epoch
        epoch_computation_metrics = []
        
        for step, batch in enumerate(train_dataloader):
            # Check if pruning should be applied at this step
            if pruning_manager is not None and pruning_manager.should_prune(completed_steps):
                # Execute pruning step
                pruning_success = pruning_manager.execute_pruning_step(completed_steps)
                # Refresh model reference after pruning
                if pruning_success:
                    unwrapped_model = accelerator.unwrap_model(model)
                    # Ensure proper LoRA layer states after pruning
                    logger.info("Checking LoRA layer states after pruning...")
                    check_lora_merge_status(unwrapped_model)
            
            outputs = model(**batch)
            
            # Measure computation metrics at specified intervals
            if step % LOG_FLOPS_INTERVAL == 0:
                try:
                    # Measure exact computation for this batch directly
                    batch_metrics = measure_computation(
                        model=model, 
                        batch=batch, 
                        metrics=["flops", "macs"],
                        accelerator=accelerator,
                        log_output=(step == 0)  # Only log the first batch
                    )
                    
                    # Add to tracking list with batch size for proper weighting
                    batch_metrics["batch_size"] = batch["input_ids"].size(0)
                    epoch_computation_metrics.append(batch_metrics)
                    
                except Exception as e:
                    logger.warning(f"Computation measurement skipped: {e}")
            
            loss = outputs.loss
            if loss is None:
                # Attempt to calculate loss from logits if available
                if hasattr(outputs, 'logits') and outputs.logits is not None:
                    if 'labels' in batch:
                        if batch['labels'].dim() == 1:  # Classification operations
                            loss = torch.nn.functional.cross_entropy(outputs.logits, batch['labels'])
                            logger.debug(f"Computed cross_entropy loss for batch with shape {outputs.logits.shape}")
                        else:  # Regression operations
                            loss = torch.nn.functional.mse_loss(outputs.logits.squeeze(), batch['labels'])
                            logger.debug(f"Computed mse loss for batch with shape {outputs.logits.shape}")
                    else:
                        # Check alternate keys if there is no label
                        label_found = False
                        for key in ['label', 'target', 'targets']:
                            if key in batch:
                                if batch[key].dim() == 1:  # Classification operations
                                    loss = torch.nn.functional.cross_entropy(outputs.logits, batch[key])
                                else:  # Regression operations
                                    loss = torch.nn.functional.mse_loss(outputs.logits.squeeze(), batch[key])
                                label_found = True
                                logger.debug(f"Used alternative label key: {key}")
                                break
                        
                        if not label_found:
                            logger.error(f"No labels found in batch. Skipping this batch. Batch keys: {list(batch.keys())}")
                            loss = None
                            continue  # Skip Batch
                else:
                    logger.warning(f"Cannot compute loss: No loss or logits in outputs. Output keys: {dir(outputs)}")
                    continue  # Skip Batch
            
            if loss is not None:
                # Process loss efficiently
                train_loss += loss.detach().float()
                total_train_loss += loss.detach().float().item()
                
                # Track batch statistics
                batch_size = batch["input_ids"].size(0)
                total_train_samples += batch_size
                
                # Scale loss for gradient accumulation
                loss = loss / args.gradient_accumulation_steps
                accelerator.backward(loss)
                
                if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                    # Apply gradient clipping if enabled
                    if args.max_grad_norm > 0:
                        accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    
                    # Update recovery mode if pruning is active
                    if pruning_manager is not None and pruning_manager.scheduler.in_recovery_mode:
                        pruning_manager.update_recovery()
                    
                    completed_steps += 1
                    total_train_steps += 1
                    
                    # Log progress at reasonable intervals
                    if completed_steps % 100 == 0:
                        avg_loss = train_loss / (step + 1)
                        progress = (completed_steps / args.max_train_steps) * 100
                        current_lr = optimizer.param_groups[0]['lr']
                        
                        # Include pruning info in log if active
                        if pruning_manager is not None:
                            pruning_info = pruning_manager.get_progress_info()
                            pruning_status = f", Pruning: {pruning_info['reduction']:.2%} reduction"
                            if pruning_manager.scheduler.in_recovery_mode:
                                pruning_status += " (in recovery)"
                        else:
                            pruning_status = ""
                        
                        logger.info(f"Step {completed_steps}/{args.max_train_steps} ({progress:.1f}%): "
                                   f"Loss={avg_loss:.4f}, LR={current_lr:.2e}{pruning_status}")
                    
                    if completed_steps >= args.max_train_steps:
                        break
        
        # Aggregate epoch computation metrics
        if epoch_computation_metrics:
            epoch_results = aggregate_computation_metrics(epoch_computation_metrics, accelerator)
            
            # Debug: Check measurement details
            measured_steps = len(epoch_computation_metrics)
            total_steps = len(train_dataloader)
                        
            # Calculate expected measurements
            expected_measured_steps = (total_steps + LOG_FLOPS_INTERVAL - 1) // LOG_FLOPS_INTERVAL
            
            # Calculate total samples in measured batches
            measured_samples = sum(m.get("batch_size", 0) for m in epoch_computation_metrics)
            
            # Extrapolate to full epoch
            extrapolation_factor = total_steps / measured_steps if measured_steps > 0 else 1
            
            # Apply extrapolation to get total computation for the epoch
            epoch_flops_total = epoch_results["flops"] * extrapolation_factor
            epoch_macs_total = epoch_results["macs"] * extrapolation_factor
            
            train_flops += epoch_flops_total
            train_macs += epoch_macs_total
            
            # Log epoch computation with detailed breakdown
            logger.info(f"Epoch {epoch+1} computation breakdown:")
            logger.info(f"  Measured {measured_steps}/{total_steps} batches (every {LOG_FLOPS_INTERVAL} steps)")
            logger.info(f"  Measured FLOPs: {format_flops(epoch_results['flops'])}")
            logger.info(f"  Extrapolation factor: {extrapolation_factor:.2f}")
            logger.info(f"  Extrapolated total FLOPs: {format_flops(epoch_flops_total)}")
            logger.info(f"  Extrapolated total MACs: {format_flops(epoch_macs_total)}")
        
        # Evaluation at epoch end
        model.eval()
        
        # Measure actual evaluation computation
        logger.info("Measuring evaluation computation metrics...")
        eval_computation_metrics = measure_model_computation(
            model=model,
            dataloader=eval_dataloader,
            num_batches=min(5, len(eval_dataloader)),  # Measure on subset for efficiency
            metrics=["flops", "macs"],
            accelerator=accelerator,
            log_output=True
        )
        
        # Store computation metrics
        eval_flops = eval_computation_metrics["flops"] 
        eval_macs = eval_computation_metrics["macs"]
        
        # Run evaluation with precise timing using utility function
        logger.info("Running evaluation with precise timing...")
        
        # Check and log LoRA merge status before evaluation
        check_lora_merge_status(model)
        
        # Force merge for optimal inference if needed
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.eval()  # This should merge LoRA weights
        
        timing_results = measure_evaluation_time(
            model=model,
            dataloader=eval_dataloader,
            device=accelerator.device,
            logger=logger
        )
        
        # Extract timing information
        eval_runtime = timing_results["total_time"]
        eval_steps_per_second = timing_results["steps_per_second"]
        avg_batch_time = timing_results["avg_batch_time"]
        
        # Process evaluation results
        eval_loss = 0.0
        eval_samples = 0
        valid_batch_added = False
        
        # Run evaluation loop for metrics calculation
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)
                
                # Track eval loss
                if outputs.loss is not None:
                    eval_loss += outputs.loss.detach().float().item()
                
                predictions = outputs.logits.argmax(dim=-1) if args.task_name != "stsb" else outputs.logits.squeeze()
                eval_samples += batch["input_ids"].size(0)
                
                if "labels" in batch:
                    metric.add_batch(
                        predictions=accelerator.gather(predictions),
                        references=accelerator.gather(batch["labels"]),
                    )
                    valid_batch_added = True
                else:
                    # Check for other possible label keys
                    possible_label_keys = ["label", "target", "targets"]
                    found_key = None
                    
                    for key in possible_label_keys:
                        if key in batch:
                            found_key = key
                            logger.info(f"Found alternative label key: {key}")
                            
                            # Add metric using the found key
                            metric.add_batch(
                                predictions=accelerator.gather(predictions),
                                references=accelerator.gather(batch[key]),
                            )
                            valid_batch_added = True
                            break
                    
                    if not found_key:
                        logger.warning(f"No label key found in batch. Keys: {list(batch.keys())}")
        
        if valid_batch_added:
            eval_metric = metric.compute()
            logger.info(f"Epoch {epoch + 1}: {eval_metric}")
        else:
            logger.warning("No valid batches were added to the metric. Skipping evaluation.")
            eval_metric = {"accuracy": 0.0}  # Provide default values

        # For MNLI also evaluate on the mismatched set
        if args.task_name == "mnli":
            # Reset metric for mismatched evaluation
            mismatched_metric = evaluate.load("glue", args.task_name, trust_remote_code=True)
            mismatched_loss = 0.0
            
            # Use utility function for precise timing measurement
            logger.info("Running MNLI mismatched evaluation with precise timing...")
            # Ensure merged state is maintained for mismatched evaluation
            unwrapped_model.eval()
            mismatched_timing_results = measure_evaluation_time(
                model=model,
                dataloader=eval_mismatched_dataloader,
                device=accelerator.device,
                logger=logger
            )
            
            # Extract timing information
            mismatched_total_runtime = mismatched_timing_results["total_time"]
            mismatched_steps_per_second = mismatched_timing_results["steps_per_second"]
            avg_mismatched_batch_time = mismatched_timing_results["avg_batch_time"]
            
            # Process mismatched evaluation results
            for batch in eval_mismatched_dataloader:
                with torch.no_grad():
                    outputs = model(**batch)
                    
                    if outputs.loss is not None:
                        mismatched_loss += outputs.loss.detach().float().item()
                    
                    predictions = outputs.logits.argmax(dim=-1)
                    
                    if "labels" in batch:
                        mismatched_metric.add_batch(
                            predictions=accelerator.gather(predictions),
                            references=accelerator.gather(batch["labels"]),
                        )
                    else:
                        # Check for other possible label keys
                        possible_label_keys = ["label", "target", "targets"]
                        found_key = None
                        
                        for key in possible_label_keys:
                            if key in batch:
                                found_key = key
                                logger.info(f"Found alternative label key: {key}")
                                
                                # Add metric using the found key
                                metric.add_batch(
                                    predictions=accelerator.gather(predictions),
                                    references=accelerator.gather(batch[key]),
                                )
                                valid_batch_added = True
                                break
            
            mismatched_results = mismatched_metric.compute()
            logger.info(f"Epoch {epoch + 1} MNLI mismatched: {mismatched_results}")
            
            # Add mismatched results to the eval_metric
            eval_metric["accuracy_mismatched"] = mismatched_results.get("accuracy", 0.0)

        # Add computational metrics to eval metrics
        eval_metric["eval_flops"] = eval_flops
        eval_metric["eval_flops_gflops"] = eval_flops / 1e9
        eval_metric["eval_macs"] = eval_macs
        eval_metric["eval_macs_gflops"] = eval_macs / 1e9
        eval_metric["eval_steps_per_second"] = eval_steps_per_second
        
        # Add eval loss
        eval_metric["eval_loss"] = eval_loss / len(eval_dataloader) if len(eval_dataloader) > 0 else 0
                    
        # Log metrics
        logger.info(f"Epoch {epoch+1}: {eval_metric}")
        
        # Track best model
        if args.task_name == "cola":
            current_metric = eval_metric["matthews_correlation"]
        elif args.task_name == "stsb":
            current_metric = eval_metric["pearson"]
        elif args.task_name == "mnli":
            # For MNLI, use the average of matched and mismatched accuracies
            matched_acc = eval_metric["accuracy"]
            mismatched_acc = eval_metric.get("accuracy_mismatched", 0)
            current_metric = (matched_acc + mismatched_acc) / 2
        else:
            current_metric = eval_metric["accuracy"]
        
        # if current_metric > best_metric:
        #     best_metric = current_metric
        #     # Save best model
        #     if accelerator.is_main_process:
        #         unwrapped_model = accelerator.unwrap_model(model)
        #         unwrapped_model.save_pretrained(
        #             os.path.join(args.output_dir, "best_model"),
        #             save_function=accelerator.save
        #         )
        #         tokenizer.save_pretrained(os.path.join(args.output_dir, "best_model"))
        #         logger.info(f"New best model saved with {args.task_name} metric: {current_metric:.4f}")
        
        # Properly clean memory after epoch with controlled garbage collection
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
    
    # Calculate final train metrics
    train_runtime = time.time() - train_start_time
    train_samples_per_second = total_train_samples / train_runtime if train_runtime > 0 else 0
    train_steps_per_second = total_train_steps / train_runtime if train_runtime > 0 else 0
    avg_train_loss = total_train_loss / total_train_steps if total_train_steps > 0 else 0

    # Calculate model size and parameters
    model_size_mb = total_params * 4 / (1024 * 1024)  # Approximate model size (4 bytes per parameter)
    
    # Get formatted runtime
    formatted_train_runtime = format_runtime(train_runtime)
    
    # Get pruning metrics if available
    pruning_metrics = {}
    if pruning_manager is not None:
        pruning_info = pruning_manager.get_progress_info()
        pruning_metrics = pruning_info.get("pruning_metrics", {})
    
    # Prepare train metrics in the original order for consistency
    train_metrics = {
        "epoch": args.num_train_epochs,
        "total_flos_gflops": train_flops / (10**9),  # In GFLOPs
        "total_macs_gflops": train_macs / (10**9),  # MACs in GFLOPs
        "train_loss": avg_train_loss,
        "train_runtime_formatted": formatted_train_runtime,  # Human-readable format
        "train_samples": total_train_samples,
        "train_samples_per_second": train_samples_per_second,
        "train_steps_per_second": train_steps_per_second,
        "trainable_lora_parameters": trainable_params,
        "total_parameters": total_params,
        "model_size_mb": model_size_mb,
    }
    
    # Add pruning metrics if available
    if pruning_metrics:
        # Calculate accurate pruning metrics for final report
        if pruning_manager is not None:
            # Get accurate reduction values for final reporting
            current_params = pruning_utils.calculate_total_parameters(
                pruning_manager.current_r_config, pruning_manager.layer_sizes)
            actual_pruned_params = pruning_manager.initial_params - current_params
            actual_reduction_ratio = (pruning_manager.initial_params - current_params) / pruning_manager.initial_params
            
            # Update metrics with accurate values
            pruning_metrics["pruned_params"] = actual_pruned_params
            pruning_metrics["model_size_reduction"] = actual_reduction_ratio * 100  # as percentage
            
            # Log calculation for debugging
            logger.debug(f"Updated pruning metrics: pruned_params={actual_pruned_params:,}, "
                        f"reduction={actual_reduction_ratio:.4f} ({actual_reduction_ratio*100:.2f}%)")
        
        train_metrics.update({
            "pruned_params": pruning_metrics.get("pruned_params", 0),
            "pruned_model_size_mb": model_size_mb * (1 - pruning_metrics.get("model_size_reduction", 0) / 100),
            "model_size_reduction_pct": pruning_metrics.get("model_size_reduction", 0)
        })
    
    # Add memory statistics
    if torch.cuda.is_available():
        memory_stats = get_memory_stats()
        train_metrics.update({
            "peak_memory_gb": memory_stats.get("peak_allocated_gb", 0),
            "memory_reserved_gb": memory_stats.get("reserved_gb", 0),
        })
    
    # Log train metrics
    logger.info("***** Train metrics *****")
    for key, value in train_metrics.items():
        if isinstance(value, float):
            logger.info(f"  {key:<25} = {value:>10.6f}")
        else:
            logger.info(f"  {key:<25} = {value:>10}")
    
    # # Save train metrics
    # if accelerator.is_main_process:
    #     save_metrics(train_metrics, os.path.join(args.output_dir, "train_metrics.json"))
    
    # Finalize pruning if it was applied
    if pruning_manager is not None:
        logger.info("=" * 80)
        logger.info("FINALIZING PROGRESSIVE LORA PRUNING")
        logger.info("=" * 80)
        
        pruned_model, final_r_config, achieved_reduction = pruning_manager.finalize()
        
        # # Save pruned model
        # if accelerator.is_main_process:
        #     pruned_model.save_pretrained(
        #         os.path.join(pruning_manager.output_dir, "final_model"),
        #     )
        #     tokenizer.save_pretrained(os.path.join(pruning_manager.output_dir, "final_model"))
            
        #     # Save final r_config
        #     with open(os.path.join(pruning_manager.output_dir, "final_r_config.json"), 'w') as f:
        #         json.dump(final_r_config, f, indent=2)

    # Final evaluation
    logger.info("***** Running final evaluation *****")
    
    # Reset memory stats for evaluation
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
    
    model.eval()
    metric = evaluate.load("glue", args.task_name, trust_remote_code=True)
    
    # Measure final evaluation computation more accurately
    final_eval_metrics = measure_model_computation(
        model=model,
        dataloader=eval_dataloader,
        num_batches=None,  # Measure all batches for accuracy
        metrics=["flops", "macs"],
        accelerator=accelerator,
        log_output=True
    )
    
    final_eval_total_flops = final_eval_metrics["flops"]
    final_eval_total_macs = final_eval_metrics["macs"]
    
    # For MNLI, we need to measure mismatched set computation as well
    if args.task_name == "mnli":
        logger.info("Measuring MNLI mismatched evaluation computation...")
        mismatched_eval_metrics = measure_model_computation(
            model=model,
            dataloader=eval_mismatched_dataloader,
            num_batches=None,
            metrics=["flops", "macs"],
            accelerator=accelerator,
            log_output=True
        )
        
        # Add mismatched computation to total
        final_eval_total_flops += mismatched_eval_metrics["flops"]
        final_eval_total_macs += mismatched_eval_metrics["macs"]
    
    # Use utility function for precise timing measurement of final evaluation
    logger.info("Running final evaluation with precise timing...")
    
    # Ensure optimal LoRA state for final evaluation
    unwrapped_model = accelerator.unwrap_model(model)
    logger.info("Ensuring optimal LoRA state for final evaluation...")
    check_lora_merge_status(unwrapped_model)
    unwrapped_model.eval()  # Force merge for inference
    
    timing_results = measure_evaluation_time(
        model=model,
        dataloader=eval_dataloader,
        device=accelerator.device,
        logger=logger
    )
    
    # Extract timing information
    eval_runtime = timing_results["total_time"]
    eval_samples_per_second = timing_results["samples_per_second"]
    eval_steps_per_second = timing_results["steps_per_second"]
    avg_batch_time = timing_results["avg_batch_time"]
    
    logger.info(f"Final evaluation completed: total runtime={eval_runtime:.2f}s, avg batch time={avg_batch_time:.4f}s")
    
    # Process evaluation results
    total_eval_loss = 0
    total_eval_samples = 0
    valid_batch_added = False
    
    # Run evaluation loop for metrics calculation
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)
            
            # Track loss
            if outputs.loss is not None:
                total_eval_loss += outputs.loss.detach().float().item()
            
            # Process predictions
            predictions = outputs.logits.argmax(dim=-1) if args.task_name != "stsb" else outputs.logits.squeeze()
            total_eval_samples += batch["input_ids"].size(0)
            
            # Update metrics
            if "labels" in batch:
                gathered_predictions = accelerator.gather(predictions)
                gathered_references = accelerator.gather(batch["labels"])
                
                metric.add_batch(
                    predictions=gathered_predictions,
                    references=gathered_references
                )
                valid_batch_added = True
    
    avg_eval_loss = total_eval_loss / len(eval_dataloader) if len(eval_dataloader) > 0 else 0
    formatted_eval_runtime = format_runtime(eval_runtime)

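    # Calculate per-sample metrics.
    # NOTE(review): at this point total_eval_samples only counts batches from
    # eval_dataloader (MNLI mismatched samples are added further below), while
    # the FLOP/MAC totals may already include the mismatched pass - confirm
    # that this is the intended denominator.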
    if total_eval_samples > 0:
        final_eval_flops_per_sample = final_eval_total_flops / total_eval_samples
        final_eval_macs_per_sample = final_eval_total_macs / total_eval_samples
        
        logger.info(f"Total evaluation computation: {format_flops(final_eval_total_flops)} FLOPs, {format_flops(final_eval_total_macs)} MACs")
        logger.info(f"Per-sample evaluation: {format_flops(final_eval_flops_per_sample)} FLOPs/sample, {format_flops(final_eval_macs_per_sample)} MACs/sample")
        logger.info(f"Total eval samples (matched + mismatched): {total_eval_samples}")

    # Calculate and save pruning metrics after final evaluation (if pruning was applied)
    if pruning_manager is not None and accelerator.is_main_process:
        # Get task performance metric
        if args.task_name == "cola":
            task_performance = eval_metric.get("matthews_correlation", 0)
        elif args.task_name == "stsb":
            task_performance = eval_metric.get("pearson", 0)
        elif args.task_name == "mnli":
            matched_acc = eval_metric.get("accuracy", 0)
            mismatched_acc = eval_metric.get("accuracy_mismatched", 0)
            task_performance = (matched_acc + mismatched_acc) / 2
        else:
            task_performance = eval_metric.get("accuracy", 0)
        
        # Calculate efficiency metrics based on final evaluation
        accuracy_per_gflops = task_performance / (final_eval_total_flops / 1e9) if final_eval_total_flops > 0 else 0
        accuracy_per_gmacs = task_performance / (final_eval_total_macs / 1e9) if final_eval_total_macs > 0 else 0

        # Prepare and save complete pruning metrics
        pruning_metrics = {
            "initial_parameters": pruning_manager.initial_params,
            "final_parameters": pruning_utils.calculate_total_parameters(
                final_r_config, pruning_manager.layer_sizes),
            "reduction_rate": float(achieved_reduction),
            "target_reduction_rate": args.pruning_target_reduction,
            "pruning_steps": args.pruning_steps,
            "baseline_performance": float(pruning_manager.baseline_performance),
            "final_performance": task_performance,  # Use final evaluation performance
            "eval_steps_per_second": eval_steps_per_second,
            "accuracy_per_gflops": accuracy_per_gflops,
            "accuracy_per_gmacs": accuracy_per_gmacs,
            "final_eval_flops_per_sample_gflops": final_eval_flops_per_sample / 1e9,
            "final_eval_macs_per_sample_gflops": final_eval_macs_per_sample / 1e9,
            "final_eval_total_flops_gflops": final_eval_total_flops / 1e9,
            "final_eval_total_macs_gflops": final_eval_total_macs / 1e9,
            "pruned_model_size_mb": model_size_mb * (1 - achieved_reduction),
            "model_size_reduction_pct": achieved_reduction * 100
        }
        
        # Save pruning metrics
        save_metrics(pruning_metrics, os.path.join(pruning_manager.output_dir, "pruning_metrics.json"))
        
        # Log final pruning efficiency metrics
        logger.info("=" * 80)
        logger.info("PRUNING EFFICIENCY METRICS (FINAL EVALUATION):")
        logger.info("=" * 80)
        logger.info(f"  Eval steps per second: {eval_steps_per_second:.2f}")
        logger.info(f"  Accuracy per GFLOPs: {accuracy_per_gflops:.6f}")
        logger.info(f"  Accuracy per GMACs: {accuracy_per_gmacs:.6f}")
        logger.info(f"  Pruned model size: {pruning_metrics['pruned_model_size_mb']:.2f} MB")
        logger.info(f"  Size reduction: {achieved_reduction*100:.2f}%")
        logger.info(f"  Final performance: {task_performance:.4f}")
        logger.info("=" * 80)

    # Calculate metrics
    if valid_batch_added:
        eval_metric = metric.compute()
        
        # For MNLI, also evaluate mismatched set
        mismatched_accuracy = None
        if args.task_name == "mnli":
            # Run evaluation on mismatched dataset with precise timing
            mismatched_metric = evaluate.load("glue", args.task_name, trust_remote_code=True)
            mismatched_batch_times = []
            mismatched_total_runtime = 0.0
            
            for batch_idx, batch in enumerate(eval_mismatched_dataloader):
                # Start timing this batch
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                batch_start = time.perf_counter()
                
                with torch.no_grad():
                    outputs = model(**batch)
                    
                    # Ensure computation is complete
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                
                # Stop timing for this batch
                batch_end = time.perf_counter()
                batch_time = batch_end - batch_start
                mismatched_batch_times.append(batch_time)
                mismatched_total_runtime += batch_time
                
                predictions = outputs.logits.argmax(dim=-1)
                total_eval_samples += batch["input_ids"].size(0)  # Add to total samples
                
                if "labels" in batch:
                    mismatched_metric.add_batch(
                        predictions=accelerator.gather(predictions),
                        references=accelerator.gather(batch["labels"]),
                    )
            
            # Calculate mismatched runtime statistics
            avg_mismatched_batch_time = sum(mismatched_batch_times) / len(mismatched_batch_times) if mismatched_batch_times else 0
            mismatched_steps_per_second = len(eval_mismatched_dataloader) / mismatched_total_runtime if mismatched_total_runtime > 0 else 0
            
            logger.info(f"MNLI mismatched final evaluation: runtime={mismatched_total_runtime:.2f}s, " 
                       f"avg batch time={avg_mismatched_batch_time:.4f}s, "
                       f"steps/sec={mismatched_steps_per_second:.2f}")
            
            # Get mismatched results
            mismatched_results = mismatched_metric.compute()
            mismatched_accuracy = mismatched_results.get("accuracy", 0.0)
            
        # Prepare combined score based on task
        if args.task_name == "cola":
            combined_score = eval_metric["matthews_correlation"]
        elif args.task_name in ["mrpc", "qqp"]:
            combined_score = (eval_metric["accuracy"] + eval_metric["f1"]) / 2
        elif args.task_name == "mnli" and mismatched_accuracy is not None:
            # For MNLI, use average of matched and mismatched accuracies
            combined_score = (eval_metric["accuracy"] + mismatched_accuracy) / 2
        else:
            combined_score = eval_metric.get("accuracy", 0)
            
        # Determine accuracy key name based on task
        accuracy_key = "eval_accuracy_matched" if args.task_name == "mnli" else "eval_accuracy"
        
        # Prepare eval metrics in the original order
        eval_metrics = {
            "epoch": args.num_train_epochs,
            "eval_flops_gflops": final_eval_total_flops / (10**9),  # Total GFLOPs
            "eval_macs_gflops": final_eval_total_macs / (10**9),  # Total GMACs
            accuracy_key: eval_metric.get("accuracy", eval_metric.get("matthews_correlation", 0)),
            "eval_combined_score": combined_score,
            "eval_loss": avg_eval_loss,
            "eval_runtime_formatted": formatted_eval_runtime,  # Human-readable format
            "eval_samples": total_eval_samples,
            "eval_samples_per_second": eval_samples_per_second,
            "eval_steps_per_second": len(eval_dataloader) / eval_runtime if eval_runtime > 0 else 0
        }
        
        # Add per-sample metrics as additional fields
        eval_metrics.update({
            "eval_flops_per_sample_gflops": final_eval_flops_per_sample / (10**9),
            "eval_macs_per_sample_gflops": final_eval_macs_per_sample / (10**9),
        })
        
        # Add specific metrics based on task
        if "f1" in eval_metric:
            eval_metrics["eval_f1"] = eval_metric["f1"]
        if "matthews_correlation" in eval_metric:
            eval_metrics["eval_matthews_correlation"] = eval_metric["matthews_correlation"]
        if "pearson" in eval_metric:
            eval_metrics["eval_pearson"] = eval_metric["pearson"]
        if "spearmanr" in eval_metric:
            eval_metrics["eval_spearmanr"] = eval_metric["spearmanr"]
        if args.task_name == "mnli" and mismatched_accuracy is not None:
            eval_metrics["eval_accuracy_mismatched"] = mismatched_accuracy
        
        # Add memory statistics
        if torch.cuda.is_available():
            memory_stats = get_memory_stats()
            eval_metrics.update({
                "eval_peak_memory_gb": memory_stats.get("peak_allocated_gb", 0),
                "eval_memory_reserved_gb": memory_stats.get("reserved_gb", 0),
            })
        
        # Log eval metrics
        logger.info("***** Eval metrics *****")
        for key, value in eval_metrics.items():
            if isinstance(value, float):
                logger.info(f"  {key:<25} = {value:>10.6f}")
            else:
                logger.info(f"  {key:<25} = {value:>10}")
                
    else:
        logger.warning("No valid batches were added to the metric. Final evaluation skipped.")
        eval_metrics = {"eval_accuracy": 0.0}  # Provide default value

    # Save results
    if accelerator.is_main_process:
        # # Save final model
        # unwrapped_model = accelerator.unwrap_model(model)
        # unwrapped_model.save_pretrained(
        #     os.path.join(args.output_dir, "final_model"),
        #     save_function=accelerator.save
        # )
        # tokenizer.save_pretrained(os.path.join(args.output_dir, "final_model"))
        
        # Save metrics
        save_metrics(eval_metrics, os.path.join(args.output_dir, "eval_metrics.json"))
        save_metrics({**train_metrics, **eval_metrics}, os.path.join(args.output_dir, "all_metrics.json"))
    
    logger.info(f"Training completed. Results saved to {args.output_dir}")


# Script entry point: invoke main() only when this file is executed directly,
# so importing the module does not trigger a run.
if __name__ == "__main__":
    main()