from .scheduler import Scheduler

class NaiveScheduler(Scheduler, scheduler_name='naive'):
    """
    Schedule precision with a fixed high-then-low switch.

    The scheduler returns the high precision (the maximum of ``precisions``)
    while the token index is below ``high_precision_steps``, and the low
    precision (the minimum of ``precisions``) for every later token. When a
    single precision is supplied, high and low are the same value.
    """

    def __init__(self, precisions, high_precision_steps):
        """
        Store the high/low precisions and the switch-over point.

        Args:
            precisions: Sized collection of one or two comparable precision
                values; the largest becomes the high precision and the
                smallest the low precision.
            high_precision_steps: Number of leading token indices that use
                the high precision. Must be non-negative.

        Raises:
            ValueError: If ``precisions`` is empty.
            AssertionError: If ``precisions`` has more than two elements or
                ``high_precision_steps`` is negative (these checks are
                skipped when Python runs with ``-O``).
        """
        # An empty collection satisfies the "at most two" check below and
        # would otherwise only fail inside max() with an unrelated message.
        if len(precisions) == 0:
            raise ValueError("Precisions should contain at least one element.")
        # Kept as asserts so the exception type seen by existing callers
        # does not change.
        assert len(precisions) <= 2, "Precisions should contain at most two elements."
        assert high_precision_steps >= 0, "high_precision_steps should be non-negative."
        super().__init__(precisions)
        self.high_precision = max(precisions)
        self.low_precision = min(precisions)
        self.high_precision_steps = high_precision_steps

    def schedule(self, **kwargs):
        """
        Pick the precision for the current token index.

        Args:
            **kwargs: Scheduling information. Only ``index`` (the current
                token index, 0-based) is read; it defaults to 0 when absent.
                All other keys are ignored.

        Returns:
            ``high_precision`` when ``index < high_precision_steps``,
            otherwise ``low_precision``.
        """
        index = kwargs.get('index', 0)
        return (
            self.high_precision
            if index < self.high_precision_steps
            else self.low_precision
        )

    def reset(self):
        """
        Reset the scheduler state (a no-op: this scheduler keeps no
        per-run state).
        """
        