from utils import *
from .generator import BaseGenerator
from dataloader.criteria_dataloder import CriteriaDataLoder
from datetime import datetime
import itertools
import threading
import json
import time
import ray
import os
import re


class MetaRewardGenerator(BaseGenerator):

    def __init__(self, args):
        """Set up judging options, the data loader and all output paths.

        Args:
            args: Run configuration. This constructor reads ``manner``,
                ``sampling_n``, ``criteria_step`` and
                ``inference_model_modality`` from it, and forwards the whole
                object to ``BaseGenerator.__init__``.
        """
        super().__init__(args)

        # Options copied from the run configuration.
        self.manner = args.manner
        self.sampling_n = args.sampling_n
        self.criteria_step = args.criteria_step

        # Held by eval_handler while it updates eval_metrics.
        self.lock = threading.Lock()

        # NOTE(review): self.args and self.output_dir are not assigned in this
        # method; presumably BaseGenerator.__init__ provides them -- confirm.
        self.dataset_loader = CriteriaDataLoder(self.args)
        self._init_file()

        # Filled by eval_handler, summarised by save().
        self.eval_metrics = {}

        # The output file name carries the modality and today's date.
        date_stamp = datetime.now().strftime("%Y%m%d")
        output_name = f"{args.inference_model_modality}_{date_stamp}.jsonl"
        self.output_file = os.path.join(self.output_dir, output_name)

    def _init_file(self):
        """Derive the per-stage JSONL paths inside ``self.output_dir``."""
        stage_files = (
            ("criteria_file", "criteria.jsonl"),
            ("criteria_merge_file", "criteria_merge.jsonl"),
            ("judge_file", "judge.jsonl"),
            # The on-disk name really is spelled "refinment".
            ("refinement_file", "refinment.jsonl"),
            ("ranking_file", "ranking.jsonl"),
        )
        for attribute, file_name in stage_files:
            setattr(self, attribute, os.path.join(self.output_dir, file_name))
        
    
    def run_eval(self, router):
        """Run the configured judging pass, then score its output.

        Args:
            router: Object handed unchanged to ``direct_judge`` or
                ``stepwise_judge``.

        Raises:
            ValueError: If ``self.manner`` is neither ``"direct"`` nor
                ``"stepwise"``.
        """
        started_at = time.time()

        # Reject unknown manners up front; nothing has run at this point.
        if self.manner not in ("direct", "stepwise"):
            raise ValueError(f"manner: {self.manner} is not supported")

        if self.manner == "direct":
            self.direct_judge(router)
        else:
            self.stepwise_judge(router)
        self.save()

        elapsed = time.time() - started_at
        print(f"==>⏱️ Eval router.run() Spend Time: {elapsed:.2f} 秒")

    def eval_handler(self, results):
        """Accumulate judge outcomes into ``self.eval_metrics``.

        Counters are kept twice: once per ``suffix`` and once per ``id``
        nested under that suffix. Only the first entry of ``result['judge']``
        is scored.

        Args:
            results: Iterable of dicts holding ``'paired_data'`` (with
                ``'answer'``, ``'suffix'`` and ``'id'``) and ``'judge'``.
        """

        def empty_counters():
            return {"correct": 0, "error": 0, "count": 0, "parse_error": 0}

        with self.lock:
            for record in results:
                pair = record['paired_data']
                answer = pair['answer']
                response = record['judge'][0]

                # Create the suffix-level and id-level counters on first use.
                suffix_stats = self.eval_metrics.setdefault(pair['suffix'], empty_counters())
                item_stats = suffix_stats.setdefault(pair['id'], empty_counters())

                correct, pred = extract_omni_answer(response, answer)

                # Every record counts once and is either correct or an error;
                # a missing prediction is additionally a parse error.
                tallies = ["count", "correct" if correct else "error"]
                if pred is None:
                    tallies.append("parse_error")
                for key in tallies:
                    suffix_stats[key] += 1
                    item_stats[key] += 1

    def save(self):
        """Score the judge file and append majority-vote metrics to disk.

        Every non-blank line of ``self.judge_file`` is parsed as JSON and the
        resulting list is passed to ``eval_handler``, which updates
        ``self.eval_metrics``. For each suffix, an id counts as correct when
        its number of correct judgments exceeds ``self.sampling_n / 2``;
        accuracy is correct ids divided by all ids (0.0 when a suffix has no
        ids). The summary is printed and appended as one JSON line to a
        ``.metrics`` file that sits next to ``self.output_file``.
        """
        with open(self.judge_file, "r", encoding="utf-8") as f:
            # Blank lines (for example a trailing newline) are not records.
            results = [json.loads(line) for line in f if line.strip()]
        self.eval_handler(results)

        majority = self.sampling_n / 2
        new_eval_metrics = {}
        for suffix, values in self.eval_metrics.items():
            # Per-id counters are the dict values; suffix totals are ints.
            per_id = [v for v in values.values() if isinstance(v, dict)]
            count = len(per_id)
            correct = sum(1 for v in per_id if v['correct'] > majority)
            acc = correct / count if count else 0.0

            new_eval_metrics[suffix] = {
                # NOTE: the misspelled "accury" key is kept as-is because
                # existing metrics files already use it.
                "accury": acc,
                "correct": correct,
                "count": count,
                "error": values['error'],
                "parse_error": values['parse_error'],
            }
            print(f"🎯 The {suffix}  Accuracy = {acc}")

        # Swap only the file extension; a plain str.replace would also rewrite
        # any "jsonl" that appears in a directory name.
        metrics_file = os.path.splitext(self.output_file)[0] + ".metrics"
        with open(metrics_file, "a", encoding="utf-8") as fw:
            fw.write(json.dumps(new_eval_metrics, ensure_ascii=False) + "\n")
        print(f"📊 The Final Evaluation Results: {new_eval_metrics}")
    def run(self, router):
        """Run the full pipeline: judge, then refinement, then ranking.

        Args:
            router: Object handed unchanged to every stage.

        Raises:
            ValueError: If ``self.manner`` is neither ``"direct"`` nor
                ``"stepwise"``.
        """
        began = time.time()

        # Reject unknown manners before any stage runs.
        if self.manner not in ("direct", "stepwise"):
            raise ValueError(f"manner: {self.manner} is not supported")

        if self.manner == "direct":
            self.direct_judge(router)
        else:
            self.stepwise_judge(router)
        print("==>⏱️ MetaRewardGenerator [Judge] done...")

        self.refinement(router)
        print("==>⏱️ MetaRewardGenerator [Refinement] done...")

        print("[DEBUG]: begin to run ranking !!!!!!")
        self.ranking(router)
        print("==>⏱️ MetaRewardGenerator [Ranking] done...")

        elapsed = time.time() - began
        print(f"==>⏱️ MetaRewardGenerator router.run() Spend Time: {elapsed:.2f} 秒")

    def run_with_router(self, router, dataset_iter, handler):
        """Run ``router`` over ``dataset_iter`` with ``handler`` registered.

        Args:
            router: Object exposing ``add_handler``, ``run`` and
                ``reset_handler``.
            dataset_iter: Passed unchanged to ``router.run``.
            handler: Callback registered through ``router.add_handler``.

        Any exception raised by ``router.run`` propagates to the caller.
        """
        router.add_handler(handler)
        try:
            router.run(dataset_iter)
        finally:
            # Reset even when the run fails, so a stale handler is never left
            # registered on the router for the next stage.
            router.reset_handler()

    def direct_judge(self, router):
        """Placeholder for the ``"direct"`` manner; it currently does nothing.

        Args:
            router: Unused.
        """
        return None

    def merge(self):
        """Group the criteria file by ``id`` and write the merged file.

        Each non-blank line of ``self.criteria_file`` is a JSON object with
        ``id``, ``criteria_step`` and ``criteria_list``, and optionally
        ``paired_data``. Records sharing an ``id`` are folded into a single
        object whose ``criteria`` dict maps each step to its criteria list;
        the most recent ``paired_data`` seen for an id is kept.
        ``self.criteria_merge_file`` is overwritten with one merged object
        per line, in order of first appearance.
        """
        criteria_data = {}
        # Default text encoding on purpose: the criteria file is written by
        # this class with a plain open(..., 'a').
        with open(self.criteria_file) as fr:
            for line in fr:
                # Blank lines (for example a trailing newline) are skipped.
                if not line.strip():
                    continue
                json_item = json.loads(line)
                item_id = json_item['id']

                merged = criteria_data.setdefault(
                    item_id, {"id": item_id, "criteria": {}}
                )
                if "paired_data" in json_item:
                    merged["paired_data"] = json_item["paired_data"]
                merged["criteria"][json_item["criteria_step"]] = json_item["criteria_list"]

        with open(self.criteria_merge_file, 'w') as fw:
            for item in criteria_data.values():
                fw.write(json.dumps(item) + "\n")

    def stepwise_judge(self, router):
        """Generate criteria step by step, merge them, then run the judge.

        The criteria phase is skipped when ``self.criteria_merge_file``
        already exists. Otherwise, for each step in
        ``range(self.criteria_step)`` the loader is asked for a dataset
        iterator; when it returns one, the inference model is initialised and
        the router appends its output to ``self.criteria_file``. ``merge`` is
        then called. The judge phase runs the same way and appends to
        ``self.judge_file``; it is skipped when the loader returns ``None``.

        Args:
            router: Object exposing ``init_inference_model`` plus whatever
                ``run_with_router`` needs.
        """

        def criteria_handler(batch_dataset):
            # Append one serialized line per item to the criteria file.
            with open(self.criteria_file, 'a') as fw:
                for data in batch_dataset:
                    fw.write(data.criteria_dumps() + "\n")

        if not os.path.exists(self.criteria_merge_file):

            for step in range(self.criteria_step):
                dataset_iter = self.dataset_loader.load_criteria_iter(self.input_file, self.criteria_file, step)
                if dataset_iter is not None:
                    router.init_inference_model()
                    self.run_with_router(router, dataset_iter, criteria_handler)
                print(f"==>⏱️ MetaRewardGenerator [Judge] --> stepwise_judge in step {step} done...")

            self.merge()
        else:
            # Report the file that was actually checked (the merged file).
            print(f"==>😎 Found existing file: {self.criteria_merge_file}, Skip Stepwise Criteria...")

        def judge_handler(batch_dataset):
            # Best effort: an item that cannot be serialized or written is
            # skipped, but reported instead of being dropped silently.
            with open(self.judge_file, 'a') as fw:
                for data in batch_dataset:
                    try:
                        fw.write(data.judge_dumps() + "\n")
                    except Exception as exc:
                        print(f"==>⚠️ Skip one judge item that could not be written: {exc!r}")
                        continue

        dataset_iter = self.dataset_loader.load_judge_iter(self.criteria_merge_file, self.judge_file)
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.judge_file}, Skip Stepwise Judge...")
        else:
            router.init_inference_model()
            self.run_with_router(router, dataset_iter, judge_handler)

        print(f"==>⏱️ MetaRewardGenerator [Judge] --> stepwise_judge done...")

    def refinement(self, router):
        """Run the refinement stage, appending to ``self.refinement_file``.

        The loader decides what is left to do; when it returns ``None`` the
        stage is skipped.

        Args:
            router: Object exposing ``init_refinement_model`` plus whatever
                ``run_with_router`` needs.
        """

        def refinement_handler(batch_dataset):
            # Append one serialized line per item of the batch.
            with open(self.refinement_file, 'a') as sink:
                for entry in batch_dataset:
                    sink.write(entry.refinement_dumps() + "\n")

        pending = self.dataset_loader.load_refinement_iter(self.judge_file, self.refinement_file)
        if pending is not None:
            router.init_refinement_model()
            self.run_with_router(router, pending, refinement_handler)
        else:
            print(f"==>😎 Found existing file: {self.refinement_file}, Skip Refinement...")
        print("==>⏱️ MetaRewardGenerator [Refinement] --> Refine done...")

    def ranking(self, router):
        """Run the ranking stage, appending results to ``self.ranking_file``.

        The loader decides what is left to do; when it returns ``None`` the
        stage is skipped.

        Args:
            router: Object exposing ``init_ranking_model`` plus whatever
                ``run_with_router`` needs.
        """

        def ranking_handler(batch_dataset):
            # Append one serialized line per item of the batch.
            with open(self.ranking_file, 'a') as fw:
                for data in batch_dataset:
                    fw.write(data.ranking_dumps() + "\n")

        dataset_iter = self.dataset_loader.load_ranking_iter(self.refinement_file, self.ranking_file)
        if dataset_iter is None:
            print(f"==>😎 Found existing file: {self.ranking_file}, Skip Ranking...")
        else:
            router.init_ranking_model()
            self.run_with_router(router, dataset_iter, ranking_handler)
        # This line used to say "Refine done", copied from the refinement stage.
        print("==>⏱️ MetaRewardGenerator [Ranking] --> Ranking done...")

    # def scoring_a_handler(self, results):
        
    #     with self.lock:
    #         with open(self.scoring_a_output_file, 'a') as fw:
    #             for r in results:
    #                 r.pop("conversation")
    #                 fw.write(json.dumps(r, ensure_ascii=False) + "\n")

    
    # def scoring_raw(self, router, init_flag):

    #     skip_raw_a_suffix = set()
    #     if os.path.exists(self.scoring_raw_a_output_file):
    #         skip_raw_a_suffix = self.get_skip_suffix(self.scoring_raw_a_output_file)
    #     dataset_a_iter = self.dataset_loader.load_correct_raw_iter_v2(self.refinement_response_a_output_file, skip_raw_a_suffix)
        
    #     if dataset_a_iter is None:
    #         print(f"==>😎 Found existing file: {self.scoring_raw_a_output_file}, work down!")
    #     else:
    #         init_flag = False
    #         if self.modality == "image":
    #             router.init_vision_scalar_reward()
    #         elif self.modality == "language":
    #             router.init_scalar_reward()
    #         else:
    #             raise Exception(f"Not Support This Modality {self.modality}")
            
    #         router.add_handler(self.scoring_raw_a_handler)
    #         router.run(dataset_a_iter)
    #         router.reset_handler()
    #         print(f"==>⏱️ CriteriaMetaRewardGenerator scoring_a A [Step 4] Scoring A done...")
        
    #     skip_raw_b_suffix = set()
    #     if os.path.exists(self.scoring_raw_b_output_file):
    #         skip_raw_b_suffix = self.get_skip_suffix(self.scoring_raw_b_output_file)
    #     dataset_b_iter = self.dataset_loader.load_correct_raw_iter_v2(self.refinement_response_b_output_file, skip_raw_b_suffix)
        
    #     if dataset_b_iter is None:
    #         print(f"==>😎 Found existing file: {self.scoring_raw_b_output_file}, work down!")
    #     else:
    #         init_flag = False

    #         if self.modality == "image":
    #             router.init_vision_scalar_reward()
    #         elif self.modality == "language":
    #             router.init_scalar_reward()
    #         else:
    #             raise Exception(f"Not Support This Modality {self.modality}")

    #         router.add_handler(self.scoring_raw_b_handler)
    #         router.run(dataset_b_iter)
    #         router.reset_handler()
    #         print(f"==>⏱️ CriteriaMetaRewardGenerator scoring_a B [Step 5] Scoring B done...")
    
    # def scoring(self, router):
        
    #     init_flag = True

    #     skip_a_suffix = set()
    #     if os.path.exists(self.scoring_a_output_file):
    #         skip_a_suffix = self.get_skip_suffix(self.scoring_a_output_file)
    #     dataset_a_iter = self.dataset_loader.load_correct_iter_v2(self.refinement_response_a_output_file, skip_a_suffix)
        
    #     if dataset_a_iter is None:
    #         print(f"==>😎 Found existing file: {self.scoring_a_output_file}, work down!")
    #     else:
    #         init_flag = False
    #         if self.modality == "image":
    #             router.init_vision_scalar_reward()
    #         elif self.modality == "language":
    #             router.init_scalar_reward()
    #         else:
    #             raise Exception(f"Not Support This Modality {self.modality}")

    #         router.add_handler(self.scoring_a_handler)
    #         router.run(dataset_a_iter)
    #         router.reset_handler()
    #         print(f"==>⏱️ CriteriaMetaRewardGenerator scoring_a A [Step 4] Scoring A done...")
        
    #     skip_b_suffix = set()
    #     if os.path.exists(self.scoring_b_output_file):
    #         skip_b_suffix = self.get_skip_suffix(self.scoring_b_output_file)
    #     dataset_b_iter = self.dataset_loader.load_correct_iter_v2(self.refinement_response_b_output_file, skip_b_suffix)
        
    #     if dataset_b_iter is None:
    #         print(f"==>😎 Found existing file: {self.scoring_b_output_file}, work down!")
    #     else:
    #         init_flag = False
    #         if self.modality == "image":
    #             router.init_vision_scalar_reward()
    #         elif self.modality == "language":
    #             router.init_scalar_reward()
    #         else:
    #             raise Exception(f"Not Support This Modality {self.modality}")

    #         router.add_handler(self.scoring_b_handler)
    #         router.run(dataset_b_iter)
    #         router.reset_handler()
    #         print(f"==>⏱️ CriteriaMetaRewardGenerator scoring_a B [Step 5] Scoring B done...")

    #     self.scoring_raw(router, init_flag)
