"""
Pruning scheduler implementation for progressive LoRA pruning.

This module provides functions to create pruning schedules based on Bezier curves,
allowing for smooth and controlled reduction of parameters over multiple steps.
"""

import logging
import math
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from . import pruning_config

# Module-level logger; get_pruning_schedule reports the generated schedule here.
logger = logging.getLogger(name=__name__)

def calculate_bezier_coefficients(control_points: List[float]) -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Prepare the Bernstein-basis ingredients for a 1-D Bezier curve.

    Args:
        control_points: Control point values; every value must lie in [0, 1].

    Returns:
        A ``(coeffs, points, degree)`` tuple where ``degree`` is
        ``len(control_points) - 1``, ``coeffs[k]`` is the binomial coefficient
        C(degree, k) stored as a float64, and ``points`` is ``control_points``
        converted to a NumPy array.

    Raises:
        ValueError: If any control point falls outside [0, 1].
    """
    # De Morgan form of "all points in range"; still short-circuits on the
    # first offending value (and rejects NaN, which fails both comparisons).
    if any(not 0 <= value <= 1 for value in control_points):
        raise ValueError("All control points must be between 0 and 1")

    points = np.array(control_points)
    degree = len(points) - 1

    # Binomial coefficients C(degree, k), k = 0..degree, as float64.
    coeffs = np.array(
        [math.comb(degree, k) for k in range(degree + 1)],
        dtype=float,
    )

    return coeffs, points, degree

def evaluate_bezier_curve(t: float, coeffs: np.ndarray, points: np.ndarray, n: int) -> float:
    """
    Compute B(t) = sum_k coeffs[k] * points[k] * t**k * (1 - t)**(n - k).

    Args:
        t: Curve parameter; must lie in [0, 1].
        coeffs: Binomial coefficients, indexed 0..n.
        points: Control point values, indexed 0..n.
        n: Degree of the curve.

    Returns:
        The curve value at ``t``. When ``n < 0`` no terms are summed and the
        result is ``0.0``.

    Raises:
        ValueError: If ``t`` is outside [0, 1] (NaN is rejected as well).
    """
    # Negated chained comparison so NaN (which fails every comparison) raises.
    if not (0 <= t <= 1):
        raise ValueError("Parameter t must be between 0 and 1")

    complement = 1 - t
    total = 0.0
    for k in range(n + 1):
        # Same left-to-right product order as the Bernstein term definition.
        total += coeffs[k] * points[k] * (t ** k) * (complement ** (n - k))
    return total

def get_pruning_schedule(
    initial_budget: int,
    target_reduction: float,
    num_steps: int,
    total_training_steps: int
) -> Dict[str, Any]:
    """
    Generate a pruning schedule using a Bezier curve for smooth reduction.

    The curve defined by ``pruning_config.BEZIER_CONTROL_POINTS`` (or a linear
    fallback when fewer than two points are configured) scales
    ``target_reduction``. Pruning step ``i`` (0-based) evaluates the curve at
    ``t = (i + 1) / num_steps``, so the last step lands on the curve's final
    control point.

    Args:
        initial_budget: Initial parameter budget (number of parameters)
        target_reduction: Target reduction ratio, strictly between 0.0 and 1.0
        num_steps: Number of pruning steps (at least 1)
        total_training_steps: Total number of training steps

    Returns:
        Dictionary with keys:
          - "initial_budget" (int): the input budget, unchanged
          - "target_budget" (int): ``int(initial_budget * (1 - target_reduction))``
          - "step_budgets" (List[int]): budget after each pruning step
            (truncated with ``int``)
          - "step_reduction_rates" (List[float]): cumulative reduction ratio
            reached at each pruning step
          - "step_triggers" (List[int]): training step at which each pruning
            step becomes due

    Raises:
        ValueError: If ``target_reduction`` is not strictly between 0 and 1
            (including NaN), if ``num_steps`` < 1, or if the configured
            control points fall outside [0, 1].
    """
    # Positive range check so NaN is rejected: every comparison with NaN is
    # False, which would slip past a "<= 0 or >= 1" test.
    if not 0 < target_reduction < 1:
        raise ValueError("Target reduction must be between 0 and 1")

    if num_steps < 1:
        raise ValueError("Number of steps must be at least 1")

    # Budget remaining once the full target reduction has been applied.
    target_budget = initial_budget * (1 - target_reduction)

    control_points = pruning_config.BEZIER_CONTROL_POINTS
    if len(control_points) < 2:
        # A degree-1 Bezier through 0 and 1 is B(t) = t: a linear reduction.
        control_points = [0.0, 1.0]

    coeffs, points, degree = calculate_bezier_coefficients(control_points)

    # Split training into num_steps + 1 equal intervals so at least one
    # interval of plain training precedes the first pruning step.
    steps_per_pruning = max(1, total_training_steps // (num_steps + 1))
    # Delay the first pruning step by PRUNING_START_EPOCH intervals (at least
    # one). NOTE(review): the config name says "epoch" but it is multiplied by
    # an interval length here — confirm the intended unit.
    initial_delay = max(pruning_config.PRUNING_START_EPOCH * steps_per_pruning,
                        steps_per_pruning)

    step_budgets = []
    step_reduction_rates = []
    step_triggers = []

    for i in range(num_steps):
        # Training step at which this pruning step becomes due.
        step_triggers.append(initial_delay + i * steps_per_pruning)

        t = (i + 1) / num_steps  # Parameter for Bezier curve, in (0, 1]
        reduction_rate = evaluate_bezier_curve(t, coeffs, points, degree)
        current_reduction = target_reduction * reduction_rate

        # int() truncates, matching how target_budget is reported below.
        step_budgets.append(int(initial_budget * (1 - current_reduction)))
        step_reduction_rates.append(current_reduction)

    schedule = {
        "initial_budget": initial_budget,
        "target_budget": int(target_budget),
        "step_budgets": step_budgets,
        "step_reduction_rates": step_reduction_rates,
        "step_triggers": step_triggers
    }

    logger.info(f"Generated pruning schedule: initial={initial_budget:,} → target={int(target_budget):,}")
    for i, (budget, rate, trigger) in enumerate(zip(step_budgets, step_reduction_rates, step_triggers)):
        logger.info(f"  Step {i+1}: budget={budget:,} (reduction={rate:.2%}) at training step {trigger}")

    # A large PRUNING_START_EPOCH (or tiny total_training_steps) can push
    # triggers past the end of training; those steps would silently never run
    # and the target budget would never be reached.
    late_steps = [
        idx + 1
        for idx, trigger in enumerate(step_triggers)
        if trigger >= total_training_steps
    ]
    if late_steps:
        logger.warning(
            "Pruning steps %s are scheduled at or beyond total_training_steps=%s "
            "and may never fire; consider lowering PRUNING_START_EPOCH or num_steps",
            late_steps,
            total_training_steps,
        )

    return schedule

class PruningScheduler:
    """
    Manages the pruning schedule during training.

    This class tracks training progress and determines when to apply pruning
    and recovery steps according to the schedule built by
    ``get_pruning_schedule``. Presumably a training loop calls
    ``should_prune`` each step and, when it returns True, prunes to
    ``get_next_pruning_budget()``, then calls ``advance_pruning_step`` and
    ``start_recovery``, calling ``update_recovery`` on subsequent steps.
    NOTE(review): confirm this call order against the caller.
    """

    def __init__(
        self,
        initial_budget: int,
        target_reduction: float,
        num_pruning_steps: int,
        total_training_steps: int
    ):
        """
        Initialize the pruning scheduler.

        Args:
            initial_budget: Initial parameter budget
            target_reduction: Target reduction ratio (strictly between 0.0 and 1.0)
            num_pruning_steps: Number of pruning steps
            total_training_steps: Total number of training steps

        Raises:
            ValueError: Propagated from ``get_pruning_schedule`` for invalid
                arguments.
        """
        self.schedule = get_pruning_schedule(
            initial_budget=initial_budget,
            target_reduction=target_reduction,
            num_steps=num_pruning_steps,
            total_training_steps=total_training_steps
        )

        # Index of the most recently applied pruning step; -1 means none yet.
        self.current_step_idx = -1
        # While True, should_prune() always returns False.
        self.in_recovery_mode = False
        # Number of update_recovery() calls since the current recovery began.
        self.recovery_step_counter = 0
        # Length of the current recovery window, in update_recovery() calls.
        self.recovery_steps = pruning_config.RECOVERY_STEPS

    def should_prune(self, training_step: int) -> bool:
        """
        Check if pruning should be applied at the current training step.

        Args:
            training_step: Current training step

        Returns:
            True when not recovering, a pruning step remains, and
            ``training_step`` has reached that step's trigger; False otherwise.
        """
        if self.in_recovery_mode:
            return False

        triggers = self.schedule["step_triggers"]
        next_step_idx = self.current_step_idx + 1
        if next_step_idx >= len(triggers):
            # Every scheduled pruning step has already been applied.
            return False

        # ">=" rather than "==" so a trigger missed (e.g. during recovery)
        # still fires on the next eligible step.
        return bool(training_step >= triggers[next_step_idx])

    def start_recovery(self, extended: bool = False) -> None:
        """
        Start recovery mode after pruning.

        Args:
            extended: If True, use ``pruning_config.EXTENDED_RECOVERY_STEPS``
                as the recovery length instead of ``RECOVERY_STEPS``.
        """
        self.in_recovery_mode = True
        self.recovery_step_counter = 0
        self.recovery_steps = pruning_config.EXTENDED_RECOVERY_STEPS if extended else pruning_config.RECOVERY_STEPS

    def update_recovery(self) -> None:
        """
        Count one recovery step and leave recovery mode once enough have passed.

        Does nothing when not in recovery mode.
        """
        if not self.in_recovery_mode:
            return
        self.recovery_step_counter += 1
        if self.recovery_step_counter >= self.recovery_steps:
            self.in_recovery_mode = False

    def get_next_pruning_budget(self) -> int:
        """
        Get the budget for the next pruning step.

        Returns:
            The scheduled budget of the next pruning step, or the schedule's
            ``target_budget`` once every step has been applied.
        """
        next_step_idx = self.current_step_idx + 1
        budgets = self.schedule["step_budgets"]
        if next_step_idx < len(budgets):
            return budgets[next_step_idx]
        return self.schedule["target_budget"]

    def advance_pruning_step(self) -> None:
        """
        Advance to the next pruning step.

        Once the last scheduled step has been reached this is a no-op, so
        ``current_step_idx`` always remains a valid index into the schedule
        lists and ``get_current_step_info`` cannot raise ``IndexError``.
        """
        last_idx = len(self.schedule["step_budgets"]) - 1
        if self.current_step_idx < last_idx:
            self.current_step_idx += 1

    def get_current_step_info(self) -> Dict[str, float]:
        """
        Get information about the current pruning step.

        Returns:
            Dictionary with keys "step_idx" (-1 before any pruning), "budget",
            "reduction_rate" (0.0 before any pruning) and "total_steps".
        """
        total_steps = len(self.schedule["step_budgets"])
        if self.current_step_idx < 0:
            return {
                "step_idx": -1,
                "budget": self.schedule["initial_budget"],
                "reduction_rate": 0.0,
                "total_steps": total_steps
            }

        return {
            "step_idx": self.current_step_idx,
            "budget": self.schedule["step_budgets"][self.current_step_idx],
            "reduction_rate": self.schedule["step_reduction_rates"][self.current_step_idx],
            "total_steps": total_steps
        }