"""
Utility functions for LoRA pruning.

This module provides functions for measuring layer importance, estimating performance loss,
classifying layers, and validating pruning results.
"""

import torch
import numpy as np
import logging
import sys
import time
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional, Any, Union, Set
from torch.utils.data import DataLoader
import torch.nn as nn
import os
from collections import defaultdict
import copy
import math

from . import pruning_config

logger = logging.getLogger(__name__)

def get_layer_size(model: nn.Module, layer_name: str) -> Tuple[int, int]:
    """
    Resolve a submodule by its dotted path and report its input/output widths.
    
    Args:
        model: Root module from which the attribute walk starts
        layer_name: Dot-separated attribute path of the target layer
        
    Returns:
        Tuple of (in_features, out_features); (0, 0) when the path cannot be
        resolved or no dimensions can be read from the resolved module
    """
    try:
        target = model
        for attr in layer_name.split('.'):
            target = getattr(target, attr)

        # Preferred source: explicit feature attributes on the module
        if hasattr(target, 'in_features') and hasattr(target, 'out_features'):
            return target.in_features, target.out_features

        # Fallback: a 2-D weight is read as (out_features, in_features)
        if hasattr(target, 'weight'):
            weight_shape = target.weight.shape
            if len(weight_shape) == 2:
                out_dim, in_dim = weight_shape
                return in_dim, out_dim
    except (AttributeError, ValueError) as e:
        logger.warning(f"Error getting size for {layer_name}: {e}")
        return 0, 0

    # The module was found but exposes neither feature attributes nor a 2-D weight
    logger.warning(f"Could not determine dimensions for layer {layer_name}")
    return 0, 0

def classify_layers_by_type(layer_names: List[str]) -> Dict[str, List[str]]:
    """
    Bucket layer names into functional categories via substring markers.
    Markers cover both BERT/RoBERTa and LLama naming conventions.
    
    Args:
        layer_names: List of layer names
        
    Returns:
        Dictionary mapping layer type to the names that fell into it, in input
        order; categories that received no names are omitted
    """
    # Ordered rules: the first rule with any marker present in the name wins
    marker_rules = (
        ("query", ("query", "q_proj")),
        ("key", ("key", "k_proj")),
        ("value", ("value", "v_proj")),
        ("attention_output", ("attention.output", "o_proj")),
        ("intermediate", ("intermediate", "gate_proj", "up_proj")),
    )
    categories = (
        "query", "key", "value", "attention_output",
        "intermediate", "output", "other",
    )
    buckets = {category: [] for category in categories}

    for name in layer_names:
        category = next(
            (label for label, markers in marker_rules
             if any(marker in name for marker in markers)),
            None,
        )
        if category is None:
            # FFN output: "output" outside of attention, or LLama's down_proj
            is_ffn_output = (
                "down_proj" in name
                or ("output" in name and "attention" not in name)
            )
            category = "output" if is_ffn_output else "other"
        buckets[category].append(name)

    # Drop categories that stayed empty
    return {category: members for category, members in buckets.items() if members}

def _resolve_connection_type(layer_type: str, known_types: Tuple[str, ...]) -> Optional[str]:
    """Map a dotted layer suffix to the most specific known endpoint it contains as whole components."""
    if layer_type in known_types:
        return layer_type
    # Pad with dots so only whole dotted components can match (no partial words)
    padded = f".{layer_type}."
    candidates = [known for known in known_types if f".{known}." in padded]
    # Longest wins, so "attention.output.dense" resolves to "attention.output"
    # rather than the shorter "output"
    return max(candidates, key=len, default=None)


def find_consecutive_layers(layer_names: List[str]) -> List[Tuple[str, str]]:
    """
    Find consecutive layer pairs in the model architecture.
    
    Each name is split on '.'; the first purely numeric component that is not
    the last component is taken as the block index, and everything after it is
    the layer's type suffix. Within each block, the suffixes are matched against
    a fixed table of (source, destination) connections. A suffix matches an
    endpoint when it equals it, or when the endpoint appears inside the suffix
    as whole dotted components (e.g. "self_attn.q_proj" matches "q_proj",
    "attention.output.dense" matches "attention.output"). Exact matches take
    priority over component matches.
    
    Args:
        layer_names: List of layer names
        
    Returns:
        List of (source_name, destination_name) pairs, ordered by block index
        and then by the order of the connection table. Names without a block
        index are ignored.
    """
    # Support both BERT/RoBERTa and LLama architectures
    connections = [
        # Self-attention internal connections (BERT/RoBERTa)
        ("query", "key"), ("key", "value"),
        ("value", "attention.output"),
        # Self-attention internal connections (LLama)
        ("q_proj", "k_proj"), ("k_proj", "v_proj"),
        ("v_proj", "o_proj"),
        # FFN connections (BERT/RoBERTa)
        ("attention.output", "intermediate"),
        ("intermediate", "output"),
        # FFN connections (LLama)
        ("o_proj", "gate_proj"), ("o_proj", "up_proj"),
        ("gate_proj", "down_proj"), ("up_proj", "down_proj")
    ]
    # Unique endpoint names, in first-seen order for deterministic tie-breaking
    known_types = tuple(dict.fromkeys(t for pair in connections for t in pair))

    # Group (name, type suffix) by transformer block index
    blocks = defaultdict(list)
    for name in layer_names:
        parts = name.split('.')
        # The last component can never be the block index: a suffix must follow it
        for i, part in enumerate(parts[:-1]):
            if part.isdigit():
                blocks[int(part)].append((name, '.'.join(parts[i + 1:])))
                break

    consecutive_pairs = []
    for block_idx in sorted(blocks):
        layers = sorted(blocks[block_idx], key=lambda item: item[1])

        # Exact suffix matches first (later entries overwrite earlier ones)
        layer_dict = {
            layer_type: name
            for name, layer_type in layers
            if layer_type in known_types
        }
        # Then component-wise matches, never displacing an exact match
        for name, layer_type in layers:
            resolved = _resolve_connection_type(layer_type, known_types)
            if resolved is not None:
                layer_dict.setdefault(resolved, name)

        for src_type, dst_type in connections:
            if src_type in layer_dict and dst_type in layer_dict:
                consecutive_pairs.append((layer_dict[src_type], layer_dict[dst_type]))

    return consecutive_pairs

