import math
from typing import Optional, Tuple


class NoiseScheduler:
    """
    Noise scheduler that adjusts noise probabilities during training.

    Supports different scheduling types:
    - linear: Linear increase from 0 to max_value over ``num_grow_steps`` steps,
      then held at max_value.
    - constant: Constant max_value throughout training.
    - exponential: Saturating exponential curve from 0 to max_value over
      ``num_grow_steps`` steps (curvature set by ``exp_gamma``), then held at
      max_value.
    - step: Step-wise increase per epoch, from 0 at epoch 0 to max_value at
      epoch ``num_epochs - 1``.

    The schedule reads ``current_step`` / ``current_epoch``, which are expected
    to be kept in sync by calling :meth:`update_training_state`.
    """

    def __init__(
        self,
        scheduler_type: str = "constant",
        max_important_noise: float = 0.1,
        max_normal_noise: float = 0.025,
        num_epochs: Optional[int] = None,
        num_grow_steps: Optional[int] = 0,
        exp_gamma: float = 2.0,
    ):
        """
        Initialize noise scheduler.

        Args:
            scheduler_type: Type of scheduling ("linear", "constant", "exponential",
                "step"); matched case-insensitively.
            max_important_noise: Maximum noise for important tokens.
            max_normal_noise: Maximum noise for normal tokens.
            num_epochs: Total number of epochs (used by the "step" scheduler).
                If None or <= 0, the "step" scheduler returns max noise.
            num_grow_steps: Total number of steps for noise to grow from 0 to
                max_value (used by "linear" and "exponential"). If None or <= 0,
                those schedulers return max noise immediately.
            exp_gamma: Curvature of the exponential scheduler. Larger values
                front-load the growth; 0 degenerates to linear.

        Raises:
            ValueError: If ``scheduler_type`` is not a supported type.
        """
        self.scheduler_type = scheduler_type.lower()
        self.max_important_noise = max_important_noise
        self.max_normal_noise = max_normal_noise
        self.num_epochs = num_epochs
        self.num_grow_steps = num_grow_steps
        self.exp_gamma = exp_gamma

        # Training state (to be updated by trainer callback)
        self.current_step = 0
        self.current_epoch = 0

        # Validate scheduler type
        valid_types = ["linear", "constant", "exponential", "step"]
        if self.scheduler_type not in valid_types:
            raise ValueError(f"scheduler_type must be one of {valid_types}, got {self.scheduler_type}")

    def update_training_state(self, current_step: int, current_epoch: int) -> None:
        """
        Update internal training state to unify with iter_decider's scheduling.

        Both values are coerced with ``int()`` (so a fractional epoch such as
        1.5 is truncated to 1); ``None`` is treated as 0.
        """
        self.current_step = int(current_step) if current_step is not None else 0
        self.current_epoch = int(current_epoch) if current_epoch is not None else 0

    @staticmethod
    def _clamp_unit(value: float) -> float:
        """Clamp ``value`` into the closed interval [0, 1]."""
        return min(max(value, 0.0), 1.0)

    def _exp_fraction(self, progress: float) -> float:
        """Map linear progress in [0, 1] onto the normalized exponential curve.

        Computes ``(1 - exp(-gamma * p)) / (1 - exp(-gamma))``, which is 0 at
        p=0 and exactly 1 at p=1, so max noise is reached at ``num_grow_steps``.
        """
        gamma = self.exp_gamma
        if gamma == 0:
            # Limit of the normalized curve as gamma -> 0 is the identity.
            return progress
        # expm1(x) == exp(x) - 1, accurate for small |x|; the signs cancel.
        return math.expm1(-gamma * progress) / math.expm1(-gamma)

    def _schedule_fraction(self) -> float:
        """Return the fraction in [0, 1] of max noise to apply right now."""
        kind = self.scheduler_type

        if kind in ("linear", "exponential"):
            if self.num_grow_steps is None or self.num_grow_steps <= 0:
                return 1.0
            # Clamp so negative steps give 0 and steps past the grow phase hold at 1.
            progress = self._clamp_unit(self.current_step / self.num_grow_steps)
            return progress if kind == "linear" else self._exp_fraction(progress)

        if kind == "step":
            if self.num_epochs is None or self.num_epochs <= 0:
                return 1.0
            # Epoch 0 -> 0, epoch (num_epochs - 1) -> 1; max(..., 1) avoids
            # dividing by zero when training runs a single epoch.
            return self._clamp_unit(self.current_epoch / max(self.num_epochs - 1, 1))

        # "constant" (and any type set after construction) uses max noise.
        return 1.0

    def get_current_noise_values(self) -> Tuple[float, float]:
        """
        Get current noise values based on internal training state.

        Returns:
            Tuple of (important_token_noise, normal_token_noise), each equal to
            the corresponding max noise scaled by the schedule's current fraction.
        """
        fraction = self._schedule_fraction()
        return self.max_important_noise * fraction, self.max_normal_noise * fraction