"""
ILP solver for optimal LoRA pruning.

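This module formulates and solves the Integer Linear Programming (ILP) problem that selects
per-layer LoRA r-values within a reduced parameter budget, and provides ProgressivePruningManager,
which drives that optimization as a progressive pruning process during training.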
"""

import time
import math
import logging
import sys
import numpy as np
from typing import Dict, List, Tuple, Optional, Any, Union
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
from . import pruning_config
from . import pruning_utils

logger = logging.getLogger(__name__)

def _candidate_r_values(
    r_config: Dict[str, int],
    available_r_values: List[int],
) -> Dict[str, List[int]]:
    """
    Return, per layer, the r values the ILP may choose for that layer.

    Candidates are the available values that do not exceed the layer's current
    r, de-duplicated while keeping their original order (a repeated value would
    be summed twice in the layer's "choose exactly one" constraint, which a
    binary variable can then never satisfy). A layer with no such value keeps
    its current r as its only candidate: an empty candidate list would turn its
    constraint into ``0 == 1`` and make the whole problem infeasible.
    """
    candidates: Dict[str, List[int]] = {}
    for layer, current_r in r_config.items():
        options = list(dict.fromkeys(r for r in available_r_values if r <= current_r))
        if not options:
            logger.warning(
                f"No available r value <= current r={current_r} for {layer}; "
                f"keeping current r as its only option"
            )
            options = [current_r]
        candidates[layer] = options
    return candidates


def _read_selected_ranks(
    variables: Dict[Tuple[str, int], Any],
    valid_r_values: Dict[str, List[int]],
    tolerance: float = 1e-6,
) -> Tuple[Dict[str, int], bool]:
    """
    Read the solver's choice out of the binary decision variables.

    Returns:
        ``(selected, integral)``: ``selected`` maps each layer to the first
        candidate whose variable value exceeds 0.5 (layers without such a value
        are absent); ``integral`` is False if any variable has no value or a
        value farther than ``tolerance`` from 0 or 1.
    """
    selected: Dict[str, int] = {}
    integral = True
    for layer, candidates in valid_r_values.items():
        for r in candidates:
            # varValue is None when the solver produced no value for this variable.
            value = variables[(layer, r)].varValue
            if value is None:
                integral = False
                continue
            if min(abs(value), abs(value - 1.0)) > tolerance:
                integral = False
            if value > 0.5 and layer not in selected:
                selected[layer] = r
    return selected, integral


def _log_solution_summary(
    r_config: Dict[str, int],
    new_r_config: Dict[str, int],
    layer_sizes: Dict[str, Tuple[int, int]],
    prev_r_config: Optional[Dict[str, int]],
    start_time: float,
) -> None:
    """Log parameter reduction, momentum stability and per-layer r changes of a solved step."""
    layer_names = list(r_config.keys())

    # Calculate achieved reduction
    initial_params = pruning_utils.calculate_total_parameters(r_config, layer_sizes)
    final_params = pruning_utils.calculate_total_parameters(new_r_config, layer_sizes)
    reduction = initial_params - final_params
    reduction_percentage = (reduction / initial_params * 100) if initial_params > 0 else 0

    logger.info(f"Optimization completed in {time.time() - start_time:.2f}s")
    logger.info(f"Parameter reduction: {reduction:,} parameters ({reduction_percentage:.2f}%)")
    logger.info(f"Final parameter count: {final_params:,} parameters")

    # Analyze the influence of the momentum penalty if applicable
    if prev_r_config is not None:
        stable_count = sum(1 for layer in layer_names if new_r_config[layer] == prev_r_config.get(layer, 0))
        changed_count = len(layer_names) - stable_count

        if changed_count > 0:
            avg_change = sum(
                abs(new_r_config[layer] - prev_r_config.get(layer, 0)) for layer in layer_names
            ) / len(layer_names)
            logger.info(f"Momentum-based stability: {stable_count}/{len(layer_names)} layers unchanged, avg change: {avg_change:.2f}")
        else:
            logger.info("Momentum-based stability: All layers remained unchanged")

    # Log changes in r values relative to the configuration this step started from
    pruned_layers = sum(1 for layer in layer_names if new_r_config[layer] < r_config[layer])
    unchanged_layers = sum(1 for layer in layer_names if new_r_config[layer] == r_config[layer])
    increased_layers = sum(1 for layer in layer_names if new_r_config[layer] > r_config[layer])

    logger.info(f"R-value changes: {pruned_layers} decreased, {unchanged_layers} unchanged, {increased_layers} increased")

    # Log top pruned layers by number of removed parameters
    pruned_info = []
    for layer in layer_names:
        if new_r_config[layer] < r_config[layer]:
            param_diff = (r_config[layer] - new_r_config[layer]) * sum(layer_sizes[layer])
            pruned_info.append((layer, r_config[layer], new_r_config[layer], param_diff))

    if pruned_info:
        logger.info("Top pruned layers (by parameter reduction):")
        for layer, old_r, new_r, param_diff in sorted(pruned_info, key=lambda x: x[3], reverse=True)[:5]:
            logger.info(f"  {layer}: r={old_r} → {new_r} (-{param_diff:,} parameters)")


def optimize_pruning_step(
    r_config: Dict[str, int],
    layer_importances: Dict[str, float],
    layer_sizes: Dict[str, Tuple[int, int]],
    available_r_values: List[int],
    step_budget: int,
    prev_r_config: Optional[Dict[str, int]] = None,
    momentum_penalty: float = pruning_config.MOMENTUM_PENALTY_WEIGHT,
    time_limit: int = pruning_config.OPTIMIZATION_TIMEOUT,
    seed: int = 42
) -> Dict[str, int]:
    """
    Optimize r-values for a single pruning step using ILP.

    Each layer picks exactly one r from the available values that do not exceed
    its current r. The choice minimizes the summed estimated performance loss
    (plus, when ``prev_r_config`` is given, a momentum term that penalizes
    distance from the previous r and rewards keeping it), subject to the total
    ``sum(r * (in_features + out_features))`` staying within ``step_budget`` and
    every layer type keeping a positive total r.

    Args:
        r_config: Current r configuration (layer_name -> r_value)
        layer_importances: Layer importance scores (layer_name -> importance);
            missing layers use importance 1.0
        layer_sizes: Layer sizes (layer_name -> (in_features, out_features))
        available_r_values: List of available r values
        step_budget: Parameter budget for this step
        prev_r_config: Previous r configuration used for the momentum penalty
        momentum_penalty: Weight of the momentum penalty on r changes
        time_limit: Time limit for optimization in seconds
        seed: Random seed passed to CBC (and used to seed NumPy's global RNG)

    Returns:
        Optimized r configuration. The input ``r_config`` object itself is
        returned when the solver fails or yields no usable integer solution.

    Raises:
        ImportError: If PuLP is not installed.
    """
    logger.info(f"Starting pruning optimization with budget {step_budget:,} parameters and seed {seed}")

    try:
        import pulp
    except ImportError as err:
        logger.error("PuLP is required for this optimization. Please install pulp.")
        raise ImportError("PuLP optimizer is required for pruning optimization.") from err

    start_time = time.time()
    layer_names = list(r_config.keys())

    # Candidate r values per layer (never above the current r)
    valid_r_values = _candidate_r_values(r_config, available_r_values)

    # Minimize the estimated performance loss
    logger.info("Formulating ILP problem...")
    ilp_model = pulp.LpProblem("lora_pruning", pulp.LpMinimize)

    # Process-wide reproducibility knobs; note np.random.seed mutates NumPy's global RNG.
    os.environ["CBC_RANDOM_SEED"] = str(seed)
    np.random.seed(seed)

    # Extensive CBC options aimed at deterministic behavior
    # (intended to mirror Gurobi parameters used previously).
    solver_options = [
        f"randomSeed {seed}",             # Set random seed for reproducibility
        f"timeLimit {time_limit}",         # Time limit in seconds
        "threads 1",                       # Use single thread for deterministic behavior
        "ratioGap 0.0",                    # Set MIP gap tolerance to 0 (exact solution)
        "allowableGap 0.0",                # Set allowable absolute gap to 0
        "presolve on",                     # Enable presolve (like Gurobi Presolve=2)
        "passPresolve 5",                  # More aggressive presolve (similar to Gurobi)
        "cutoff 1e50",                     # High cutoff value
        "sec 3600",                        # Maximum time per node in seconds
        "strong 10",                       # Strong branching on 10 variables
        "perturbation on",                 # Enable perturbation for stability
        "passC 1000",                      # Pass limit for cut generator
        "cuts on",                         # Enable cut generation (similar to Gurobi Cuts=2)
        "passCuts 10",                     # Number of cut passes
        "cost off",                        # Disable automatic computation for priorities
        "primalP on",                      # Enable primal heuristics
        "logLevel 1",                      # Standard log level (like Gurobi OutputFlag=1)
        "nodeStrategy depth",              # Node selection strategy
        "scaling aggressive",              # Aggressive scaling
        "integerT 1e-9",                   # Integer tolerance (like Gurobi IntFeasTol)
        "primalT 1e-9",                    # Primal tolerance (like Gurobi FeasibilityTol)
        "dualT 1e-9",                      # Dual tolerance
        "OrbitalBranching on",             # Enable orbital branching for symmetry detection
        "prioritize on",                   # Prioritize important variables
        "autoScale on",                    # Auto scaling
    ]

    # One binary decision variable per (layer, candidate r)
    variables = {}
    for layer in layer_names:
        for r in valid_r_values[layer]:
            variables[(layer, r)] = pulp.LpVariable(
                f"{layer}_r{r}",
                cat=pulp.LpBinary
            )

    # Constraint: each layer chooses exactly one r value
    for layer in layer_names:
        ilp_model += (
            pulp.lpSum(variables[(layer, r)] for r in valid_r_values[layer]) == 1,
            f"one_r_{layer}"
        )

    # Constraint: total LoRA parameters must not exceed the step budget
    total_params = pulp.lpSum(
        r * (layer_sizes[layer][0] + layer_sizes[layer][1]) * variables[(layer, r)]
        for layer in layer_names
        for r in valid_r_values[layer]
    )
    ilp_model += (total_params <= step_budget, "budget_constraint")

    # Layer type constraints: the average r of every layer type must be > 0,
    # i.e. the sum of chosen r values per type must be positive.
    layer_types = pruning_utils.classify_layers_by_type(layer_names)
    for layer_type, type_layers in layer_types.items():
        if not type_layers:
            continue

        sum_r_expr = pulp.lpSum(
            r * variables[(layer, r)]
            for layer in type_layers
            for r in valid_r_values[layer]
        )
        ilp_model += (
            sum_r_expr >= 0.001,  # Small epsilon expresses the strict "> 0"
            f"avg_r_positive_{layer_type}"
        )

    # Objective: estimated performance loss plus optional momentum penalty
    objective = pulp.LpAffineExpression()
    for layer in layer_names:
        current_r = r_config[layer]
        importance = layer_importances.get(layer, 1.0)

        for r in valid_r_values[layer]:
            loss = pruning_utils.estimate_performance_loss(
                layer_name=layer,
                current_r=current_r,
                new_r=r,
                importance=importance,
                layer_size=layer_sizes[layer]
            )

            if prev_r_config is not None and layer in prev_r_config:
                prev_r = prev_r_config[layer]

                # Penalty proportional to |r - prev_r|, scaled by layer importance
                r_diff = abs(r - prev_r)
                momentum_term = r_diff * momentum_penalty * importance

                # Bonus (negative cost) for keeping exactly the previous r
                if r == prev_r:
                    momentum_term -= pruning_config.STABLE_LAYER_BONUS * importance

                loss += momentum_term

                if r_diff > 0 and logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Layer {layer}: r={prev_r}→{r}, momentum penalty: {momentum_term:.6f}")

            objective += loss * variables[(layer, r)]

    ilp_model += objective

    logger.info(f"Starting deterministic optimization with CBC solver (seed: {seed}, threads: 1)")

    solver = pulp.PULP_CBC_CMD(
        msg=False,
        timeLimit=time_limit,
        options=solver_options,
        keepFiles=False,  # Delete CBC's temporary files after solving
        mip=True,         # Force use of MIP solver
        threads=1,        # Redundant with the option above, but explicit
        gapRel=0.0,       # Relative gap tolerance
        gapAbs=0.0        # Absolute gap tolerance
    )

    solve_start = time.time()
    ilp_model.solve(solver)
    solve_time = time.time() - solve_start

    status = pulp.LpStatus[ilp_model.status]
    if status == 'Optimal':
        logger.info(f"Optimal solution found in {solve_time:.2f} seconds!")
    elif status == 'Not Solved':
        logger.warning(f"Time limit reached after {solve_time:.2f} seconds, using best solution found so far")
    else:
        logger.error(f"Optimization failed with status {status} after {solve_time:.2f} seconds")
        return r_config  # Return original configuration

    selected, integral = _read_selected_ranks(variables, valid_r_values)

    if status == 'Not Solved':
        # 'Not Solved' does not guarantee an integer solution: values may be
        # missing (None) or fractional. Only use them if they form a complete,
        # integral selection that respects the budget.
        selected_params = sum(r * sum(layer_sizes[layer]) for layer, r in selected.items())
        if not integral or len(selected) != len(layer_names) or selected_params > step_budget:
            logger.error("No usable integer solution available; keeping current configuration")
            return r_config

    # Build the new configuration; keep the current r for any layer without a selection
    new_r_config = {}
    for layer in layer_names:
        if layer in selected:
            new_r_config[layer] = selected[layer]
        else:
            logger.warning(f"No r value selected for {layer}, keeping current r={r_config[layer]}")
            new_r_config[layer] = r_config[layer]

    _log_solution_summary(r_config, new_r_config, layer_sizes, prev_r_config, start_time)

    return new_r_config

class ProgressivePruningManager:
    """
    Manages the progressive pruning process during training.
    
    This class integrates pruning, recovery, validation, and rollback
    mechanisms into the training process.
    """
    
    def __init__(
        self,
        model: nn.Module,
        initial_r_config: Dict[str, int],
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        target_reduction: float,
        num_pruning_steps: int,
        total_training_steps: int,
        device: torch.device,
        output_dir: str = "pruning_results",
        seed: int = 42,
        enable_rollback: bool = False,
        importance_ema_decay: Optional[float] = None,
        momentum_penalty_weight: Optional[float] = None
    ):
        """
        Initialize the progressive pruning manager.

        Creates ``output_dir``, computes each layer's size and the initial LoRA
        parameter count, and builds the pruning scheduler.

        Args:
            model: Model to prune
            initial_r_config: Initial r configuration (copied, not aliased)
            train_dataloader: DataLoader for training data (used for importance measurement)
            eval_dataloader: DataLoader for evaluation data
            target_reduction: Target parameter reduction ratio (0-1)
            num_pruning_steps: Number of pruning steps
            total_training_steps: Total number of training steps
            device: Device to run on
            output_dir: Directory to save results and the pre-pruning checkpoint
            seed: Random seed for reproducibility
            enable_rollback: Whether a step that fails post-recovery validation is
                rolled back to the pre-pruning checkpoint (default: False)
            importance_ema_decay: EMA decay for layer importances; None uses
                ``pruning_config.IMPORTANCE_EMA_DECAY``
            momentum_penalty_weight: Weight of the ILP momentum penalty; None uses
                ``pruning_config.MOMENTUM_PENALTY_WEIGHT``
        """
        self.model = model
        # Copy so later mutation of the caller's dict cannot change the recorded initial state.
        self.initial_r_config = initial_r_config.copy()
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.device = device
        self.output_dir = output_dir
        self.seed = seed
        self.enable_rollback = enable_rollback

        # Parameter overrides with fallback to config defaults
        self.importance_ema_decay = importance_ema_decay if importance_ema_decay is not None else pruning_config.IMPORTANCE_EMA_DECAY
        self.momentum_penalty_weight = momentum_penalty_weight if momentum_penalty_weight is not None else pruning_config.MOMENTUM_PENALTY_WEIGHT

        # Log parameter values for transparency
        logger.info(f"EMA decay factor: {self.importance_ema_decay} {'(overridden)' if importance_ema_decay is not None else '(config default)'}")
        logger.info(f"Momentum penalty weight: {self.momentum_penalty_weight} {'(overridden)' if momentum_penalty_weight is not None else '(config default)'}")

        # Create output directory (os is imported at module level)
        os.makedirs(output_dir, exist_ok=True)

        logger.info(f"Initializing progressive pruning manager with seed {seed}")

        # Per-layer sizes as reported by pruning_utils.get_layer_size
        self.layer_sizes = {}
        for layer_name in initial_r_config:
            self.layer_sizes[layer_name] = pruning_utils.get_layer_size(model, layer_name)

        # Initial LoRA parameter count; the scheduler's initial budget
        self.initial_params = pruning_utils.calculate_total_parameters(initial_r_config, self.layer_sizes)
        logger.info(f"Initial LoRA parameters: {self.initial_params:,}")

        from .pruning_scheduler import PruningScheduler
        self.scheduler = PruningScheduler(
            initial_budget=self.initial_params,
            target_reduction=target_reduction,
            num_pruning_steps=num_pruning_steps,
            total_training_steps=total_training_steps
        )

        # Pruning state
        self.current_r_config = initial_r_config.copy()
        self.current_performance = None
        self.baseline_performance = None
        self.pruning_history = []
        self.checkpoint_path = os.path.join(output_dir, "pre_pruning_checkpoint.pt")

        # Metrics tracking
        self.pruning_metrics = {
            'eval_flops': 0,
            'eval_macs': 0,
            'pruned_params': 0,
            'model_size_reduction': 0,
            'accuracy_per_gflops': 0,
            'accuracy_per_gmacs': 0,
            'eval_steps_per_second': 0
        }

        # State carried between pruning steps
        self.prev_importances = None  # For EMA of layer importances
        self.prev_r_config = None     # For momentum-based penalty
        
    def initialize_baseline_performance(self) -> None:
        """
        Evaluate the unpruned model and record it as the baseline.

        Sets both the baseline and current performance, appends the step-0
        'initial' entry to the pruning history, and logs how the initial LoRA
        parameters are distributed across layer types.
        """
        baseline = pruning_utils.evaluate_model(
            model=self.model,
            eval_dataloader=self.eval_dataloader,
            device=self.device
        )
        self.baseline_performance = baseline
        self.current_performance = baseline

        logger.info(f"Baseline performance: {self.baseline_performance:.4f}")

        # Step 0 of the history is the untouched starting configuration.
        initial_entry = {
            'step': 0,
            'r_config': self.current_r_config.copy(),
            'param_count': self.initial_params,
            'performance': baseline,
            'status': 'initial'
        }
        self.pruning_history.append(initial_entry)

        banner = "=" * 80
        logger.info(banner)
        logger.info("INITIAL MODEL PARAMETERS BREAKDOWN")
        logger.info(banner)

        grouped = pruning_utils.classify_layers_by_type(list(self.initial_r_config.keys()))
        total = self.initial_params

        for type_name, members in grouped.items():
            type_total = 0
            for name in members:
                rank = self.initial_r_config.get(name, 0)
                # Layers with r <= 0 or unknown size contribute nothing.
                if rank <= 0 or name not in self.layer_sizes:
                    continue
                fan_in, fan_out = self.layer_sizes[name]
                type_total += rank * (fan_in + fan_out)

            share = (type_total / total * 100) if total > 0 else 0
            logger.info(f"{type_name} layers: {len(members)} layers, {type_total:,} parameters ({share:.2f}%)")

        logger.info(f"Total LoRA parameters: {self.initial_params:,}")
        logger.info(banner)
    
    def should_prune(self, training_step: int) -> bool:
        """
        Ask the pruning scheduler whether a pruning step is due.

        Args:
            training_step: Current training step

        Returns:
            The scheduler's decision for ``training_step``, passed through unchanged
        """
        scheduler = self.scheduler
        decision = scheduler.should_prune(training_step)
        return decision
    
    def execute_pruning_step(self, training_step: int) -> bool:
        """
        Execute a pruning step.

        Saves a pre-pruning checkpoint, measures EMA-smoothed layer importances,
        solves the ILP for the scheduler's next budget, applies the new r values
        to the model, validates the applied configuration, starts the
        scheduler's recovery phase and records a 'pruned' history entry whose
        performance is filled in after recovery.

        Args:
            training_step: Current training step (used for logging)

        Returns:
            True once the step has been applied, or found to change nothing.
            No code path currently returns False; failures in the helpers
            propagate as exceptions.
        """
        # Force INFO for the duration of the step so stage banners are visible;
        # the original level is restored in ``finally``.
        original_level = logger.level
        logger.setLevel(logging.INFO)

        try:
            logger.info(f"\n{'='*80}\nEXECUTING PRUNING STEP AT TRAINING STEP {training_step}\n{'='*80}")

            # Save checkpoint for potential rollback
            pruning_utils.save_model_checkpoint(self.model, self.checkpoint_path)

            # Budget for this pruning step
            step_budget = self.scheduler.get_next_pruning_budget()
            step_info = self.scheduler.get_current_step_info()
            next_step_idx = step_info["step_idx"] + 1

            logger.info(f"Pruning step {next_step_idx+1}/{step_info['total_steps']} - "
                        f"Target budget: {step_budget:,} parameters")

            # Snapshot of the pre-pruning config, used to validate the applied changes
            prev_r_config = self.current_r_config.copy()

            # Measure layer importances, smoothed with the previous measurement (EMA)
            layer_importances = pruning_utils.measure_layer_importance(
                model=self.model,
                dataloader=self.train_dataloader,
                r_config=self.current_r_config,
                device=self.device,
                prev_importances=self.prev_importances,
                ema_decay=self.importance_ema_decay
            )
            self.prev_importances = layer_importances.copy()

            # ILP optimization with momentum-based penalty and seed
            new_r_config = optimize_pruning_step(
                r_config=self.current_r_config,
                layer_importances=layer_importances,
                layer_sizes=self.layer_sizes,
                available_r_values=pruning_config.AVAILABLE_R_VALUES,
                step_budget=step_budget,
                prev_r_config=self.prev_r_config,
                momentum_penalty=self.momentum_penalty_weight,
                time_limit=pruning_config.OPTIMIZATION_TIMEOUT,
                seed=self.seed
            )

            # The config this step started from becomes the momentum reference next time
            self.prev_r_config = self.current_r_config.copy()

            self._log_planned_r_changes(new_r_config)

            # Apply the new r configuration to the model
            self.model, changes = pruning_utils.modify_lora_layers(
                model=self.model,
                new_r_config=new_r_config
            )

            if not changes:
                logger.info("No changes made to model, skipping validation")
                # Advance scheduler anyway
                self.scheduler.advance_pruning_step()
                return True

            # Validate that the model now carries the intended configuration
            pruning_validation, validation_summary = pruning_utils.validate_pruning_configuration(
                model=self.model,
                prev_r_config=prev_r_config,
                new_r_config=new_r_config,
                layer_sizes=self.layer_sizes,
                logger=logger
            )

            if validation_summary['mismatch_layers'] > 0:
                logger.error(f"CRITICAL: Found {validation_summary['mismatch_layers']} layer(s) with r-value mismatches!")
                for layer_name, details in pruning_validation.items():
                    if not details['is_match']:
                        logger.error(f"Mismatch in layer {layer_name}: expected r={details['expected_r']}, "
                                     f"got r={details['actual_r']}")

            # Track pruning metrics
            self.pruning_metrics['pruned_params'] = validation_summary['total_reduction']
            self.pruning_metrics['model_size_reduction'] = validation_summary['reduction_percentage']

            self.scheduler.start_recovery()

            current_params = pruning_utils.calculate_total_parameters(new_r_config, self.layer_sizes)
            self.current_r_config = new_r_config

            self.pruning_history.append({
                'step': next_step_idx + 1,
                'r_config': self.current_r_config.copy(),
                'param_count': current_params,
                'performance': None,  # Filled in after recovery
                'status': 'pruned',
                'pruning_details': validation_summary
            })

            self.scheduler.advance_pruning_step()

            return True

        finally:
            # Restore original log level
            logger.setLevel(original_level)

    def _log_planned_r_changes(self, new_r_config: Dict[str, int]) -> None:
        """Log, grouped by layer type, how each layer's r will change from the current configuration."""
        logger.info("=" * 60)
        logger.info("LAYER-WISE R-VALUE CHANGES PLANNED")
        logger.info("=" * 60)

        increased = 0
        decreased = 0
        unchanged = 0

        layer_types = pruning_utils.classify_layers_by_type(list(self.current_r_config.keys()))

        for layer_type, layers in layer_types.items():
            logger.info(f"\n[{layer_type.upper()} LAYERS]")
            for layer_name in sorted(layers):
                old_r = self.current_r_config.get(layer_name, 0)
                new_r = new_r_config.get(layer_name, 0)

                if old_r > new_r:
                    change_symbol = "↓"
                    decreased += 1
                    change_info = f"PRUNED ({old_r-new_r} reduction)"
                elif old_r < new_r:
                    change_symbol = "↑"
                    increased += 1
                    change_info = f"INCREASED ({new_r-old_r} addition)"
                else:
                    change_symbol = "="
                    unchanged += 1
                    change_info = "UNCHANGED"

                # LoRA parameter count before/after: r * (in_features + out_features)
                in_features, out_features = self.layer_sizes.get(layer_name, (0, 0))
                old_params = old_r * (in_features + out_features) if old_r > 0 else 0
                new_params = new_r * (in_features + out_features) if new_r > 0 else 0

                logger.info(f"  {layer_name}: r={old_r} {change_symbol} {new_r} - {change_info}")
                if old_params > 0 or new_params > 0:
                    logger.info(f"    Parameters: {old_params:,} → {new_params:,} ({new_params-old_params:+,})")

        logger.info("\nSUMMARY:")
        logger.info(f"  Decreased r-values: {decreased} layers")
        logger.info(f"  Unchanged r-values: {unchanged} layers")
        logger.info(f"  Increased r-values: {increased} layers")
        logger.info("=" * 60)
    
    def update_recovery(self) -> None:
        """
        Advance the scheduler's recovery phase and validate once it has finished.

        Validation runs only when the scheduler is no longer in recovery mode and
        the latest history entry is a not-yet-validated pruning step (status
        'pruned'). An empty history (baseline never initialized) means there is
        nothing to validate, rather than raising IndexError.
        """
        self.scheduler.update_recovery()

        if self.scheduler.in_recovery_mode or not self.pruning_history:
            return

        if self.pruning_history[-1]['status'] == 'pruned':
            self._validate_after_recovery()
    
    def _validate_after_recovery(self) -> None:
        """
        Evaluate the model after a recovery phase and accept or roll back the step.

        The latest history entry (the step just recovered) gets its performance,
        evaluation time and evaluation steps/second recorded. If performance is
        acceptable, or rollback is disabled, the entry is marked 'success' and
        the pre-pruning checkpoint is overwritten with the current model.
        Otherwise the checkpoint is reloaded, the configuration and performance
        are restored from the most recent accepted ('initial' or 'success')
        history entry, the entry is marked 'rollback', and an extended recovery
        is started.
        """
        # Force INFO for visibility of the validation report; restored in ``finally``.
        original_level = logger.level
        logger.setLevel(logging.INFO)

        try:
            start_time = time.time()
            performance, is_acceptable = pruning_utils.validate_pruning(
                model=self.model,
                eval_dataloader=self.eval_dataloader,
                baseline_performance=self.baseline_performance,
                threshold=pruning_config.PERFORMANCE_DROP_THRESHOLD,
                device=self.device
            )
            eval_time = time.time() - start_time

            num_eval_steps = len(self.eval_dataloader)
            steps_per_second = num_eval_steps / eval_time if eval_time > 0 else 0

            self.current_performance = performance

            self.pruning_history[-1]['performance'] = performance
            self.pruning_history[-1]['eval_time'] = eval_time
            self.pruning_history[-1]['eval_steps_per_second'] = steps_per_second

            self.pruning_metrics['eval_steps_per_second'] = steps_per_second

            logger.info("=" * 60)
            logger.info(f"PRUNING STEP {len(self.pruning_history)-1} VALIDATION RESULTS")
            logger.info("=" * 60)

            current_params = self.pruning_history[-1]['param_count']
            param_reduction = self.initial_params - current_params
            param_reduction_pct = (param_reduction / self.initial_params * 100) if self.initial_params > 0 else 0

            logger.info(f"Parameters: {self.initial_params:,} → {current_params:,} ({current_params-self.initial_params:+,}, {param_reduction_pct:.2f}% reduction)")
            logger.info(f"Performance: {self.baseline_performance:.4f} → {performance:.4f} ({(performance-self.baseline_performance):+.4f})")
            logger.info(f"Evaluation speed: {steps_per_second:.2f} steps/second")

            if is_acceptable or not self.enable_rollback:
                # Pruning succeeded or rollback disabled
                success_msg = "SUCCESS - Performance is within acceptable threshold"
                if not is_acceptable and not self.enable_rollback:
                    success_msg = "ACCEPTED (rollback disabled) - Performance drop exceeded threshold but rollback is disabled"

                self.pruning_history[-1]['status'] = 'success'
                logger.info(f"VERDICT: {success_msg}")

                # The accepted model becomes the new rollback point
                pruning_utils.save_model_checkpoint(self.model, self.checkpoint_path)
            else:
                logger.warning("VERDICT: FAILED - Performance drop exceeds threshold, rolling back")

                self.model = pruning_utils.load_model_checkpoint(
                    self.model,
                    self.checkpoint_path,
                    allow_size_mismatch=True  # Allow size mismatches for safe rollback
                )

                # Restore from the last accepted entry, not simply history[-2]:
                # after an earlier rollback, history[-2] holds that failed config,
                # while the checkpoint holds the last accepted model.
                accepted = self._last_accepted_history_entry()
                self.current_r_config = accepted['r_config'].copy()
                self.current_performance = accepted['performance']

                # Keep the momentum penalty from anchoring on the failed config
                self.prev_r_config = self.current_r_config.copy()

                self.pruning_history[-1]['status'] = 'rollback'

                self.scheduler.start_recovery(extended=True)

            logger.info("=" * 60)

        finally:
            logger.setLevel(original_level)

    def _last_accepted_history_entry(self) -> Dict[str, Any]:
        """Return the newest entry before the latest one whose status is 'initial' or 'success'."""
        for entry in reversed(self.pruning_history[:-1]):
            if entry['status'] in ('initial', 'success'):
                return entry
        # No accepted entry recorded; fall back to the previous entry.
        return self.pruning_history[-2]
    
    def get_progress_info(self) -> Dict[str, Any]:
        """
        Get information about pruning progress.

        The current parameter count is computed from ``self.current_r_config``
        and ``self.layer_sizes`` via ``pruning_utils.calculate_total_parameters``.
        The reduction is relative to ``self.initial_params``. If
        ``self.initial_params`` is zero, the reduction is reported as 0.0
        instead of raising ZeroDivisionError.

        Returns:
            Dictionary with pruning progress information, with keys:
            ``initial_params``, ``current_params``, ``reduction``,
            ``target_reduction`` (last entry of the scheduler's
            ``step_reduction_rates``), ``current_performance``,
            ``baseline_performance``, ``steps_completed``, ``total_steps``,
            ``in_recovery`` and ``pruning_metrics``.
        """
        current_params = pruning_utils.calculate_total_parameters(
            self.current_r_config, self.layer_sizes)

        # Avoid ZeroDivisionError when there were no prunable parameters to begin with.
        if self.initial_params:
            reduction = (self.initial_params - current_params) / self.initial_params
        else:
            reduction = 0.0

        schedule = self.scheduler.schedule
        return {
            "initial_params": self.initial_params,
            "current_params": current_params,
            "reduction": reduction,
            # The last scheduled reduction rate is reported as the overall target.
            "target_reduction": schedule["step_reduction_rates"][-1],
            "current_performance": self.current_performance,
            "baseline_performance": self.baseline_performance,
            # +1 converts the step index into a count (index presumably zero-based).
            "steps_completed": self.scheduler.current_step_idx + 1,
            "total_steps": len(schedule["step_budgets"]),
            "in_recovery": self.scheduler.in_recovery_mode,
            "pruning_metrics": self.pruning_metrics
        }
    
    def update_efficiency_metrics(self, eval_flops: float, eval_macs: float) -> None:
        """
        Update efficiency metrics with latest evaluation data.

        Always stores the raw ``eval_flops`` / ``eval_macs`` values in
        ``self.pruning_metrics``. When a current performance value is
        available (not ``None``), it also stores performance per GFLOP and
        per GMAC, but only for positive compute counts.

        A performance of ``0.0`` is a valid measurement. It yields a ratio of
        0.0 rather than being skipped and leaving a stale value from an
        earlier call.

        Args:
            eval_flops: FLOPs from evaluation
            eval_macs: MACs from evaluation
        """
        metrics = self.pruning_metrics

        # Store raw values
        metrics['eval_flops'] = eval_flops
        metrics['eval_macs'] = eval_macs

        performance = self.current_performance
        if performance is None:
            # Nothing to normalize yet.
            return

        # Calculate efficiency metrics (accuracy per compute), guarding against
        # division by zero / meaningless negative compute counts.
        if eval_flops > 0:
            metrics['accuracy_per_gflops'] = performance / (eval_flops / 1e9)

        if eval_macs > 0:
            metrics['accuracy_per_gmacs'] = performance / (eval_macs / 1e9)
    
    def finalize(self) -> Tuple[nn.Module, Dict[str, int], float]:
        """
        Finalize pruning and generate reports.

        Steps:
        - Computes the achieved parameter reduction and records it in
          ``self.pruning_metrics`` (``pruned_params`` and
          ``model_size_reduction``, the latter as a percentage).
        - Logs a summary.
        - Optionally plots the pruning trajectory when
          ``pruning_config.PLOT_PRUNING_TRAJECTORY`` is set and there is more
          than one history entry.
        - Writes ``final_r_config.json`` and ``pruning_metrics.json`` into
          ``self.output_dir``, creating the directory if needed.

        Efficiency metrics that were never recorded are skipped in the log
        instead of raising KeyError. If ``self.initial_params`` is zero, the
        achieved reduction is reported as 0.0.

        Returns:
            Tuple of (pruned_model, final_r_config, achieved_reduction), where
            achieved_reduction is the fraction of the initial parameter count
            that was removed.
        """
        import json

        # Generate summary
        current_params = pruning_utils.calculate_total_parameters(
            self.current_r_config, self.layer_sizes)
        if self.initial_params:
            achieved_reduction = (self.initial_params - current_params) / self.initial_params
        else:
            # No prunable parameters at all: nothing was (or could be) removed.
            achieved_reduction = 0.0

        # Update pruning metrics with final accurate values for consistent reporting
        # This ensures metrics are accurate in both train_metrics and final logs
        metrics = self.pruning_metrics
        metrics['pruned_params'] = self.initial_params - current_params
        metrics['model_size_reduction'] = achieved_reduction * 100  # Convert to percentage

        logger.info(f"\n{'='*80}\nPROGRESSIVE PRUNING COMPLETED\n{'='*80}")
        logger.info(f"Initial parameters = {self.initial_params:,}")
        logger.info(f"Final parameters = {current_params:,}")
        logger.info(f"Achieved reduction = {achieved_reduction:.2%}")
        logger.info(f"Final performance = {self.current_performance:.4f} (baseline: {self.baseline_performance:.4f})")

        # Detailed efficiency metrics. The per-compute ratios are only written
        # conditionally, so check for their presence instead of indexing blindly.
        eval_flops = metrics.get('eval_flops') or 0
        if eval_flops > 0 and 'accuracy_per_gflops' in metrics:
            logger.info(f"Performance per GFLOPs = {metrics['accuracy_per_gflops']:.6f}")

        eval_macs = metrics.get('eval_macs') or 0
        if eval_macs > 0 and 'accuracy_per_gmacs' in metrics:
            logger.info(f"Performance per GMACs = {metrics['accuracy_per_gmacs']:.6f}")

        steps_per_second = metrics.get('eval_steps_per_second')
        if steps_per_second is not None:
            logger.info(f"Evaluation speed = {steps_per_second:.2f} steps/second")

        # Make sure the report destination exists ("" means the current directory).
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

        # Create visualization
        if pruning_config.PLOT_PRUNING_TRAJECTORY and len(self.pruning_history) > 1:
            pruning_utils.plot_pruning_results(
                pruning_history=self.pruning_history,
                output_dir=self.output_dir
            )

        # Save final r_config
        with open(os.path.join(self.output_dir, "final_r_config.json"), 'w', encoding='utf-8') as f:
            json.dump(self.current_r_config, f, indent=2)

        # Build the summary before opening the file so a conversion error
        # cannot leave a truncated/empty JSON file behind.
        summary = {
            "initial_params": self.initial_params,
            "final_params": current_params,
            "reduction": float(achieved_reduction),
            "baseline_performance": float(self.baseline_performance),
            "final_performance": float(self.current_performance),
            "pruning_metrics": metrics
        }
        with open(os.path.join(self.output_dir, "pruning_metrics.json"), 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        return self.model, self.current_r_config, achieved_reduction