def measure_layer_importance(
    model: nn.Module, 
    dataloader: DataLoader, 
    r_config: Dict[str, int], 
    device: torch.device,
    num_batches: int = 5,
    prev_importances: Optional[Dict[str, float]] = None,
    ema_decay: Optional[float] = None  # Allow override from caller
) -> Dict[str, float]:
    """
    Measure the importance of each LoRA layer based on gradient information with EMA.
    
    For up to `num_batches` batches the model is run in train mode, a loss is
    back-propagated, and each LoRA layer's score is the geometric mean of the
    mean absolute gradients of `lora_A` and `lora_B`, multiplied by sqrt(r).
    Scores are averaged over the processed batches, optionally blended with
    `prev_importances` via EMA, optionally normalized by the maximum, and
    layers that end up at or below zero are given a small positive value.
    
    Side effects: the model is moved to `device`; its gradients are cleared
    before returning so the measurement passes cannot leak into a later
    optimizer step; train/eval mode is restored even if a batch raises.
    
    Args:
        model: The model to analyze
        dataloader: DataLoader for batch processing; each batch must be a dict
            that can be passed to the model as keyword arguments
        r_config: Current r configuration
        device: Device to run the model on
        num_batches: Number of batches to use for importance estimation
        prev_importances: Previous importance scores for EMA calculation
        ema_decay: EMA decay factor (default from config)
        
    Returns:
        Dictionary mapping layer_name to importance score
    """
    # Use provided ema_decay or fall back to config default
    actual_ema_decay = ema_decay if ema_decay is not None else pruning_config.IMPORTANCE_EMA_DECAY
    
    logger.info(f"Measuring layer importance with EMA decay: {actual_ema_decay}")
    start_time = time.time()
    
    # Save original model state
    was_training = model.training
    
    model.to(device)
    model.train()  # Set to train mode to enable gradients
    
    # Initialize importance dict for all layers in r_config
    importances = {name: 0.0 for name in r_config}
    
    # Identify LoRA layers and map names to modules
    lora_layers = {
        name: module
        for name, module in model.named_modules()
        if hasattr(module, 'lora_A') and hasattr(module, 'lora_B') and name in r_config
    }
    
    # Process batches
    batch_count = 0
    
    try:
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_batches:
                break
            
            # Move batch to device efficiently
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            
            # Start every batch from clean gradients so scores are per-batch
            model.zero_grad(set_to_none=True)
            outputs = model(**batch)
            
            # Use loss if available, otherwise create a pseudo-loss
            loss = None
            if hasattr(outputs, 'loss') and outputs.loss is not None:
                loss = outputs.loss
            elif hasattr(outputs, 'logits') and outputs.logits is not None:
                if 'labels' in batch:
                    # Use cross entropy loss for classification
                    if outputs.logits.dim() > 1 and outputs.logits.size(-1) > 1:
                        loss = torch.nn.functional.cross_entropy(outputs.logits, batch['labels'])
                    else:
                        # Use MSE for regression
                        loss = torch.nn.functional.mse_loss(outputs.logits.squeeze(), batch['labels'].float())
                else:
                    # Create a pseudo-loss using logits (L2 norm)
                    loss = torch.mean(outputs.logits ** 2)
            
            if loss is None:
                logger.warning(f"Batch {batch_idx}: Unable to compute loss, skipping")
                continue
                
            # Backward pass
            loss.backward()
            batch_count += 1
            
            # Calculate importance based on gradients of lora_A and lora_B
            for name, layer in lora_layers.items():
                grad_a = layer.lora_A.grad
                grad_b = layer.lora_B.grad
                if grad_a is None or grad_b is None:
                    continue
                
                # Mean absolute gradient of the A and B matrices
                grad_a_norm = grad_a.abs().mean().item()
                grad_b_norm = grad_b.abs().mean().item()
                
                # Geometric mean of gradient norms for more stable importance
                importance = (grad_a_norm * grad_b_norm) ** 0.5
                
                # Scale by sqrt(r) so layers with more rank dimensions weigh more;
                # layers with r <= 0 keep the unscaled value
                current_r = r_config[name]
                if current_r > 0:
                    importance = importance * (current_r ** 0.5)
                
                importances[name] += importance
    finally:
        # Do not leave measurement gradients behind for a later optimizer step
        model.zero_grad(set_to_none=True)
        # Restore original model state
        if not was_training:
            model.eval()
    
    # Average importance across batches
    if batch_count > 0:
        for name in importances:
            importances[name] /= batch_count
    
    # Apply Exponential Moving Average (EMA) if previous importances are available
    # NOTE(review): the fresh scores are raw here while prev_importances is
    # presumably a previous (possibly normalized) return value of this function,
    # so the two may be on different scales - confirm with callers.
    if prev_importances is not None:
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        for name in importances:
            if name in prev_importances:
                old_importance = prev_importances[name]
                new_importance = importances[name]
                importances[name] = actual_ema_decay * old_importance + (1 - actual_ema_decay) * new_importance
                
                # Log significant changes in importance scores
                change_ratio = abs(new_importance - old_importance) / (old_importance + 1e-10)
                if change_ratio > 0.5 and debug_enabled:  # Only log if significant change
                    logger.debug(f"Layer {name}: importance changed from {old_importance:.6f} to {new_importance:.6f} "
                                f"(EMA: {importances[name]:.6f}, change: {change_ratio:.2%})")
    
    # Normalize importances if configured
    if pruning_config.NORMALIZE_IMPORTANCE and importances:
        max_importance = max(importances.values())
        if max_importance > 0:  # Avoid division by zero
            for name in importances:
                importances[name] /= max_importance
    
    # Set importance to a small non-zero value for layers with zero importance
    min_importance = min((imp for imp in importances.values() if imp > 0), default=1e-6)
    for name in importances:
        if importances[name] <= 0:
            importances[name] = min_importance * 0.1
    
    # Log importance scores
    if pruning_config.LOG_IMPORTANCE_SCORES:
        logger.info("Layer importance scores (top 10):")
        for name, importance in sorted(importances.items(), key=lambda x: x[1], reverse=True)[:10]:
            logger.info(f"  {name}: {importance:.6f}")
    
    logger.info(f"Importance measurement completed in {time.time() - start_time:.2f}s")
    return importances

