#!/usr/bin/env python
# coding=utf-8

"""
Fine-tuning LLM models for Alpaca instruction-following task using optimal LoRA configurations.
Optimized for memory efficiency and computational performance while maintaining generation quality.
"""
import argparse
import logging
import math
import os
import random
import json
import time
import sys
import gc
from functools import partial
from pathlib import Path
import numpy as np

# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()

import datasets
import torch
from datasets import load_dataset, load_metric
import evaluate
from torch.utils.data import DataLoader
import torch.profiler as profiler
import transformers
from accelerate import Accelerator
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PretrainedConfig,
    SchedulerType,
    default_data_collator,
    get_scheduler,
    set_seed,
)
import subprocess
import tempfile
import shutil

from trainer.optimizer import LoRAOptimizer
from models.optimal_lora import prepare_model_for_alpaca
from utils.logging_utils import setup_logger, measure_batch_inference_time, measure_evaluation_time
import loralib as lora

# Import progressive pruning modules
from pruning.progressive_pruning import ProgressivePruningManager
from pruning import pruning_utils

# Control flags for optimized execution
LOG_FLOPS_INTERVAL = 50  # How often to log computation metrics during training (presumably in steps — confirm against the training loop)
MAX_OPTIMIZATION_SAMPLES = 500  # Maximum samples to use for LoRA optimization

# Alpaca dataset configuration
# Prompt template for examples that carry a non-empty "input" field. The
# {instruction}, {input} and {output} placeholders are presumably filled with
# str.format by the dataset preprocessing code (not visible in this chunk).
ALPACA_PROMPT_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
{output}"""

# Same as ALPACA_PROMPT_TEMPLATE but without the "### Input:" section, for
# examples whose input field is empty.
ALPACA_PROMPT_NO_INPUT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{output}"""

# Module-level logger shared by the helpers below.
logger = logging.getLogger(__name__)

def _parse_mt_bench_score(stdout, model_id):
    """Extract an MT-Bench score for ``model_id`` from ``show_result.py`` stdout.

    Every line mentioning ``model_id`` is scanned right-to-left, and its last
    token that parses as a float in [0, 10] is taken. This keeps a leading
    numeric column (such as a turn index) from being mistaken for the score.
    The value from the LAST matching line wins.

    NOTE(review): assumes show_result.py prints the overall/average table after
    the per-turn tables — confirm against the FastChat version in use.

    Returns:
        The parsed score as a float, or None if no matching line had one.
    """
    score = None
    for line in stdout.splitlines():
        if model_id not in line:
            continue
        for token in reversed(line.split()):
            try:
                value = float(token)
            except ValueError:
                continue
            if 0 <= value <= 10:
                score = value
                break
    return score


def run_mt_bench_evaluation(model, tokenizer, output_dir, accelerator, logger, skip_judgment=False, fastchat_path=None):
    """
    Run MT-Bench evaluation on the final trained model via FastChat's llm_judge scripts.

    The unwrapped model and the tokenizer are saved to a temporary directory,
    which is passed to ``gen_model_answer.py``. GPT-4 judging runs only when
    judgment is not skipped and ``OPENAI_API_KEY`` is set. In that case
    ``gen_judgment.py`` and ``show_result.py`` run as well, and the parsed
    score is stored under ``"overall_score"``. The temporary directory is
    always removed before the function returns.

    Args:
        model: The trained model (can be wrapped by accelerator)
        tokenizer: The tokenizer
        output_dir: Directory to save MT-Bench results
        accelerator: Accelerator object (only ``unwrap_model`` is used)
        logger: Logger object
        skip_judgment: Skip GPT-4 judgment if True
        fastchat_path: Path to FastChat repository; auto-detected from a few
            candidate locations when None

    Returns:
        Dictionary with MT-Bench results. It always has ``model_id``,
        ``model_path`` and ``timestamp``. ``model_path`` is a temporary
        directory that no longer exists once this function returns. Depending
        on how far the pipeline got, it may also have ``answer_file``,
        ``overall_score`` and ``judgment_skipped``.

        If answer generation fails, the partial dictionary is returned without
        writing a results file. Returns None if FastChat cannot be located or
        any exception is raised.
    """
    temp_model_dir = None
    try:
        # Find FastChat path if not provided
        if fastchat_path is None:
            possible_paths = [
                "./FastChat",
                "../FastChat",
                "../../FastChat",
                "/home/work/cjpark/llama3_DPlora_0.8_v3.2_v1_alpaca_v2/FastChat",
                "/home/work/cjpark/llama3_DPlora_0.8_v3.2_v1_alpaca_v4/FastChat",
            ]
            for path in possible_paths:
                if os.path.exists(path) and os.path.exists(os.path.join(path, "fastchat")):
                    fastchat_path = path
                    break

            if fastchat_path is None:
                logger.error("FastChat repository not found. Please clone: git clone https://github.com/lm-sys/FastChat.git")
                return None

        fastchat_path = os.path.abspath(fastchat_path)
        llm_judge_path = os.path.join(fastchat_path, "fastchat", "llm_judge")

        # Create output directory for MT-Bench results
        os.makedirs(output_dir, exist_ok=True)

        # Create temporary directory to save model
        temp_model_dir = tempfile.mkdtemp(prefix="mt_bench_model_")
        logger.info(f"Saving model to temporary directory: {temp_model_dir}")

        # Unwrap and save model
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(temp_model_dir)
        tokenizer.save_pretrained(temp_model_dir)

        # Generate unique model ID with timestamp
        model_id = f"alpaca_final_{time.strftime('%Y%m%d_%H%M%S')}"

        results = {
            "model_id": model_id,
            "model_path": temp_model_dir,
            "timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
        }

        # Run the FastChat scripts with the interpreter executing this script,
        # so they see the same virtualenv/packages. A bare "python" may resolve
        # to a different interpreter on PATH.
        python_executable = sys.executable or "python"

        # Step 1: Generate model answers
        logger.info("Generating MT-Bench answers...")
        gen_answer_script = os.path.join(llm_judge_path, "gen_model_answer.py")

        cmd = [
            python_executable, gen_answer_script,
            "--model-path", temp_model_dir,
            "--model-id", model_id,
            "--num-gpus-per-model", "1",
            "--num-gpus-total", str(torch.cuda.device_count() if torch.cuda.is_available() else 1),
        ]

        logger.info(f"Running command: {' '.join(cmd)}")

        result = subprocess.run(
            cmd,
            cwd=llm_judge_path,
            capture_output=True,
            text=True
        )

        if result.returncode != 0:
            logger.error(f"Answer generation failed: {result.stderr}")
            return results  # temp dir is removed in `finally`

        answer_file = os.path.join(llm_judge_path, "data/mt_bench/model_answer", f"{model_id}.jsonl")
        results["answer_file"] = answer_file
        logger.info(f"Answers generated: {answer_file}")

        # Step 2: Generate GPT-4 judgments (optional)
        if not skip_judgment and "OPENAI_API_KEY" in os.environ:
            logger.info("Getting GPT-4 judgments...")
            gen_judgment_script = os.path.join(llm_judge_path, "gen_judgment.py")

            cmd = [
                python_executable, gen_judgment_script,
                "--model-list", model_id,
                "--parallel", "2",
                "--mode", "single",
            ]

            result = subprocess.run(
                cmd,
                cwd=llm_judge_path,
                capture_output=True,
                text=True,
                env=os.environ.copy()
            )

            if result.returncode == 0:
                # Step 3: Calculate scores
                show_result_script = os.path.join(llm_judge_path, "show_result.py")

                cmd = [
                    python_executable, show_result_script,
                    "--model-list", model_id,
                    "--mode", "single",
                ]

                result = subprocess.run(
                    cmd,
                    cwd=llm_judge_path,
                    capture_output=True,
                    text=True
                )

                if result.returncode == 0:
                    score = _parse_mt_bench_score(result.stdout, model_id)
                    if score is not None:
                        results["overall_score"] = score
                        logger.info(f"MT-Bench Overall Score: {score:.2f}")
                    else:
                        logger.warning(f"Could not parse an MT-Bench score for {model_id} from show_result.py output")
                else:
                    logger.warning("Score calculation failed")
                    if result.stderr:
                        logger.warning(f"show_result.py stderr: {result.stderr}")
            else:
                logger.warning("GPT-4 judgment failed or skipped")
                if result.stderr:
                    logger.warning(f"gen_judgment.py stderr: {result.stderr}")
        else:
            if skip_judgment:
                logger.info("Skipping GPT-4 judgment as requested")
            else:
                logger.warning("Skipping GPT-4 judgment - OpenAI API key not set")
            results["judgment_skipped"] = True

        # Save results to JSON
        results_file = os.path.join(output_dir, f"mt_bench_results_{model_id}.json")
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=2)
        logger.info(f"MT-Bench results saved to {results_file}")

        return results

    except Exception as e:
        logger.error(f"MT-Bench evaluation failed with error: {e}")
        return None

    finally:
        # Single cleanup point for success, early return and error paths.
        if temp_model_dir is not None:
            shutil.rmtree(temp_model_dir, ignore_errors=True)
            logger.info("Cleaned up temporary model directory")

