from .scheduler import Scheduler
import numpy as np
 
class DescentDeweyScheduler(Scheduler, scheduler_name='descent_dewey_score'):
    '''
    Score-driven scheduler that can step the precision down during the "cot" phase.

    The scheduler counts split events. Once more than ``windows`` splits have
    been seen since the last switch, and the current phase is "cot", it blends
    the statistics of the score list matching ``dewey_text_type`` with the
    change in surprisal (-log2 probability) between the last two split
    probabilities, and compares the blended mean/max against per-category
    thresholds. Each threshold sequence is indexed by tier: index 0 is used
    while running at the high precision, index 1 while running at the low
    precision.
    '''

    # dewey_text_type -> (kwargs key holding the scores,
    #                     index into alpha_split / minus_score,
    #                     attribute with mean thresholds,
    #                     attribute with max thresholds)
    _SCORE_CATEGORIES = {
        "problem_formulation": (
            "problem_scores", 0,
            "problem_split_mean_score", "problem_split_max_score",
        ),
        "computation": (
            "computation_scores", 1,
            "computation_split_mean_score", "computation_split_max_score",
        ),
        "verification": (
            "verification_scores", 2,
            "verification_split_mean_score", "verification_split_max_score",
        ),
    }

    def __init__(
        self,
        precisions,
        problem_split_mean_score,
        problem_split_max_score,
        computation_split_mean_score,
        computation_split_max_score,
        verification_split_mean_score,
        verification_split_max_score,
        sol_precision,
        windows,
        alpha_split,
        minus_score,
    ):
        '''
        Store the scheduling configuration.

        Args:
            precisions: At most two precision values; the larger becomes
                ``high_precision`` and the smaller ``low_precision``.
            problem_split_mean_score, problem_split_max_score,
            computation_split_mean_score, computation_split_max_score,
            verification_split_mean_score, verification_split_max_score:
                Threshold sequences indexed by tier (0 = currently at high
                precision, 1 = currently at low precision).
            sol_precision: Precision returned whenever the phase is not "cot"
                (except inside the split window, see ``schedule``).
            windows: Number of split events that must be exceeded before a
                switch is considered.
            alpha_split: Blend weights, indexed 0/1/2 for the
                problem/computation/verification categories.
            minus_score: Offsets subtracted from the score statistics, indexed
                like ``alpha_split``.
        '''
        assert len(precisions) <= 2, "Precisions should contain at most two elements."
        super().__init__(precisions)
        self.high_precision = max(precisions)
        self.low_precision = min(precisions)
        self.problem_split_mean_score = problem_split_mean_score
        self.problem_split_max_score = problem_split_max_score
        self.computation_split_mean_score = computation_split_mean_score
        self.computation_split_max_score = computation_split_max_score
        self.verification_split_mean_score = verification_split_mean_score
        self.verification_split_max_score = verification_split_max_score
        self.sol_precision = sol_precision
        self.windows = windows
        self.count = 0
        self.alpha_split = alpha_split
        self.minus_score = minus_score
        # These lists are not read by schedule() (it takes the scores from its
        # kwargs); they are kept as public attributes and cleared by reset().
        self.problem_scores = []
        self.computation_scores = []
        self.verification_scores = []

    @staticmethod
    def _safe_stats(scores):
        '''Return ``(mean, max)`` of ``scores``, or ``(0.0, 0.0)`` when it is empty.'''
        if not scores:
            return 0.0, 0.0
        return np.mean(scores), np.max(scores)

    def _exceeds_thresholds(self, tier, kwargs):
        '''
        Tell whether the blended scores of the current text type beat both thresholds.

        Args:
            tier: Index into the threshold sequences (0 or 1).
            kwargs: The keyword arguments given to ``schedule``.

        Returns:
            True when both the blended mean and the blended max are strictly
            greater than their thresholds; False otherwise, including when
            ``dewey_text_type`` is missing or not a known category.
        '''
        category = self._SCORE_CATEGORIES.get(kwargs.get("dewey_text_type", ""))
        if category is None:
            return False
        scores_key, idx, mean_attr, max_attr = category

        # Change in surprisal between the two most recent split probabilities.
        split_prob = kwargs["one_question_split_prob"]
        surprisal_delta = (-np.log2(split_prob[-1])) - (-np.log2(split_prob[-2]))

        mean_score, max_score = self._safe_stats(kwargs[scores_key])
        alpha = self.alpha_split[idx]
        offset = self.minus_score[idx]
        blended_mean = (1 - alpha) * (mean_score - offset) + alpha * surprisal_delta
        blended_max = (1 - alpha) * (max_score - offset) + alpha * surprisal_delta

        return (blended_mean > getattr(self, mean_attr)[tier]
                and blended_max > getattr(self, max_attr)[tier])

    def schedule(self, **kwargs):
        '''
        Choose the precision for the next step.

        Keyword Args:
            is_split: Truthy when the current step is a split event.
            cur_phase: Current phase name; only "cot" is treated specially.
            precision: Precision currently in use.
            dewey_text_type: Optional; "problem_formulation", "computation"
                or "verification".
            one_question_split_prob: Sequence of split probabilities; the last
                two entries are used.
            problem_scores, computation_scores, verification_scores: Score
                sequences; only the one matching ``dewey_text_type`` is read.

        Returns:
            The precision to use next. While at the low precision, a passed
            threshold check returns the literal ``0``
            (NOTE(review): presumably a sentinel understood by the caller --
            confirm).
        '''
        if not kwargs["is_split"]:
            if kwargs["cur_phase"] == "cot":
                return kwargs["precision"]
            return self.sol_precision

        self.count += 1
        if self.count <= self.windows:
            # NOTE(review): inside the window the current precision is kept
            # without looking at cur_phase -- confirm this is intended.
            return kwargs["precision"]

        if kwargs["cur_phase"] != "cot":
            return self.sol_precision

        current = kwargs["precision"]
        if current == self.high_precision:
            tier, unchanged, lowered = 0, self.high_precision, self.low_precision
        elif current == self.low_precision:
            tier, unchanged, lowered = 1, self.low_precision, 0
        else:
            # Not one of the configured precisions (e.g. the 0 returned by an
            # earlier call): keep it instead of implicitly returning None.
            return current

        if self._exceeds_thresholds(tier, kwargs):
            # A switch restarts the split window.
            self.count = 0
            return lowered
        return unchanged

    def reset(self):
        '''
        Reset the split counter and clear the stored score lists.
        '''
        self.count = 0
        self.problem_scores = []
        self.computation_scores = []
        self.verification_scores = []