from .meta_reward import MetaRewardGenerator
from dataloader.criteria_dataloder import *
from datetime import datetime
import itertools
import threading
import json
import time
import ray
import os


class CriteriaMetaRewardGenerator(MetaRewardGenerator):
    """Meta-reward pipeline whose judge stage is driven by generated criteria.

    ``run`` executes three stages in order: judge (criteria generation
    followed by criteria-based judging), refinement, and ranking.  Each stage
    asks ``self.dataset_loader`` for an iterator over its work and appends one
    serialized record per line to that stage's output file.

    Apart from ``criteria_n``, the attributes used here (``manner``,
    ``input_file``, ``criteria_file``, ``judge_file``, ``refinement_file``,
    ``ranking_file``, ``dataset_loader``, ``lock``, ``run_with_router``) are
    not defined in this class; presumably ``MetaRewardGenerator`` provides
    them -- confirm against the base class.
    """

    def __init__(self, args):
        """Initialise the base generator and keep ``args.criteria_n``."""
        super().__init__(args)
        # Forwarded to ``criteria_n_dumps`` when criteria records are written.
        self.criteria_n = args.criteria_n

    def run(self, router):
        """Run the judge, refinement and ranking stages, then print the time.

        Args:
            router: Object providing ``init_inference_model``,
                ``init_refinement_model`` and ``init_ranking_model``; it is
                also passed on to ``run_with_router``.

        Raises:
            ValueError: If ``self.manner`` is not ``"criteria_n"``.
        """
        started_at = time.time()

        # Only the "criteria_n" manner is implemented; fail before any stage
        # runs for everything else.
        if self.manner != "criteria_n":
            raise ValueError(f"manner: {self.manner} is not supported")
        self.criteria_n_judge(router)
        print("==>⏱️ CriteriaMetaRewardGenerator [Judge] done...")

        self.refinement(router)
        print("==>⏱️ CriteriaMetaRewardGenerator [Refinement] done...")

        print("[DEBUG]: begin to run ranking !!!!!!")
        self.ranking(router)
        print("==>⏱️ CriteriaMetaRewardGenerator [Ranking] done...")

        elapsed = time.time() - started_at
        print(f"==>⏱️ CriteriaMetaRewardGenerator router.run() Spend Time: {elapsed:.2f} 秒")

    def _make_append_handler(self, path, dumps, skip_failures=False):
        """Build a batch handler that appends ``dumps(record)`` lines to *path*.

        Args:
            path: File opened in append mode each time a batch is handled.
            dumps: Callable turning one record into its serialized text
                (without the trailing newline).
            skip_failures: When true, a record that fails to serialize or
                write is reported and skipped; when false the exception
                propagates to the caller.

        Returns:
            A callable accepting one batch (an iterable of records).
        """
        def handler(batch_dataset):
            # All stages append under the same lock so that every handler
            # treats its output file the same way.
            with self.lock, open(path, 'a') as fw:
                for data in batch_dataset:
                    try:
                        fw.write(dumps(data) + "\n")
                    except Exception as exc:
                        if not skip_failures:
                            raise
                        # Best-effort stage: drop the record, but leave a
                        # trace instead of failing silently.
                        print(f"==>⚠️ Skip record while writing {path}: {exc!r}")

        return handler

    def criteria_n_judge(self, router):
        """Generate criteria, then judge with them (the two judge sub-stages).

        Both sub-stages are best-effort per record: a record that cannot be
        serialized or written is reported and skipped.

        Args:
            router: Routing object; ``init_inference_model`` is called on it
                before each sub-stage that has work to do.
        """
        # Sub-stage 1: criteria generation -> ``self.criteria_file``.
        dataset_iter = self.dataset_loader.load_criteria_iter(
            self.input_file, self.criteria_file, step=0, manner=self.manner
        )
        # A ``None`` iterator means there is nothing to run for this sub-stage.
        if dataset_iter is not None:
            criteria_handler = self._make_append_handler(
                self.criteria_file,
                lambda data: data.criteria_n_dumps(self.criteria_n),
                skip_failures=True,
            )
            router.init_inference_model()
            self.run_with_router(router, dataset_iter, criteria_handler)
        print("==>⏱️ CriteriaMetaRewardGenerator [Judge] --> criteria generation done...")

        # Sub-stage 2: judging with the criteria -> ``self.judge_file``.
        dataset_iter = self.dataset_loader.load_judge_iter(
            self.criteria_file, self.judge_file, self.manner
        )
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.judge_file}, Skip Stepwise Judge...")
        else:
            judge_handler = self._make_append_handler(
                self.judge_file,
                lambda data: data.criteria_n_judge_dumps(),
                skip_failures=True,
            )
            router.init_inference_model()
            self.run_with_router(router, dataset_iter, judge_handler)
        print("==>⏱️ MetaRewardGenerator [Judge] --> stepwise_judge done...")

    def refinement(self, router):
        """Run the refinement stage, appending results to ``refinement_file``.

        Unlike the judge sub-stages, a failing record is not skipped: the
        exception propagates.

        Args:
            router: Routing object; ``init_refinement_model`` is called on it
                when the loader returns work to do.
        """
        dataset_iter = self.dataset_loader.load_refinement_iter(
            self.judge_file, self.refinement_file
        )
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.refinement_file}, Skip Refinement...")
        else:
            refinement_handler = self._make_append_handler(
                self.refinement_file,
                lambda data: data.refinement_dumps(),
            )
            router.init_refinement_model()
            self.run_with_router(router, dataset_iter, refinement_handler)
        print("==>⏱️ MetaRewardGenerator [Refinement] --> Refine done...")

    def ranking(self, router):
        """Run the ranking stage, appending results to ``ranking_file``.

        As with refinement, a failing record raises instead of being skipped.

        Args:
            router: Routing object; ``init_ranking_model`` is called on it
                when the loader returns work to do.
        """
        dataset_iter = self.dataset_loader.load_ranking_iter(
            self.refinement_file, self.ranking_file
        )
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.ranking_file}, Skip Ranking...")
        else:
            ranking_handler = self._make_append_handler(
                self.ranking_file,
                lambda data: data.ranking_dumps(),
            )
            router.init_ranking_model()
            self.run_with_router(router, dataset_iter, ranking_handler)
        print("==>⏱️ MetaRewardGenerator [Ranking] --> Ranking done...")
