from generator.meta_reward import MetaRewardGenerator
from dataloader.traj_data import PairedData, Traj
from dataloader.exploration_dataloader import DpoPool, ExplorationDataLoader

import json
import os
import time


    
class ExplorationSampler(MetaRewardGenerator):
    """Run the exploration -> judge -> refinement -> ranking pipeline.

    The mode is selected by ``args.method``:

    * ``"criteria_exploration"``: no DPO pool is used (``args.dpo_pool_file``
      must be ``None``) and ``self.criteria_n`` is set to
      ``CRITERIA_N_EXPLORATION``.
    * any other method: a ``DpoPool`` is loaded from ``args.dpo_pool_file``
      and ``self.criteria_n`` is set to ``CRITERIA_N_WITH_POOL``.

    ``self.args``, ``self.lock``, ``self.input_file``, ``self.criteria_file``,
    ``self.judge_file`` and the methods ``run_with_router``, ``refinement`` and
    ``ranking`` are not defined in this class; they are expected to be
    provided by ``MetaRewardGenerator``.
    """

    # Values stored in ``self.criteria_n`` and passed to ``criteria_dumps`` when
    # writing criteria. The startup message describes them as a constraint on
    # the number of generated criteria for each mode.
    CRITERIA_N_EXPLORATION = 10
    CRITERIA_N_WITH_POOL = 7

    def __init__(self, args):
        """Validate ``args``, pick the mode-specific settings and build the loader.

        Raises:
            AssertionError: if ``args.criteria_step != 1``, or if
                ``args.method == "criteria_exploration"`` while a
                ``dpo_pool_file`` is given.
        """
        super().__init__(args)

        # This sampler only supports a single criteria-generation step.
        assert args.criteria_step == 1, "[Error] ExplorationSampler: only need one step criteria generation. change this value into `1`."

        if args.method == "criteria_exploration":
            assert args.dpo_pool_file is None, "Don't pass dpo pool file. When you want to `criteria_exploration`."
            self.dpo_pool = None
            self.criteria_n = self.CRITERIA_N_EXPLORATION
            print(f"`criteria_exploration` use ExplorationSampler to generate data! Constraint the generated Criteria_n:[{self.criteria_n}]")
        else:
            self.dpo_pool = DpoPool(args.dpo_pool_file)
            self.criteria_n = self.CRITERIA_N_WITH_POOL
            print(f"use ExplorationSampler to explore more criteria! Constraint the generated Criteria_n:[{self.criteria_n}]")

        self.dataset_loader = ExplorationDataLoader(self.args)

    def _append_dumped_lines(self, path, batch_dataset, dumps, stage):
        """Append ``dumps(item)`` as one line per item of ``batch_dataset`` to ``path``.

        Items whose serialization raises are skipped (best effort), but the number
        skipped and the last error are reported instead of being dropped silently.
        Only serialization is guarded; errors while writing the file propagate.

        Args:
            path: Output file, opened in append mode.
            batch_dataset: Iterable of items produced by ``run_with_router``.
            dumps: Callable mapping one item to the string to write (no newline).
            stage: Label used in the warning message.
        """
        skipped = 0
        last_error = None
        # Writes from concurrent handler calls are serialized through the shared lock.
        with self.lock:
            # Explicit UTF-8 so non-ASCII content does not depend on the locale.
            with open(path, 'a', encoding='utf-8') as fw:
                for data in batch_dataset:
                    try:
                        line = dumps(data)
                    except Exception as e:
                        skipped += 1
                        last_error = e
                        continue
                    fw.write(line + "\n")
        if skipped:
            print(f"[Warning] ExplorationSampler [{stage}]: skipped {skipped} record(s) that failed to serialize; last error: {last_error!r}")

    def exploration_judge(self, router):
        """Generate criteria for pending inputs, then judge the generated criteria.

        Stage 1 (exploration) reads ``self.input_file`` (plus ``self.dpo_pool``)
        and appends results to ``self.criteria_file``. Stage 2 (judge) reads
        ``self.criteria_file`` and appends results to ``self.judge_file``. A stage
        is skipped when its loader returns ``None``.
        """

        def criteria_handler(batch_dataset):
            self._append_dumped_lines(
                self.criteria_file,
                batch_dataset,
                lambda data: data.criteria_dumps(self.criteria_n),
                "exploration",
            )

        dataset_iter = self.dataset_loader.load_criteria_iter(self.input_file, self.criteria_file, self.dpo_pool)
        if dataset_iter is None:
            print(f"==>😎 Nothing to generate for {self.criteria_file}, Skip Criteria Exploration...")
        else:
            # NOTE(review): init_inference_model() is called before each stage that
            # runs; confirm it is idempotent / cheap when called twice.
            router.init_inference_model()
            self.run_with_router(router, dataset_iter, criteria_handler)
        print("==>⏱️ ExplorationSampler [Judge][exploration] Done.")

        def judge_handler(batch_dataset):
            self._append_dumped_lines(
                self.judge_file,
                batch_dataset,
                lambda data: data.criteria_n_judge_dumps(),
                "judge",
            )

        dataset_iter = self.dataset_loader.load_judge_iter(self.criteria_file, self.judge_file, manner=None)
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.judge_file}, Skip Stepwise Judge...")
        else:
            router.init_inference_model()
            self.run_with_router(router, dataset_iter, judge_handler)
        print("==>⏱️ ExplorationSampler [Judge][judge] Done.")

    def run(self, router):
        """Run exploration/judge, refinement and ranking in order, then print the elapsed time."""
        # perf_counter is monotonic, so the measured duration is immune to clock changes.
        start_time = time.perf_counter()

        self.exploration_judge(router)
        print("==>⏱️ ExplorationSampler [Judge] done...")

        self.refinement(router)
        print("==>⏱️ ExplorationSampler [Refinement] done...")

        print("[DEBUG]: begin to run ranking !!!!!!")
        self.ranking(router)
        print("==>⏱️ ExplorationSampler [Ranking] done...")

        # NOTE(review): a step to "get the index order and the estimation" was
        # noted here but is not implemented in this method.

        elapsed = time.perf_counter() - start_time
        # The message suffix "秒" means "seconds".
        print(f"==>⏱️ ExplorationSampler router.run() Spend Time: {elapsed:.2f} 秒")


