from .scheduler import Scheduler

class SimpleSplitScheduler(Scheduler, scheduler_name='simple_split'):
    '''
    Schedule tasks by splitting them between two precisions.

    The first precision is used for the "cot" phase (and for any phase not
    listed below). The second precision is used for the "solution" and
    "answer" phases.
    '''

    # Phases served by the solution precision; every other phase, including
    # "cot" and unrecognised values, uses the chain-of-thought precision.
    _SOLUTION_PHASES = frozenset(("solution", "answer"))

    def __init__(self, precisions):
        '''
        Args:
            precisions: A sequence of one or two precision values. Element 0 is
                used for the "cot" phase. Element 1 is used for the "solution"
                and "answer" phases. If only one element is given, it is used
                for every phase.

        Raises:
            ValueError: If ``precisions`` is empty or has more than two
                elements.
        '''
        # Use an explicit exception instead of ``assert`` so the check survives
        # ``python -O``. Also reject the empty case here, which previously
        # slipped past the assert and failed later with an IndexError.
        if not 1 <= len(precisions) <= 2:
            raise ValueError(
                "Precisions should contain one or two elements, got {}.".format(
                    len(precisions)
                )
            )
        super(SimpleSplitScheduler, self).__init__(precisions)
        self.cot_precision = precisions[0]
        # precisions[-1] is the second element when two are given. With a
        # single element it is the same value, so that precision is applied
        # to all phases instead of raising IndexError.
        self.sol_precision = precisions[-1]

    def schedule(self, **kwargs):
        '''
        Return the precision to use for the current phase.

        Args:
            **kwargs: Must contain ``cur_phase``. A missing key raises
                ``KeyError``. Other keyword arguments are ignored.

        Returns:
            ``self.sol_precision`` when ``cur_phase`` is "solution" or
            "answer"; otherwise ``self.cot_precision``.
        '''
        if kwargs['cur_phase'] in self._SOLUTION_PHASES:
            return self.sol_precision
        # "cot" and any unrecognised phase fall back to the CoT precision.
        return self.cot_precision