def estimate_performance_loss(
    layer_name: str,
    current_r: int,
    new_r: int,
    importance: float,
    layer_size: Tuple[int, int]
) -> float:
    """
    Estimate the performance loss when reducing a layer's r value.
    
    The estimate is `importance * (current_r - new_r) / current_r`, i.e. the
    importance scaled by the fraction of rank removed.
    
    Args:
        layer_name: Name of the layer (not used by the current formula; kept
            for interface compatibility)
        current_r: Current r value
        new_r: New r value
        importance: Importance score of the layer
        layer_size: Tuple of (in_features, out_features) for the layer (not
            used by the current formula; kept for interface compatibility)
        
    Returns:
        Estimated performance loss; 0.0 when r is not being reduced or when
        current_r is not positive
    """
    # Nothing is lost when r grows or stays the same, and a layer that is
    # already at (or below) zero cannot be reduced further
    if current_r <= 0 or new_r >= current_r:
        return 0.0
    
    # Fraction of the current rank being removed (current_r > 0 is guaranteed here)
    reduction_ratio = (current_r - new_r) / current_r
    
    # Simplified formula: loss = importance * reduction_ratio
    return importance * reduction_ratio

def calculate_total_parameters(r_config: Dict[str, int], layer_sizes: Dict[str, Tuple[int, int]]) -> int:
    """
    Sum the LoRA parameter counts implied by an r configuration.
    
    Args:
        r_config: Dictionary mapping layer name to r value
        layer_sizes: Dictionary mapping layer name to (in_features, out_features)
        
    Returns:
        Total parameter count; layers with r <= 0 or without a known size
        contribute nothing
    """
    def _layer_params(name: str, rank: int) -> int:
        # A rank-r adapter holds r * in_features + out_features * r weights
        in_dim, out_dim = layer_sizes[name]
        return rank * (in_dim + out_dim)

    return sum(
        _layer_params(name, rank)
        for name, rank in r_config.items()
        if rank > 0 and name in layer_sizes
    )

def get_lora_modules(model: nn.Module) -> Dict[str, nn.Module]:
    """
    Collect every submodule that carries both LoRA matrices.
    
    Args:
        model: The model to analyze
        
    Returns:
        Dictionary mapping the module's qualified name to the module, for each
        module that has both a `lora_A` and a `lora_B` attribute
    """
    return {
        qualified_name: submodule
        for qualified_name, submodule in model.named_modules()
        if hasattr(submodule, 'lora_A') and hasattr(submodule, 'lora_B')
    }

def save_model_checkpoint(model: nn.Module, path: str) -> None:
    """
    Save model checkpoint for rollback purposes.
    
    Only state-dict entries whose name contains 'lora_' are written, to keep
    the checkpoint small.
    
    Args:
        model: Model to save
        path: Path to save the checkpoint; missing parent directories are
            created. A bare filename (no directory part) is saved relative to
            the current working directory.
    """
    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when the path actually has one
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    
    # Save only LoRA weights to minimize checkpoint size
    lora_state_dict = {
        name: param
        for name, param in model.state_dict().items()
        if 'lora_' in name
    }
    
    torch.save(lora_state_dict, path)
    logger.info(f"Saved LoRA checkpoint to {path}")

def _fit_lora_rank(
    checkpoint_param: torch.Tensor,
    target_shape: Tuple[int, ...],
    param_name: str
) -> Optional[torch.Tensor]:
    """Truncate or pad a checkpointed LoRA matrix along its rank axis to target_shape; None if it cannot fit."""
    # lora_A is (r, in_features): rank on axis 0; lora_B is (out_features, r): rank on axis 1
    rank_axis = 0 if param_name == 'lora_A' else 1
    fixed_axis = 1 - rank_axis
    
    if checkpoint_param.dim() != 2 or len(target_shape) != 2:
        return None
    # Only the rank may differ; a different feature dimension cannot be repaired
    if checkpoint_param.shape[fixed_axis] != target_shape[fixed_axis]:
        return None
    
    checkpoint_r = checkpoint_param.shape[rank_axis]
    target_r = target_shape[rank_axis]
    
    if target_r <= checkpoint_r:
        # Keep the first target_r rank dimensions of the checkpoint
        return checkpoint_param.narrow(rank_axis, 0, target_r)
    
    # Copy the available rank dimensions into a zero tensor of the target shape
    resized = torch.zeros(tuple(target_shape), dtype=checkpoint_param.dtype, device=checkpoint_param.device)
    resized.narrow(rank_axis, 0, checkpoint_r).copy_(checkpoint_param)
    if param_name == 'lora_A':
        # New lora_A rows get kaiming-uniform values; new lora_B columns stay
        # zero (LoRA B initialization)
        nn.init.kaiming_uniform_(resized[checkpoint_r:, :], a=math.sqrt(5))
    return resized


