from .scheduler import Scheduler
import random

# class SimpleTextScheduler(Scheduler, scheduler_name='simple_text'):
#     def __init__(self, precisions, text_type, sol_precision):
#         super(SimpleTextScheduler, self).__init__(precisions)
#         self.high_precision = max(precisions)
#         self.low_precision = min(precisions)
#         self.sol_precision = sol_precision
#         self.text_type = text_type
#         self.state = "reset"

#     def schedule(self, **kwargs):
#         # print(kwargs['single_word_type'])
#         if self.state == "normal":
#             if kwargs['single_word_type'] == "calc":
#                 self.state = "is_calc"
#             elif kwargs['single_word_type'] == "verify":
#                 self.state = "is_veri"
#         elif self.state == "is_calc":
#             if kwargs['single_word_type'] == "verify":
#                 self.state = "calc_veri"
#         elif self.state == "is_veri":
#             if kwargs['single_word_type'] == "calc":
#                 self.state = "calc_veri"

#         if kwargs['cur_phase'] == "cot":
#             if kwargs['is_split'] or self.state == "reset":
#                 self.state = "normal"
#                 return self.low_precision
#             else:
#                 if self.text_type == "calon" and self.state == "is_calc":
#                     return self.high_precision
#                 elif self.text_type == "verion" and self.state == "is_veri":
#                     return self.high_precision
#                 elif self.text_type == "calve" and self.state == "calc_veri":
#                     return self.high_precision
#                 else:
#                     return kwargs['precision']
#         else:
#             return self.sol_precision

#     def reset(self):
#         self.state = "reset"

class SimpleTextScheduler(Scheduler, scheduler_name='simple_text'):
    """Pick a precision per step from the word type of the current text segment.

    The scheduler runs a small state machine. After a split (or a reset) it is
    in the ``"normal"`` state. The first ``single_word_type`` seen in that state
    moves it to ``"is_<type>"``, for one of ``"computation"``,
    ``"verification"``, ``"conclusion"`` or ``"problem_formulation"``. It then
    stays there until the next split or :meth:`reset`.

    ``text_type`` selects which kind of segment is escalated:

    * ``"verification"`` / ``"conclusion"``: every cot step while the state
      matches returns ``high_precision``.
    * ``"computation"``: only the first cot step of a computation segment makes
      a decision. Every ``computation_skip + 1``-th computation segment gets
      ``high_precision``; the others get the caller-supplied ``precision``.
      Later steps of the segment fall through to the caller-supplied
      ``precision``.
    * ``"problem_formulation"``: each split starts at ``high_precision``. Steps
      inside a computation/verification/conclusion segment return
      ``low_precision``.

    Outside the ``"cot"`` phase, ``sol_precision`` is always returned.
    """

    def __init__(self, precisions, text_type, sol_precision, computation_skip=3):
        """Create the scheduler.

        Args:
            precisions: candidate precisions, forwarded to the base class.
                Their max/min become the high/low precisions.
            text_type: which segment kind to escalate (see class docstring).
            sol_precision: precision returned outside the ``"cot"`` phase.
            computation_skip: number of computation segments passed over
                between escalations when ``text_type == "computation"``.
                The default of 3 escalates every 4th segment; 0 (or less)
                escalates every segment.
        """
        super().__init__(precisions)
        self.high_precision = max(precisions)
        self.low_precision = min(precisions)
        self.sol_precision = sol_precision
        self.text_type = text_type
        self.computation_skip = computation_skip
        # State-machine state; "reset" until the first cot step is seen.
        self.state = "reset"
        # Computation segments passed over since the last escalation.
        self.count = 0
        # True while the current computation segment's one-shot decision is pending.
        self.target = True

    def schedule(self, **kwargs):
        """Return the precision to use for the current step.

        Keyword args (each is read only when the code path reaches it):
            single_word_type: word-type label of this step; read only while
                the state is ``"normal"``.
            cur_phase: ``"cot"`` for the reasoning phase; any other value
                returns ``sol_precision``.
            is_split: truthy at a segment boundary; read only in ``"cot"``.
            precision: fallback precision returned when no rule applies
                (presumably the precision currently in use; confirm with the
                caller).
        """
        if self.state == "normal":
            word_type = kwargs['single_word_type']
            if word_type == "computation":
                # Arm the one-shot escalation decision for this new segment.
                self.target = True
                self.state = "is_computation"
            elif word_type == "verification":
                self.state = "is_verification"
            elif word_type == "conclusion":
                self.state = "is_conclusion"
            elif word_type == "problem_formulation":
                self.state = "is_problem_formulation"

        if kwargs['cur_phase'] != "cot":
            return self.sol_precision

        if kwargs['is_split'] or self.state == "reset":
            # Segment boundary: wait for the next segment's word type.
            self.state = "normal"
            if self.text_type == "problem_formulation":
                return self.high_precision
            return self.low_precision

        if self.text_type == "computation" and self.state == "is_computation" and self.target:
            # Decide once per computation segment. Later steps of the same
            # segment fall through to the caller-supplied precision.
            self.target = False
            if self.count < self.computation_skip:
                self.count += 1
                return kwargs['precision']
            self.count = 0
            return self.high_precision
        if self.text_type == "verification" and self.state == "is_verification":
            return self.high_precision
        if self.text_type == "conclusion" and self.state == "is_conclusion":
            return self.high_precision
        if self.text_type == "problem_formulation":
            if self.state in ("is_computation", "is_verification", "is_conclusion"):
                return self.low_precision
            return kwargs['precision']
        return kwargs['precision']

    def reset(self):
        """Return to the initial state and re-arm the computation decision.

        NOTE(review): ``self.count`` is not cleared here, so the computation
        escalation cycle carries over across resets. Confirm this is intended.
        """
        self.state = "reset"
        self.target = True