from .scheduler import Scheduler
import numpy as np

class ScoreDescentCutScheduler(Scheduler, scheduler_name='score_descent_cut'):
    '''
    Step precision down as scores rise, then cut.

    During the "cot" phase the scheduler performs a two-stage descent driven
    by the mean and max of the current ``scores``:

    1. At ``high_precision``: once ``mean(scores) > mean_score[0]`` and
       ``max(scores) > max_score[0]``, switch to ``low_precision``.
    2. At ``low_precision``: once ``mean(scores) > mean_score[1]`` and
       ``max(scores) > max_score[1]``, return ``0`` (printed as "cut down";
       presumably the caller treats 0 as "stop this phase" -- confirm
       against the caller).

    Transitions are only evaluated on split steps (``is_split=True``), and
    only after more than ``windows`` split steps have elapsed since
    construction, the last transition, or :meth:`reset`. Outside the "cot"
    phase ``sol_precision`` is returned, except on a split step that is still
    inside the window, where the incoming precision is returned unchanged.
    '''

    def __init__(self, precisions, mean_score, max_score, sol_precision, windows):
        '''
        Args:
            precisions: one or two precision values. The larger is used as
                the high precision and the smaller as the low precision. The
                values are also forwarded to the base ``Scheduler``.
            mean_score: indexable thresholds for the mean score. ``[0]`` gates
                high -> low and ``[1]`` gates low -> cut.
            max_score: indexable thresholds for the max score, with the same
                indexing as ``mean_score``.
            sol_precision: precision returned outside the "cot" phase.
            windows: number of split steps to wait after a (re)start before
                transitions are evaluated.

        Raises:
            ValueError: if ``precisions`` does not contain one or two elements.
        '''
        # Explicit check instead of ``assert``: asserts vanish under ``-O``,
        # and an empty sequence would otherwise reach the base class and
        # only fail later inside ``max()``.
        if not 1 <= len(precisions) <= 2:
            raise ValueError("Precisions should contain one or two elements.")
        super().__init__(precisions)
        self.high_precision = max(precisions)
        self.low_precision = min(precisions)
        self.mean_score = mean_score  # thresholds: [0] high->low, [1] low->cut
        self.max_score = max_score  # thresholds: [0] high->low, [1] low->cut
        self.sol_precision = sol_precision
        self.windows = windows
        self.count = 0  # split steps since construction / last transition / reset

    def schedule(self, **kwargs):
        '''
        Return the precision to use for the next step.

        Keyword Args:
            is_split: whether this call is a split step. Only split steps
                advance the window counter and can trigger a transition.
            cur_phase: current phase. Transitions are only evaluated when it
                equals "cot".
            precision: the precision currently in use.
            scores: numeric scores accepted by ``np.mean`` / ``np.max``. They
                are only read on a split step past the window in "cot".

        Returns:
            One of ``high_precision``, ``low_precision``, ``sol_precision``,
            the incoming ``precision``, or ``0`` to cut.
        '''
        if not kwargs["is_split"]:
            # Non-split steps never touch the counter or change precision.
            if kwargs["cur_phase"] == "cot":
                return kwargs["precision"]
            return self.sol_precision

        self.count += 1
        if self.count <= self.windows:
            # Still inside the window after the last (re)start: hold.
            return kwargs["precision"]

        if kwargs["cur_phase"] != "cot":
            return self.sol_precision

        scores = kwargs["scores"]
        precision = kwargs["precision"]
        if np.size(scores) == 0:
            # No evidence to act on. np.max would raise on an empty array,
            # so keep the current precision.
            return precision

        mean_score = np.mean(scores)
        max_score = np.max(scores)

        if precision == self.high_precision:
            if mean_score > self.mean_score[0] and max_score > self.max_score[0]:
                self.count = 0  # restart the window at the new precision
                print("Alternate precision")
                return self.low_precision
            return self.high_precision

        if precision == self.low_precision:
            if mean_score > self.mean_score[1] and max_score > self.max_score[1]:
                self.count = 0
                print("cut down")
                return 0
            return self.low_precision

        # The precision matches neither configured level (for example, a
        # value set elsewhere by the caller). Hold it instead of falling off
        # the end and implicitly returning None.
        return precision

    def reset(self):
        '''
        Restart the window counter.

        After a reset, the next ``windows`` split steps hold the incoming
        precision before transitions are evaluated again.
        '''
        self.count = 0