def load_model_checkpoint(model: nn.Module, path: str, allow_size_mismatch: bool = True) -> nn.Module:
    """
    Load model checkpoint for rollback with handling of LoRA layer size mismatches.
    
    Checkpoint entries whose name exists in the model's state dict are copied
    in. When a `lora_A`/`lora_B` entry differs from the model only in its rank
    dimension and `allow_size_mismatch` is True, the checkpoint tensor is
    truncated (first r dimensions) or padded to the shape the model currently
    holds. The target rank is taken from the model's actual parameter shape,
    so the subsequent `load_state_dict` always receives matching shapes.
    
    Errors are logged, not raised: on any failure the model is returned as it
    was at the point of failure.
    
    Args:
        model: Model to update
        path: Path to the checkpoint
        allow_size_mismatch: If True, handle LoRA layer size mismatches
        
    Returns:
        Updated model
    """
    if not os.path.exists(path):
        logger.error(f"Checkpoint not found at {path}")
        return model
    
    # Load state dict
    try:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted locations. Consider weights_only=True if the installed
        # torch version supports it.
        lora_state_dict = torch.load(path, map_location='cpu')
        
        # Get current state dict
        model_state_dict = model.state_dict()
        
        # First, check for size mismatches
        mismatched_params = [
            (name, checkpoint_param.shape, model_state_dict[name].shape)
            for name, checkpoint_param in lora_state_dict.items()
            if name in model_state_dict and model_state_dict[name].shape != checkpoint_param.shape
        ]
        
        if mismatched_params and not allow_size_mismatch:
            # If size mismatches are not allowed, log the error and return the original model
            logger.error(f"Size mismatches detected in {len(mismatched_params)} parameters.")
            for name, checkpoint_shape, model_shape in mismatched_params:
                logger.error(f"  {name}: checkpoint={checkpoint_shape}, model={model_shape}")
            
            logger.error("Rollback aborted due to parameter size mismatches. Use allow_size_mismatch=True to attempt resize.")
            return model
        
        # Update LoRA parameters with resize handling
        resized_params = []
        for name, checkpoint_param in lora_state_dict.items():
            if name not in model_state_dict:
                continue
            
            target_shape = model_state_dict[name].shape
            if target_shape == checkpoint_param.shape:
                # Shapes match, directly update
                model_state_dict[name] = checkpoint_param
                continue
            
            module_path, _, param_name = name.rpartition('.')
            can_resize = (
                allow_size_mismatch
                and ('lora_A' in name or 'lora_B' in name)
                and param_name in ('lora_A', 'lora_B')
            )
            resized = None
            if can_resize:
                logger.warning(f"Resizing parameter {name}: checkpoint={checkpoint_param.shape}, model={target_shape}")
                resized = _fit_lora_rank(checkpoint_param, target_shape, param_name)
            
            if resized is None:
                # Not a resizable LoRA parameter, or the mismatch is not in the rank dimension
                logger.warning(f"Skipping parameter {name} due to shape mismatch: checkpoint={checkpoint_param.shape}, model={target_shape}")
                continue
            
            model_state_dict[name] = resized
            resized_params.append(name)
            
            # Keep the owning module's scaling consistent with its r value
            try:
                module = model
                for part in (module_path.split('.') if module_path else []):
                    module = getattr(module, part)
            except AttributeError as e:
                logger.error(f"Error resizing {name}: {e}")
                continue
            
            if hasattr(module, 'r') and hasattr(module, 'scaling') and hasattr(module, 'lora_alpha'):
                current_r = module.r
                module.scaling = module.lora_alpha / current_r if current_r > 0 else 0
        
        # Load updated state dict
        model.load_state_dict(model_state_dict)
        
        # Log summary
        if resized_params:
            logger.info(f"Resized {len(resized_params)} LoRA parameters during checkpoint loading")
            
        logger.info(f"Loaded LoRA checkpoint from {path}")
        
    except Exception as e:
        # Deliberate best-effort: a failed rollback is reported, not raised
        logger.error(f"Error loading checkpoint: {e}")
        import traceback
        logger.error(traceback.format_exc())
        
    return model

def _lora_init_dtype(module: nn.Module) -> torch.dtype:
    """Dtype for freshly allocated LoRA matrices: the base weight's dtype if floating point, else lora_A's dtype."""
    base_weight = getattr(module, 'weight', None)
    # A non-floating base weight cannot be used for the random initialization
    # of lora_A, so fall back to the dtype the LoRA matrices already have
    if isinstance(base_weight, torch.Tensor) and base_weight.is_floating_point():
        return base_weight.dtype
    return module.lora_A.dtype


def _install_lora_params(
    module: nn.Module,
    new_lora_A: torch.Tensor,
    new_lora_B: torch.Tensor,
    new_r: int
) -> None:
    """Replace a module's LoRA matrices (keeping each requires_grad flag) and refresh r and scaling."""
    # Read the old flags before the attributes are overwritten
    module.lora_A = nn.Parameter(new_lora_A, requires_grad=module.lora_A.requires_grad)
    module.lora_B = nn.Parameter(new_lora_B, requires_grad=module.lora_B.requires_grad)
    module.r = new_r
    module.scaling = module.lora_alpha / new_r


