from prompt.criteria_prompt import *
from prompt.omni_prompt import *
from .omni_dataloader import DataLoader
from .traj_data import PairedData, Traj
import json
import random
import os
import re


class CriteriaDataLoder(DataLoader):
    """Turn JSON-lines pipeline files into ``Traj`` objects, one stage at a time.

    Stages covered here: criteria generation (``load_criteria_iter``), judging
    (``load_judge_iter``), refinement, ranking and filtering. The resumable
    stages take a ``recover_file`` (JSON lines written by an earlier run) and
    skip records whose ``paired_data.id`` already appears in it.
    """

    def __init__(self, args):
        super().__init__(args)
        self.criteria_n = args.criteria_n

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _iter_jsonl(path):
        """Yield one parsed JSON object per non-blank line of *path*.

        The file is opened with a context manager, so it is closed when
        iteration finishes or the generator is closed. Blank lines (e.g. a
        trailing newline) are ignored instead of making ``json.loads`` raise.
        """
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)

    def _reached_limit(self, handled_num):
        """Return True when no more records may be handled under ``args.max_input_size``.

        ``max_input_size == -1`` means unlimited. ``handled_num`` counts
        recovered plus newly yielded records; ``>=`` caps that total at exactly
        ``max_input_size`` (a ``>`` comparison would let one extra record through).
        """
        limit = self.args.max_input_size
        return limit != -1 and handled_num >= limit

    def _load_skip_ids(self, recover_file):
        """Return the ids already written to *recover_file*, or an empty set if it is absent."""
        if os.path.exists(recover_file):
            return self._recover(recover_file)
        return set()

    def _recover(self, recover_file):
        """Collect every ``paired_data.id`` present in *recover_file*."""
        return {item['paired_data']['id'] for item in self._iter_jsonl(recover_file)}

    # ---------------------------------------------------------- criteria stage

    def _recover_criteria(self, recover_file, step):
        """Rebuild criteria-generation state from a previous run.

        Returns:
            skip_id: ids that have a line in *recover_file* whose
                ``criteria_step`` equals ``step``.
            id2criteria: id -> criteria from every recovered line for that id,
                concatenated in file order.
            id2answer: id -> ``answer`` of the last recovered line for that id.
        """
        skip_id = set()
        id2criteria = {}
        id2answer = {}
        for item in self._iter_jsonl(recover_file):
            paired_data_id = item['paired_data']['id']
            if item['criteria_step'] == step:
                skip_id.add(paired_data_id)
            id2answer[paired_data_id] = item['answer']
            id2criteria.setdefault(paired_data_id, []).extend(item['criteria_list'])
        return skip_id, id2criteria, id2answer

    def load_criteria_iter(self, dataset_file, recover_file, step, manner: str):
        """Yield one criteria-generation ``Traj`` per dataset record for ``step``.

        Records already produced for ``step`` (per *recover_file*) are skipped.
        For ``step > 0`` only records with a recovered answer are continued.
        Iteration stops once ``args.max_input_size`` records are handled
        (``-1`` means no cap). Video/audio records yield whatever their stub
        loaders return (currently ``None``).
        """
        skip_id, id2criteria, id2answer = set(), {}, {}
        if os.path.exists(recover_file):
            skip_id, id2criteria, id2answer = self._recover_criteria(recover_file, step)

        handled_num = len(skip_id)
        for json_item in self._iter_jsonl(dataset_file):
            item_id = json_item['id']
            if item_id in skip_id:
                continue
            # Later steps look up id2answer[item_id] when building the Traj, so
            # records without a recovered answer cannot be continued.
            if step > 0 and item_id not in id2answer:
                continue
            if self._reached_limit(handled_num):
                return
            handled_num += 1

            if "images" in json_item:
                yield self.load_images_iter(json_item, id2criteria, id2answer, step, manner=manner)
            elif "videos" in json_item:
                yield self.load_videos_iter(json_item)
            elif "audios" in json_item:
                # The stub is named ``load_audios_iter``; ``load_audio_iter`` does not exist.
                yield self.load_audios_iter(json_item)
            else:
                yield self.load_language_criteria_iter(json_item, id2criteria, id2answer, step, manner=manner)

    def _build_criteria_traj(self, json_item, query, id2criteria, id2answer, step):
        """Wrap *json_item* (with the given *query* payload) in a loaded ``Traj``.

        Step 0 loads only the pair with ``shuffle=True``. Later steps also load
        the recovered criteria and answer for the record, with ``shuffle=False``.
        """
        p_data = PairedData(dict(
            id=json_item['id'],
            suffix=json_item['suffix'],
            query=query,
            chosen={"content": json_item['chosen']['content']},
            rejected={"content": json_item['rejected']['content']},
        ))
        t_data = Traj()
        if step == 0:
            t_data.loads({"paired_data": p_data}, shuffle=True)
        else:
            t_data.loads({
                "paired_data": p_data,
                "criteria_list": id2criteria.get(json_item['id'], []),
                "answer": id2answer[json_item['id']],
            }, shuffle=False)
        return t_data

    def load_images_iter(self, json_item, id2criteria, id2answer, step, manner: str):
        """Build a criteria ``Traj`` for an image record.

        Raises:
            ValueError: if ``manner`` is not ``"direct"`` (the only image manner).
        """
        if manner != "direct":
            raise ValueError(f"manner {manner} not supported in load_images_iter.")
        query = {"content": json_item['conversations'][0]['content'], "images": json_item['images']}
        t_data = self._build_criteria_traj(json_item, query, id2criteria, id2answer, step)
        # Pass the real step, as the text path does. A hard-coded step=0 ignored
        # the step whose criteria/answer were just loaded when resuming.
        t_data.build_direct_judge_criteria_conversation(step=step, modality="image")
        return t_data

    def load_audios_iter(self, json_item):
        """Placeholder for audio records: not implemented, returns ``None``."""
        return None

    def load_videos_iter(self, json_item):
        """Placeholder for video records: not implemented, returns ``None``."""
        return None

    def load_language_criteria_iter(self, json_item, id2criteria, id2answer, step, manner: str) -> Traj:
        """Build a criteria ``Traj`` for a text-only record.

        ``manner`` selects the conversation builder: ``"stepwise"``,
        ``"direct"`` or ``"criteria_n"`` (the last uses ``self.criteria_n``).

        Raises:
            ValueError: for any other ``manner``.
        """
        query = {"content": json_item['conversations'][0]['content']}
        t_data = self._build_criteria_traj(json_item, query, id2criteria, id2answer, step)

        if manner == 'stepwise':
            t_data.build_stepwise_criteria_conversation(step)
        elif manner == "direct":
            t_data.build_direct_judge_criteria_conversation(step)
        elif manner == "criteria_n":
            t_data.build_criteria_n_conversation(step, self.criteria_n)
        else:
            raise ValueError(f"manner {manner} not supported in load_language_criteria_iter.")
        return t_data

    # ------------------------------------------------------------- judge stage

    def load_judge_iter(self, dataset_file, recover_file, manner: str):
        """Yield one judging ``Traj`` per record not yet present in *recover_file*.

        Records with fewer than ``criteria_n`` criteria are skipped; the
        ``criteria_n`` manner builds one judge turn per criterion index.
        Iteration stops once ``args.max_input_size`` records are handled.
        """
        skip_id = self._load_skip_ids(recover_file)
        handled_num = len(skip_id)
        for json_item in self._iter_jsonl(dataset_file):
            paired_data_id = json_item['paired_data']['id']
            if paired_data_id in skip_id or len(json_item['criteria_list']) < self.criteria_n:
                continue
            if self._reached_limit(handled_num):
                return
            handled_num += 1
            yield self._load_judge_iter(json_item, manner)

    def _load_judge_iter(self, json_item, manner: str):
        """Build the judge conversations for one record.

        Raises:
            ValueError: if ``manner`` is not ``stepwise``/``direct``/``criteria_n``.
        """
        p_data = PairedData(json_item['paired_data'])
        t_data = Traj()

        if manner in ('stepwise', 'direct'):
            # NOTE(review): ``self.criteria_step`` is not assigned in this class;
            # presumably ``DataLoader.__init__`` sets it — confirm.
            num_steps = self.criteria_step
            t_data.loads({
                "paired_data": p_data,
                "answer": json_item['answer'],
                "criteria_list": [json_item['criteria'][str(s)] for s in range(num_steps)],
            }, shuffle=False)
            for s in range(num_steps):
                t_data.build_stepwise_judge_conversation(s)

        elif manner == "criteria_n":
            t_data.loads({
                "paired_data": p_data,
                "answer": json_item['answer'],
                "criteria_list": json_item['criteria_list'],
            }, shuffle=False)
            for n in range(self.criteria_n):
                t_data.build_criteria_n_judge_conversation(n)

        else:
            raise ValueError(f"manner {manner} not supported in load_judge_iter.")
        return t_data

    # ------------------------------------------------- refinement / ranking

    def load_refinement_iter(self, dataset_file, recover_file):
        """Yield a refinement ``Traj`` for each record not yet in *recover_file*."""
        skip_id = self._load_skip_ids(recover_file)
        for json_item in self._iter_jsonl(dataset_file):
            if json_item['paired_data']['id'] in skip_id:
                continue
            yield self._load_refinement_iter(json_item)

    def _load_refinement_iter(self, json_item):
        """Build the refinement conversation from the record's ``judge_pair``."""
        t_data = Traj()
        t_data.loads({
            "paired_data": PairedData(json_item['paired_data']),
            "answer": json_item['answer'],
        }, shuffle=False)
        t_data.build_refinement_conversation(json_item['judge_pair'])
        return t_data

    def load_ranking_iter(self, dataset_file: str, recover_file: str):
        """Yield a ranking ``Traj`` for each record not yet in *recover_file*."""
        skip_id = self._load_skip_ids(recover_file)
        for json_item in self._iter_jsonl(dataset_file):
            if json_item['paired_data']["id"] in skip_id:
                continue
            yield self._load_ranking_iter(json_item)

    def _load_ranking_iter(self, json_item):
        """Build the ranking conversation from the record's ``refinement_list``."""
        t_data = Traj()
        t_data.loads({
            "paired_data": PairedData(json_item['paired_data']),
            "answer": json_item['answer'],
        }, shuffle=False)
        t_data.build_ranking_conversation(json_item['refinement_list'])
        return t_data

    # ------------------------------------------------------------ filter stage

    def load_filter_data_iter(self, dataset_file: str):
        """Yield every record of *dataset_file* reloaded as a ``Traj``."""
        for json_item in self._iter_jsonl(dataset_file):
            yield self._load_filter_iter(json_item)

    def _load_filter_iter(self, json_item):
        """Load a full record into a ``Traj``, converting ``paired_data`` to ``PairedData``.

        Works on a shallow copy so the caller's dict is left untouched.
        """
        record = dict(json_item)
        record['paired_data'] = PairedData(json_item['paired_data'])
        t_data = Traj()
        t_data.loads(record, shuffle=False)
        return t_data





