from generator.meta_reward import MetaRewardGenerator
import time
import os
from utils import extract_omni_answer



class RuleFilter(MetaRewardGenerator):
    """Run the judge stage of ``MetaRewardGenerator``, then write a filtered JSONL file.

    The judge stage is selected by ``self.manner`` (``"direct"`` or
    ``"stepwise"``). The filter stage re-reads the judged records from
    ``self.judge_file`` and writes them to ``self.filter_file``.

    By default every record is kept. ``filter(drop_correct=True)`` instead
    drops records whose first judge response ``extract_omni_answer`` reports
    as correct.
    """

    def __init__(self, args):
        """Initialise the base generator and set the output path of the filter stage.

        Args:
            args: Configuration forwarded unchanged to ``MetaRewardGenerator``.
        """
        super().__init__(args)
        # NOTE(review): ``self.output_dir`` is presumably set by the base-class
        # constructor. The misspelled file name is kept as-is because other
        # tooling may depend on it.
        self.filter_file = os.path.join(self.output_dir, "filted_data.jsonl")

    def run(self, router):
        """Run the configured judge stage, then the filter stage, and report timing.

        Args:
            router: Passed through to ``direct_judge`` / ``stepwise_judge``.
                Neither method is defined in this class; they are presumably
                inherited from ``MetaRewardGenerator``.

        Raises:
            ValueError: If ``self.manner`` is neither ``"direct"`` nor ``"stepwise"``.
        """
        # perf_counter is monotonic, so the reported duration cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        if self.manner == "direct":
            self.direct_judge(router)
        elif self.manner == "stepwise":
            self.stepwise_judge(router)
        else:
            raise ValueError(f"manner: {self.manner} is not supported")
        print(f"==>⏱️ [{self.__class__.__name__}] [Judge] done...")

        self.filter()

        elapsed = time.perf_counter() - start_time
        # "秒" means "seconds".
        print(f"==>⏱️ [{self.__class__.__name__}] router.run() Spend Time: {elapsed:.2f} 秒")

    def filter(self, *, drop_correct=False):
        """Write judged records to ``self.filter_file`` as JSON lines and print how many were kept.

        Records are read via
        ``self.dataset_loader.load_filter_data_iter(self.judge_file)``.
        Each kept record is written as ``data.raw_data_dumps()`` followed by a
        newline. The output file is truncated first.

        Args:
            drop_correct: When True, skip records for which
                ``extract_omni_answer(response=data.judge[0], answer=data.answer)``
                returns a truthy first value. Defaults to False, which keeps
                every record.
        """
        dataset_iter = self.dataset_loader.load_filter_data_iter(self.judge_file)
        count = 0

        # Explicit UTF-8 so the JSONL output does not depend on the platform locale.
        with open(self.filter_file, "w", encoding="utf-8") as fw:
            for data in dataset_iter:
                # Only run the answer check when its result is actually used.
                if drop_correct:
                    correct, _pred = extract_omni_answer(
                        response=data.judge[0], answer=data.answer
                    )
                    if correct:
                        continue
                fw.write(data.raw_data_dumps() + "\n")
                count += 1

        print(f"The number of filtered data is: [{count}].")