def modify_lora_layers(
    model: nn.Module, 
    new_r_config: Dict[str, int]
) -> Tuple[nn.Module, Dict[str, Any]]:
    """
    Modify LoRA layers according to a new r configuration.
    
    For every layer in `new_r_config` that is a LoRA module of `model`:
      * unchanged r: nothing happens;
      * new r == 0: both matrices are zeroed in place (shape kept), r and
        scaling are set to 0;
      * current r == 0: fresh matrices of rank new r are created (A kaiming
        uniform, B zeros);
      * smaller r: the rank dimensions with the largest
        ||A row|| * ||B column|| product are kept;
      * larger r: existing weights are copied and the added A rows are
        kaiming-uniform initialized, added B columns are zero.
    
    Replaced parameters are new `nn.Parameter` objects that keep the previous
    `requires_grad` flag. While the function runs, this module's logger is
    temporarily forced to INFO (with a temporary stdout handler if it has
    none); the previous logger configuration is restored on exit.
    
    Args:
        model: The model to modify
        new_r_config: New r configuration
        
    Returns:
        Tuple of (modified model, changes dictionary). The changes dictionary
        maps each layer whose r changed to its old_r, new_r, in_features and
        out_features.
    """
    # Store original logger settings
    original_level = logger.level
    original_propagate = getattr(logger, 'propagate', True)
    
    # Ensure propagation and visibility
    logger.propagate = True
    logger.setLevel(logging.INFO)
    
    temp_handler = None
    original_handler_levels = {}
    
    # Check if we need to add a temporary handler
    if not logger.handlers:
        temp_handler = logging.StreamHandler(sys.stdout)
        temp_handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        temp_handler.setFormatter(formatter)
        logger.addHandler(temp_handler)
    else:
        # Store and adjust existing handler levels
        for handler in logger.handlers:
            original_handler_levels[handler] = handler.level
            handler.setLevel(logging.INFO)
    
    try:
        changes = {}
        lora_modules = get_lora_modules(model)
        mismatch_detected = False
        
        logger.info("=" * 60)
        logger.info("MODIFYING LoRA LAYERS")
        logger.info("=" * 60)
        
        # Group layers by type for more organized logging
        layer_names = list(new_r_config.keys())
        layer_types = classify_layers_by_type(layer_names)
        
        # Track statistics
        total_decreased = 0
        total_increased = 0
        total_unchanged = 0
        total_params_before = 0
        total_params_after = 0
        
        for layer_type, layers in layer_types.items():
            logger.info(f"\n[{layer_type.upper()} LAYERS]")
            
            for layer_name in sorted(layers):
                new_r = new_r_config[layer_name]
                
                if layer_name not in lora_modules:
                    if new_r == 0:
                        # For r=0 layers, this is expected behavior
                        logger.info(f"Layer {layer_name}: r=0 (SKIPPED)")
                        total_unchanged += 1
                    else:
                        # For r>0 layers, log a more informative warning
                        logger.warning(f"Layer {layer_name} not found as a LoRA module (r={new_r}), SKIPPING...")
                    continue
                    
                module = lora_modules[layer_name]
                current_r = module.r if hasattr(module, 'r') else module.lora_A.shape[0]
                
                # Calculate parameters
                in_features = module.lora_A.size(1)
                out_features = module.lora_B.size(0)
                current_params = current_r * (in_features + out_features)
                new_params = new_r * (in_features + out_features)
                
                total_params_before += current_params
                
                if current_r == new_r:
                    logger.info(f"Layer {layer_name}: r={current_r} (UNCHANGED)")
                    total_unchanged += 1
                    total_params_after += current_params
                    continue
                    
                # Record change
                changes[layer_name] = {
                    'old_r': current_r,
                    'new_r': new_r,
                    'in_features': in_features,
                    'out_features': out_features
                }
                
                # Update statistics based on change type
                if new_r < current_r:
                    total_decreased += 1
                else:
                    total_increased += 1
                
                if new_r == 0:
                    # Set weights to zero instead of resizing to maintain structure
                    with torch.no_grad():
                        module.lora_A.fill_(0)
                        module.lora_B.fill_(0)
                        # Make sure r value is actually updated
                        module.r = new_r
                        module.scaling = 0
                    
                    logger.info(f"Layer {layer_name}: r={current_r} → 0 (zeroed)")
                    logger.info(f"  Parameters: {current_params:,} → 0 ({-current_params:,})")
                    continue
                    
                if current_r == 0:
                    # Reinitialize weights
                    with torch.no_grad():
                        init_dtype = _lora_init_dtype(module)
                        new_lora_A = torch.zeros(
                            (new_r, in_features), 
                            device=module.lora_A.device, 
                            dtype=init_dtype
                        )
                        new_lora_B = torch.zeros(
                            (out_features, new_r), 
                            device=module.lora_B.device, 
                            dtype=init_dtype
                        )
                        
                        # Initialize with kaiming uniform for A and zeros for B
                        nn.init.kaiming_uniform_(new_lora_A, a=math.sqrt(5))
                        
                        # Update module parameters
                        _install_lora_params(module, new_lora_A, new_lora_B, new_r)
                    
                    logger.info(f"Layer {layer_name}: r=0 → {new_r} (reinitialized)")
                    logger.info(f"  Parameters: 0 → {new_params:,} ({new_params:+,})")
                    
                    total_params_after += new_params
                    continue
                
                # Handle changing r value (both increase and decrease)
                with torch.no_grad():
                    if new_r < current_r:
                        # Score each rank dimension by ||A row|| * ||B column||
                        a_importance = torch.norm(module.lora_A, dim=1)
                        b_importance = torch.norm(module.lora_B, dim=0)
                        combined_importance = a_importance * b_importance
                        
                        # Get indices of the most important dimensions
                        _, top_indices = torch.topk(combined_importance, new_r)
                        
                        # Create new tensors with selected dimensions
                        new_lora_A = module.lora_A[top_indices]
                        new_lora_B = module.lora_B[:, top_indices]
                        
                        logger.info(f"Layer {layer_name}: r={current_r} → {new_r} (PRUNED BY {current_r-new_r})")
                    else:
                        # Increase r by adding new dimensions with correct dtype
                        init_dtype = _lora_init_dtype(module)
                        new_lora_A = torch.zeros(
                            (new_r, in_features), 
                            device=module.lora_A.device, 
                            dtype=init_dtype
                        )
                        new_lora_B = torch.zeros(
                            (out_features, new_r), 
                            device=module.lora_B.device, 
                            dtype=init_dtype
                        )
                        
                        # Copy existing weights efficiently
                        new_lora_A[:current_r].copy_(module.lora_A)
                        new_lora_B[:, :current_r].copy_(module.lora_B)
                        
                        # Initialize the added A rows; the added B columns stay zero
                        nn.init.kaiming_uniform_(new_lora_A[current_r:], a=math.sqrt(5))
                        
                        logger.info(f"Layer {layer_name}: r={current_r} → {new_r} (expanded by {new_r-current_r})")
                    
                    # Update module parameters
                    _install_lora_params(module, new_lora_A, new_lora_B, new_r)
                    
                    # Verify that r was actually updated
                    if module.r != new_r:
                        logger.error(f"FAILED to update r value for {layer_name}: expected {new_r}, got {module.r}")
                        mismatch_detected = True
                    
                    logger.info(f"  Parameters: {current_params:,} → {new_params:,} ({new_params-current_params:+,})")
                    
                    total_params_after += new_params
        
        # Summarize changes
        total_reduction = total_params_before - total_params_after
        reduction_pct = (total_reduction / max(total_params_before, 1)) * 100  # Avoid division by zero
        
        logger.info("\nSUMMARY OF CHANGES:")
        logger.info(f"  Decreased r-values: {total_decreased} layers")
        logger.info(f"  Increased r-values: {total_increased} layers")
        logger.info(f"  Unchanged r-values: {total_unchanged} layers")
        logger.info(f"  Total parameters: {total_params_before:,} → {total_params_after:,} ({total_params_after-total_params_before:+,})")
        logger.info(f"  Parameter reduction: {reduction_pct:.2f}%")
        
        if mismatch_detected:
            logger.error("CRITICAL: Some r-values were not correctly updated in the model.")
        
        logger.info("=" * 60)
        
        return model, changes
        
    finally:
        # Clean up and restore logger settings
        if temp_handler is not None:
            logger.removeHandler(temp_handler)
            
        # Restore original handler levels
        for handler, level in original_handler_levels.items():
            handler.setLevel(level)
            
        # Restore original logger settings
        logger.setLevel(original_level)
        logger.propagate = original_propagate
