from utils import *
from .generator import BaseGenerator
from dataloader.criteria_dataloder import CriteriaDataLoder
from datetime import datetime
import itertools
import threading
import json
import time
import ray
import os
import re


class MetaRewardGenerator(BaseGenerator):

    def __init__(self, args):
        """Configure the generator from the parsed run arguments.

        Args:
            args: Run configuration. It is forwarded to the base-class
                initializer and must expose ``manner``, ``sampling_n`` and
                ``criteria_step``.
        """
        super().__init__(args)

        # Settings copied straight from the run arguments.
        self.criteria_step = args.criteria_step
        self.sampling_n = args.sampling_n
        self.manner = args.manner

        # Held by every router handler in this class while it appends to an
        # output file.
        self.lock = threading.Lock()

        # NOTE(review): self.args is presumably populated by
        # BaseGenerator.__init__ -- confirm against the base class.
        self.dataset_loader = CriteriaDataLoder(self.args)

        self._init_file()
    
    def _init_file(self):
        
        os.makedirs(self.output_dir, exist_ok=True)
        self.criteria_file = os.path.join(self.output_dir, f"criteria.jsonl")
        self.criteria_merge_file = os.path.join(self.output_dir, f"criteria_merge.jsonl")
        self.judge_file = os.path.join(self.output_dir, f"judge.jsonl")
        self.refinement_file = os.path.join(self.output_dir, f"refinment.jsonl")
        self.ranking_file = os.path.join(self.output_dir, f"ranking.jsonl")
    
    def run(self, router):

        start_time = time.time()
        
        if self.manner == "direct":
            self.direct_judge(router)
        elif self.manner == "stepwise":
            self.stepwise_judge(router)
        else:
            raise ValueError(f"manner: {self.manner} is not supported")
        print(f"==>⏱️ MetaRewardGenerator [Judge] done...")

        self.refinement(router)
        print(f"==>⏱️ MetaRewardGenerator [Refinement] done...")
        
        print(f"[DEBUG]: begin to run ranking !!!!!!")
        self.ranking(router)
        print(f"==>⏱️ MetaRewardGenerator [Ranking] done...")

        end_time = time.time()
        elapsed = end_time - start_time
        print(f"==>⏱️ MetaRewardGenerator router.run() Spend Time: {elapsed:.2f} 秒")

    def run_with_router(self, router, dataset_iter, handler):
        router.add_handler(handler)
        router.run(dataset_iter)
        router.reset_handler()
    
    def direct_judge(self, router):

        def criteria_handler(batch_dataset):
            with self.lock:
                with open(self.criteria_file, 'a') as fw:
                    for data in batch_dataset:
                        try:
                            fw.write(data.criteria_and_judge_dumps() + "\n")
                        except Exception as e:
                            print(e)
                            continue
        
        for step in range(self.criteria_step):
            dataset_iter = self.dataset_loader.load_criteria_iter(self.input_file, self.criteria_file, step, manner=self.manner)
            if dataset_iter is not None:
                router.init_inference_model()
                self.run_with_router(router, dataset_iter, criteria_handler)
        
        self.judge_merge()

    def judge_merge(self):

        judge_data = {}
        for line in open(self.criteria_file).readlines():
            json_item = json.loads(line)
            paired_data_id = json_item['paired_data']['id']

            if paired_data_id not in judge_data:
                
                judge_data[paired_data_id] = {
                    "criteria": {},
                    "judge": [],
                    "judge_pair": {
                        "judge_a_list": [],
                        "judge_b_list": []
                    }
                }

            if "paired_data" in json_item:
                judge_data[paired_data_id]["paired_data"] = json_item["paired_data"]
            
            judge_data[paired_data_id]["criteria"][json_item["criteria_step"]] = json_item["criteria_list"]
            judge_data[paired_data_id]['judge'].extend(json_item['judge'])
            judge_data[paired_data_id]['judge_pair']['judge_a_list'].extend(json_item['judge_pair']['judge_a_list'])
            judge_data[paired_data_id]['judge_pair']['judge_b_list'].extend(json_item['judge_pair']['judge_b_list'])
            judge_data[paired_data_id]["answer"] = json_item['answer']

        with open(self.judge_file, 'w') as fw:
            for _, item in judge_data.items():
                fw.write(json.dumps(item) + "\n")

    def merge(self):
        
        criteria_data = {}
        for line in open(self.criteria_file).readlines():
            json_item = json.loads(line)

            if json_item['id'] not in criteria_data:
                criteria_data[json_item['id']] = {"criteria": {}}

            if "paired_data" in json_item:
                criteria_data[json_item['id']]["paired_data"] = json_item["paired_data"]
            
            criteria_data[json_item['id']]["criteria"][json_item["criteria_step"]] = json_item["criteria_list"]
        
        with open(self.criteria_merge_file, 'w') as fw:
            for _, item in criteria_data.items():
                fw.write(json.dumps(item) + "\n")

    def stepwise_judge(self, router):

        def criteria_handler(batch_dataset):
            with self.lock:
                with open(self.criteria_file, 'a') as fw:
                    for data in batch_dataset:
                        fw.write(data.criteria_dumps() + "\n")
            
        if not os.path.exists(self.criteria_merge_file):
            
            for step in range(self.criteria_step):
                dataset_iter = self.dataset_loader.load_criteria_iter(self.input_file, self.criteria_file, step, manner=self.manner)
                if dataset_iter is not None:
                    router.init_inference_model()
                    self.run_with_router(router, dataset_iter, criteria_handler)
                print(f"==>⏱️ MetaRewardGenerator [Judge] --> stepwise_judge in step {step} done...")
            self.merge()
        else:
            print(f"==>😎 Found existing file: {self.criteria_file}, Skip Stepwise Criteria...")

        def judge_handler(batch_dataset):
            with self.lock:
                with open(self.judge_file, 'a') as fw:
                    for data in batch_dataset:
                        try:
                            fw.write(data.judge_dumps() + "\n")
                        except:
                            continue
        
        dataset_iter = self.dataset_loader.load_judge_iter(self.criteria_merge_file, self.judge_file)
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.judge_file}, Skip Stepwise Judge...")
        else:
            router.init_inference_model()
            self.run_with_router(router, dataset_iter, judge_handler)
        print(f"==>⏱️ MetaRewardGenerator [Judge] --> stepwise_judge done...")

    def refinement(self, router):

        def refinement_handler(batch_dataset):
            with self.lock:
                with open(self.refinement_file, 'a') as fw:
                    for data in batch_dataset:
                        fw.write(data.refinement_dumps() + "\n")
        
        dataset_iter = self.dataset_loader.load_refinement_iter(self.judge_file, self.refinement_file)
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.refinement_file}, Skip Refinement...")
        else:
            router.init_refinement_model()
            self.run_with_router(router, dataset_iter, refinement_handler)
        print(f"==>⏱️ MetaRewardGenerator [Refinement] --> Refine done...")

    def ranking(self, router):
        
        def ranking_handler(batch_dataset):
            with self.lock:
                with open(self.ranking_file, 'a') as fw:
                    for data in batch_dataset:
                        fw.write(data.ranking_dumps() + "\n")
        
        dataset_iter = self.dataset_loader.load_ranking_iter(self.refinement_file, self.ranking_file)
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.ranking_file}, Skip Ranking...")
        else:
            router.init_ranking_model()
            self.run_with_router(router, dataset_iter, ranking_handler)
        print(f"==>⏱️ MetaRewardGenerator [Ranking] --> Refine done...")