class LanguageCriteriaDataLoder:
    """Build chat-style prompt records for text-only preference data.

    Every ``load_*`` method streams a JSON-lines file and yields plain dicts
    with ``suffix``, ``prompt``, ``chosen``, ``rejected``, ``conversation`` and
    ``answer`` (some add ``criteria_list``). ``answer`` encodes the order the
    pair is presented in: ``0`` puts the chosen response first (position A);
    any other value puts the rejected response first.
    """

    # Tag patterns for the v2 (lower-case ``judge``) and v3 (``Judge``/``Criteria``) outputs.
    _JUDGE_A_V2 = re.compile(r"<judge A>(.*?)</judge A>", re.DOTALL)
    _JUDGE_B_V2 = re.compile(r"<judge B>(.*?)</judge B>", re.DOTALL)
    _JUDGE_A_V3 = re.compile(r"<Judge A>(.*?)</Judge A>", re.DOTALL)
    _JUDGE_B_V3 = re.compile(r"<Judge B>(.*?)</Judge B>", re.DOTALL)
    _CRITERIA_V3 = re.compile(r"<Criteria\s*(\d+)>(.*?)</Criteria\s*\1>", re.DOTALL)

    def __init__(self, args):
        # This class declares no base class, so ``super().__init__(args)`` would
        # reach ``object.__init__`` and raise TypeError; just keep the config.
        self.args = args
        self.criteria_n = args.criteria_n

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _iter_jsonl(path, skip_suffix=None):
        """Yield parsed records from a JSON-lines file.

        Blank lines are ignored. When *skip_suffix* is given, records whose
        ``suffix`` is in it are skipped.
        """
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                json_item = json.loads(line)
                if skip_suffix is not None and json_item['suffix'] in skip_suffix:
                    continue
                yield json_item

    @staticmethod
    def _random_answer():
        """Draw a presentation order: ``0`` if ``random.random() > 0.5``, else ``1``."""
        return 0 if random.random() > 0.5 else 1

    @staticmethod
    def _ordered(answer, chosen, rejected):
        """Return the pair as ``(position A, position B)`` for *answer*."""
        return (chosen, rejected) if answer == 0 else (rejected, chosen)

    @staticmethod
    def _pair_fields(json_item):
        """Return ``(prompt, chosen, rejected)`` for either record layout.

        Records with ``conversations`` store message dicts holding ``content``;
        otherwise the three fields are plain top-level values.
        """
        if 'conversations' in json_item:
            return (json_item['conversations'][0]['content'],
                    json_item['chosen']['content'],
                    json_item['rejected']['content'])
        return json_item['prompt'], json_item['chosen'], json_item['rejected']

    @staticmethod
    def _user_turn(content):
        """A single-message conversation containing *content* as the user turn."""
        return [{'role': 'user', 'content': content}]

    @staticmethod
    def _scoring_turn(prompt, response):
        """A user/assistant exchange used for scoring *response* against *prompt*."""
        return [{'role': 'user', 'content': prompt}, {'role': 'assistant', 'content': response}]

    @staticmethod
    def _record(suffix, prompt, chosen, rejected, conversation, answer, **extra):
        """Assemble an output record; *extra* keys are placed just before ``answer``."""
        return {
            "suffix": suffix,
            "prompt": prompt,
            "chosen": chosen,
            "rejected": rejected,
            "conversation": conversation,
            **extra,
            "answer": answer,
        }

    def _flat_record(self, json_item, conversation, answer, **extra):
        """Output record for flat-layout input (plain ``prompt``/``chosen``/``rejected``)."""
        return self._record(json_item['suffix'], json_item['prompt'], json_item['chosen'],
                            json_item['rejected'], conversation, answer, **extra)

    # -------------------------------------------------------------- benchmarks

    def load_benchmark_ciriteria_iter(self, benchmark_file, skip_suffix):
        """Yield long-omni criteria prompts for benchmark records.

        Records carrying ``answer`` use the flat layout and their first
        ``response`` as criteria. Other records use the ``conversations`` layout,
        get a random presentation order, and empty criteria.
        """
        for json_item in self._iter_jsonl(benchmark_file, skip_suffix):
            if 'answer' in json_item:
                answer = json_item['answer']
                prompt, chosen, rejected = json_item['prompt'], json_item['chosen'], json_item['rejected']
                criteria = json_item['response'][0]
            else:
                answer = self._random_answer()
                prompt = json_item['conversations'][0]['content']
                chosen = json_item['chosen']['content']
                rejected = json_item['rejected']['content']
                criteria = ""
            first, second = self._ordered(answer, chosen, rejected)
            conversation = self._user_turn(build_long_omni_criteria_prompt(prompt, first, second, criteria=criteria))
            yield self._record(json_item['suffix'], prompt, chosen, rejected, conversation, answer)

    def load_benchmark_iter(self, benchmark_file, skip_suffix):
        """Yield one long-omni judge prompt per record, using its first ``response`` as criteria."""
        for json_item in self._iter_jsonl(benchmark_file, skip_suffix):
            prompt, chosen, rejected = self._pair_fields(json_item)
            first, second = self._ordered(json_item['answer'], chosen, rejected)
            judge_prompt = build_long_omni_judge_prompt(prompt, first, second, criteria=json_item['response'][0])
            yield self._record(json_item['suffix'], prompt, chosen, rejected,
                               self._user_turn(judge_prompt), json_item['answer'])

    def load_benchmark_list_iter(self, benchmark_file_list, skip_suffix):
        """Yield, per record, one judge prompt for every criterion pooled across files.

        Criteria are pooled per ``suffix``: the first file that mentions a
        suffix contributes its whole ``response`` list, and later files append
        their first response. Records themselves are read from the last file
        in *benchmark_file_list*. An empty list yields nothing.
        """
        benchmark_files = list(benchmark_file_list)
        if not benchmark_files:
            return

        suffix2criteria = {}
        for benchmark_file in benchmark_files:
            for json_item in self._iter_jsonl(benchmark_file):
                suffix = json_item['suffix']
                if suffix not in suffix2criteria:
                    suffix2criteria[suffix] = json_item['response']
                else:
                    suffix2criteria[suffix].append(json_item['response'][0])

        # Read records from the last file explicitly rather than via a loop
        # variable leaking out of the pooling loop above.
        for json_item in self._iter_jsonl(benchmark_files[-1], skip_suffix):
            prompt, chosen, rejected = self._pair_fields(json_item)
            first, second = self._ordered(json_item['answer'], chosen, rejected)
            judge_convs = [
                self._user_turn(build_long_omni_judge_prompt(prompt, first, second, criteria=criteria))
                for criteria in suffix2criteria[json_item['suffix']]
            ]
            yield self._record(json_item['suffix'], prompt, chosen, rejected, judge_convs, json_item['answer'])

    # ----------------------------------------------------- criteria / critique

    def _iter_criteria_judge(self, criteria_file, skip_suffix, build_prompt):
        """Shared body of ``load_criteria_iter``/``_v2``: one judge prompt per criterion."""
        for json_item in self._iter_jsonl(criteria_file, skip_suffix):
            answer = self._random_answer()
            first, second = self._ordered(answer, json_item['chosen'], json_item['rejected'])
            judge_prompt = [
                self._user_turn(build_prompt(json_item['prompt'], first, second, criterion))
                for criterion in json_item['criteria_list']
            ]
            yield self._flat_record(json_item, judge_prompt, answer, criteria_list=json_item['criteria_list'])

    def load_criteria_iter(self, criteria_file, skip_suffix):
        """Yield judge prompts (``build_judge_prompt``) in a random order per record."""
        yield from self._iter_criteria_judge(criteria_file, skip_suffix, build_judge_prompt)

    def load_critique_iter(self, critique_file, skip_suffix):
        """Yield critique prompts of the rejected response, one per (judge, criteria) pair."""
        for json_item in self._iter_jsonl(critique_file, skip_suffix):
            correct_prompt = [
                self._user_turn(build_critique_prompt(json_item['prompt'], json_item['rejected'], criteria, judge))
                for judge, criteria in zip(json_item['judge_list'], json_item["criteria_list"])
            ]
            yield self._flat_record(json_item, correct_prompt, json_item['answer'])

    def load_judge_iter(self, judge_file, skip_suffix, A=True):
        """Yield correction prompts for the response in position A (``A=True``) or B."""
        for json_item in self._iter_jsonl(judge_file, skip_suffix):
            pos_a, pos_b = self._ordered(json_item['answer'], json_item['chosen'], json_item['rejected'])
            target = pos_a if A else pos_b
            correct_prompt = [
                self._user_turn(build_correct_prompt(json_item['prompt'], target, judge))
                for judge in json_item['critique_list']
            ]
            yield self._flat_record(json_item, correct_prompt, json_item['answer'])

    def load_correct_iter(self, judge_file):
        """Yield scoring conversations for every response in ``correct_list`` (no suffix filter)."""
        for json_item in self._iter_jsonl(judge_file):
            scoring_prompt = [self._scoring_turn(json_item['prompt'], response)
                              for response in json_item['correct_list']]
            yield self._flat_record(json_item, scoring_prompt, json_item['answer'])

    def load_criteria_iter_v2(self, criteria_file, skip_suffix):
        """Same as ``load_criteria_iter`` but with ``build_judge_prompt_v2``."""
        yield from self._iter_criteria_judge(criteria_file, skip_suffix, build_judge_prompt_v2)

    def load_judge_iter_v2(self, judge_file, skip_suffix, A=True):
        """Yield correction prompts built from ``<judge A>``/``<judge B>`` blocks.

        Best-effort: a record is skipped when any judge text lacks the needed
        tag or the record is otherwise malformed.
        """
        side = "A" if A else "B"

        def extract(text: str):
            results = {}
            for key, pattern in (("A", self._JUDGE_A_V2), ("B", self._JUDGE_B_V2)):
                match = pattern.search(text)
                if match:
                    results[key] = match.group(1).strip()
            return results

        for json_item in self._iter_jsonl(judge_file, skip_suffix):
            try:
                pos_a, pos_b = self._ordered(json_item['answer'], json_item['chosen'], json_item['rejected'])
                target = pos_a if A else pos_b
                correct_prompt = [
                    self._user_turn(build_correct_prompt(json_item['prompt'], target, extract(judge)[side]))
                    for judge in json_item['judge_list']
                ]
            except Exception:
                # Deliberate skip of bad records; ``Exception`` (not a bare
                # ``except``) still lets KeyboardInterrupt/SystemExit propagate.
                continue
            yield self._flat_record(json_item, correct_prompt, json_item['answer'])

    def load_correct_iter_v2(self, judge_file, skip_suffix):
        """Yield scoring conversations from ``correct_a_list`` (else ``correct_b_list``).

        Records with neither list are skipped.
        """
        for json_item in self._iter_jsonl(judge_file, skip_suffix):
            if 'correct_a_list' in json_item:
                responses = json_item['correct_a_list']
            elif 'correct_b_list' in json_item:
                responses = json_item['correct_b_list']
            else:
                continue
            scoring_prompt = [self._scoring_turn(json_item['prompt'], response) for response in responses]
            yield self._flat_record(json_item, scoring_prompt, json_item['answer'])

    def load_correct_raw_iter_v2(self, judge_file, skip_suffix):
        """Yield a scoring conversation for the *uncorrected* response on the same side.

        ``correct_a_list`` selects the position-A response, ``correct_b_list``
        the position-B one. Records with neither list are skipped.
        """
        for json_item in self._iter_jsonl(judge_file, skip_suffix):
            if 'correct_a_list' in json_item:
                position = 0
            elif 'correct_b_list' in json_item:
                position = 1
            else:
                continue
            response = self._ordered(json_item['answer'], json_item['chosen'], json_item['rejected'])[position]
            scoring_prompt = [self._scoring_turn(json_item['prompt'], response)]
            yield self._flat_record(json_item, scoring_prompt, json_item['answer'])

    # meta reward
    def load_judge_iter_v3(self, judge_file, skip_suffix, A=True, num_criteria=3):
        """Yield correction prompts from ``<Judge A|B>`` blocks of ``<Criteria N>`` items.

        For each critique, the first *num_criteria* criteria of the selected
        side each produce one prompt. Best-effort: a record is skipped when a
        critique lacks either judge block, has fewer than *num_criteria*
        criteria, or the record is otherwise malformed.
        """
        key = "criteria_a_list" if A else "criteria_b_list"

        def extract_criteria(text):
            return [content.strip() for _num, content in self._CRITERIA_V3.findall(text)]

        def extract(text: str):
            match_a = self._JUDGE_A_V3.search(text)
            match_b = self._JUDGE_B_V3.search(text)
            if not match_a or not match_b:
                raise ValueError("criteria_a_list or criteria_b_list not found")
            return {
                "criteria_a_list": extract_criteria(match_a.group(1).strip()),
                "criteria_b_list": extract_criteria(match_b.group(1).strip()),
            }

        for json_item in self._iter_jsonl(judge_file, skip_suffix):
            try:
                pos_a, pos_b = self._ordered(json_item['answer'], json_item['chosen'], json_item['rejected'])
                target = pos_a if A else pos_b
                correct_prompt = []
                for critique in json_item['critique_list']:
                    criteria = extract(critique)[key]
                    for i in range(num_criteria):
                        prompt = build_correct_prompt(json_item['prompt'], target, criteria[i])
                        correct_prompt.append(self._user_turn(prompt))
            except Exception:
                # Deliberate skip of bad records; KeyboardInterrupt/SystemExit still propagate.
                continue
            yield self._flat_record(json_item, correct_prompt, json_item['answer'])


