from prompt.omni_prompt import extract_omni_answer
from .generator import BaseGenerator
from .dataloader import *
from .criteria_dataloder import *
from datetime import datetime
import itertools
import threading
import json
import time
import ray
import os


class PreferenceGenerator(BaseGenerator):
    """Grade sampled responses and sort each prompt into an SFT or preference file.

    ``run`` pushes the dataset through a router; the router hands batches of
    results to ``filter_handler``, which grades every response with
    ``extract_omni_answer`` and then either finalises the prompt (one JSON
    line is appended to an output file) or leaves it pending so that ``run``
    submits it again.

    Routing of a single result:
      * ``sampling_n`` responses correct  -> ``sft_<timestamp>.jsonl``
      * at least one, but not all correct -> ``preference_<timestamp>.jsonl``
      * none correct, budget remaining    -> stays pending for another round
      * none correct, budget exhausted    -> ``preference_<timestamp>.jsonl``,
        and ``retry_number`` is incremented
    """

    def __init__(self, args):
        """Pick the data loader for the configured modality and set up outputs.

        ``self.modality``, ``self.args`` and ``self.output_dir`` are not
        assigned here; presumably ``BaseGenerator.__init__`` sets them --
        confirm against the base class.

        Args:
            args: Configuration object; ``sampling_max`` and ``sampling_n``
                are read from it directly.

        Raises:
            ValueError: If ``self.modality`` is not ``"language"``,
                ``"image"`` or ``"video"``.
        """
        super().__init__(args)

        # Kept as an if/elif chain on purpose: only the loader class that is
        # actually selected gets looked up.
        if self.modality == "language":
            self.dataset_loader = LanguageDataLoder(self.args)
        elif self.modality == "image":
            self.dataset_loader = ImageDataLoder(self.args)
        elif self.modality == "video":
            self.dataset_loader = VideoDataLoder(self.args)
        else:
            # ValueError is still caught by any existing `except Exception`.
            raise ValueError(f"Not Support This Modality {self.modality}")

        # Guards the pending-index bookkeeping and the output files inside
        # filter_handler.
        self.lock = threading.Lock()

        # Minute resolution: generators created within the same minute share
        # (and append to) the same pair of files.
        time_str = datetime.now().strftime("%Y%m%d_%H%M")
        self.output_file = os.path.join(self.output_dir, f"sft_{time_str}.jsonl")
        self.preference_output_file = os.path.join(self.output_dir, f"preference_{time_str}.jsonl")

        # Upper bound on the per-prompt sample counter kept in loop_index_count.
        self.sampling_max = args.sampling_max
        # Samples charged to a prompt per round.
        # NOTE(review): assumed to equal the number of responses the router
        # returns per result -- confirm against the router's sampling config.
        self.sampling_n = args.sampling_n

        # Mapping suffix -> samples drawn so far; populated by run().
        self.loop_index_count = None
        # Number of prompts that used up sampling_max without a correct response.
        self.retry_number = 0

    def show(self):
        """Print how many prompts exhausted the sampling budget with no correct answer."""
        print(f"📊 The Final Retry Number: {self.retry_number}")

    @staticmethod
    def _grade_responses(result):
        """Grade every response of *result*; return ``(parser_result, corr_count)``.

        ``parser_result`` holds one ``(correct, pred)`` tuple per response, in
        response order; ``corr_count`` is how many of them were truthy-correct.
        """
        answer = result['answer']
        parser_result = []
        corr_count = 0
        for response in result['response']:
            correct, pred = extract_omni_answer(response, answer)
            parser_result.append((correct, pred))
            if correct:
                corr_count += 1
        return parser_result, corr_count

    def filter_handler(self, results):
        """Grade a batch of results and write out the ones that are finished.

        Each result is mutated in place: ``parser_result`` and ``loop`` (the
        updated sample counter) are added before it is serialised. A finished
        result is removed from ``self.loop_index_count``; an unfinished one
        stays there so ``run`` picks it up again.

        Args:
            results: Iterable of dicts providing the keys ``answer``,
                ``response`` (iterable of model outputs) and ``suffix`` (the
                key used in ``self.loop_index_count``).

        Raises:
            AssertionError: If called before ``run`` loaded the index.
            KeyError: If a result's ``suffix`` is not pending any more.
        """
        assert self.loop_index_count is not None, (
            "filter_handler called before run() loaded the index"
        )

        # Use the configured sample count instead of a hard-coded 3 so that the
        # budget accounting and the "all correct" test follow the sampling
        # configuration.
        per_round = self.sampling_n

        with self.lock:
            with open(self.output_file, "a", encoding="utf-8") as fw, \
                    open(self.preference_output_file, "a", encoding="utf-8") as fp:
                for result in results:
                    parser_result, corr_count = self._grade_responses(result)

                    suffix = result['suffix']
                    self.loop_index_count[suffix] += per_round

                    result['parser_result'] = parser_result
                    result['loop'] = self.loop_index_count[suffix]

                    if corr_count == per_round:
                        # Every sample correct: SFT file.
                        sink = fw
                    elif corr_count > 0:
                        # Mixed correct / incorrect samples: preference file.
                        sink = fp
                    elif result['loop'] >= self.sampling_max:
                        # Nothing correct and no budget left: give up on it.
                        self.retry_number += 1
                        sink = fp
                    else:
                        # Nothing correct yet but budget remains: keep pending.
                        continue

                    self.loop_index_count.pop(suffix)
                    sink.write(json.dumps(result, ensure_ascii=False) + "\n")

    def load_loop_iter(self):
        """Yield the values of ``self.loop_iter``.

        NOTE(review): ``self.loop_iter`` is not assigned anywhere in this
        class; unless ``BaseGenerator`` provides it, iterating this generator
        raises ``AttributeError`` -- confirm or remove.
        """
        for _, value in self.loop_iter.items():
            yield value

    def run(self, router):
        """Run the router over the dataset until no prompt is left pending.

        The first pass covers the whole dataset; every following pass only
        covers the entries still present in ``self.loop_index_count``.
        The reported time spans ``router.init_vllm()`` and all passes.

        Args:
            router: Object providing ``add_handler``, ``init_vllm`` and ``run``.
        """
        self.loop_index_count = self.dataset_loader.load_index()
        dataset_iter = self.dataset_loader.load_iter()

        router.add_handler(self.filter_handler)

        start_time = time.time()

        router.init_vllm()
        router.run(dataset_iter)

        # filter_handler pops finished prompts, so this drains to empty.
        while self.loop_index_count:
            print(f"✨ Now loop_index_count : {len(self.loop_index_count)}")
            router.run(self.dataset_loader.load_iter(self.loop_index_count))

        elapsed = time.time() - start_time

        # Label the timing with this class instead of the copy-pasted
        # "LanguageEvaluator".
        print(f"⏱️ {type(self).__name__} router.run() Spend Time: {elapsed:.2f} 秒")
        self.show()