def evaluate_model(
    model: nn.Module, 
    eval_dataloader: DataLoader, 
    device: torch.device
) -> float:
    """
    Evaluate model performance.
    
    The model is switched to eval mode and run under ``torch.no_grad()`` on
    every batch. Each batch is expected to be a dict of keyword arguments for
    the model's forward pass; tensor values are moved to ``device``.
    
    If at least one batch yields ``logits`` and contains ``labels``, the
    return value is the fraction of label elements matched by the argmax
    prediction. Otherwise the negative mean per-batch loss is returned, so
    that "higher is better" holds for both metrics.
    
    NOTE(review): predictions are compared element-wise with ``labels`` as
    given (no shifting, no ignore-index masking) - confirm this matches the
    label layout the caller supplies.
    
    Args:
        model: Model to evaluate
        eval_dataloader: Evaluation dataloader
        device: Device to run evaluation on
        
    Returns:
        Performance metric (accuracy or -loss); 0.0 if the dataloader
        yields no batches
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    
    # Check if model has pad_token_id set
    if hasattr(model, 'config') and hasattr(model.config, 'pad_token_id'):
        if model.config.pad_token_id is None:
            logger.warning("Model pad_token_id is None, setting to eos_token_id for evaluation")
            if hasattr(model.config, 'eos_token_id'):
                model.config.pad_token_id = model.config.eos_token_id
    
    with torch.no_grad():
        for batch in eval_dataloader:
            # Count batches ourselves instead of relying on len(dataloader),
            # which is undefined for iterable-style loaders.
            num_batches += 1
            
            # Move batch to device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            
            # For LLaMA models, ensure we have attention_mask to handle padding
            if "attention_mask" not in batch and "input_ids" in batch:
                # Create attention mask (all 1s since we assume no padding in individual samples)
                batch["attention_mask"] = torch.ones_like(batch["input_ids"])
            
            # Forward pass
            outputs = model(**batch)
            
            # Calculate loss
            if hasattr(outputs, 'loss') and outputs.loss is not None:
                total_loss += outputs.loss.item()
            
            # Calculate accuracy if applicable
            if hasattr(outputs, 'logits') and 'labels' in batch:
                labels = batch['labels']
                predictions = torch.argmax(outputs.logits, dim=-1)
                correct += (predictions == labels).sum().item()
                # `correct` sums over every label element, so the denominator
                # must count elements too; size(0) would only count rows and
                # push accuracy above 1 for token-level labels.
                total += labels.numel()
    
    # Return accuracy if available, otherwise negative loss
    if total > 0:
        return correct / total
    
    if num_batches == 0:
        # Nothing was evaluated; avoid dividing by zero.
        logger.warning("Evaluation dataloader yielded no batches; returning 0.0")
        return 0.0
    
    return -total_loss / num_batches

def validate_pruning(
    model: nn.Module, 
    eval_dataloader: DataLoader, 
    baseline_performance: float,
    threshold: float = pruning_config.PERFORMANCE_DROP_THRESHOLD,
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
) -> Tuple[float, bool]:
    """
    Validate the model after pruning to ensure performance is maintained.
    
    The pruned model is scored with ``evaluate_model`` and the drop relative
    to ``baseline_performance`` is compared against a relative budget of
    ``threshold * |baseline_performance|``. The magnitude of the baseline is
    used because ``evaluate_model`` may return a negative value (negative
    loss); scaling by the signed baseline would make the budget negative and
    reject every result.
    
    Args:
        model: The pruned model
        eval_dataloader: DataLoader for evaluation
        baseline_performance: Performance metric before pruning
        threshold: Maximum acceptable performance drop, as a fraction of the
            baseline's magnitude
        device: Device to run evaluation on
        
    Returns:
        Tuple of (current performance, is_acceptable)
    """
    logger.info(f"Validating pruned model (baseline: {baseline_performance:.4f}, threshold: {threshold:.4f})")
    
    # Evaluate pruned model
    performance = evaluate_model(model, eval_dataloader, device)
    
    # Calculate performance drop (improvements are clamped to a drop of 0)
    performance_drop = max(0, baseline_performance - performance)
    
    # Scale by the baseline's magnitude so negative baselines (-loss) work too
    baseline_scale = abs(baseline_performance)
    performance_drop_percentage = (performance_drop / baseline_scale) * 100 if baseline_scale > 0 else 0
    
    # Check if performance is acceptable
    is_acceptable = performance_drop <= threshold * baseline_scale
    
    logger.info(f"Validation results: performance={performance:.4f}, drop={performance_drop:.4f} ({performance_drop_percentage:.2f}%)")
    logger.info(f"Verdict: {'ACCEPTABLE' if is_acceptable else 'REJECTED'}")
    
    return performance, is_acceptable

def _resolve_lora_rank(model: nn.Module, layer_name: str) -> int:
    """
    Read the LoRA rank currently set on a named sub-module.
    
    Args:
        model: The model containing the layer
        layer_name: Dotted attribute path of the layer
        
    Returns:
        ``module.r`` if present, else ``module.lora_A.shape[0]`` if the module
        has a ``lora_A`` attribute, else 0
        
    Raises:
        AttributeError: If any component of the dotted path does not exist
    """
    module = model
    for attr in layer_name.split('.'):
        module = getattr(module, attr)
    
    if hasattr(module, 'r'):
        return module.r
    if hasattr(module, 'lora_A'):
        return module.lora_A.shape[0]
    return 0


def validate_pruning_configuration(
    model: nn.Module,
    prev_r_config: Dict[str, int],
    new_r_config: Dict[str, int],
    layer_sizes: Optional[Dict[str, Tuple[int, int]]] = None,
    logger: Optional[logging.Logger] = None
) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
    """
    Validate that the pruned model's configuration matches the expected configuration.
    Similar to validate_lora_configuration but specifically for pruning.
    
    For every layer named in either configuration, the rank found on the model
    is compared with the rank in ``new_r_config`` and the LoRA parameter count
    is computed as ``r * (in_features + out_features)``. Results are logged at
    INFO level; the logger's level, propagation flag and handler levels are
    temporarily forced to make the output visible and are restored on exit.
    
    Layers that cannot be resolved on the model are logged, excluded from
    ``layer_details`` and the parameter totals, and counted under
    ``summary['missing_layers']``.
    
    Args:
        model: The pruned model
        prev_r_config: Previous r configuration before pruning
        new_r_config: New r configuration after pruning
        layer_sizes: Dictionary of layer sizes (in_features, out_features) for parameter calculation
        logger: Optional logger for output (uses module logger if None)
        
    Returns:
        Tuple of (layer_details, summary) with validation results
    """
    # Use module logger if none provided
    if logger is None:
        logger = logging.getLogger(__name__)
    
    # Store original logger settings
    original_level = logger.level
    original_propagate = getattr(logger, 'propagate', True)
    
    # Populated during setup so the finally-block knows what to undo
    temp_handler = None
    original_handler_levels = {}
    
    try:
        # Enable propagation and INFO level to ensure log visibility
        logger.propagate = True
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            # No handler at all: attach a temporary stdout handler
            temp_handler = logging.StreamHandler(sys.stdout)
            temp_handler.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            temp_handler.setFormatter(formatter)
            logger.addHandler(temp_handler)
        else:
            # Store and adjust existing handler levels
            for handler in logger.handlers:
                original_handler_levels[handler] = handler.level
                handler.setLevel(logging.INFO)
        
        logger.info("=" * 60)
        logger.info("VALIDATING PRUNED LoRA CONFIGURATION")
        logger.info("=" * 60)
        
        layer_details = {}
        applied_count = 0
        mismatch_count = 0
        pruned_count = 0
        missing_count = 0
        total_params_before = 0
        total_params_after = 0
        total_reduction = 0
        
        # Check all layers from the configurations
        all_layers = set(prev_r_config).union(new_r_config)
        
        # Extract layer sizes if not provided
        if layer_sizes is None:
            layer_sizes = {name: get_layer_size(model, name) for name in all_layers}
        
        # Group layers by type for more organized output
        layer_types = classify_layers_by_type(list(all_layers))
        
        for layer_type, type_layers in layer_types.items():
            logger.info(f"\n[{layer_type.upper()} LAYERS]")
            
            for layer_name in sorted(type_layers):
                prev_r = prev_r_config.get(layer_name, 0)
                expected_r = new_r_config.get(layer_name, 0)
                
                # Get actual r value from model
                try:
                    actual_r = _resolve_lora_rank(model, layer_name)
                except (AttributeError, ValueError):
                    logger.warning(f"Module {layer_name} not found in model")
                    # Remember the gap so the final verdict cannot claim
                    # success for layers that were never checked.
                    missing_count += 1
                    continue
                
                # Calculate parameter counts
                in_features, out_features = layer_sizes.get(layer_name, (0, 0))
                prev_params = prev_r * (in_features + out_features) if prev_r > 0 else 0
                current_params = actual_r * (in_features + out_features) if actual_r > 0 else 0
                param_reduction = prev_params - current_params
                
                # Check if pruned
                was_pruned = prev_r > actual_r
                if was_pruned:
                    pruned_count += 1
                    
                # Check for mismatch
                is_match = actual_r == expected_r
                if not is_match:
                    mismatch_count += 1
                    
                # Status and formatting
                status = "✓ MATCH" if is_match else "✗ MISMATCH"
                
                if was_pruned:
                    pruned_status = "PRUNED"
                    change_symbol = "↓"
                elif prev_r < actual_r:
                    pruned_status = "INCREASED"
                    change_symbol = "↑"
                else:
                    pruned_status = "UNCHANGED"
                    change_symbol = "="
                
                # Log details
                logger.info(f"Layer: {layer_name}")
                logger.info(f"  r-value: {prev_r} {change_symbol} {actual_r} (Expected: {expected_r}) - Status: {status}, {pruned_status}")
                if prev_params > 0 or current_params > 0:
                    param_change = current_params - prev_params
                    logger.info(f"  Parameters: {prev_params:,} → {current_params:,} ({param_change:+,})")
                
                layer_details[layer_name] = {
                    'prev_r': prev_r,
                    'actual_r': actual_r,
                    'expected_r': expected_r,
                    'is_match': is_match,
                    'was_pruned': was_pruned,
                    'prev_params': prev_params,
                    'current_params': current_params,
                    'param_reduction': param_reduction
                }
                
                # Update totals
                total_params_before += prev_params
                total_params_after += current_params
                total_reduction += param_reduction
                
                if actual_r > 0:
                    applied_count += 1
        
        # Summary with protection against division by zero
        reduction_percentage = (total_reduction / max(total_params_before, 1)) * 100
        
        summary = {
            'total_layers': len(all_layers),
            'pruned_layers': pruned_count,
            'mismatch_layers': mismatch_count,
            'applied_layers': applied_count,
            'total_params_before': total_params_before,
            'total_params_after': total_params_after,
            'total_reduction': total_reduction,
            'reduction_percentage': reduction_percentage,
            'missing_layers': missing_count
        }
        
        # Log summary
        logger.info("-" * 60)
        logger.info("PRUNING SUMMARY")
        logger.info(f"  Total layers: {len(all_layers)}")
        logger.info(f"  Pruned layers: {pruned_count}")
        logger.info(f"  Applied layers: {applied_count}")
        logger.info(f"  LoRA parameters: {total_params_before:,} → {total_params_after:,} ({total_params_after - total_params_before:+,})")
        logger.info(f"  Reduction: {reduction_percentage:.2f}%")
        
        if missing_count > 0:
            logger.warning(f"INCOMPLETE: {missing_count} layers could not be found in the model and were not validated")
        
        if mismatch_count > 0:
            logger.warning(f"CRITICAL: Found {mismatch_count} layers with r-value mismatches!")
        elif missing_count == 0:
            logger.info("SUCCESS: All r-values correctly applied")
        logger.info("=" * 60)
        
        return layer_details, summary
    
    finally:
        # Remove temporary handler if added
        if temp_handler is not None:
            logger.removeHandler(temp_handler)
        
        # Restore original handler levels
        for handler, level in original_handler_levels.items():
            handler.setLevel(level)
        
        # Restore original logger settings
        logger.setLevel(original_level)
        logger.propagate = original_propagate

def calculate_model_size_mb(model: nn.Module, only_lora: bool = True) -> float:
    """
    Estimate the storage footprint of a model's parameters in MB.
    
    Every parameter is costed at 4 bytes (float32), regardless of the
    tensor's actual dtype.
    
    Args:
        model: The model to analyze
        only_lora: If True, only parameters whose name contains 'lora_' are counted;
            otherwise all parameters are counted
        
    Returns:
        Size in MB
    """
    if only_lora:
        element_counts = (
            tensor.numel()
            for param_name, tensor in model.named_parameters()
            if 'lora_' in param_name
        )
    else:
        element_counts = (tensor.numel() for tensor in model.parameters())
    
    total_params = sum(element_counts)
    
    # float32 storage assumed: 4 bytes per element
    bytes_per_param = 4
    return (total_params * bytes_per_param) / (1024 * 1024)

def _save_line_plot(
    x_values: List[int],
    y_values: List[Any],
    fmt: str,
    y_label: str,
    title: str,
    output_path: str
) -> None:
    """
    Draw one series against the pruning step and save it as an image.
    
    Args:
        x_values: Pruning step indices
        y_values: Series values, one per step
        fmt: Matplotlib format string for the line/markers
        y_label: Label for the y-axis
        title: Plot title
        output_path: File path the figure is written to
    """
    plt.figure(figsize=(10, 6))
    plt.plot(x_values, y_values, fmt, linewidth=2)
    plt.xlabel('Pruning Step')
    plt.ylabel(y_label)
    plt.title(title)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_pruning_results(
    pruning_history: List[Dict[str, Any]],
    output_dir: str,
    file_prefix: str = "pruning"
) -> None:
    """
    Plot pruning results including parameter reduction and performance.
    
    Writes ``<file_prefix>_param_reduction.png``, ``<file_prefix>_performance.png``
    and ``<file_prefix>_combined.png`` into ``output_dir`` (created if needed),
    plus ``<file_prefix>_efficiency.png`` when every history entry has an
    'efficiency' key. Parameter reduction is expressed relative to the first
    entry's 'param_count'; if that count is 0 the reduction is reported as 0%
    for every step.
    
    Args:
        pruning_history: List of dictionaries containing pruning step information;
            each must have 'param_count' and 'performance'
        output_dir: Directory to save plots
        file_prefix: Prefix for plot filenames
    """
    if not pruning_history:
        logger.warning("No pruning history to plot")
        return
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Extract data from history
    steps = list(range(len(pruning_history)))
    param_counts = [step['param_count'] for step in pruning_history]
    performances = [step['performance'] for step in pruning_history]
    
    # Calculate reduction percentages
    initial_params = param_counts[0]
    if initial_params:
        param_reductions = [(initial_params - p) / initial_params * 100 for p in param_counts]
    else:
        # A relative reduction is undefined against a zero baseline
        param_reductions = [0.0 for _ in param_counts]
    
    # Create parameter reduction plot
    _save_line_plot(
        steps, param_reductions, 'bo-',
        'Parameter Reduction (%)',
        'Progressive Pruning: Parameter Reduction',
        os.path.join(output_dir, f"{file_prefix}_param_reduction.png"),
    )
    
    # Create performance plot
    _save_line_plot(
        steps, performances, 'ro-',
        'Performance Metric',
        'Progressive Pruning: Model Performance',
        os.path.join(output_dir, f"{file_prefix}_performance.png"),
    )
    
    # Create combined plot
    fig, ax1 = plt.subplots(figsize=(12, 6))
    
    # Parameter reduction (left y-axis)
    ax1.set_xlabel('Pruning Step')
    ax1.set_ylabel('Parameter Reduction (%)', color='blue')
    ax1.plot(steps, param_reductions, 'bo-', linewidth=2)
    ax1.tick_params(axis='y', labelcolor='blue')
    
    # Performance (right y-axis)
    ax2 = ax1.twinx()
    ax2.set_ylabel('Performance Metric', color='red')
    ax2.plot(steps, performances, 'ro-', linewidth=2)
    ax2.tick_params(axis='y', labelcolor='red')
    
    plt.title('Progressive Pruning: Parameter Reduction vs Performance')
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{file_prefix}_combined.png"))
    plt.close()
    
    # Efficiency metrics plot (only when every step recorded one)
    if all('efficiency' in step for step in pruning_history):
        efficiencies = [step['efficiency'] for step in pruning_history]
        _save_line_plot(
            steps, efficiencies, 'go-',
            'Performance/Parameters Ratio',
            'Progressive Pruning: Efficiency Metric',
            os.path.join(output_dir, f"{file_prefix}_efficiency.png"),
        )
    
    logger.info(f"Pruning result plots saved to {output_dir}")

def get_efficiency_metrics(
    model: nn.Module,
    r_config: Dict[str, int],
    layer_sizes: Dict[str, Tuple[int, int]],
    performance: float, 
    flops: float,
    macs: float
) -> Dict[str, float]:
    """
    Build a dictionary relating model performance to its cost.
    
    The parameter count comes from ``calculate_total_parameters(r_config,
    layer_sizes)`` and the size from ``calculate_model_size_mb`` restricted to
    LoRA parameters. ``flops`` and ``macs`` are divided by 1e9; a non-positive
    input is replaced by 1 so the corresponding ratio stays defined.
    
    Args:
        model: The model
        r_config: Current r configuration
        layer_sizes: Layer dimensions
        performance: Model performance (accuracy or other metric)
        flops: Floating point operations
        macs: Multiply-accumulate operations
        
    Returns:
        Dictionary with the raw quantities and performance-per-cost ratios
        (a ratio is 0 when its denominator is not positive)
    """
    # Cost side: LoRA parameter count and storage footprint
    n_params = calculate_total_parameters(r_config, layer_sizes)
    size_mb = calculate_model_size_mb(model, only_lora=True)
    
    # Compute cost scaled to giga-units, falling back to 1 when unknown
    if flops > 0:
        gflops = flops / 1e9
    else:
        gflops = 1
    if macs > 0:
        gmacs = macs / 1e9
    else:
        gmacs = 1
    
    # Performance per unit of each cost; performance is scaled by 1e6 for
    # the per-parameter ratio
    per_param = (performance * 1e6) / n_params if n_params > 0 else 0
    per_mb = performance / size_mb if size_mb > 0 else 0
    per_gflops = performance / gflops if gflops > 0 else 0
    per_gmacs = performance / gmacs if gmacs > 0 else 0
    
    return {
        'param_count': n_params,
        'model_size_mb': size_mb,
        'performance': performance,
        'flops_g': gflops,
        'macs_g': gmacs,
        'performance_per_param': per_param,
        'performance_per_mb': per_mb,
        'performance_per_gflops': per_gflops,
        'performance_per_gmacs': per_gmacs,
    }