from .scheduler import Scheduler
import numpy as np

class DescentDeweyScheduler(Scheduler, scheduler_name='descent_dewey'):
    '''
    Step precision down (high -> low -> 0) based on per-segment score statistics.

    Each reasoning segment is tagged with a Dewey text type
    ("problem_formulation", "computation" or "verification"). When the mean
    AND max of that type's scores both exceed the configured thresholds, the
    scheduler descends one precision level.

    Every threshold argument is a sequence indexed by phase: element ``[0]``
    is used while running at the high precision, element ``[1]`` while
    running at the low precision.
    '''
    def __init__(self, precisions, problem_mean_score, problem_max_score, computation_mean_score, computation_max_score, verification_mean_score, verification_max_score, sol_precision, windows):
        '''
        Args:
            precisions: one or two precision values; the larger is treated as
                the high precision and the smaller as the low precision.
            problem_mean_score / problem_max_score: phase-indexed thresholds
                for "problem_formulation" segments.
            computation_mean_score / computation_max_score: phase-indexed
                thresholds for "computation" segments.
            verification_mean_score / verification_max_score: phase-indexed
                thresholds for "verification" segments.
            sol_precision: precision returned outside the "cot" phase.
            windows: number of split steps that must pass (after start or
                after the last descent) before thresholds are checked.
        '''
        assert len(precisions) <= 2, "Precisions should contain at most two elements."
        super().__init__(precisions)
        self.high_precision = max(precisions)
        self.low_precision = min(precisions)
        self.problem_mean_score = problem_mean_score
        self.problem_max_score = problem_max_score
        self.computation_mean_score = computation_mean_score
        self.computation_max_score = computation_max_score
        self.verification_mean_score = verification_mean_score
        self.verification_max_score = verification_max_score
        self.sol_precision = sol_precision
        self.windows = windows
        # Split steps seen since construction, reset() or the last descent.
        self.count = 0

        self.problem_scores = []
        self.computation_scores = []
        self.verification_scores = []

    @staticmethod
    def _safe_stats(scores):
        '''Return ``(mean, max)`` of ``scores``, or ``(0.0, 0.0)`` if None/empty.'''
        # An explicit length check (rather than ``not scores``) also works for
        # numpy arrays, whose truth value is ambiguous when size > 1.
        if scores is None or len(scores) == 0:
            return 0.0, 0.0
        return np.mean(scores), np.max(scores)

    def _thresholds_exceeded(self, kwargs, phase):
        '''
        Return True if the current segment's mean AND max scores both exceed
        the thresholds at index ``phase`` (0 = high precision, 1 = low).

        Only the score list matching ``dewey_text_type`` is read; unknown or
        missing text types never trigger a descent.
        '''
        # text type -> (kwargs key of its scores, mean thresholds, max thresholds)
        by_type = {
            "problem_formulation": ("problem_scores", self.problem_mean_score, self.problem_max_score),
            "computation": ("computation_scores", self.computation_mean_score, self.computation_max_score),
            "verification": ("verification_scores", self.verification_mean_score, self.verification_max_score),
        }
        spec = by_type.get(kwargs.get("dewey_text_type", ""))
        if spec is None:
            return False
        scores_key, mean_thresholds, max_thresholds = spec
        mean_score, max_score = self._safe_stats(kwargs[scores_key])
        return mean_score > mean_thresholds[phase] and max_score > max_thresholds[phase]

    def schedule(self, **kwargs):
        '''
        Return the precision to use next.

        Keyword Args:
            is_split: whether this step ends a segment; only split steps
                advance the window counter and may change precision.
            cur_phase: "cot" for the reasoning phase; any other value gets
                ``sol_precision``.
            precision: the precision currently in use.
            dewey_text_type: type of the current segment (optional).
            problem_scores / computation_scores / verification_scores:
                score sequences; the one matching ``dewey_text_type`` is read.

        Returns:
            ``low_precision`` when running at high precision and the phase-0
            thresholds are exceeded; ``0`` when running at low precision and
            the phase-1 thresholds are exceeded; ``sol_precision`` outside
            "cot"; otherwise the current precision.
        '''
        if not kwargs["is_split"]:
            if kwargs["cur_phase"] == "cot":
                return kwargs["precision"]
            return self.sol_precision

        self.count += 1
        if self.count <= self.windows:
            # Still inside the observation window: keep the current precision.
            return kwargs["precision"]
        if kwargs["cur_phase"] != "cot":
            return self.sol_precision

        precision = kwargs["precision"]
        # The high precision is checked first so that a single-precision
        # configuration (high == low) uses the phase-0 thresholds.
        if precision == self.high_precision:
            phase, descended = 0, self.low_precision
        elif precision == self.low_precision:
            phase, descended = 1, 0
        else:
            # Already below both levels (e.g. 0 after a previous descent):
            # nothing further to descend to. Previously this fell through and
            # returned None.
            return precision

        if self._thresholds_exceeded(kwargs, phase):
            self.count = 0  # restart the window after each descent
            return descended
        return precision

    def reset(self):
        '''
        Reset the window counter and the stored score lists.
        '''
        self.count = 0
        self.problem_scores = []
        self.computation_scores = []
        self.verification_scores = []
