from abc import ABC, abstractmethod
from benchmark.rewardbench import load_rewardbench_iter, load_rewardbench_judge_iter
from benchmark.vl_rewardbench import load_vl_rewardbench_iter
from benchmark.genai_bench import load_genai_video_iter
from benchmark.audio_bench import load_audiobench_iter
from prompt.omni_prompt import *
from .traj_data import PairedData, Traj
from utils import *
import json
import random
import os


class BaseLoader(ABC):
    """Abstract base class for data loaders.

    Keeps the configuration object and requires every concrete subclass to
    implement :meth:`load_iter`.
    """

    def __init__(self, args):
        """Store the configuration object.

        Args:
            args: Configuration object; kept unchanged as ``self.args``.
        """
        self.args = args

    @abstractmethod
    def load_iter(self, *args, **kwargs):
        """Load data; concrete subclasses must override this method.

        The base implementation does nothing and returns ``None``.
        """


class BenchmarkDataLoader(BaseLoader):

    def __init__(self, args):
        """Initialise the loader from a configuration object.

        Args:
            args: Configuration object exposing ``benchmark``,
                ``benchmark_dir`` and ``suffix`` attributes; it is also
                forwarded to the parent initialiser.
        """
        super().__init__(args)

        # Mirror the settings that the loading methods read from ``self``:
        #   benchmark     - name used by load_iter to pick a loader
        #   benchmark_dir - directory holding the "<suffix>.jsonl" files
        #   suffix        - subset name; "" selects every subset
        for field_name in ("benchmark", "benchmark_dir", "suffix"):
            setattr(self, field_name, getattr(args, field_name))

    def load_iter(self, output_file, manner: str):
        """Return the sample iterator for the configured benchmark.

        Args:
            output_file: Path of a results file; only forwarded to the
                loaders that accept it (rewardbench, rmb, audio_bench).
            manner: Prompt-construction mode, forwarded to the loader.

        Returns:
            The generator created by the benchmark-specific loader method.

        Raises:
            Exception: If ``self.benchmark`` is not a supported name.
        """
        name = self.benchmark

        # Loaders that also receive the results-file path.
        with_output_file = {
            "rewardbench": self.load_rewardbench_iter,
            "rmb": self.load_rmb_iter,
            "audio_bench": self.load_audio_bench_iter,
        }
        # Loaders that only receive the manner.
        manner_only = {
            "vl_rewardbench": self.load_vl_rewardbench_iter,
            "multimodal_rewardbench": self.load_multimodal_rewardbench_iter,
            "genai_image": self.load_genai_image_iter,
            "genai_video": self.load_genai_video_iter,
            "ppe_bench": self.load_ppe_bench_iter,
        }

        if name in with_output_file:
            return with_output_file[name](output_file, manner=manner)
        if name in manner_only:
            return manner_only[name](manner=manner)

        raise Exception("Not Support {} for LanguageEvaluator".format(name))

    def load_rewardbench_iter(self, output_file, manner: str):
        """Yield RewardBench samples that are not yet recorded in ``output_file``.

        Each line of ``output_file`` (when the file exists) is parsed as JSON
        and its ``["paired_data"]["id"]`` is added to the set of ids to skip.

        Args:
            output_file: Path of a JSONL results file from a previous run;
                the file may be missing.
            manner: Prompt-construction mode, forwarded to
                ``load_rewardbench_suffix_iter``.

        Yields:
            The items produced by ``load_rewardbench_suffix_iter`` for the
            configured subset, or for every subset when ``self.suffix`` is "".

        Raises:
            AssertionError: If ``self.suffix`` is non-empty and not a known
                subset name.
        """
        subsets = ["chat", "chat_hard", "safety", "reasoning"]

        # Resume support: collect the ids already present in the results file.
        # The context manager closes the handle instead of leaking it.
        skip_id = set()
        if os.path.exists(output_file):
            with open(output_file) as done_file:
                for line in done_file:
                    json_item = json.loads(line)
                    skip_id.add(json_item["paired_data"]['id'])

        print("========>", len(skip_id))

        if self.suffix != "":
            assert self.suffix in subsets
            yield from self.load_rewardbench_suffix_iter(skip_id, manner)
            return

        try:
            for _suffix in subsets:
                # The per-subset loader reads the subset name from self.suffix.
                self.suffix = _suffix
                yield from self.load_rewardbench_suffix_iter(skip_id, manner)
        finally:
            # Restore the "all subsets" setting; otherwise a second call would
            # only see the last subset that was iterated.
            self.suffix = ""

    def load_rmb_iter(self, output_file, manner: str):
        """Yield RMB samples that are not yet recorded in ``output_file``.

        Each line of ``output_file`` (when the file exists) is parsed as JSON
        and its ``["paired_data"]["id"]`` is added to the set of ids to skip.
        The records themselves are read by ``load_rewardbench_suffix_iter``.

        Args:
            output_file: Path of a JSONL results file from a previous run;
                the file may be missing (e.g. on a first run).
            manner: Prompt-construction mode, forwarded to
                ``load_rewardbench_suffix_iter``.

        Yields:
            The items produced by ``load_rewardbench_suffix_iter`` for the
            configured subset, or for every subset when ``self.suffix`` is "".

        Raises:
            AssertionError: If ``self.suffix`` is non-empty and not a known
                subset name.
        """
        subsets = [
            "bon_harmlessness",
            "bon_helpfulness",
            "pairwise_harmlessness",
            "pairwise_helpfulness",
        ]

        # Resume support. The existence check mirrors the RewardBench and
        # audio loaders so that a fresh run without a results file works.
        skip_id = set()
        if os.path.exists(output_file):
            with open(output_file) as done_file:
                for line in done_file:
                    json_item = json.loads(line)
                    skip_id.add(json_item["paired_data"]['id'])

        print("========>", len(skip_id))

        if self.suffix != "":
            assert self.suffix in subsets
            yield from self.load_rewardbench_suffix_iter(skip_id, manner)
            return

        try:
            for _suffix in subsets:
                # The per-subset loader reads the subset name from self.suffix.
                self.suffix = _suffix
                yield from self.load_rewardbench_suffix_iter(skip_id, manner)
        finally:
            # Restore the "all subsets" setting; otherwise a second call would
            # only see the last subset that was iterated.
            self.suffix = ""

    def load_rewardbench_suffix_iter(self, skip_id, manner):
        """Yield one ``Traj`` per unseen record of the current subset file.

        Reads ``<benchmark_dir>/<suffix>.jsonl`` line by line, skips records
        whose ``id`` is in ``skip_id``, wraps the rest in ``PairedData`` and
        loads each into a fresh ``Traj``.

        Args:
            skip_id: Collection of record ids to leave out.
            manner: Only "direct" is supported.

        Yields:
            Traj: Loaded with ``shuffle=True`` and with the direct-judge
            conversation built for ``step=0``.

        Raises:
            ValueError: If ``manner`` is not "direct" (raised when the first
                non-skipped record is processed).
        """
        jsonl_path = os.path.join(self.benchmark_dir, f"{self.suffix}.jsonl")

        print(f"[Construct Rewardbench] Suffix:[{self.suffix}] ")
        with open(jsonl_path, 'r') as handle:
            for raw_line in handle:
                record = json.loads(raw_line)
                if record['id'] in skip_id:
                    continue

                # Only the first conversation turn is used as the query.
                pair = PairedData({
                    "id": record['id'],
                    "suffix": record['suffix'],
                    "query": {"content": record['conversations'][0]['content']},
                    "chosen": {"content": record['chosen']['content']},
                    "rejected": {"content": record['rejected']['content']},
                })

                traj = Traj()
                traj.loads({"paired_data": pair}, shuffle=True)

                if manner != "direct":
                    raise ValueError(f"manner {manner} not supported in load_language_criteria_iter.")
                traj.build_direct_judge_criteria_conversation(step=0)

                yield traj
    
    def load_ppe_bench_iter(self, manner: str):
        """Yield PPE-bench samples for the configured subset(s).

        Args:
            manner: Prompt-construction mode, forwarded to
                ``load_ppe_bench_suffix_iter``.

        Yields:
            The items produced by ``load_ppe_bench_suffix_iter`` for
            ``self.suffix``, or for every PPE subset when it is "".
        """
        if self.suffix != "":
            yield from self.load_ppe_bench_suffix_iter(manner)
            return

        try:
            for _suffix in ["mmlu_pro", "math", "gpqa"]:
                # The per-subset loader reads the subset name from self.suffix.
                self.suffix = _suffix
                yield from self.load_ppe_bench_suffix_iter(manner)
        finally:
            # Restore the "all subsets" setting; otherwise a second call would
            # only see the last subset that was iterated.
            self.suffix = ""
    
    def load_ppe_bench_suffix_iter(self, manner):
        """Yield one ``Traj`` per record of the current PPE subset file.

        Reads ``<benchmark_dir>/<suffix>.jsonl`` line by line, wraps each
        record in ``PairedData`` and loads it into a fresh ``Traj``.

        Args:
            manner: Only "direct" is supported.

        Yields:
            Traj: Loaded with ``shuffle=True`` and with the direct-judge
            conversation built for ``step=0``.

        Raises:
            AssertionError: If ``self.suffix`` is not a PPE subset name.
            ValueError: If ``manner`` is not "direct" (raised when the first
                record is processed).
        """
        assert self.suffix in ["mmlu_pro", "math", "gpqa"]
        jsonl_path = os.path.join(self.benchmark_dir, f"{self.suffix}.jsonl")

        print(f"[Construct PPE Bench] Suffix:[{self.suffix}] ")
        with open(jsonl_path, 'r') as handle:
            for raw_line in handle:
                record = json.loads(raw_line)

                # Only the first conversation turn is used as the query.
                pair = PairedData({
                    "id": record['id'],
                    "suffix": record['suffix'],
                    "query": {"content": record['conversations'][0]['content']},
                    "chosen": {"content": record['chosen']['content']},
                    "rejected": {"content": record['rejected']['content']},
                })

                traj = Traj()
                traj.loads({"paired_data": pair}, shuffle=True)

                if manner != "direct":
                    raise ValueError(f"manner {manner} not supported in load_language_criteria_iter.")
                traj.build_direct_judge_criteria_conversation(step=0)

                yield traj
    def load_vl_rewardbench_iter(self, manner: str):
        """Yield VL-RewardBench samples for the configured subset(s).

        Args:
            manner: Prompt-construction mode, forwarded to
                ``load_vl_rewardbench_suffix_iter``.

        Yields:
            The items produced by ``load_vl_rewardbench_suffix_iter`` for
            ``self.suffix``, or for every subset when it is "".
        """
        if self.suffix != "":
            yield from self.load_vl_rewardbench_suffix_iter(manner)
            return

        try:
            for _suffix in ["general", "hallucination", "reasoning"]:
                # The per-subset loader reads the subset name from self.suffix.
                self.suffix = _suffix
                yield from self.load_vl_rewardbench_suffix_iter(manner)
        finally:
            # Restore the "all subsets" setting; otherwise a second call would
            # only see the last subset that was iterated.
            self.suffix = ""

    def load_vl_rewardbench_suffix_iter(self, manner):
        """Yield one ``Traj`` per record of the current VL-RewardBench subset.

        Reads ``<benchmark_dir>/<suffix>.jsonl`` line by line. The record's
        ``images`` field is attached to the query.

        Args:
            manner: Only "direct" is supported.

        Yields:
            Traj: Loaded with ``shuffle=True`` and with the direct-judge
            conversation built for ``step=0`` and ``modality="image"``.

        Raises:
            AssertionError: If ``self.suffix`` is not a known subset name.
            ValueError: If ``manner`` is not "direct" (raised when the first
                record is processed).
        """
        assert self.suffix in ["general", "hallucination", "reasoning"]
        jsonl_path = os.path.join(self.benchmark_dir, f"{self.suffix}.jsonl")

        print(f"[Construct Rewardbench] Suffix:[{self.suffix}] ")
        with open(jsonl_path, 'r') as handle:
            for raw_line in handle:
                record = json.loads(raw_line)

                # Only the first conversation turn is used as the query text.
                pair = PairedData({
                    "id": record['id'],
                    "suffix": record['suffix'],
                    "query": {
                        "content": record['conversations'][0]['content'],
                        "images": record['images'],
                    },
                    "chosen": {"content": record['chosen']['content']},
                    "rejected": {"content": record['rejected']['content']},
                })

                traj = Traj()
                traj.loads({"paired_data": pair}, shuffle=True)

                if manner != "direct":
                    raise ValueError(f"manner {manner} not supported in load_language_criteria_iter.")
                traj.build_direct_judge_criteria_conversation(step=0, modality="image")

                yield traj

    def load_multimodal_rewardbench_iter(self, manner: str):
        """Yield Multimodal-RewardBench samples for the configured subset(s).

        Args:
            manner: Prompt-construction mode, forwarded to
                ``load_multimodal_rewardbench_suffix_iter``.

        Yields:
            The items produced by ``load_multimodal_rewardbench_suffix_iter``
            for ``self.suffix``, or for every subset when it is "".
        """
        if self.suffix != "":
            yield from self.load_multimodal_rewardbench_suffix_iter(manner)
            return

        subsets = [
            "coding", "correctness", "knowledge", "math",
            "preference", "reasoning", "safety", "vqa",
        ]
        try:
            for _suffix in subsets:
                # The per-subset loader reads the subset name from self.suffix.
                self.suffix = _suffix
                yield from self.load_multimodal_rewardbench_suffix_iter(manner)
        finally:
            # Restore the "all subsets" setting; otherwise a second call would
            # only see the last subset that was iterated.
            self.suffix = ""
    
    def load_multimodal_rewardbench_suffix_iter(self, manner):
        """Yield one ``Traj`` per record of the current Multimodal-RewardBench subset.

        Reads ``<benchmark_dir>/<suffix>.jsonl`` line by line. The record's
        ``images`` field is attached to the query.

        Args:
            manner: Only "direct" is supported.

        Yields:
            Traj: Loaded with ``shuffle=True`` and with the direct-judge
            conversation built for ``step=0`` and ``modality="image"``.

        Raises:
            AssertionError: If ``self.suffix`` is not a known subset name.
            ValueError: If ``manner`` is not "direct" (raised when the first
                record is processed).
        """
        assert self.suffix in ["coding", "correctness", "knowledge", "math", "preference", "reasoning", "safety", "vqa"]
        jsonl_path = os.path.join(self.benchmark_dir, f"{self.suffix}.jsonl")

        print(f"[Construct Multimodal Rewardbench] Suffix:[{self.suffix}] ")
        with open(jsonl_path, 'r') as handle:
            for raw_line in handle:
                record = json.loads(raw_line)

                # Only the first conversation turn is used as the query text.
                pair = PairedData({
                    "id": record['id'],
                    "suffix": record['suffix'],
                    "query": {
                        "content": record['conversations'][0]['content'],
                        "images": record['images'],
                    },
                    "chosen": {"content": record['chosen']['content']},
                    "rejected": {"content": record['rejected']['content']},
                })

                traj = Traj()
                traj.loads({"paired_data": pair}, shuffle=True)

                if manner != "direct":
                    raise ValueError(f"manner {manner} not supported in load_language_criteria_iter.")
                traj.build_direct_judge_criteria_conversation(step=0, modality="image")

                yield traj

    def load_genai_image_iter(self, manner: str):
        """Yield GenAI image-generation samples.

        Args:
            manner: Prompt-construction mode, forwarded to
                ``load_genai_image_suffix_iter``.

        Yields:
            The items produced by ``load_genai_image_suffix_iter``. When
            ``self.suffix`` is "", it is first set to "image_gen" (and stays
            set afterwards).
        """
        if self.suffix != "":
            yield from self.load_genai_image_suffix_iter(manner)
            return

        for subset in ("image_gen",):
            # The per-subset loader reads the subset name from self.suffix.
            self.suffix = subset
            yield from self.load_genai_image_suffix_iter(manner)

    def load_genai_image_suffix_iter(self, manner):
        """Yield one ``Traj`` per record of the GenAI image-generation file.

        Reads ``<benchmark_dir>/<suffix>.jsonl`` line by line. The record's
        ``chosen_images`` / ``rejected_images`` fields are attached to the
        chosen and rejected responses respectively.

        Args:
            manner: Only "direct" is supported.

        Yields:
            Traj: Loaded with ``shuffle=True`` and with the direct-judge
            conversation built for ``step=0`` and ``modality="image"``.

        Raises:
            AssertionError: If ``self.suffix`` is not "image_gen".
            ValueError: If ``manner`` is not "direct" (raised when the first
                record is processed).
        """
        assert self.suffix in ["image_gen"]
        jsonl_path = os.path.join(self.benchmark_dir, f"{self.suffix}.jsonl")

        print(f"[Construct GenAI ImageGeneration] Suffix:[{self.suffix}] ")
        with open(jsonl_path, 'r') as handle:
            for raw_line in handle:
                record = json.loads(raw_line)

                # Only the first conversation turn is used as the query text.
                pair = PairedData({
                    "id": record['id'],
                    "suffix": record['suffix'],
                    "query": {"content": record['conversations'][0]['content']},
                    "chosen": {
                        "content": record['chosen']['content'],
                        "images": record['chosen_images'],
                    },
                    "rejected": {
                        "content": record['rejected']['content'],
                        "images": record['rejected_images'],
                    },
                })

                traj = Traj()
                traj.loads({"paired_data": pair}, shuffle=True)

                if manner != "direct":
                    raise ValueError(f"manner {manner} not supported in load_genai_image_suffix_iter.")
                traj.build_direct_judge_criteria_conversation(step=0, modality="image")

                yield traj
    
    def load_genai_video_iter(self, manner: str):
        """Yield GenAI video-generation samples.

        Args:
            manner: Prompt-construction mode, forwarded to
                ``load_genai_video_suffix_iter``.

        Yields:
            The items produced by ``load_genai_video_suffix_iter``. When
            ``self.suffix`` is "", it is first set to "video_gen" (and stays
            set afterwards).
        """
        if self.suffix != "":
            yield from self.load_genai_video_suffix_iter(manner)
            return

        for subset in ("video_gen",):
            # The per-subset loader reads the subset name from self.suffix.
            self.suffix = subset
            yield from self.load_genai_video_suffix_iter(manner)

    def load_genai_video_suffix_iter(self, manner):
        """Yield one ``Traj`` per record of the GenAI video-generation file.

        Reads ``<benchmark_dir>/<suffix>.jsonl`` line by line. The chosen
        response carries the record's ``chosen_videos``; the rejected response
        carries ``rejected_videos`` when that key exists and otherwise the
        singular ``rejected_video`` key.

        Args:
            manner: Only "direct" is supported.

        Yields:
            Traj: Loaded with ``shuffle=True`` and with the direct-judge
            conversation built for ``step=0`` and ``modality="video"``.

        Raises:
            AssertionError: If ``self.suffix`` is not "video_gen".
            ValueError: If ``manner`` is not "direct" (raised when the first
                record is processed).
            KeyError: If a record lacks a required field.
        """
        assert self.suffix in ["video_gen"]
        benchmark_suffix_file = os.path.join(self.benchmark_dir, f"{self.suffix}.jsonl")

        print(f"[Construct Rewardbench] Suffix:[{self.suffix}] ")
        with open(benchmark_suffix_file, 'r') as f:
            for line in f:

                json_item = json.loads(line)

                # The original code read the singular 'rejected_video' next to
                # the plural 'chosen_videos'. Prefer the consistent plural key
                # and fall back to the singular one so existing data still loads.
                # NOTE(review): confirm which spelling the dataset actually uses.
                if 'rejected_videos' in json_item:
                    rejected_key = 'rejected_videos'
                else:
                    rejected_key = 'rejected_video'

                # Only the first conversation turn is used as the query text.
                p_data = PairedData(dict(
                    id=json_item['id'],
                    suffix=json_item['suffix'],
                    query={"content": json_item['conversations'][0]['content']},
                    chosen={"content": json_item['chosen']['content'], "videos": json_item['chosen_videos']},
                    rejected={"content": json_item['rejected']['content'], "videos": json_item[rejected_key]}
                ))

                t_data = Traj()
                t_data.loads({"paired_data": p_data}, shuffle=True)

                if manner == "direct":
                    t_data.build_direct_judge_criteria_conversation(step=0, modality="video")
                else:
                    # The message now names this method instead of a
                    # non-existent one.
                    raise ValueError(f"manner {manner} not supported in load_genai_video_suffix_iter.")

                yield t_data

    def load_audio_bench_iter(self, output_file, manner: str):
        """Yield audio-bench samples that are not yet recorded in ``output_file``.

        Each line of ``output_file`` (when the file exists) is parsed as JSON
        and its ``["paired_data"]["id"]`` is added to the set of ids to skip.

        Args:
            output_file: Path of a JSONL results file from a previous run;
                the file may be missing.
            manner: Prompt-construction mode, forwarded to
                ``load_audiobench_suffix_iter``.

        Yields:
            The items produced by ``load_audiobench_suffix_iter`` for
            ``self.suffix``, or for every audio subset when it is "".
        """
        # Resume support: collect the ids already present in the results file.
        # The context manager closes the handle instead of leaking it.
        skip_id = set()
        if os.path.exists(output_file):
            with open(output_file) as done_file:
                for line in done_file:
                    json_item = json.loads(line)
                    skip_id.add(json_item["paired_data"]['id'])

        print("========>", len(skip_id))

        if self.suffix != "":
            yield from self.load_audiobench_suffix_iter(skip_id, manner)
            return

        try:
            for _suffix in ["audio_und", "audio_gen"]:
                # The per-subset loader reads the subset name from self.suffix.
                self.suffix = _suffix
                yield from self.load_audiobench_suffix_iter(skip_id, manner)
        finally:
            # Restore the "all subsets" setting; otherwise a second call would
            # only see the last subset that was iterated.
            self.suffix = ""
    
    def load_audiobench_suffix_iter(self, skip_id, manner):
        """Yield one ``Traj`` per usable record of the current audio subset.

        Reads ``<benchmark_dir>/<suffix>.jsonl`` line by line and skips
        records whose ``id`` is in ``skip_id``. Two record layouts are
        handled:

        * records with an ``audios`` field: the audio is attached to the
          query and the chosen/rejected responses are text;
        * records without it: the query is text and the responses carry
          ``chosen_audios`` / ``rejected_audios``.

        Records for which ``is_longer_than_100s`` returns true on the first
        relevant audio entry are skipped.

        Args:
            skip_id: Collection of record ids to leave out.
            manner: Only "direct" is supported.

        Yields:
            Traj: Loaded with ``shuffle=True`` and with the direct-judge
            conversation built for ``step=0`` and ``modality="audio"``.

        Raises:
            AssertionError: If ``self.suffix`` is not an audio subset name.
            ValueError: If ``manner`` is not "direct" (raised when the first
                usable record is processed).
        """
        assert self.suffix in ["audio_und", "audio_gen"]
        benchmark_suffix_file = os.path.join(self.benchmark_dir, f"{self.suffix}.jsonl")

        print(f"[Construct Rewardbench] Suffix:[{self.suffix}] ")
        with open(benchmark_suffix_file, 'r') as f:
            for line in f:

                json_item = json.loads(line)
                if json_item['id'] in skip_id:
                    continue

                if "audios" in json_item:
                    # Audio is part of the question; responses are text.
                    if is_longer_than_100s(json_item['audios'][0]):
                        continue

                    _query = {"content": json_item['conversations'][0]['content'], "audios": json_item['audios']}
                    _chosen = {"content": json_item['chosen']['content']}
                    # Fixed: this previously copied the chosen text, which made
                    # both candidate responses identical.
                    _rejected = {"content": json_item['rejected']['content']}
                else:
                    # The responses themselves are audio clips.
                    if is_longer_than_100s(json_item['chosen_audios'][0]) or is_longer_than_100s(json_item['rejected_audios'][0]):
                        continue

                    _query = {"content": json_item['conversations'][0]['content']}
                    _chosen = {"content": "<Audio>: ", "audios": json_item['chosen_audios']}
                    _rejected = {"content": "<Audio>: ", "audios": json_item['rejected_audios']}

                p_data = PairedData(dict(
                    id=json_item['id'],
                    suffix=json_item['suffix'],
                    query=_query,
                    chosen=_chosen,
                    rejected=_rejected
                ))

                t_data = Traj()
                t_data.loads({"paired_data": p_data}, shuffle=True)

                if manner == "direct":
                    t_data.build_direct_judge_criteria_conversation(step=0, modality="audio")
                else:
                    # The message now names this method instead of a
                    # non-existent one.
                    raise ValueError(f"manner {manner} not supported in load_audiobench_suffix_iter.")

                yield t_data