from .scheduler import Scheduler

class PartSplitScheduler(Scheduler, scheduler_name='part_split'):
    '''
    Schedule each phase at either a high or a low precision.

    The scheduler is configured with one or two precision values. The larger
    one (by ``max``) becomes the high precision and the smaller one (by
    ``min``) the low precision. Three per-phase flags -- ``cot``,
    ``solution`` and ``answer`` -- select which of those phases run at high
    precision. Every other phase, including any phase name not listed here,
    runs at low precision.
    '''

    # Phase names that can be promoted to high precision. Each name is also
    # the name of the instance attribute holding that phase's flag.
    _HIGH_PRECISION_PHASES = ("cot", "solution", "answer")

    def __init__(self, precisions, cot, solution, answer):
        '''
        Args:
            precisions: Sized collection of one or two mutually comparable
                precision values (they are combined with ``max``/``min``).
                With a single value, high and low precision are identical.
            cot: Truthy to run the "cot" phase at high precision.
            solution: Truthy to run the "solution" phase at high precision.
            answer: Truthy to run the "answer" phase at high precision.

        Raises:
            ValueError: If ``precisions`` is empty or has more than two
                elements.
        '''
        count = len(precisions)
        # Explicit raise rather than ``assert`` so the check is not stripped
        # when Python runs with ``-O``.
        if count > 2:
            raise ValueError("Precisions should contain at most two elements.")
        if count == 0:
            # max()/min() below would fail on an empty collection anyway;
            # fail early with a clearer message.
            raise ValueError("Precisions should contain at least one element.")
        super().__init__(precisions)
        self.high_precision = max(precisions)
        self.low_precision = min(precisions)
        self.cot = cot
        self.solution = solution
        self.answer = answer

    def schedule(self, **kwargs):
        '''
        Return the precision to use for the current phase.

        Only ``kwargs['cur_phase']`` is read; any other keyword arguments
        are accepted and ignored.

        Returns:
            ``self.high_precision`` if ``cur_phase`` equals "cot",
            "solution" or "answer" and the matching flag attribute is
            truthy; otherwise ``self.low_precision``.

        Raises:
            KeyError: If ``cur_phase`` is not passed.
        '''
        phase = kwargs['cur_phase']
        for name in self._HIGH_PRECISION_PHASES:
            if phase == name:
                # Read the flag at call time (not cached in __init__) so later
                # reassignment of e.g. ``self.cot`` is still honoured.
                if getattr(self, name):
                    return self.high_precision
                return self.low_precision
        # Unknown phase: fall back to low precision.
        return self.low_precision