from .scheduler import Scheduler
import numpy as np

class DescentCalveseekScheduler(Scheduler, scheduler_name='descent_calveseek'):
    '''
    Use score statistics to schedule precision, with separate thresholds for
    the "calon", "calve" and "seek" text categories.

    Every threshold argument is a sequence indexed by the current precision
    level: index ``0`` is consulted while running at ``high_precision`` and
    index ``1`` while running at ``low_precision``.
    '''
    def __init__(self, precisions, calon_mean_score, calon_max_score, calve_mean_score, calve_max_score, seek_mean_score, seek_max_score, sol_precision, windows):
        '''
        Args:
            precisions: At most two precision levels; the larger one becomes
                ``high_precision`` and the smaller one ``low_precision``.
            calon_mean_score, calon_max_score: Mean / max score thresholds for
                the "calon" category (``[0]`` at high, ``[1]`` at low precision).
            calve_mean_score, calve_max_score: Same, for the "calve" category.
            seek_mean_score, seek_max_score: Same, for the "seek" category.
            sol_precision: Precision returned outside the "cot" phase.
            windows: Thresholds are only evaluated once more than ``windows``
                split calls have happened since construction, ``reset()`` or
                the last precision switch.
        '''
        assert len(precisions) <= 2, "Precisions should contain at most two elements."
        super(DescentCalveseekScheduler, self).__init__(precisions)
        self.high_precision = max(precisions)
        self.low_precision = min(precisions)
        self.calon_mean_score = calon_mean_score
        self.calon_max_score = calon_max_score
        self.calve_mean_score = calve_mean_score
        self.calve_max_score = calve_max_score
        self.seek_mean_score = seek_mean_score
        self.seek_max_score = seek_max_score
        self.sol_precision = sol_precision
        self.windows = windows
        self.count = 0

        # Per-category reward statistics for the four categories. They are not
        # populated inside this class; presumably filled by callers — verify.
        self.calon_scores = []
        self.verion_scores = []
        self.calve_scores = []
        self.seek_scores = []

    @staticmethod
    def _safe_stats(scores):
        '''
        Return ``(mean, max)`` of *scores*, or ``(0.0, 0.0)`` when it is None or empty.

        Emptiness is tested with ``len`` rather than truthiness so numpy arrays
        work too (``not arr`` raises for arrays with more than one element).
        '''
        if scores is None or len(scores) == 0:
            return 0.0, 0.0
        return np.mean(scores), np.max(scores)

    def _thresholds_met(self, text_type, kwargs, level):
        '''
        Return True if any category named in *text_type* beats both its mean
        and max thresholds at column *level* (0 = high, 1 = low precision).

        Score sequences are read from ``kwargs["<category>_scores"]``; a
        missing entry is treated as empty.
        '''
        thresholds = (
            ("calon", self.calon_mean_score, self.calon_max_score),
            ("calve", self.calve_mean_score, self.calve_max_score),
            ("seek", self.seek_mean_score, self.seek_max_score),
        )
        for category, mean_thresholds, max_thresholds in thresholds:
            if category not in text_type:
                continue
            mean_score, max_score = self._safe_stats(kwargs.get(f"{category}_scores"))
            if mean_score > mean_thresholds[level] and max_score > max_thresholds[level]:
                return True
        return False

    def schedule(self, **kwargs):
        '''
        Return the precision to use next.

        Reads ``is_split``, ``cur_phase`` and ``precision`` from kwargs, plus
        optionally ``text_type``, ``text`` and the score sequences
        ``calon_scores`` / ``calve_scores`` / ``seek_scores``.

        In the "cot" phase, after the cool-down window, a high-precision run
        that meets its thresholds drops to ``low_precision``, and a
        low-precision run that meets its thresholds returns ``0`` (presumably
        a sentinel understood by callers — confirm). Any other precision value
        is returned unchanged.
        '''
        if not kwargs["is_split"]:
            if kwargs["cur_phase"] == "cot":
                return kwargs["precision"]
            return self.sol_precision

        self.count += 1
        if self.count <= self.windows:
            # Still inside the cool-down window since the last switch/reset.
            return kwargs["precision"]
        if kwargs["cur_phase"] != "cot":
            return self.sol_precision

        # `or ""` also guards against an explicit text_type=None.
        text_type = kwargs.get("text_type") or ""
        precision = kwargs["precision"]

        if precision == self.high_precision:
            if self._thresholds_met(text_type, kwargs, 0):
                self.count = 0
                print("dsc48================================================")
                print(f"Alternatebefore: {self.high_precision}")
                print(f"text_type: {text_type}, text: {kwargs.get('text', '')}")
                print(f"Alternate precision: {self.low_precision}")
                print("===!!!!!!!!!=")
                return self.low_precision
            print("====dsc55=========")
            print(f"precision: {self.high_precision}")
            print(f"text_type: {text_type}")
            print("====no change==============")
            return self.high_precision

        if precision == self.low_precision:
            if self._thresholds_met(text_type, kwargs, 1):
                self.count = 0
                print(f"dsc650000000000: {self.high_precision}")
                print(f"text_type: {text_type}")
                print(f"dsc670000000000000: {self.low_precision}")
                return 0
            print("dsc70no change")
            return self.low_precision

        # Precision matches neither level (e.g. 0 after a previous descent);
        # previously this fell through and returned None. Keep it unchanged.
        return precision

    def reset(self):
        '''
        Reset the scheduler: clear the window counter and all per-category
        reward statistics.
        '''
        self.count = 0
        # Reset reward statistics for every category.
        self.calon_scores = []
        self.verion_scores = []
        self.calve_scores = []
        self.seek_scores = []