def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` as before. Passing a list lets the
            function be used from tests or notebooks.

    Returns:
        argparse.Namespace with these post-processing steps applied:
        - ``output_dir`` defaults to ``runs/<task>_<model basename>_<timestamp>``;
        - ``lora_r_values`` is converted from a comma-separated string into a
          list of ints (blank entries, e.g. from a trailing comma, are ignored);
        - ``mt_bench_output_dir`` defaults to ``<output_dir>/mt_bench``.

    Exits through ``parser.error`` (usage message, exit status 2) when
    ``--lora_r_values`` has a non-integer entry or no entries at all.
    """
    parser = argparse.ArgumentParser(description="Fine-tune LLM models with optimal LoRA for Alpaca instruction-following task")

    # Task and model arguments
    parser.add_argument(
        "--task_name",
        type=str,
        default="alpaca",
        help="Task name (default: alpaca)",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="meta-llama/Llama-2-7b-hf",
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=512,
        help="Maximum sequence length after tokenization",
    )

    # LoRA optimization arguments
    parser.add_argument(
        "--lora_r_values",
        type=str,
        default="0,1,2,4,8,16,32",
        help="Comma-separated list of r values to consider for LoRA optimization",
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=None,
        help="LoRA scaling parameter (alpha). If None, uses 2*r for each layer",
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.1,
        help="Dropout probability for LoRA layers",
    )
    parser.add_argument(
        "--lora_budget",
        type=float,
        default=400000000.0,  # ~5% of LLama model parameters
        help="Budget constraint for LoRA optimization",
    )
    parser.add_argument(
        "--use_existing_lora_config",
        type=str,
        default=None,
        help="Path to existing LoRA configuration file (skips optimization phase)",
    )

    # Training arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save model, logs, and results",
    )
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        help="Overwrite the content of the output directory",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size per GPU/TPU for training",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size per GPU/TPU for evaluation",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Initial learning rate (after warmup period)",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to apply",
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.0,
        help="Maximum gradient norm for gradient clipping",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=3,
        help="Total number of training epochs",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps (overrides num_train_epochs)",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of steps for gradient accumulation before optimization",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="LR scheduler type",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the LR scheduler",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help="For debugging: limit the number of training examples",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="For debugging: limit the number of evaluation examples",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        help="Path to local Alpaca dataset JSON file",
    )

    # Other arguments
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for initialization",
    )
    parser.add_argument(
        "--optimize_only",
        action="store_true",
        help="Only perform LoRA optimization without training",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="info",
        choices=["debug", "info", "warning", "error", "critical"],
        help="Log level",
    )

    # Pruning arguments
    parser.add_argument(
        "--apply_pruning",
        action="store_true",
        help="Apply progressive pruning during training",
    )
    parser.add_argument(
        "--pruning_target_reduction",
        type=float,
        default=0.5,
        help="Target parameter reduction for progressive pruning (0.0-1.0)",
    )
    parser.add_argument(
        "--pruning_steps",
        type=int,
        default=4,
        help="Number of pruning steps for progressive pruning",
    )
    parser.add_argument(
        "--pruning_output_dir",
        type=str,
        default=None,
        help="Directory to save pruning results (defaults to output_dir/pruning)",
    )
    parser.add_argument(
        "--importance_ema_decay",
        type=float,
        default=0.9,
        help="EMA decay factor for layer importance scores (0.0-1.0)",
    )
    parser.add_argument(
        "--momentum_penalty_weight",
        type=float,
        default=0.1,
        help="Weight for momentum-based penalty in pruning",
    )

    parser.add_argument(
        "--recovery_steps",
        type=int,
        default=500,
        help="Number of recovery steps after pruning",
    )
    parser.add_argument(
        "--extended_recovery_steps",
        type=int,
        default=1000,
        help="Number of extended recovery steps after rollback",
    )

    parser.add_argument(
        "--disable_rollback",
        action="store_true",
        help="Disable rollback when pruning performance drops",
    )

    # Additional training arguments
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=50,
        help="Log every X updates steps",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save checkpoint every X updates steps",
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=500,
        help="Run evaluation every X steps",
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=0.0,
        help="Ratio of total training steps used for warmup",
    )

    # MT-Bench evaluation arguments
    parser.add_argument(
        "--run_mt_bench",
        action="store_true",
        help="Run MT-Bench evaluation after final training",
    )
    parser.add_argument(
        "--mt_bench_output_dir",
        type=str,
        default=None,
        help="Directory to save MT-Bench results (default: output_dir/mt_bench)",
    )
    parser.add_argument(
        "--skip_mt_bench_judgment",
        action="store_true",
        help="Skip GPT-4 judgment (useful when OpenAI API key not available)",
    )
    parser.add_argument(
        "--fastchat_path",
        type=str,
        default=None,
        help="Path to FastChat repository (will auto-detect if not provided)",
    )

    args = parser.parse_args(argv)

    # Process arguments
    if args.output_dir is None:
        # Strip trailing slashes so a local checkpoint path like "/ckpt/model/"
        # still contributes "model" rather than an empty string.
        model_basename = args.model_name_or_path.rstrip("/").split("/")[-1]
        args.output_dir = os.path.join(
            "runs",
            f"{args.task_name}_{model_basename}_{time.strftime('%Y%m%d-%H%M%S')}"
        )

    # Parse r_values; ignore blank entries (e.g. "0,4,8,") and report bad input
    # as a usage error rather than an uncaught ValueError traceback.
    raw_r_values = args.lora_r_values
    try:
        r_values = [int(token) for token in raw_r_values.split(",") if token.strip()]
    except ValueError:
        parser.error(f"--lora_r_values must be a comma-separated list of integers, got {raw_r_values!r}")
    if not r_values:
        parser.error("--lora_r_values must contain at least one integer")
    args.lora_r_values = r_values

    # Set MT-Bench output directory if not specified
    if args.mt_bench_output_dir is None:
        args.mt_bench_output_dir = os.path.join(args.output_dir, "mt_bench")

    return args

def set_deterministic_environment(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic CUDA behavior.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and ``torch`` (and to
            all CUDA devices when CUDA is available).

    Side effects:
        - Sets the ``PYTHONHASHSEED`` and ``CUBLAS_WORKSPACE_CONFIG`` environment
          variables. The current interpreter's hash seed is fixed at startup,
          so ``PYTHONHASHSEED`` only affects subprocesses started afterwards.
        - When CUDA is available, disables cuDNN benchmarking, enables cuDNN
          determinism, and calls ``torch.use_deterministic_algorithms(True)``.
          If that call is unavailable or rejected, a warning is logged instead
          of the error being silently ignored.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        try:
            torch.use_deterministic_algorithms(True)
        except (AttributeError, RuntimeError) as exc:
            # AttributeError: very old torch without this API; RuntimeError: rejected at runtime.
            logger.warning(f"Could not enable deterministic algorithms: {exc}")

    logger.info(f"Deterministic environment set with seed {seed}")

def seed_worker(worker_id):
    """Seed ``random`` and ``numpy.random`` inside a DataLoader worker.

    Intended as a DataLoader ``worker_init_fn``. If a module-level ``args``
    namespace with a ``seed`` attribute exists, the original scheme
    ``args.seed + worker_id`` is kept. Otherwise the worker falls back to
    ``torch.initial_seed()``, which PyTorch already sets per worker. The
    previous version raised NameError when ``args`` was not a module global,
    e.g. in workers started with the "spawn" method.

    Args:
        worker_id: Index of the DataLoader worker.
    """
    cli_args = globals().get("args")
    if cli_args is not None and getattr(cli_args, "seed", None) is not None:
        worker_seed = cli_args.seed + worker_id
    else:
        worker_seed = torch.initial_seed() % 2**32
    # numpy only accepts seeds in [0, 2**32 - 1].
    np.random.seed(worker_seed % 2**32)
    random.seed(worker_seed)

def format_flops(flops):
    """Render a FLOP count as a human-readable string.

    Values of at least 1e12, 1e9 or 1e6 are scaled to TFLOPs, GFLOPs or MFLOPs
    with four decimals. Anything smaller is printed as a thousands-separated
    integer followed by "FLOPs".
    """
    scales = ((1e12, "TFLOPs"), (1e9, "GFLOPs"), (1e6, "MFLOPs"))
    for threshold, unit in scales:
        if flops >= threshold:
            return f"{flops / threshold:.4f} {unit}"
    return f"{flops:,.0f} FLOPs"


def format_runtime(seconds):
    """Format a duration in seconds as a fixed-point string with two decimals."""
    return format(seconds, ".2f")

def get_memory_stats():
    """Report CUDA memory usage in GiB.

    Returns:
        ``{}`` when CUDA is unavailable. Otherwise a dict with
        ``peak_allocated_gb`` (from ``torch.cuda.max_memory_allocated``) and
        ``reserved_gb`` (from ``torch.cuda.memory_reserved``), each divided by
        1024**3.
    """
    stats = {}
    if torch.cuda.is_available():
        bytes_per_gib = 1024 ** 3
        stats["peak_allocated_gb"] = torch.cuda.max_memory_allocated() / bytes_per_gib
        stats["reserved_gb"] = torch.cuda.memory_reserved() / bytes_per_gib
    return stats


def setup_minimal_logging():
    """Configure minimal logging while keeping important warnings.

    Sets the transformers log verbosity to WARNING. If DeepSpeed has already
    been imported, its logger level is raised to ERROR.

    The previous version referenced a ``_ds_logger_silenced`` flag that is not
    defined in this module's visible code, so it raised NameError whenever
    deepspeed was imported. The flag is now read defensively. If some other
    part of the file defines it as truthy, the DeepSpeed logger is still left
    untouched.
    """
    # Allow transformers warnings but suppress less critical messages
    transformers.logging.set_verbosity_warning()

    # Silence DeepSpeed profiler logging
    already_silenced = globals().get("_ds_logger_silenced", False)
    if 'deepspeed' in sys.modules and not already_silenced:
        logging.getLogger('deepspeed').setLevel(logging.ERROR)


def validate_model_parameters(model):
    """Check trainable parameters for NaN/Inf values.

    Only parameters with ``requires_grad=True`` are inspected. A warning is
    logged for each one containing a non-finite value, followed by a summary
    warning if any were found.

    Returns:
        True if every trainable parameter is finite, False otherwise.
    """
    offending = []
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        # isfinite is False exactly where isnan or isinf is True.
        if not torch.isfinite(tensor).all():
            logger.warning(f"Parameter {param_name} contains NaN or Inf values!")
            offending.append(param_name)

    if offending:
        logger.warning("Model contains problematic parameters. Consider adjusting learning rate.")

    return not offending


def get_layer_size(model, layer_name):
    """Resolve a dotted attribute path on ``model`` and return its dimensions.

    Args:
        model: Root object (typically an ``nn.Module``).
        layer_name: Dotted path such as ``"model.layers.0.mlp.up_proj"``.

    Returns:
        ``(in_features, out_features)`` of the resolved object, or ``(0, 0)``
        if the path cannot be resolved (AttributeError/ValueError) or the
        object lacks either attribute.
    """
    try:
        target = model
        for attr in layer_name.split('.'):
            target = getattr(target, attr)
        has_dims = hasattr(target, 'in_features') and hasattr(target, 'out_features')
        return (target.in_features, target.out_features) if has_dims else (0, 0)
    except (AttributeError, ValueError):
        return 0, 0

def check_lora_merge_status(model):
    """Log and count the merge state of every LoRA-augmented submodule.

    A submodule counts as LoRA-augmented when it has both ``lora_A`` and
    ``lora_B`` attributes. Its state is read from a ``merged`` attribute.
    Submodules without one are logged as "unknown" and counted in neither
    total.

    Returns:
        Tuple ``(merged_count, unmerged_count)``.
    """
    separator = "=" * 80
    for header_line in (separator, "CHECKING LoRA MERGE STATUS", separator):
        logger.info(header_line)

    tally = {True: 0, False: 0}
    for module_name, submodule in model.named_modules():
        if not (hasattr(submodule, 'lora_A') and hasattr(submodule, 'lora_B')):
            continue
        if not hasattr(submodule, 'merged'):
            logger.info(f"Layer {module_name} merge status unknown")
            continue
        is_merged = bool(submodule.merged)
        tally[is_merged] += 1
        logger.info(f"Layer {module_name} is {'MERGED' if is_merged else 'UNMERGED'}")

    merged_count, unmerged_count = tally[True], tally[False]
    logger.info(f"Merged layers: {merged_count}, Unmerged layers: {unmerged_count}")
    logger.info(separator)

    return merged_count, unmerged_count

def _profiling_backward_loss(outputs):
    """Build a scalar loss to backpropagate while profiling, or return None.

    Prefers the model's own ``outputs.loss`` (populated when ``labels`` are in
    the batch), so the profiled backward pass matches a real training step.
    Otherwise it derives a dummy loss from ``outputs.logits``:

    - last dimension > 1: cross-entropy against class 0, after flattening the
      logits to ``(-1, num_classes)``. This accepts both ``(batch, classes)``
      and ``(batch, seq, vocab)`` logits; 3-D logits previously made
      ``cross_entropy`` raise a target-shape error.
    - last dimension == 1: MSE against zeros.

    Returns None when there is neither a usable loss nor non-empty logits.
    """
    model_loss = getattr(outputs, "loss", None)
    if isinstance(model_loss, torch.Tensor) and model_loss.requires_grad:
        # backward() needs a scalar; reduce if a model returns per-example losses.
        return model_loss if model_loss.dim() == 0 else model_loss.mean()

    logits = getattr(outputs, "logits", None)
    if logits is None or logits.size(0) == 0:
        return None

    if logits.size(-1) > 1:  # Classification / LM head
        flat_logits = logits.reshape(-1, logits.size(-1))
        dummy_target = torch.zeros(flat_logits.size(0), dtype=torch.long, device=logits.device)
        return torch.nn.functional.cross_entropy(flat_logits, dummy_target)

    # Regression
    dummy_target = torch.zeros_like(logits)
    return torch.nn.functional.mse_loss(logits, dummy_target)


def measure_computation(model, batch, metrics=("flops", "macs"), accelerator=None, log_output=False, measure_backward=False):
    """
    Estimate computation metrics for one batch using the PyTorch profiler plus heuristic corrections.

    Steps:
    1. Profile one forward pass, or forward + backward when ``measure_backward``
       is set and the model is in training mode. FLOPs are summed over the
       profiler's ``key_averages()``.
    2. If the batch has an ``attention_mask``, scale total FLOPs by a padding
       correction factor: ``0.65 * mask_ratio + 0.35 * attention_ratio**2``,
       clamped to [0.1, 1.0].
    3. If the model has LoRA modules (attributes ``lora_A`` and ``lora_B``),
       compute analytic base and LoRA FLOPs per layer. Total FLOPs are then
       scaled by ``0.95 * lora / base``, clamped to [0.1, 0.9].
    4. Estimate MACs as ``flops * calculate_mac_ratio(...)``, falling back to a
       ratio of 0.5 if that helper raises.

    Args:
        model: Model to profile
        batch: Input batch (exactly as it would be used in training/evaluation)
        metrics: Collection of metric names; MACs are computed only if it
            contains ``"macs"`` (FLOPs are always computed)
        accelerator: Optional accelerator for model unwrapping
        log_output: Whether to log the output
        measure_backward: Whether to measure backward pass (only works when model in training mode)

    Returns:
        Dictionary with ``flops``, ``macs``, ``lora_flops``, ``original_flops``,
        ``attention_flops``, ``correction_factors``, and the per-sample values
        ``flops_per_sample`` / ``macs_per_sample``. On error it also contains
        ``measurement_failed=True``.

    NOTE(review): with ``measure_backward`` the backward pass accumulates into
    the parameters' ``.grad``. Callers profiling mid-training may need to zero
    gradients afterwards.
    """
    import torch.profiler as profiler
    import logging

    logger = logging.getLogger(__name__)

    result = {
        "flops": 0,
        "macs": 0,
        "lora_flops": 0,
        "original_flops": 0,
        "attention_flops": 0,
        "correction_factors": {}
    }

    batch_size = batch["input_ids"].size(0) if "input_ids" in batch else 1
    seq_length = batch["input_ids"].size(1) if "input_ids" in batch else 0

    # Calculate effective sequence length considering padding
    if "attention_mask" in batch:
        non_pad_tokens = batch["attention_mask"].sum().item()
        avg_seq_length = non_pad_tokens / batch_size
    else:
        avg_seq_length = seq_length

    # Get model in correct state - do this BEFORE profiling
    unwrapped_model = accelerator.unwrap_model(model) if accelerator else model
    model_was_training = unwrapped_model.training

    # Set model mode BEFORE creating the profiler context
    if not measure_backward:
        unwrapped_model.eval()

    # Wait for any pending operations to complete
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Prepare inputs for profiling
    allowed_keys = {"input_ids", "attention_mask", "token_type_ids", "labels"}
    profiling_batch = {k: v for k, v in batch.items() if k in allowed_keys}
    if not profiling_batch and batch:
        profiling_batch = batch

    try:
        activities = [profiler.ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(profiler.ProfilerActivity.CUDA)

        # with_modules / with_stack are intentionally omitted to avoid profiler internal errors
        with profiler.profile(
            activities=activities,
            record_shapes=True,
            profile_memory=False,
            with_flops=True,
        ) as prof:
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            if measure_backward and model_was_training:
                outputs = unwrapped_model(**profiling_batch)
                loss = _profiling_backward_loss(outputs)
                if loss is not None:
                    loss.backward()
            else:
                # Just run forward pass with no_grad
                with torch.no_grad():
                    _ = unwrapped_model(**profiling_batch)

            if torch.cuda.is_available():
                torch.cuda.synchronize()

        # Process profiler events to get FLOPs
        total_flops = 0
        attention_flops = 0

        for evt in prof.key_averages():
            if evt.flops > 0:
                total_flops += evt.flops
                event_key = evt.key.lower()

                # LoRA-specific operations are only reported, not accumulated
                if any(x in event_key for x in ['lora', 'low_rank', 'adapter']):
                    if log_output:
                        logger.debug(f"LoRA operation detected: {evt.key}, FLOPs: {evt.flops}")

                if any(x in event_key for x in ['attention', 'softmax', 'attn']):
                    attention_flops += evt.flops
                    if log_output:
                        logger.debug(f"Attention operation detected: {evt.key}, FLOPs: {evt.flops}")

        result["attention_flops"] = attention_flops

        # Apply attention mask correction
        if "attention_mask" in profiling_batch:
            mask_ratio = avg_seq_length / seq_length if seq_length > 0 else 1.0

            attention_mask = profiling_batch["attention_mask"]
            attention_ratio = mask_ratio

            if attention_mask.dim() == 2:  # [batch, seq_len]
                valid_tokens = attention_mask.sum(dim=1).float().mean().item()
                attention_ratio = valid_tokens / seq_length if seq_length > 0 else 1.0
                attention_correction = attention_ratio ** 2
            else:
                attention_correction = mask_ratio ** 2

            # Heuristic split: linear layers scale linearly, attention quadratically
            linear_weight = 0.65
            attention_weight = 0.35

            correction_factor = linear_weight * mask_ratio + attention_weight * attention_correction
            correction_factor = max(0.1, min(1.0, correction_factor))

            total_flops = int(total_flops * correction_factor)

            result["correction_factors"]["attention_mask"] = correction_factor

            if log_output:
                logger.info(f"Applied attention mask correction factor: {correction_factor:.4f} "
                           f"(mask_ratio={mask_ratio:.4f}, attention_ratio={attention_ratio:.4f})")

        # Apply LoRA architecture correction
        has_lora = False
        lora_layers = 0
        lora_flops_calculated = 0
        original_flops_calculated = 0

        # Extract r_config from model
        r_config = {}
        if hasattr(unwrapped_model, 'initial_r_config'):
            r_config = unwrapped_model.initial_r_config
        else:
            for name, module in unwrapped_model.named_modules():
                if hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                    r = module.lora_A.shape[0]
                    if r > 0:
                        r_config[name] = r

        # Batch dimensions are the same for every layer; compute them once.
        batch_size_calc = profiling_batch["input_ids"].size(0) if "input_ids" in profiling_batch else 1
        seq_len_calc = profiling_batch["input_ids"].size(1) if "input_ids" in profiling_batch else 128
        if "attention_mask" in profiling_batch:
            effective_seq_len = profiling_batch["attention_mask"].sum().item() / batch_size_calc
        else:
            effective_seq_len = seq_len_calc

        for name, module in unwrapped_model.named_modules():
            if hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                has_lora = True
                lora_layers += 1

                # Get dimensions
                if hasattr(module, 'weight'):
                    in_features = module.weight.size(1)
                    out_features = module.weight.size(0)
                else:
                    in_features, out_features = get_layer_dimensions(module)

                r = r_config.get(name, 0)

                # Fallback matching for complex layer names
                if r == 0:
                    for config_name, config_r in r_config.items():
                        if config_name in name or name in config_name:
                            r = config_r
                            if log_output:
                                logger.debug(f"Found r={r} for layer {name} via partial match with {config_name}")
                            break

                # If still no match, read the rank from the LoRA matrix itself
                if r == 0:
                    r = module.lora_A.shape[0]

                # Original computation (without LoRA)
                original_layer_flops = 2 * batch_size_calc * effective_seq_len * in_features * out_features
                original_flops_calculated += original_layer_flops

                if r > 0:
                    layer_lora_flops = calculate_lora_layer_flops(
                        in_features, out_features, r, batch_size_calc, effective_seq_len
                    )
                    lora_flops_calculated += layer_lora_flops

                    if log_output:
                        logger.info(f"LoRA layer {name}: r={r}, "
                                  f"original_flops={original_layer_flops:,}, "
                                  f"lora_flops={layer_lora_flops:,}")

        result["lora_flops"] = lora_flops_calculated
        result["original_flops"] = original_flops_calculated

        if has_lora and original_flops_calculated > 0 and lora_flops_calculated > 0:
            actual_reduction = lora_flops_calculated / original_flops_calculated
            lora_correction = actual_reduction * 0.95
            lora_correction = max(0.1, min(0.9, lora_correction))

            total_flops = int(total_flops * lora_correction)

            result["correction_factors"]["lora"] = lora_correction

            if log_output:
                logger.info(f"LoRA correction applied: reduction={actual_reduction:.4f}, "
                           f"correction={lora_correction:.4f}")
                logger.info(f"LoRA stats: {lora_layers} layers, "
                           f"original_flops={original_flops_calculated:,}, "
                           f"lora_flops={lora_flops_calculated:,}")

        result["flops"] = int(total_flops)

        if "macs" in metrics:
            try:
                mac_ratio = calculate_mac_ratio(unwrapped_model, profiling_batch)
            except Exception:
                logger.warning("calculate_mac_ratio() failed. Using fallback MAC ratio of 0.5")
                mac_ratio = 0.5

            macs = int(total_flops * mac_ratio)
            result["macs"] = macs

            if log_output:
                logger.info(f"MACs estimated from FLOPs: {format_flops(macs)} (using {mac_ratio:.2f} ratio)")

        if log_output:
            logger.info(f"Measured FLOPs: {format_flops(total_flops)} "
                       f"(batch_size={batch_size}, avg_seq_length={avg_seq_length:.1f})")

    except Exception as e:
        logger.error(f"Computation measurement error: {e}")
        # isEnabledFor respects the effective (inherited) level; `logger.level`
        # is 0 (NOTSET) on unconfigured loggers and would always pass.
        if log_output and logger.isEnabledFor(logging.DEBUG):
            import traceback
            logger.debug(f"Measurement error details: {traceback.format_exc()}")

        result["measurement_failed"] = True

    # Restore original model state - do this AFTER profiling
    if model_was_training:
        unwrapped_model.train()
    else:
        unwrapped_model.eval()

    if batch_size > 0:
        result["flops_per_sample"] = result["flops"] / batch_size
        result["macs_per_sample"] = result["macs"] / batch_size

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return result

def get_layer_dimensions(module):
    """Return ``(in_features, out_features)`` for a layer-like module.

    A weight with at least two dimensions is read as ``(out, in, ...)``, the
    ``nn.Linear`` layout, so ``(shape[1], shape[0])`` is returned. Otherwise the
    module's ``in_features`` / ``out_features`` attributes are used. Modules
    with neither return ``(0, 0)``.

    Modules whose ``weight`` attribute exists but is ``None`` are handled
    gracefully. ``LayerNorm(elementwise_affine=False)`` is one example.
    """
    weight = getattr(module, 'weight', None)
    if weight is not None:
        weight_shape = getattr(weight, 'shape', ())
        if len(weight_shape) >= 2:
            return weight_shape[1], weight_shape[0]  # in_features, out_features
    if hasattr(module, 'in_features') and hasattr(module, 'out_features'):
        return module.in_features, module.out_features
    return 0, 0


def calculate_lora_layer_flops(in_features, out_features, r, batch_size, seq_len):
    """Estimate FLOPs of one projection layer at LoRA rank ``r``.

    For ``r <= 0`` this is the dense matmul cost, ``2 * tokens * in * out``.

    For ``r > 0`` it is the sum of the following terms, where
    ``tokens = batch_size * seq_len``:

    * the down projection ``2 * tokens * in * r``;
    * the up projection ``2 * tokens * r * out``;
    * a dropout term of ``tokens * r``;
    * an output scaling of ``tokens * out``;
    * an addition onto the base output of ``tokens * out``.

    NOTE(review): for ``r > 0`` the frozen base projection (``2 * tokens * in * out``)
    is NOT included, although the addition term implies the base output is
    computed. Confirm with the caller whether this is intended.
    NOTE(review): the dropout term is counted over ``r`` elements. In loralib-style
    layers dropout is typically applied to the ``in_features``-wide input; verify.
    """
    tokens = batch_size * seq_len

    if r <= 0:
        return 2 * tokens * in_features * out_features

    projection_flops = 2 * tokens * in_features * r + 2 * tokens * r * out_features
    # dropout over r, then scaling and residual addition over out_features
    elementwise_flops = tokens * r + tokens * out_features + tokens * out_features
    return projection_flops + elementwise_flops


def calculate_mac_ratio(model, sample_batch):
    """Estimate the MAC/FLOP ratio of ``model`` from its architecture, with LoRA ranks.

    Walks ``model.named_modules()`` and tallies multiply-accumulate operations
    (MACs) and total operations for two kinds of module.

    * Every ``nn.Linear``. When the layer has a LoRA rank ``r > 0``, the
      low-rank down/up projection cost is used, plus non-MAC dropout, scaling
      and addition ops.
    * Attention modules, meaning ``nn.MultiheadAttention`` or any non-Linear
      module whose qualified name contains "attention". These contribute the
      score (Q @ K^T) and context (weights @ V) matmuls and a softmax cost. They
      also contribute estimated Q/K/V and output projections, but only when
      those projections are not ``nn.Linear`` submodules, which the Linear
      branch already counts.

    Each attention subtree is counted once. Nested modules whose names also
    contain "attention" (e.g. an attention block's dropout or LayerNorm) are
    skipped instead of being tallied as additional attention blocks.

    Args:
        model: Module to analyse. Its ``initial_r_config`` attribute (layer
            name -> rank) is used when present. Otherwise ranks are read from
            ``lora_A.shape[0]`` of modules having both ``lora_A`` and ``lora_B``.
        sample_batch: Mapping. ``input_ids``, if present, supplies batch size
            and sequence length via ``.size(0)`` / ``.size(1)``. The defaults
            are 1 and 128.

    Returns:
        float: MACs / total operations clamped to [0.3, 0.7], or 0.4 when no
        countable operation was found.

    Attention modules lacking ``embed_dim`` / ``num_heads`` attributes fall back
    to 768 / 12, so the estimate is only approximate for such modules.
    """
    import torch.nn as nn

    mac_operations = 0
    total_operations = 0
    batch_size = sample_batch["input_ids"].size(0) if "input_ids" in sample_batch else 1
    seq_len = sample_batch["input_ids"].size(1) if "input_ids" in sample_batch else 128
    tokens = batch_size * seq_len

    # Layer name -> LoRA rank
    r_config = {}
    if hasattr(model, 'initial_r_config'):
        r_config = model.initial_r_config
    else:
        for name, module in model.named_modules():
            if hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                r = module.lora_A.shape[0]
                if r > 0:
                    r_config[name] = r

    # Qualified names of attention modules already tallied. named_modules() is
    # pre-order, so the outermost match of each subtree is seen first.
    counted_attention = []

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            r = r_config.get(name, 0)
            if r > 0:
                down_macs = tokens * module.in_features * r
                up_macs = tokens * r * module.out_features
                mac_operations += down_macs + up_macs
                total_operations += (down_macs + up_macs) * 2
                # Dropout, scaling and the residual addition are not MACs
                total_operations += tokens * r
                total_operations += 2 * tokens * module.out_features
            else:
                macs = tokens * module.in_features * module.out_features
                mac_operations += macs
                total_operations += macs * 2
            continue

        if not (isinstance(module, nn.MultiheadAttention) or 'attention' in name.lower()):
            continue
        if any(prefix == "" or name.startswith(prefix + ".") for prefix in counted_attention):
            continue  # part of an attention block that is already counted
        counted_attention.append(name)

        embed_dim = getattr(module, 'embed_dim', 768)
        num_heads = getattr(module, 'num_heads', 12) or 1  # avoid division by zero
        head_dim = embed_dim // num_heads

        has_linear_children = any(
            isinstance(sub, nn.Linear) for sub_name, sub in module.named_modules() if sub_name
        )
        # nn.MultiheadAttention keeps Q/K/V in a fused Parameter, not in nn.Linear
        # children, so its input projections must always be estimated here.
        count_qkv = isinstance(module, nn.MultiheadAttention) or not has_linear_children
        count_out = not has_linear_children

        if count_qkv:
            for comp in ('query', 'key', 'value'):
                r = r_config.get(f"{name}.{comp}", 0)
                if r > 0:
                    lora_macs = 2 * tokens * embed_dim * r  # down + up projection
                    mac_operations += lora_macs
                    total_operations += lora_macs * 2 + tokens * embed_dim
                else:
                    macs = tokens * embed_dim * embed_dim
                    mac_operations += macs
                    total_operations += macs * 2

        # Attention scores (Q @ K^T) and context (weights @ V) have equal cost
        score_macs = batch_size * num_heads * seq_len * seq_len * head_dim
        mac_operations += 2 * score_macs
        total_operations += 2 * score_macs * 2
        # Softmax: ~5 non-MAC ops per score element
        total_operations += batch_size * num_heads * seq_len * seq_len * 5

        if count_out:
            r = r_config.get(f"{name}.out_proj", 0)
            if r > 0:
                lora_macs = 2 * tokens * embed_dim * r
                mac_operations += lora_macs
                total_operations += lora_macs * 2 + tokens * embed_dim
            else:
                macs = tokens * embed_dim * embed_dim
                mac_operations += macs
                total_operations += macs * 2

    if total_operations > 0:
        ratio = mac_operations / total_operations
        return max(0.3, min(0.7, ratio))  # Reasonable bounds for transformer models
    return 0.4  # Default fallback for transformers with LoRA

def _all_reduce_computation_totals(local_flops, local_macs, local_samples, accelerator):
    """Sum (flops, macs, samples) over all processes.

    Returns a ``(flops, macs, samples)`` tuple. Returns ``None`` when the
    distributed all-reduce is unavailable or fails, in which case the caller
    aggregates locally.
    """
    try:
        import torch.distributed as dist

        if not dist.is_available():
            raise RuntimeError("Distributed package not available")
        if not dist.is_initialized():
            logger.warning("Distributed environment detected but not initialized. Falling back to local aggregation.")
            return None

        # Double precision for large FLOP counts
        totals = torch.tensor(
            [local_flops, local_macs, float(local_samples)],
            device=accelerator.device,
            dtype=torch.float64,
        )
        # all_reduce is a collective and synchronises all ranks by itself.
        # dist.barrier() has no ``timeout`` parameter, so it is not used here.
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        if not torch.isfinite(totals).all():
            raise ValueError("Invalid values after all-reduce operation")
        return float(totals[0].item()), float(totals[1].item()), int(totals[2].item())
    except Exception as e:
        logger.warning(f"Distributed metrics aggregation error: {e}")
        logger.info("Falling back to local aggregation")
        return None


def aggregate_computation_metrics(metrics_list, accelerator=None):
    """
    Aggregate per-batch computation metrics, summing across processes when distributed.

    Failed measurements (``measurement_failed`` truthy) are dropped with a
    warning. FLOPs, MACs and sample counts (``batch_size`` entries) are summed.
    With more than one process and an initialized process group, the sums are
    all-reduced so every rank receives the global totals. If the all-reduce is
    unavailable or fails, local sums are used instead.

    In the distributed case every rank joins the all-reduce even when it has no
    valid measurements. This prevents one rank's failures from leaving the
    others blocked in the collective.

    Args:
        metrics_list: List of per-measurement dicts with keys ``flops``,
            ``macs``, and optionally ``batch_size``, ``flops_per_sample``,
            ``macs_per_sample`` and ``measurement_failed``.
        accelerator: Optional Accelerator. Its ``num_processes`` and ``device``
            drive distributed aggregation.

    Returns:
        Dict with integer ``flops``/``macs``, ``samples``, ``flops_per_sample``
        and ``macs_per_sample``. When no usable measurement exists and no global
        totals are available, zeros are returned. ``all_measurements_failed`` is
        set when every measurement failed.
    """
    metrics_list = list(metrics_list or [])

    measurement_failures = sum(1 for m in metrics_list if m.get("measurement_failed", False))
    if measurement_failures:
        logger.warning(f"{measurement_failures} out of {len(metrics_list)} measurements failed")
    valid_metrics = [m for m in metrics_list if not m.get("measurement_failed", False)]

    # Accumulate in float to avoid overflow surprises; converted back to int at the end
    local_flops = sum(float(m.get("flops", 0)) for m in valid_metrics)
    local_macs = sum(float(m.get("macs", 0)) for m in valid_metrics)
    local_samples = sum(int(m.get("batch_size", 0)) for m in valid_metrics)

    totals = None
    if accelerator is not None and accelerator.num_processes > 1:
        totals = _all_reduce_computation_totals(local_flops, local_macs, local_samples, accelerator)

    if totals is None:
        # Local aggregation
        if not metrics_list:
            return {"flops": 0, "macs": 0, "flops_per_sample": 0, "macs_per_sample": 0}
        if not valid_metrics:
            logger.error("All measurements failed - returning zeros")
            return {"flops": 0, "macs": 0, "flops_per_sample": 0, "macs_per_sample": 0,
                    "all_measurements_failed": True}
        totals = (local_flops, local_macs, local_samples)

    flops, macs, samples = totals
    result = {"flops": flops, "macs": macs, "samples": samples}

    if samples > 0:
        result["flops_per_sample"] = flops / samples
        result["macs_per_sample"] = macs / samples
    else:
        # No batch sizes recorded: average any per-sample values that were reported
        with_per_sample = [m for m in valid_metrics if "flops_per_sample" in m]
        if with_per_sample:
            count = len(with_per_sample)
            result["flops_per_sample"] = sum(m["flops_per_sample"] for m in with_per_sample) / count
            result["macs_per_sample"] = sum(m.get("macs_per_sample", 0) for m in with_per_sample) / count
        else:
            result["flops_per_sample"] = 0
            result["macs_per_sample"] = 0

    result["flops"] = int(flops)
    result["macs"] = int(macs)
    return result

def measure_model_computation(model, dataloader, num_batches=None, metrics=["flops", "macs"],
                           accelerator=None, log_output=False, measure_backward=False):
    """
    Measure computation across multiple batches for more accurate estimation.

    Each batch is profiled with measure_computation(), and the per-batch results
    are combined with aggregate_computation_metrics(). Variation in the data is
    therefore captured directly instead of being extrapolated from one batch.

    Args:
        model: Model to profile.
        dataloader: DataLoader to iterate over batches.
        num_batches: Number of leading batches to measure (None for all).
        metrics: List of metrics to calculate. It is not mutated.
        accelerator: Accelerator for distributed training.
        log_output: Whether to log progress and a final summary.
        measure_backward: Whether to measure the backward pass. The model is
            left in its current mode when True and switched to eval otherwise.

    Returns:
        The dict produced by aggregate_computation_metrics(), extended with
        ``measured_batches`` (successfully measured batches) and
        ``total_samples``.

    The model's original train/eval mode is restored on exit, even on error.
    """
    metrics_list = []
    sample_count = 0
    attempted_batches = 0  # batches actually profiled, whether or not they succeeded

    model_was_training = model.training
    if not measure_backward:
        model.eval()

    # Record CUDA memory so a possible leak can be reported afterwards
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated()

    try:
        # Number of leading batches to profile; None means the whole dataloader
        batch_limit = max(0, min(num_batches, len(dataloader))) if num_batches is not None else None

        for batch_idx, batch in enumerate(dataloader):
            if batch_limit is not None and batch_idx >= batch_limit:
                break
            attempted_batches += 1

            if log_output and batch_idx > 0 and batch_idx % 10 == 0:
                logger.info(f"Measuring batch {batch_idx}/{batch_limit if batch_limit else len(dataloader)}")

            batch_metrics = measure_computation(
                model=model,
                batch=batch,
                metrics=metrics,
                accelerator=accelerator,
                log_output=(log_output and batch_idx == 0),  # Log only first batch
                measure_backward=measure_backward
            )

            if batch_metrics.get("measurement_failed", False):
                logger.warning(f"Measurement failed for batch {batch_idx} - skipping this batch")
                continue

            # Batch size enables sample-weighted aggregation
            batch_size = batch["input_ids"].size(0) if "input_ids" in batch else 1
            batch_metrics["batch_size"] = batch_size
            sample_count += batch_size

            metrics_list.append(batch_metrics)

            # Periodically release cached memory for large models
            if torch.cuda.is_available() and batch_idx % 5 == 0:
                torch.cuda.empty_cache()

        results = aggregate_computation_metrics(metrics_list, accelerator)
        results["measured_batches"] = len(metrics_list)
        results["total_samples"] = sample_count

        if log_output:
            # The batch that triggers the num_batches cut-off is never profiled,
            # so it must not count towards the denominator.
            success_rate = len(metrics_list) / attempted_batches if attempted_batches else 0
            logger.info(f"Measured {len(metrics_list)} batches ({sample_count} samples) with {success_rate:.1%} success rate:")
            if "flops" in metrics:
                logger.info(f"  Total FLOPs: {format_flops(results['flops'])}")
                logger.info(f"  FLOPs per sample: {format_flops(results['flops_per_sample'])}")
            if "macs" in metrics:
                logger.info(f"  Total MACs: {format_flops(results['macs'])}")
                logger.info(f"  MACs per sample: {format_flops(results['macs_per_sample'])}")
            if measure_backward:
                logger.info(f"  Note: Measurements include both forward and backward passes")

        return results

    finally:
        if model_was_training:
            model.train()
        else:
            model.eval()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            final_memory = torch.cuda.memory_allocated()
            if final_memory > initial_memory + 10 * 1024 * 1024:  # 10MB threshold
                logger.warning(f"Possible memory leak detected: {(final_memory - initial_memory) / (1024*1024):.2f}MB increase")

# ---------- END OF CORE COMPUTATION MEASUREMENT FUNCTIONS ----------

# Backward-compatible wrapper functions, kept so older call sites keep working.
# They are thin wrappers around measure_computation() plus a simple accumulator.

def calculate_flops(model, batch, step=-1, force=False, accelerator=None, log_output=False):
    """Return the FLOP count for ``batch`` (legacy API).

    Delegates to measure_computation() with ``metrics=["flops"]``. ``step`` and
    ``force`` are accepted for backward compatibility and are ignored.
    """
    measurement = measure_computation(
        model=model,
        batch=batch,
        metrics=["flops"],
        accelerator=accelerator,
        log_output=log_output,
    )
    flops = measurement["flops"]
    return flops

def calculate_macs(model, batch, step=-1, accelerator=None, log_output=False):
    """Return the MAC count for ``batch`` (legacy API).

    Delegates to measure_computation() with ``metrics=["macs"]``. ``step`` is
    accepted for backward compatibility and is ignored.
    """
    measurement = measure_computation(
        model=model,
        batch=batch,
        metrics=["macs"],
        accelerator=accelerator,
        log_output=log_output,
    )
    macs = measurement["macs"]
    return macs

def calculate_compute_metrics(model, batch, step=-1, force=False, accelerator=None, 
                            log_output=False, metrics=["flops", "macs"], normalize_by_batch_size=True):
    """Return the full measure_computation() result dict (legacy API).

    ``metrics`` is forwarded unchanged. ``step``, ``force`` and
    ``normalize_by_batch_size`` are accepted for backward compatibility but are
    not used.
    """
    forwarded = dict(
        model=model,
        batch=batch,
        metrics=metrics,
        accelerator=accelerator,
        log_output=log_output,
    )
    return measure_computation(**forwarded)

def store_compute_metrics(current_metrics, total_metrics, accelerator):
    """Add ``current_metrics`` FLOPs/MACs into ``total_metrics`` in place (legacy API).

    Tensor values in ``current_metrics`` are converted with ``float()``. Only the
    ``flops`` and ``macs`` keys are accumulated, and only when present in both
    dicts. The results are stored as floats. ``accelerator`` is unused and kept
    for signature compatibility.

    Returns:
        The same ``total_metrics`` dict, updated.
    """
    as_python = {
        key: float(value) if isinstance(value, torch.Tensor) else value
        for key, value in current_metrics.items()
    }

    for key in ("flops", "macs"):
        if key not in as_python or key not in total_metrics:
            continue
        total_metrics[key] = float(total_metrics[key]) + float(as_python[key])

    return total_metrics

# ---------- END OF COMPUTATION MEASUREMENT FUNCTIONS ----------

def preprocess_dataset(tokenizer, raw_datasets, task_name, max_length, max_train_samples=None, max_eval_samples=None):
    """Format and tokenize an Alpaca-style dataset for causal-LM training and evaluation.

    Each row has an ``instruction``, an optional ``input`` and an ``output``.
    Rows with a non-empty input are rendered with ALPACA_PROMPT_TEMPLATE, and
    the rest with ALPACA_PROMPT_NO_INPUT_TEMPLATE. The text is then tokenized
    to exactly ``max_length`` tokens (padded and truncated). ``labels`` mirror
    ``input_ids``, except that padding positions are set to -100 so they are
    ignored by the language-modelling loss.

    Args:
        tokenizer: Hugging Face tokenizer. It presumably needs a pad token for
            max-length padding; confirm it is set by the caller.
        raw_datasets: A split mapping (e.g. ``DatasetDict``) containing a
            "train" split, or a single un-split ``Dataset``. An un-split
            dataset is divided 90/10 into train/validation (seed 42).
        task_name: Unused; kept for signature compatibility.
        max_length: Fixed token length of every example.
        max_train_samples: Optional cap on the number of training rows.
        max_eval_samples: Optional cap on the number of evaluation rows.

    Returns:
        ``(train_dataset, eval_dataset)``. When no "validation" or "test" split
        exists, 10% of the (possibly capped) train split is held out.

    Raises:
        ValueError: If ``raw_datasets`` is a split mapping without a "train" split.
    """

    def preprocess_function(examples):
        """Render a batch of Alpaca rows into prompts and tokenize them."""
        instructions = examples.get('instruction') or []
        # Some Alpaca variants omit the "input" column. Treat it as empty rather
        # than letting zip() truncate the batch to zero rows.
        inputs = examples.get('input') or [""] * len(instructions)
        outputs = examples.get('output') or []

        texts = []
        for instruction, input_text, output in zip(instructions, inputs, outputs):
            if input_text and input_text.strip():
                text = ALPACA_PROMPT_TEMPLATE.format(
                    instruction=instruction,
                    input=input_text,
                    output=output
                )
            else:
                text = ALPACA_PROMPT_NO_INPUT_TEMPLATE.format(
                    instruction=instruction,
                    output=output
                )
            texts.append(text)

        model_inputs = tokenizer(
            texts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors=None
        )

        # Padding must not contribute to the LM loss: mark it with -100, the
        # ignore index of the Hugging Face causal-LM loss.
        attention_masks = model_inputs.get("attention_mask")
        if attention_masks is None:
            model_inputs["labels"] = [list(ids) for ids in model_inputs["input_ids"]]
        else:
            model_inputs["labels"] = [
                [token if keep else -100 for token, keep in zip(ids, mask)]
                for ids, mask in zip(model_inputs["input_ids"], attention_masks)
            ]

        return model_inputs

    # DatasetDict subclasses dict; a plain Dataset does not. Using
    # `"train" in dataset` on a Dataset would scan every row.
    is_split_mapping = isinstance(raw_datasets, dict)
    if is_split_mapping and "train" not in raw_datasets:
        raise ValueError(f"Expected a 'train' split, found splits: {list(raw_datasets.keys())}")
    columns_to_remove = raw_datasets["train"].column_names if is_split_mapping else raw_datasets.column_names

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=columns_to_remove,
        desc="Preprocessing Alpaca dataset",
    )

    if is_split_mapping:
        splits = processed_datasets
    else:
        # Single un-split dataset: carve out a validation split
        holdout = processed_datasets.train_test_split(test_size=0.1, seed=42)
        splits = {"train": holdout["train"], "validation": holdout["test"]}

    train_dataset = splits["train"]
    if max_train_samples is not None:
        train_dataset = train_dataset.select(range(min(max_train_samples, len(train_dataset))))

    eval_dataset = splits.get("validation", splits.get("test"))
    if eval_dataset is None:
        # Create validation set from training data if not present
        split_datasets = train_dataset.train_test_split(test_size=0.1, seed=42)
        train_dataset = split_datasets["train"]
        eval_dataset = split_datasets["test"]

    if max_eval_samples is not None:
        eval_dataset = eval_dataset.select(range(min(max_eval_samples, len(eval_dataset))))

    return train_dataset, eval_dataset


def validate_lora_configuration(r_config, replaced_layers, logger):
    """Compare intended LoRA ranks with the ranks actually applied, and log a report.

    Args:
        r_config: Mapping of layer name to expected rank. Missing layers count
            as rank 0.
        replaced_layers: Mapping of layer name to an info dict with keys ``r``,
            ``applied`` and, optionally, ``in_features`` / ``out_features``.
            The latter two are used to count LoRA parameters as
            ``r * (in + out)``.
        logger: Logger that receives the per-layer report and summary.

    Returns:
        Dict with ``applied_layers``, ``skipped_layers``, ``mismatch_layers``
        and ``total_params``.
    """
    separator = "=" * 60
    logger.info(separator)
    logger.info("VALIDATING LoRA CONFIGURATION")
    logger.info(separator)

    n_applied = n_skipped = n_mismatch = 0
    total_params = 0

    for layer_name, info in replaced_layers.items():
        expected_r = r_config.get(layer_name, 0)
        actual_r = info.get('r', 0)
        was_applied = info.get('applied', False)

        has_dims = 'in_features' in info and 'out_features' in info
        param_count = actual_r * (info['in_features'] + info['out_features']) if has_dims else 0
        total_params += param_count

        ranks_match = expected_r == actual_r
        status = "✓ MATCH" if ranks_match else "✗ MISMATCH"
        applied_status = "APPLIED" if was_applied else "SKIPPED"

        logger.info(f"Layer: {layer_name}")
        logger.info(f"  Expected r: {expected_r}, Actual r: {actual_r}, Status: {status}, {applied_status}")
        if param_count > 0:
            logger.info(f"  Parameters: {param_count:,}")

        if logger.level <= logging.DEBUG:
            logger.debug(f"Layer {layer_name}: Expected r={expected_r}, "
                         f"Actual r={actual_r}, Status: {status}, {applied_status}")

        if not ranks_match:
            n_mismatch += 1
            logger.warning(f"MISMATCH in layer {layer_name}: Expected r={expected_r}, Actual r={actual_r}")

        if was_applied:
            n_applied += 1
        else:
            n_skipped += 1

    logger.info(f"LoRA SUMMARY: {n_applied} layers applied, {n_skipped} layers skipped")
    logger.info(f"Total LoRA parameters: {total_params:,}")

    if n_mismatch:
        logger.error(f"CRITICAL: Found {n_mismatch} layers with r-value mismatches!")
    else:
        logger.info("SUCCESS: All r-values correctly applied")
    logger.info(separator)

    return {
        "applied_layers": n_applied,
        "skipped_layers": n_skipped,
        "mismatch_layers": n_mismatch,
        "total_params": total_params
    }


def evaluate_alpaca_generation(model, eval_dataloader, tokenizer, accelerator, max_new_tokens=128):
    """
    Evaluate an Alpaca model by perplexity and log a few sampled generations.

    Every batch contributes its loss, repeated per example and gathered across
    processes. Generation runs only until 10 samples have been produced. The
    prompt fed to ``generate`` is approximated as the first half of the
    shortest sequence in the batch, which assumes right padding (TODO confirm
    the tokenizer's padding side).

    Args:
        model: The model to evaluate, possibly wrapped by ``accelerator.prepare``.
        eval_dataloader: DataLoader yielding dicts with ``input_ids``,
            ``attention_mask`` and optionally ``labels``.
        tokenizer: Tokenizer used for pad/eos ids and decoding.
        accelerator: Accelerator for gathering losses and unwrapping the model.
        max_new_tokens: Maximum tokens to generate per sample.

    Returns:
        Dict with ``perplexity``, ``eval_loss`` and ``num_samples``. If the
        dataloader is empty, the loss metrics are NaN and ``num_samples`` is 0.
    """
    max_generated_samples = 10
    model.eval()
    # DDP-style wrappers produced by accelerator.prepare() do not expose generate()
    generation_model = accelerator.unwrap_model(model)
    losses = []
    generated_texts = []
    reference_texts = []

    with torch.no_grad():
        for batch in eval_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            losses.append(accelerator.gather(loss.repeat(batch["input_ids"].shape[0])))

            if len(generated_texts) < max_generated_samples:
                input_ids = batch["input_ids"]
                attention_mask = batch["attention_mask"]

                # Approximate prompt boundary; keep at least one token so generate() has input
                prompt_length = max(1, attention_mask.sum(dim=1).min().item() // 2)
                prompt_ids = input_ids[:, :prompt_length]
                prompt_mask = attention_mask[:, :prompt_length]

                generated_ids = generation_model.generate(
                    prompt_ids,
                    attention_mask=prompt_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

                for i in range(min(batch["input_ids"].shape[0], max_generated_samples - len(generated_texts))):
                    generated_texts.append(tokenizer.decode(generated_ids[i], skip_special_tokens=True))

                    if "labels" in batch:
                        ref_ids = batch["labels"][i]
                        ref_ids = ref_ids[ref_ids != -100]  # drop ignored (padding) positions
                        reference_texts.append(tokenizer.decode(ref_ids, skip_special_tokens=True))

    if not losses:
        logger.warning("Evaluation dataloader produced no batches; returning NaN metrics")
        return {"perplexity": float("nan"), "eval_loss": float("nan"), "num_samples": 0}

    losses = torch.cat(losses)
    eval_loss = losses.mean()
    perplexity = torch.exp(eval_loss)

    metrics = {
        "perplexity": perplexity.item(),
        "eval_loss": eval_loss.item(),
        "num_samples": len(losses),
    }

    # Full AlpacaEval integration would require additional setup; log samples instead
    if generated_texts and reference_texts:
        logger.info("Sample generations for quality assessment:")
        for i, gen in enumerate(generated_texts[:3]):
            logger.info(f"Sample {i+1}: {gen[:200]}...")

    return metrics


def save_metrics(metrics, output_path):
    """Write ``metrics`` to ``output_path`` as indented JSON (best effort).

    Values are converted before serialization as follows:

    * single-element tensors become Python floats;
    * larger tensors become nested lists;
    * NumPy scalars become Python scalars via ``.item()``.

    The parent directory is created when ``output_path`` has one.
    Serialization and write failures are logged as warnings, not raised.
    """
    serializable_metrics = {}
    for key, value in metrics.items():
        if isinstance(value, torch.Tensor):
            value = float(value) if value.numel() == 1 else value.tolist()
        elif isinstance(value, np.generic):
            value = value.item()
        serializable_metrics[key] = value

    # A bare filename has no directory component, and os.makedirs("") raises
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    try:
        # Serialize before opening the file so a bad value cannot leave a truncated file
        payload = json.dumps(serializable_metrics, indent=2)
        with open(output_path, 'w') as f:
            f.write(payload)
    except Exception as e:
        logger.warning(f"Failed to save metrics: {e}")


def optimize_lora_config(model, dataloader, args, accelerator, output_dir):
    """Search for per-layer LoRA ranks with LoRAOptimizer and log a summary.

    The optimizer gets ``args.lora_r_values``, ``args.lora_budget`` and
    ``args.seed``, and writes under ``<output_dir>/lora_optimization``. The
    resulting ``r_config`` (layer name -> rank) is saved to
    ``<output_dir>/lora_r_config.json`` by the main process. The average rank
    is logged per layer family. Both LLaMA-style (q_proj, ...) and
    BERT/RoBERTa-style (query, ...) names are recognised. The total LoRA
    parameter count is logged as well.

    Returns:
        The ``r_config`` dict produced by the optimizer.
    """
    logger.info(f"Starting LoRA optimization with seed {args.seed}...")

    optimizer_output_dir = os.path.join(output_dir, "lora_optimization")
    lora_optimizer = LoRAOptimizer(
        model=model,
        r_values=args.lora_r_values,
        budget=args.lora_budget,
        device=accelerator.device,
        output_dir=optimizer_output_dir,
        seed=args.seed,
    )

    with accelerator.main_process_first():
        model.to(accelerator.device)
        r_config = lora_optimizer.optimize(dataloader)

    if accelerator.is_main_process:
        config_path = os.path.join(output_dir, "lora_r_config.json")
        with open(config_path, 'w') as f:
            json.dump(r_config, f, indent=2)

    # Ordered substring rules; the first match wins. LLaMA names are tested
    # before the generic BERT/RoBERTa names.
    grouping_rules = (
        (("q_proj",), "query"),
        (("k_proj",), "key"),
        (("v_proj",), "value"),
        (("o_proj",), "attention.output"),
        (("gate_proj", "up_proj"), "intermediate"),
        (("down_proj",), "output"),
        (("query",), "query"),
        (("key",), "key"),
        (("value",), "value"),
        (("attention.output",), "attention.output"),
        (("intermediate",), "intermediate"),
    )

    def _layer_family(layer_name):
        for needles, family in grouping_rules:
            if any(needle in layer_name for needle in needles):
                return family
        if "output" in layer_name and "attention" not in layer_name:
            return "output"
        return None

    families = ("query", "key", "value", "attention.output", "intermediate", "output")
    layer_types = {family: [] for family in families}
    for layer_name, r_value in r_config.items():
        family = _layer_family(layer_name)
        if family is not None:
            layer_types[family].append((layer_name, r_value))

    logger.info("OPTIMAL LoRA RANK CONFIGURATION:")
    for layer_type, layers in layer_types.items():
        if layers:
            ranks = [rank for _, rank in layers]
            avg_r = sum(ranks) / len(ranks)
            logger.info(f"  {layer_type}: avg_r={avg_r:.2f}, count={len(layers)}")

    total_params = 0
    for layer_name, r_value in r_config.items():
        if r_value > 0:
            layer_size = get_layer_size(model, layer_name)
            total_params += r_value * (layer_size[0] + layer_size[1])
    logger.info(f"Total LoRA parameters: {total_params:,}")

    return r_config

def check_lora_merge_status(model):
    """Log whether every LoRA-bearing submodule of ``model`` is merged.

    A submodule is treated as LoRA-bearing when it has both ``lora_A`` and
    ``lora_B`` attributes. If it also exposes ``merged`` and ``r`` it is
    classified as zero-rank (``r == 0``), merged, or unmerged. Otherwise it
    is logged as having unknown status and is not counted in any tally.

    Args:
        model: Object providing ``named_modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        tuple[int, int]: ``(merged_count, unmerged_count)``. The zero-rank
        count appears only in the logged summary and is not returned.
    """
    divider = "=" * 80
    logger.info(divider)
    logger.info("CHECKING LoRA MERGE STATUS")
    logger.info(divider)

    tally = {"merged": 0, "unmerged": 0, "zero_r": 0}

    for module_name, submodule in model.named_modules():
        # Only modules carrying both LoRA factor matrices are of interest.
        if not (hasattr(submodule, 'lora_A') and hasattr(submodule, 'lora_B')):
            continue

        # Without these attributes there is no way to tell the merge state.
        if not (hasattr(submodule, 'merged') and hasattr(submodule, 'r')):
            logger.info(f"Layer {module_name} merge status unknown")
            continue

        # Rank is checked first so pruned-out layers are reported as such
        # regardless of their merged flag.
        if submodule.r == 0:
            tally["zero_r"] += 1
            logger.info(f"Layer {module_name} has r=0 (pruned out)")
        elif submodule.merged:
            tally["merged"] += 1
            logger.info(f"Layer {module_name} is MERGED")
        else:
            tally["unmerged"] += 1
            logger.info(f"Layer {module_name} is UNMERGED")

    logger.info(
        f"Summary: Merged={tally['merged']}, Unmerged={tally['unmerged']}, "
        f"Zero-r={tally['zero_r']}"
    )
    logger.info(divider)

    return tally["merged"], tally["unmerged"]


def main():
    args = parse_args()
    
    # Set environment variables for reproducibility
    os.environ["PYTHONHASHSEED"] = str(args.seed)
    os.environ["GRB_NUM_THREADS"] = "1"  # Force Gurobi to use a single thread
    
    # Setup minimal logging
    log_level = getattr(logging, args.log_level.upper())
    logger = setup_logger("optimal_lora", os.path.join(args.output_dir, "training.log"), level=log_level)
    
    # Hugging Face authentication for LLaMA models
    if "llama" in args.model_name_or_path.lower():
        try:
            from huggingface_hub import login
            # Try to get token from environment variable first
            hf_token = os.environ.get("HF_TOKEN")
            if hf_token:
                login(token=hf_token)
                logger.info("Successfully authenticated with Hugging Face using HF_TOKEN")
            else:
                # Fallback: replace with your actual token
                login(token="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")
                logger.info("Successfully authenticated with Hugging Face using hardcoded token")
        except Exception as e:
            logger.warning(f"Hugging Face authentication failed: {e}")
            logger.warning("Please set your HF token or use CLI login: huggingface-cli login")
    
    # Configure minimal logging for external libraries
    setup_minimal_logging()
    
    # Set random seed
    set_deterministic_environment(args.seed)
    
    # Initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(args.output_dir, "logs")
    )
    
    # Initialize tracking variables
    train_flops = 0
    train_macs = 0
    eval_flops = 0
    eval_macs = 0
    
    # Basic memory tracking
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        logger.info(f"Initial GPU memory: {torch.cuda.memory_allocated() / (1024**3):.2f}GB")
    
    # Load configuration for Alpaca
    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
    )
    
    # Load tokenizer with special handling for LLaMA
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        use_fast=True,
    )
    
    # Set padding token if not already set (for models like LLaMA, GPT-2, etc.)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
    
    # Load Alpaca dataset
    if args.dataset_path and os.path.exists(args.dataset_path):
        # Load from local file
        with open(args.dataset_path, 'r') as f:
            data = json.load(f)
        raw_datasets = datasets.Dataset.from_dict({k: [d[k] for d in data] for k in data[0].keys()})
        raw_datasets = datasets.DatasetDict({"train": raw_datasets})
    else:
        # Load from Hugging Face Hub
        raw_datasets = load_dataset("tatsu-lab/alpaca", trust_remote_code=True)
    
    # Preprocess Alpaca dataset
    train_dataset, eval_dataset = preprocess_dataset(
        tokenizer=tokenizer,
        raw_datasets=raw_datasets,
        task_name=args.task_name,
        max_length=args.max_seq_length,
        max_train_samples=args.max_train_samples,
        max_eval_samples=args.max_eval_samples
    )
    
    # Create data collator for language modeling
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Causal LM, not masked LM
        pad_to_multiple_of=8 if accelerator.mixed_precision == "fp16" else None
    )
    
    g = torch.Generator()
    g.manual_seed(args.seed)

    train_dataloader = DataLoader(
        train_dataset, 
        shuffle=True, 
        collate_fn=data_collator, 
        batch_size=args.per_device_train_batch_size,
        worker_init_fn=seed_worker,
        generator=g,
        persistent_workers=False
    )

    # Create evaluation dataloader for Alpaca
    eval_dataloader = DataLoader(
        eval_dataset, 
        collate_fn=data_collator, 
        batch_size=args.per_device_eval_batch_size,
        worker_init_fn=seed_worker,
        generator=g,
        persistent_workers=False
    )
    
    # Load base model for generation task
    logger.info("Loading base model for Alpaca task...")
    if "llama" in args.model_name_or_path.lower():
        # For LLaMA, we'll use prepare_model_for_alpaca which handles the model loading
        pass  # Model will be loaded in prepare_model_for_alpaca
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            config=config,
        )
    
    # === LoRA Optimization Phase ===
    if args.use_existing_lora_config:
        logger.info(f"Using existing LoRA configuration from {args.use_existing_lora_config}")
        with open(args.use_existing_lora_config, 'r') as f:
            r_config = json.load(f)
    else:
        # For LLaMA, we need to load a temporary model for optimization
        if "llama" in args.model_name_or_path.lower():
            logger.info("Loading temporary LLaMA model for optimization...")
            from transformers import LlamaForCausalLM
            temp_model = LlamaForCausalLM.from_pretrained(
                args.model_name_or_path,
                torch_dtype=torch.bfloat16,
                device_map="cpu",  # Load on CPU first for optimization
                pad_token_id=tokenizer.pad_token_id  # Set pad token for the model
            )
            model = temp_model
        
        # Create a subset of the training data for optimization
        optimization_dataset = train_dataset.select(range(min(MAX_OPTIMIZATION_SAMPLES, len(train_dataset))))
        optimization_dataloader = DataLoader(
            optimization_dataset, 
            shuffle=False, 
            collate_fn=data_collator, 
            batch_size=args.per_device_eval_batch_size
        )
        
        # Find optimal LoRA config
        r_config = optimize_lora_config(
            model=model,
            dataloader=optimization_dataloader,
            args=args,
            accelerator=accelerator,
            output_dir=args.output_dir
        )
        
        # Explicitly log the optimal r configuration here in the main function
        if accelerator.is_main_process:
            # Group layers by type for detailed logging
            layer_types = {
                "query": [],
                "key": [],
                "value": [],
                "attention.output": [],
                "intermediate": [],
                "output": [],
            }
            
            # Group layers by type (handle both BERT/RoBERTa and LLaMA naming)
            for layer_name, r_value in r_config.items():
                # LLaMA layer names
                if "q_proj" in layer_name:
                    layer_types["query"].append((layer_name, r_value))
                elif "k_proj" in layer_name:
                    layer_types["key"].append((layer_name, r_value))
                elif "v_proj" in layer_name:
                    layer_types["value"].append((layer_name, r_value))
                elif "o_proj" in layer_name:
                    layer_types["attention.output"].append((layer_name, r_value))
                elif "gate_proj" in layer_name or "up_proj" in layer_name:
                    layer_types["intermediate"].append((layer_name, r_value))
                elif "down_proj" in layer_name:
                    layer_types["output"].append((layer_name, r_value))
                # BERT/RoBERTa layer names
                elif "query" in layer_name:
                    layer_types["query"].append((layer_name, r_value))
                elif "key" in layer_name:
                    layer_types["key"].append((layer_name, r_value))
                elif "value" in layer_name:
                    layer_types["value"].append((layer_name, r_value))
                elif "attention.output" in layer_name:
                    layer_types["attention.output"].append((layer_name, r_value))
                elif "intermediate" in layer_name:
                    layer_types["intermediate"].append((layer_name, r_value))
                elif "output" in layer_name and "attention" not in layer_name:
                    layer_types["output"].append((layer_name, r_value))
            
            # Log detailed r configuration report
            logger.info("=" * 50)
            logger.info("OPTIMAL LoRA RANK (r) CONFIGURATION:")
            logger.info("=" * 50)
            for layer_type, layers in layer_types.items():
                if not layers:
                    continue
                
                logger.info(f"\n[{layer_type.upper()} LAYERS]")
                # Calculate average r value for this layer type
                r_values = [r for _, r in layers]
                avg_r = sum(r_values) / len(r_values) if r_values else 0
                
                for layer_name, r in sorted(layers, key=lambda x: x[0]):
                    logger.info(f"  {layer_name}: r = {r}")
                
                logger.info(f"  Average r for {layer_type} layers: {avg_r:.2f}")
            logger.info("=" * 50)
            
            # Calculate and log total LoRA parameters (skip for LLaMA since model is deleted)
            if "llama" not in args.model_name_or_path.lower():
                total_params = sum(
                    r_value * (get_layer_size(model, layer_name)[0] + get_layer_size(model, layer_name)[1]) 
                    for layer_name, r_value in r_config.items() if r_value > 0
                )
                logger.info(f"Total LoRA parameters: {total_params:,}")
            else:
                logger.info("Total LoRA parameters: Will be calculated after model loading")
            logger.info("=" * 50)

        # Clean up temporary model for LLaMA
        if "llama" in args.model_name_or_path.lower():
            del temp_model
            del model
            torch.cuda.empty_cache()

        # Free memory
        del optimization_dataset, optimization_dataloader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    # If optimize_only, exit here
    if args.optimize_only:
        logger.info("Optimization completed. Exiting as --optimize_only was specified.")
        return
    
    # === Training Phase ===
    logger.info("Preparing model with LoRA for training...")
    
    # Replace model with LoRA version for Alpaca
    model, replaced_layers = prepare_model_for_alpaca(
        base_model_name=args.model_name_or_path,
        r_config=r_config,
        dropout=args.lora_dropout,
        tokenizer=tokenizer
    )
    
    # Validate the LoRA configuration
    validate_lora_configuration(r_config, replaced_layers, logger)
    
    # Save model architecture for reference
    if accelerator.is_main_process:
        with open(os.path.join(args.output_dir, "model_architecture.txt"), 'w') as f:
            f.write(str(model))
    
    # Mark only LoRA parameters as trainable for efficiency
    lora.mark_only_lora_as_trainable(model)
    
    # Log parameter information
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%} of total)")
    
    # Calculate LoRA parameters
    total_lora_params = 0
    for layer_name, r_value in r_config.items():
        if r_value > 0:
            in_features, out_features = get_layer_size(model, layer_name)
            total_lora_params += r_value * (in_features + out_features)
    logger.info(f"Total LoRA parameters: {total_lora_params:,}")
    
    # Setup optimizer - use transformers AdamW for accuracy consistency
    optimizer = AdamW(
        [param for param in model.parameters() if param.requires_grad],
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        no_deprecation_warning=True  # Suppress deprecation warning
    )
    
    # Prepare with accelerator for Alpaca
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )
    
    # Calculate training steps
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    
    # Calculate warmup steps from ratio if not explicitly set
    if args.num_warmup_steps == 0 and args.warmup_ratio > 0:
        args.num_warmup_steps = int(args.max_train_steps * args.warmup_ratio)
        
    # Create learning rate scheduler
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )
    # Initialize metrics for Alpaca evaluation
    # We'll use perplexity as primary metric and optionally AlpacaEval for quality assessment
    
    # Initialize progressive pruning manager if enabled
    pruning_manager = None
    if args.apply_pruning:
        logger.info("=" * 80)
        logger.info("INITIALIZING PROGRESSIVE LORA PRUNING")
        logger.info("=" * 80)
        
        # Get initial r_config from model
        unwrapped_model = accelerator.unwrap_model(model)
        if not hasattr(unwrapped_model, 'initial_r_config'):
            logger.warning("Model doesn't have initial_r_config attribute. Extracting from model...")
            
            # Extract r_config from model
            initial_r_config = {}
            for name, module in unwrapped_model.named_modules():
                if hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                    r = getattr(module, 'lora_A').shape[0] if hasattr(module, 'lora_A') else 0
                    if r > 0:
                        initial_r_config[name] = r
        else:
            initial_r_config = unwrapped_model.initial_r_config
            
        # Set pruning output directory
        pruning_output_dir = args.pruning_output_dir or os.path.join(args.output_dir, "pruning")
        os.makedirs(pruning_output_dir, exist_ok=True)
        
        # Override config values with command line arguments
        from pruning import pruning_config
        pruning_config.IMPORTANCE_EMA_DECAY = args.importance_ema_decay
        pruning_config.MOMENTUM_PENALTY_WEIGHT = args.momentum_penalty_weight
        pruning_config.RECOVERY_STEPS = args.recovery_steps
        pruning_config.EXTENDED_RECOVERY_STEPS = args.extended_recovery_steps
        
        # Log EMA and momentum parameters
        logger.info(f"Using importance EMA decay: {pruning_config.IMPORTANCE_EMA_DECAY}")
        logger.info(f"Using momentum penalty weight: {pruning_config.MOMENTUM_PENALTY_WEIGHT}")
        logger.info(f"Using recovery steps: {pruning_config.RECOVERY_STEPS}")
        logger.info(f"Using extended recovery steps: {pruning_config.EXTENDED_RECOVERY_STEPS}")
        
        # Create pruning manager with seed
        pruning_manager = ProgressivePruningManager(
            model=unwrapped_model,
            initial_r_config=initial_r_config,
            train_dataloader=train_dataloader,
            eval_dataloader=eval_dataloader,
            target_reduction=args.pruning_target_reduction,
            num_pruning_steps=args.pruning_steps,
            total_training_steps=args.max_train_steps,
            device=accelerator.device,
            output_dir=pruning_output_dir,
            seed=args.seed,
            enable_rollback=not args.disable_rollback
        )
        
        # Initialize baseline performance
        pruning_manager.initialize_baseline_performance()
        merged_count, unmerged_count = check_lora_merge_status(model)

    # Training loop preparation

    logger.info(f"Starting training for {args.num_train_epochs} epochs ({args.max_train_steps} steps)")
    completed_steps = 0
    best_metric = 0.0
    train_start_time = time.time()
    total_train_loss = 0
    total_train_samples = 0
    total_train_steps = 0
    
    # Measure computation with sample batch for baseline
    logger.info("Measuring baseline computation metrics...")
    sample_batch = next(iter(train_dataloader))
    sample_batch_size_1 = {k: v[0:1] for k, v in sample_batch.items() if isinstance(v, torch.Tensor)}
    
    # Measure both metrics for baseline
    baseline_metrics = measure_computation(
        model=model, 
        batch=sample_batch_size_1, 
        metrics=["flops", "macs"],
        accelerator=accelerator, 
        log_output=True
    )
    
    logger.info(f"Baseline computation metrics (single sample):")
    logger.info(f"  - FLOPs per sample: {format_flops(baseline_metrics['flops_per_sample'])}")
    if baseline_metrics['macs_per_sample'] > 0:
        logger.info(f"  - MACs per sample: {format_flops(baseline_metrics['macs_per_sample'])}")
    
    # Verify model parameters before training
    validate_model_parameters(model)
    
    # Training loop - optimized for performance and accuracy
    for epoch in range(args.num_train_epochs):
        # Reset memory stats at epoch start
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            logger.info(f"Epoch {epoch+1}/{args.num_train_epochs} - Starting memory: "
                       f"{torch.cuda.memory_allocated() / (1024**3):.2f}GB")
        
        model.train()
        train_loss = 0.0
        
        # Initialize list to store computation metrics for this epoch
        epoch_computation_metrics = []
        
        for step, batch in enumerate(train_dataloader):
            # Check if pruning should be applied at this step
            if pruning_manager is not None and pruning_manager.should_prune(completed_steps):
                # Execute pruning step
                pruning_success = pruning_manager.execute_pruning_step(completed_steps)
                # Refresh model reference after pruning
                if pruning_success:
                    unwrapped_model = accelerator.unwrap_model(model)
                    # Ensure proper LoRA layer states after pruning
                    logger.info("Checking LoRA layer states after pruning...")
                    check_lora_merge_status(unwrapped_model)
            
            outputs = model(**batch)
            
            # Measure computation metrics at specified intervals
            if step % LOG_FLOPS_INTERVAL == 0:
                try:
                    # Measure exact computation for this batch directly
                    batch_metrics = measure_computation(
                        model=model, 
                        batch=batch, 
                        metrics=["flops", "macs"],
                        accelerator=accelerator,
                        log_output=(step == 0)  # Only log the first batch
                    )
                    
                    # Add to tracking list with batch size for proper weighting
                    batch_metrics["batch_size"] = batch["input_ids"].size(0)
                    epoch_computation_metrics.append(batch_metrics)
                    
                except Exception as e:
                    logger.warning(f"Computation measurement skipped: {e}")
            
            loss = outputs.loss
            if loss is None:
                # Attempt to calculate loss from logits if available
                if hasattr(outputs, 'logits') and outputs.logits is not None:
                    if 'labels' in batch:
                        if batch['labels'].dim() == 1:  # Classification operations
                            loss = torch.nn.functional.cross_entropy(outputs.logits, batch['labels'])
                            logger.debug(f"Computed cross_entropy loss for batch with shape {outputs.logits.shape}")
                        else:  # Regression operations
                            loss = torch.nn.functional.mse_loss(outputs.logits.squeeze(), batch['labels'])
                            logger.debug(f"Computed mse loss for batch with shape {outputs.logits.shape}")
                    else:
                        # Check alternate keys if there is no label
                        label_found = False
                        for key in ['label', 'target', 'targets']:
                            if key in batch:
                                if batch[key].dim() == 1:  # Classification operations
                                    loss = torch.nn.functional.cross_entropy(outputs.logits, batch[key])
                                else:  # Regression operations
                                    loss = torch.nn.functional.mse_loss(outputs.logits.squeeze(), batch[key])
                                label_found = True
                                logger.debug(f"Used alternative label key: {key}")
                                break
                        
                        if not label_found:
                            logger.error(f"No labels found in batch. Skipping this batch. Batch keys: {list(batch.keys())}")
                            loss = None
                            continue  # Skip Batch
                else:
                    logger.warning(f"Cannot compute loss: No loss or logits in outputs. Output keys: {dir(outputs)}")
                    continue  # Skip Batch
            
            if loss is not None:
                # Process loss efficiently
                train_loss += loss.detach().float()
                total_train_loss += loss.detach().float().item()
                
                # Track batch statistics
                batch_size = batch["input_ids"].size(0)
                total_train_samples += batch_size
                
                # Scale loss for gradient accumulation
                loss = loss / args.gradient_accumulation_steps
                accelerator.backward(loss)
                
                if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    
                    # Update recovery mode if pruning is active
                    if pruning_manager is not None and pruning_manager.scheduler.in_recovery_mode:
                        pruning_manager.update_recovery()
                    
                    completed_steps += 1
                    total_train_steps += 1
                    
                    # Log progress at reasonable intervals
                    if completed_steps % 100 == 0:
                        avg_loss = train_loss / (step + 1)
                        progress = (completed_steps / args.max_train_steps) * 100
                        current_lr = optimizer.param_groups[0]['lr']
                        
                        # Include pruning info in log if active
                        if pruning_manager is not None:
                            pruning_info = pruning_manager.get_progress_info()
                            pruning_status = f", Pruning: {pruning_info['reduction']:.2%} reduction"
                            if pruning_manager.scheduler.in_recovery_mode:
                                pruning_status += " (in recovery)"
                        else:
                            pruning_status = ""
                        
                        logger.info(f"Step {completed_steps}/{args.max_train_steps} ({progress:.1f}%): "
                                   f"Loss={avg_loss:.4f}, LR={current_lr:.2e}{pruning_status}")
                    
                    if completed_steps >= args.max_train_steps:
                        break
        
        # Aggregate epoch computation metrics
        if epoch_computation_metrics:
            epoch_results = aggregate_computation_metrics(epoch_computation_metrics, accelerator)
            
            # Debug: Check measurement details
            measured_steps = len(epoch_computation_metrics)
            total_steps = len(train_dataloader)
                        
            # Calculate expected measurements
            expected_measured_steps = (total_steps + LOG_FLOPS_INTERVAL - 1) // LOG_FLOPS_INTERVAL
            
            # Calculate total samples in measured batches
            measured_samples = sum(m.get("batch_size", 0) for m in epoch_computation_metrics)
            
            # Extrapolate to full epoch
            extrapolation_factor = total_steps / measured_steps if measured_steps > 0 else 1
            
            # Apply extrapolation to get total computation for the epoch
            epoch_flops_total = epoch_results["flops"] * extrapolation_factor
            epoch_macs_total = epoch_results["macs"] * extrapolation_factor
            
            train_flops += epoch_flops_total
            train_macs += epoch_macs_total
            
            # Log epoch computation with detailed breakdown
            logger.info(f"Epoch {epoch+1} computation breakdown:")
            logger.info(f"  Measured {measured_steps}/{total_steps} batches (every {LOG_FLOPS_INTERVAL} steps)")
            logger.info(f"  Measured FLOPs: {format_flops(epoch_results['flops'])}")
            logger.info(f"  Extrapolation factor: {extrapolation_factor:.2f}")
            logger.info(f"  Extrapolated total FLOPs: {format_flops(epoch_flops_total)}")
            logger.info(f"  Extrapolated total MACs: {format_flops(epoch_macs_total)}")
        
        # Evaluation at epoch end
        model.eval()
        
        # Measure actual evaluation computation
        logger.info("Measuring evaluation computation metrics...")
        eval_computation_metrics = measure_model_computation(
            model=model,
            dataloader=eval_dataloader,
            num_batches=min(5, len(eval_dataloader)),  # Measure on subset for efficiency
            metrics=["flops", "macs"],
            accelerator=accelerator,
            log_output=True
        )
        
        # Store computation metrics
        eval_flops = eval_computation_metrics["flops"] 
        eval_macs = eval_computation_metrics["macs"]
        
        # Run evaluation with precise timing using utility function
        logger.info("Running evaluation with precise timing...")
        
        # Check and log LoRA merge status before evaluation
        check_lora_merge_status(model)
        
        # Force merge for optimal inference if needed
        # Don't eval the unwrapped model - it will merge LoRA weights permanently
        # Keep using the wrapped model for eval
        model.eval()  # Use wrapped model to maintain proper state
        
        timing_results = measure_evaluation_time(
            model=model,
            dataloader=eval_dataloader,
            device=accelerator.device,
            logger=logger
        )
        
        # Extract timing information
        eval_runtime = timing_results["total_time"]
        eval_steps_per_second = timing_results["steps_per_second"]
        avg_batch_time = timing_results["avg_batch_time"]
        
        # Process evaluation results
        eval_loss = 0.0
        eval_samples = 0
        valid_batch_added = False
        
        # Run Alpaca evaluation
        logger.info("Running Alpaca evaluation...")
        eval_metric = evaluate_alpaca_generation(
            model=model,
            eval_dataloader=eval_dataloader,
            tokenizer=tokenizer,
            accelerator=accelerator,
            max_new_tokens=128
        )
        
        logger.info(f"Epoch {epoch + 1} evaluation results:")
        logger.info(f"  Perplexity: {eval_metric.get('perplexity', 0):.2f}")
        logger.info(f"  Eval Loss: {eval_metric.get('eval_loss', 0):.4f}")

        # No MNLI-specific logic needed for Alpaca task

        # Add computational metrics to eval metrics
        eval_metric["eval_flops"] = eval_flops
        eval_metric["eval_flops_gflops"] = eval_flops / 1e9
        eval_metric["eval_macs"] = eval_macs
        eval_metric["eval_macs_gflops"] = eval_macs / 1e9
        eval_metric["eval_steps_per_second"] = eval_steps_per_second
        
        # Don't overwrite eval_loss - it's already in eval_metric from evaluate_alpaca_generation
        # eval_metric["eval_loss"] is already set correctly
                    
        # Log metrics
        logger.info(f"Epoch {epoch+1}: {eval_metric}")
        
        # Track best model using perplexity (lower is better)
        current_metric = -eval_metric.get("perplexity", float('inf'))  # Negative for consistency (higher is better)
        
        if current_metric > best_metric:
            best_metric = current_metric
            # Save best model
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(
                    os.path.join(args.output_dir, "best_model"),
                    save_function=accelerator.save
                )
                tokenizer.save_pretrained(os.path.join(args.output_dir, "best_model"))
                logger.info(f"New best model saved with {args.task_name} metric: {current_metric:.4f}")
        
        # Properly clean memory after epoch with controlled garbage collection
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
    
    # Calculate final train metrics
    train_runtime = time.time() - train_start_time
    train_samples_per_second = total_train_samples / train_runtime if train_runtime > 0 else 0
    train_steps_per_second = total_train_steps / train_runtime if train_runtime > 0 else 0
    avg_train_loss = total_train_loss / total_train_steps if total_train_steps > 0 else 0

    # Calculate model size and parameters
    model_size_mb = total_params * 4 / (1024 * 1024)  # Approximate model size (4 bytes per parameter)
    
    # Get formatted runtime
    formatted_train_runtime = format_runtime(train_runtime)
    
    # Get pruning metrics if available
    pruning_metrics = {}
    if pruning_manager is not None:
        pruning_info = pruning_manager.get_progress_info()
        pruning_metrics = pruning_info.get("pruning_metrics", {})
    
    # Prepare train metrics in the original order for consistency
    train_metrics = {
        "epoch": args.num_train_epochs,
        "total_flos_gflops": train_flops / (10**9),  # In GFLOPs
        "total_macs_gflops": train_macs / (10**9),  # MACs in GFLOPs
        "train_loss": avg_train_loss,
        "train_runtime_formatted": formatted_train_runtime,  # Human-readable format
        "train_samples": total_train_samples,
        "train_samples_per_second": train_samples_per_second,
        "train_steps_per_second": train_steps_per_second,
        "trainable_lora_parameters": trainable_params,
        "total_parameters": total_params,
        "model_size_mb": model_size_mb,
    }
    
    # Add pruning metrics if available
    if pruning_metrics:
        # Calculate accurate pruning metrics for final report
        if pruning_manager is not None:
            # Get accurate reduction values for final reporting
            current_params = pruning_utils.calculate_total_parameters(
                pruning_manager.current_r_config, pruning_manager.layer_sizes)
            actual_pruned_params = pruning_manager.initial_params - current_params
            actual_reduction_ratio = (pruning_manager.initial_params - current_params) / pruning_manager.initial_params
            
            # Update metrics with accurate values
            pruning_metrics["pruned_params"] = actual_pruned_params
            pruning_metrics["model_size_reduction"] = actual_reduction_ratio * 100  # as percentage
            
            # Log calculation for debugging
            logger.debug(f"Updated pruning metrics: pruned_params={actual_pruned_params:,}, "
                        f"reduction={actual_reduction_ratio:.4f} ({actual_reduction_ratio*100:.2f}%)")
        
        train_metrics.update({
            "pruned_params": pruning_metrics.get("pruned_params", 0),
            "pruned_model_size_mb": model_size_mb * (1 - pruning_metrics.get("model_size_reduction", 0) / 100),
            "model_size_reduction_pct": pruning_metrics.get("model_size_reduction", 0)
        })
    
    # Add memory statistics
    if torch.cuda.is_available():
        memory_stats = get_memory_stats()
        train_metrics.update({
            "peak_memory_gb": memory_stats.get("peak_allocated_gb", 0),
            "memory_reserved_gb": memory_stats.get("reserved_gb", 0),
        })
    
    # Log train metrics
    logger.info("***** Train metrics *****")
    for key, value in train_metrics.items():
        if isinstance(value, float):
            logger.info(f"  {key:<25} = {value:>10.6f}")
        else:
            logger.info(f"  {key:<25} = {value:>10}")
    
    # Save train metrics
    if accelerator.is_main_process:
        save_metrics(train_metrics, os.path.join(args.output_dir, "train_metrics.json"))
    
    # Finalize pruning if it was applied
    if pruning_manager is not None:
        logger.info("=" * 80)
        logger.info("FINALIZING PROGRESSIVE LORA PRUNING")
        logger.info("=" * 80)
        
        pruned_model, final_r_config, achieved_reduction = pruning_manager.finalize()
        
        # Save pruned model - COMMENTED OUT to avoid duplicate saving
        # The model will be saved once at the end of training in the main output directory
        # if accelerator.is_main_process:
        #     pruned_model.save_pretrained(
        #         os.path.join(pruning_manager.output_dir, "final_model"),
        #     )
        #     tokenizer.save_pretrained(os.path.join(pruning_manager.output_dir, "final_model"))
        #     # Save final r_config
        #     with open(os.path.join(pruning_manager.output_dir, "final_r_config.json"), 'w') as f:
        #         json.dump(final_r_config, f, indent=2)
        
        # Still save the r_config separately for reference
        if accelerator.is_main_process:
            with open(os.path.join(pruning_manager.output_dir, "final_r_config.json"), 'w') as f:
                json.dump(final_r_config, f, indent=2)

    # Final evaluation
    logger.info("***** Running final evaluation *****")
    
    # Reset memory stats for evaluation
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
    
    model.eval()
    # No need to load GLUE metric for Alpaca task
    
    # Measure final evaluation computation more accurately
    final_eval_metrics = measure_model_computation(
        model=model,
        dataloader=eval_dataloader,
        num_batches=None,  # Measure all batches for accuracy
        metrics=["flops", "macs"],
        accelerator=accelerator,
        log_output=True
    )
    
    final_eval_total_flops = final_eval_metrics["flops"]
    final_eval_total_macs = final_eval_metrics["macs"]
    
    
    # Use utility function for precise timing measurement of final evaluation
    logger.info("Running final evaluation with precise timing...")
    
    # Ensure optimal LoRA state for final evaluation
    logger.info("Ensuring optimal LoRA state for final evaluation...")
    model.eval()  # Use wrapped model to maintain state
    unwrapped_model = accelerator.unwrap_model(model)
    check_lora_merge_status(unwrapped_model)
    
    timing_results = measure_evaluation_time(
        model=model,
        dataloader=eval_dataloader,
        device=accelerator.device,
        logger=logger
    )
    
    # Extract timing information
    eval_runtime = timing_results["total_time"]
    eval_samples_per_second = timing_results["samples_per_second"]
    eval_steps_per_second = timing_results["steps_per_second"]
    avg_batch_time = timing_results["avg_batch_time"]
    
    logger.info(f"Final evaluation completed: total runtime={eval_runtime:.2f}s, avg batch time={avg_batch_time:.4f}s")
    
    # Process evaluation results
    total_eval_loss = 0
    total_eval_samples = 0
    # Run final Alpaca evaluation
    final_eval_results = evaluate_alpaca_generation(
        model=model,
        eval_dataloader=eval_dataloader,
        tokenizer=tokenizer,
        accelerator=accelerator,
        max_new_tokens=128
    )
    
    total_eval_samples = final_eval_results.get("num_samples", 0)
    total_eval_loss = final_eval_results.get("eval_loss", 0) * len(eval_dataloader)
    valid_batch_added = True
    
    avg_eval_loss = total_eval_loss / len(eval_dataloader) if len(eval_dataloader) > 0 else 0
    formatted_eval_runtime = format_runtime(eval_runtime)

    # Calculate per-sample metrics with correct total samples
    if total_eval_samples > 0:
        final_eval_flops_per_sample = final_eval_total_flops / total_eval_samples
        final_eval_macs_per_sample = final_eval_total_macs / total_eval_samples
        
        logger.info(f"Total evaluation computation: {format_flops(final_eval_total_flops)} FLOPs, {format_flops(final_eval_total_macs)} MACs")
        logger.info(f"Per-sample evaluation: {format_flops(final_eval_flops_per_sample)} FLOPs/sample, {format_flops(final_eval_macs_per_sample)} MACs/sample")
        logger.info(f"Total eval samples: {total_eval_samples}")

    # Calculate and save pruning metrics after final evaluation (if pruning was applied)
    if pruning_manager is not None and accelerator.is_main_process:
        # Get task performance metric (for Alpaca, use inverse perplexity)
        task_performance = 1.0 / eval_metric.get("perplexity", float('inf'))
        
        # Calculate efficiency metrics based on final evaluation
        accuracy_per_gflops = task_performance / (final_eval_total_flops / 1e9) if final_eval_total_flops > 0 else 0
        accuracy_per_gmacs = task_performance / (final_eval_total_macs / 1e9) if final_eval_total_macs > 0 else 0

        # Prepare and save complete pruning metrics
        pruning_metrics = {
            "initial_parameters": pruning_manager.initial_params,
            "final_parameters": pruning_utils.calculate_total_parameters(
                final_r_config, pruning_manager.layer_sizes),
            "reduction_rate": float(achieved_reduction),
            "target_reduction_rate": args.pruning_target_reduction,
            "pruning_steps": args.pruning_steps,
            "baseline_performance": float(pruning_manager.baseline_performance),
            "final_performance": task_performance,  # Use final evaluation performance
            "eval_steps_per_second": eval_steps_per_second,
            "accuracy_per_gflops": accuracy_per_gflops,
            "accuracy_per_gmacs": accuracy_per_gmacs,
            "final_eval_flops_per_sample_gflops": final_eval_flops_per_sample / 1e9,
            "final_eval_macs_per_sample_gflops": final_eval_macs_per_sample / 1e9,
            "final_eval_total_flops_gflops": final_eval_total_flops / 1e9,
            "final_eval_total_macs_gflops": final_eval_total_macs / 1e9,
            "pruned_model_size_mb": model_size_mb * (1 - achieved_reduction),
            "model_size_reduction_pct": achieved_reduction * 100
        }
        
        # Save pruning metrics
        save_metrics(pruning_metrics, os.path.join(pruning_manager.output_dir, "pruning_metrics.json"))
        
        # Log final pruning efficiency metrics
        logger.info("=" * 80)
        logger.info("PRUNING EFFICIENCY METRICS (FINAL EVALUATION):")
        logger.info("=" * 80)
        logger.info(f"  Eval steps per second: {eval_steps_per_second:.2f}")
        logger.info(f"  Accuracy per GFLOPs: {accuracy_per_gflops:.6f}")
        logger.info(f"  Accuracy per GMACs: {accuracy_per_gmacs:.6f}")
        logger.info(f"  Pruned model size: {pruning_metrics['pruned_model_size_mb']:.2f} MB")
        logger.info(f"  Size reduction: {achieved_reduction*100:.2f}%")
        logger.info(f"  Final performance: {task_performance:.4f}")
        logger.info("=" * 80)

    # Use Alpaca evaluation results
    if valid_batch_added:
        eval_metric = final_eval_results
        
        # For Alpaca, use perplexity as the main metric
        combined_score = 1.0 / eval_metric.get("perplexity", float('inf'))  # Lower perplexity is better
        
        # Determine metric key name
        accuracy_key = "eval_perplexity"
        
        # Prepare eval metrics in the original order
        eval_metrics = {
            "epoch": args.num_train_epochs,
            "eval_flops_gflops": final_eval_total_flops / (10**9),  # Total GFLOPs
            "eval_macs_gflops": final_eval_total_macs / (10**9),  # Total GMACs
            accuracy_key: eval_metric.get("accuracy", eval_metric.get("matthews_correlation", 0)),
            "eval_combined_score": combined_score,
            "eval_loss": avg_eval_loss,
            "eval_runtime_formatted": formatted_eval_runtime,  # Human-readable format
            "eval_samples": total_eval_samples,
            "eval_samples_per_second": eval_samples_per_second,
            "eval_steps_per_second": len(eval_dataloader) / eval_runtime if eval_runtime > 0 else 0
        }
        
        # Add per-sample metrics as additional fields
        eval_metrics.update({
            "eval_flops_per_sample_gflops": final_eval_flops_per_sample / (10**9),
            "eval_macs_per_sample_gflops": final_eval_macs_per_sample / (10**9),
        })
        
        # Add Alpaca-specific metrics
        if "perplexity" in eval_metric:
            eval_metrics["eval_perplexity"] = eval_metric["perplexity"]
        
        # Add memory statistics
        if torch.cuda.is_available():
            memory_stats = get_memory_stats()
            eval_metrics.update({
                "eval_peak_memory_gb": memory_stats.get("peak_allocated_gb", 0),
                "eval_memory_reserved_gb": memory_stats.get("reserved_gb", 0),
            })
        
        # Log eval metrics
        logger.info("***** Eval metrics *****")
        for key, value in eval_metrics.items():
            if isinstance(value, float):
                logger.info(f"  {key:<25} = {value:>10.6f}")
            else:
                logger.info(f"  {key:<25} = {value:>10}")
                
    else:
        logger.warning("No valid batches for evaluation. Final evaluation skipped.")
        eval_metrics = {"eval_perplexity": float('inf'), "eval_loss": 0.0}  # Provide default values

    # Run MT-Bench evaluation if requested
    mt_bench_results = None
    if args.run_mt_bench and accelerator.is_main_process:
        logger.info("***** Running MT-Bench evaluation *****")
        logger.info(f"MT-Bench results will be saved to: {args.mt_bench_output_dir}")
        
        mt_bench_results = run_mt_bench_evaluation(
            model=model,
            tokenizer=tokenizer,
            output_dir=args.mt_bench_output_dir,
            accelerator=accelerator,
            logger=logger,
            skip_judgment=args.skip_mt_bench_judgment,
            fastchat_path=args.fastchat_path
        )
        
        if mt_bench_results:
            # Add MT-Bench results to eval metrics
            eval_metrics["mt_bench_score"] = mt_bench_results.get("overall_score", -1)
            eval_metrics["mt_bench_judgment_skipped"] = mt_bench_results.get("judgment_skipped", False)
            eval_metrics["mt_bench_answer_file"] = mt_bench_results.get("answer_file", "")
            
            if "overall_score" in mt_bench_results:
                logger.info(f"MT-Bench Overall Score: {mt_bench_results['overall_score']:.2f}")
            else:
                logger.info("MT-Bench answers generated. Judgment skipped or pending.")
        else:
            logger.warning("MT-Bench evaluation failed or was skipped")
            eval_metrics["mt_bench_score"] = -1
            eval_metrics["mt_bench_error"] = True

    # Save results
    if accelerator.is_main_process:
        # Save final model
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            os.path.join(args.output_dir, "final_model"),
            save_function=accelerator.save
        )
        tokenizer.save_pretrained(os.path.join(args.output_dir, "final_model"))
        
        # Save metrics
        save_metrics(eval_metrics, os.path.join(args.output_dir, "eval_metrics.json"))
        save_metrics({**train_metrics, **eval_metrics}, os.path.join(args.output_dir, "all_metrics.json"))
    
    logger.info(f"Training completed. Results saved to {args.output_dir}")


if __name__ == "__main__":
    # Script entry point: run main() only when this file is executed directly,
    # not when it is imported as a module.
    main()