class ImageCriteriaDataLoder(DataLoader):
    """Stream prompts for the image criteria -> judge -> correct pipeline.

    Every ``load_*`` method is a generator: it reads its input file as JSON
    Lines, skips records whose ``suffix`` is in ``skip_suffix``, and yields a
    dict carrying the record's fields plus the chat ``conversation`` (or list
    of conversations) to send to the model.

    Position convention used throughout: ``answer == 0`` means the chosen
    response is shown first (position "A") and the rejected one second ("B");
    any other value means the order is swapped.
    """

    # Pixel bound attached, as both min_pixels and max_pixels, to every image entry.
    _IMAGE_PIXELS = 512 * 512

    # Patterns that pull the per-response judgements out of a judge completion.
    # The closed-tag forms are preferred; the fallbacks tolerate a missing
    # closing tag by stopping at the next marker instead.
    _JUDGE_A_CLOSED = re.compile(r"<judge A>(.*?)</judge A>", re.DOTALL)
    _JUDGE_A_OPEN = re.compile(r"<judge A>(.*?)<judge B>", re.DOTALL)
    _JUDGE_B_CLOSED = re.compile(r"<judge B>(.*?)</judge B>", re.DOTALL)
    _JUDGE_B_OPEN = re.compile(r"<judge B>(.*?)The Final Verdict is", re.DOTALL)

    def __init__(self, args):
        # Inherits DataLoader (as CriteriaDataLoder does): without a base class,
        # super().__init__(args) reached object.__init__ and raised TypeError.
        # load_iter reads self.input_dir, presumably set by DataLoader — confirm.
        super().__init__(args)
        # Forwarded to build_criteria_prompt in the text-only branch of load_iter.
        self.criteria_n = args.criteria_n

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _iter_records(path, skip_suffix):
        """Yield one parsed JSON object per non-blank line of ``path``.

        Records whose ``suffix`` is contained in ``skip_suffix`` are skipped.
        """
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                json_item = json.loads(line)
                if json_item['suffix'] in skip_suffix:
                    continue
                yield json_item

    @classmethod
    def _image_part(cls, image):
        """Return a chat content entry for ``image`` with the fixed pixel bounds."""
        return {
            "type": "image",
            "image": image,
            "min_pixels": cls._IMAGE_PIXELS,
            "max_pixels": cls._IMAGE_PIXELS,
        }

    @classmethod
    def _text_image_text(cls, prefix, image, suffix):
        """Return a single-turn user conversation laid out as text, image, text."""
        return [
            {
                'role': 'user',
                'content': [
                    {"type": "text", "text": prefix},
                    cls._image_part(image),
                    {"type": "text", "text": suffix},
                ],
            }
        ]

    @classmethod
    def _scoring_conversation(cls, image, prompt, response):
        """Return a two-turn conversation: user (image + prompt), assistant (response)."""
        return [
            {
                'role': 'user',
                'content': [
                    cls._image_part(image),
                    {"type": "text", "text": prompt},
                ],
            },
            {'role': 'assistant', 'content': response},
        ]

    @staticmethod
    def _response_at(json_item, first):
        """Return the response shown at position A (``first`` true) or B.

        ``json_item['answer'] == 0`` means chosen was shown first; anything
        else means rejected was shown first.
        """
        chosen_first = json_item['answer'] == 0
        return json_item['chosen'] if first == chosen_first else json_item['rejected']

    @classmethod
    def _extract_judge_blocks(cls, text):
        """Split a judge completion into its ``"A"`` and ``"B"`` judgement texts.

        A key is present in the returned dict only when its block was found;
        the extracted text is whitespace-stripped.
        """
        results = {}

        if "<judge A>" in text and "</judge A>" in text:
            match_a = cls._JUDGE_A_CLOSED.search(text)
        else:
            match_a = cls._JUDGE_A_OPEN.search(text)
        if match_a:
            results["A"] = match_a.group(1).strip()

        if "<judge B>" in text and "</judge B>" in text:
            match_b = cls._JUDGE_B_CLOSED.search(text)
        else:
            match_b = cls._JUDGE_B_OPEN.search(text)
        if match_b:
            results["B"] = match_b.group(1).strip()

        return results

    def _text_criteria_item(self, json_item):
        """Build the load_iter record for a text-only pair (top-level ``prompt`` key)."""
        prompt = json_item['prompt']
        chosen = json_item['chosen']
        rejected = json_item['rejected']

        if random.random() > 0.5:
            answer = 0
            text = build_criteria_prompt(prompt, chosen, rejected, self.criteria_n)
        else:
            answer = 1
            text = build_criteria_prompt(prompt, rejected, chosen, self.criteria_n)

        return {
            "suffix": json_item['suffix'],
            "prompt": prompt,
            "chosen": chosen,
            "rejected": rejected,
            "conversation": [{'role': 'user', 'content': text}],
            # Previously computed but dropped; needed to map A/B back to chosen/rejected.
            "answer": answer,
        }

    def _image_criteria_item(self, json_item):
        """Build the load_iter record for an image pair (``conversations`` schema)."""
        question = json_item['conversations'][0]['content']
        chosen = json_item['chosen']['content']
        rejected = json_item['rejected']['content']
        images = json_item['images']

        chosen_first = random.random() > 0.5
        answer = 0 if chosen_first else 1
        first, second = (chosen, rejected) if chosen_first else (rejected, chosen)

        if len(images) == 1:
            # One shared image for both responses.
            prefix_pm, suffix_pm = build_criteria_omni_prompt_split(question, first, second)
            conversation = self._text_image_text(prefix_pm, images[0], suffix_pm)
        else:
            # images[0] accompanies chosen and images[1] rejected; the images
            # are ordered to match the response order in the prompt.
            # NOTE(review): this branch uses build_omni_prompt_triple rather than
            # a criteria-specific builder — confirm that is intended.
            if chosen_first:
                first_image, second_image = images[0], images[1]
            else:
                first_image, second_image = images[1], images[0]
            prefix_pm, infix_pm, suffix_pm = build_omni_prompt_triple(question, first, second)
            conversation = [
                {
                    'role': 'user',
                    'content': [
                        {"type": "text", "text": prefix_pm},
                        self._image_part(first_image),
                        {"type": "text", "text": infix_pm},
                        self._image_part(second_image),
                        {"type": "text", "text": suffix_pm},
                    ],
                }
            ]

        return {
            "suffix": json_item['suffix'],
            "prompt": question,
            "chosen": chosen,
            "rejected": rejected,
            "images": images,
            "conversation": conversation,
            # Previously computed but dropped; needed to map A/B back to chosen/rejected.
            "answer": answer,
        }

    # ---------------------------------------------------------------- loaders

    def load_iter(self, skip_suffix):
        """Yield criteria-generation prompts for every pair in ``self.input_dir``.

        Records with a top-level ``prompt`` key are treated as text-only pairs;
        all others are read from ``conversations``/``chosen``/``rejected``
        ``content`` fields with their ``images``. The presentation order of
        chosen vs. rejected is randomized per record and reported as ``answer``.
        """
        for json_item in self._iter_records(self.input_dir, skip_suffix):
            if "prompt" in json_item:
                yield self._text_criteria_item(json_item)
            else:
                yield self._image_criteria_item(json_item)

    def load_criteria_iter_v2(self, criteria_file, skip_suffix):
        """Yield one judge prompt per criterion for every record in ``criteria_file``.

        The presentation order is randomized per record (reported as
        ``answer``) and shared by all of that record's criterion prompts.
        The first entry of ``images`` is attached to each prompt.
        """
        for json_item in self._iter_records(criteria_file, skip_suffix):
            chosen_first = random.random() > 0.5
            answer = 0 if chosen_first else 1
            if chosen_first:
                first, second = json_item['chosen'], json_item['rejected']
            else:
                first, second = json_item['rejected'], json_item['chosen']

            judge_prompt = []
            for criterion in json_item['criteria_list']:
                prefix_pm, suffix_pm = build_judge_omni_prompt_v2(
                    json_item['prompt'], first, second, criterion)
                judge_prompt.append(
                    self._text_image_text(prefix_pm, json_item['images'][0], suffix_pm))

            yield {
                "suffix": json_item['suffix'],
                "prompt": json_item['prompt'],
                "chosen": json_item['chosen'],
                "rejected": json_item['rejected'],
                "images": json_item['images'],
                "conversation": judge_prompt,
                "criteria_list": json_item['criteria_list'],
                "answer": answer,
            }

    def load_judge_iter_v2(self, judge_file, skip_suffix, A=True):
        """Yield correction prompts built from the judge completions in ``judge_file``.

        Args:
            judge_file: JSONL path; each record is read for ``prompt``,
                ``chosen``, ``rejected``, ``images``, ``answer`` and
                ``judge_list`` (raw judge completions).
            skip_suffix: container of ``suffix`` values to skip.
            A: if truthy, build prompts for the response at position A using
                its ``<judge A>`` block; otherwise position B with ``<judge B>``.

        Records that fail to build (missing keys, an unextractable judge
        block, ...) are skipped. ``conversation`` holds one prompt per entry
        of ``judge_list``; ``judge_pair`` is the parsed blocks of the last
        judge, or ``{}`` when ``judge_list`` is empty.
        """
        side = "A" if A else "B"
        for json_item in self._iter_records(judge_file, skip_suffix):
            # Reset per record so an empty judge_list never leaks the previous
            # record's value (or hits an unbound name on the first record).
            judge_pair = {}
            try:
                image = json_item['images'][0]
                response = self._response_at(json_item, first=bool(A))
                correct_prompt = []
                for judge in json_item['judge_list']:
                    judge_pair = self._extract_judge_blocks(judge)
                    prefix_pm, suffix_pm = build_correct_omni_prompt(
                        json_item['prompt'], response, judge_pair[side])
                    correct_prompt.append(self._text_image_text(prefix_pm, image, suffix_pm))
            except Exception:
                # Best effort: skip malformed records, but unlike a bare
                # ``except:`` let KeyboardInterrupt / SystemExit propagate.
                continue

            yield {
                "suffix": json_item['suffix'],
                "prompt": json_item['prompt'],
                "chosen": json_item['chosen'],
                "rejected": json_item['rejected'],
                "images": json_item['images'],
                "judge_pair": judge_pair,
                "conversation": correct_prompt,
                "answer": json_item['answer'],
            }

    # load_correct_raw_iter_v2
    def load_correct_iter_v2(self, judge_file, skip_suffix):
        """Yield scoring conversations for each corrected response in ``judge_file``.

        Uses ``correct_a_list`` if present, else ``correct_b_list``; records
        with neither are skipped. Each conversation pairs the first image and
        ``prompt`` (user turn) with one corrected response (assistant turn).
        """
        for json_item in self._iter_records(judge_file, skip_suffix):
            image = json_item['images'][0]

            if 'correct_a_list' in json_item:
                responses = json_item['correct_a_list']
            elif 'correct_b_list' in json_item:
                responses = json_item['correct_b_list']
            else:
                continue

            scoring_prompt = [
                self._scoring_conversation(image, json_item['prompt'], response)
                for response in responses
            ]

            yield {
                "suffix": json_item['suffix'],
                "prompt": json_item['prompt'],
                "chosen": json_item['chosen'],
                "rejected": json_item['rejected'],
                "images": json_item['images'],
                "conversation": scoring_prompt,
                "answer": json_item['answer'],
            }

    def load_correct_raw_iter_v2(self, judge_file, skip_suffix):
        """Yield a scoring conversation for the *uncorrected* response of each record.

        Mirrors load_correct_iter_v2, but scores the original response that was
        shown at the corrected position. That is position A when
        ``correct_a_list`` is present, otherwise B when ``correct_b_list`` is.
        Records with neither list are skipped.
        """
        for json_item in self._iter_records(judge_file, skip_suffix):
            image = json_item['images'][0]

            if 'correct_a_list' in json_item:
                response = self._response_at(json_item, first=True)
            elif 'correct_b_list' in json_item:
                response = self._response_at(json_item, first=False)
            else:
                continue

            scoring_prompt = [self._scoring_conversation(image, json_item['prompt'], response)]

            yield {
                "suffix": json_item['suffix'],
                "prompt": json_item['prompt'],
                "chosen": json_item['chosen'],
                "rejected": json_item['rejected'],
                "images": json_item['images'],
                "conversation": scoring_prompt,
                "answer": json_item['answer'],
            }
