import json
import random

from dataloader.bench_dataloader import BaseLoader
from prompt.omni_prompt import *

from .traj_data import PairedData

class DataLoader(BaseLoader):
    """Generic pairwise-preference loader over a JSONL file.

    Only the ``"language"`` modality is implemented; it produces the same
    record format as ``LanguageDataLoder.load_iter``. Each input line is
    expected to be a JSON object with ``conversations[0].content`` (the
    query), ``chosen.content`` and ``rejected.content``.
    """

    def __init__(self, args):
        super().__init__(args)
        # Same source for the JSONL path as the sibling loaders in this module.
        self.input_dir = args.input_dir

    def load_iter(self, filter_index=None, data_modality="language", shuffle=True):
        """Yield pairwise-judgement prompts for the selected lines of the input file.

        Args:
            filter_index: Container of 1-based line numbers to emit (a dict or
                any object supporting ``in``). ``None`` (the default) selects
                nothing, matching the previous ``{}`` default.
            data_modality: Only ``"language"`` is supported.
            shuffle: When true, the order in which the chosen and rejected
                responses appear in the prompt is randomised per item; when
                false the chosen response is always shown first.

        Yields:
            dict with ``"suffix"`` (1-based line number), ``"conversation"``
            (a single user turn whose content is the prompt string) and
            ``"answer"`` (0 if the chosen response is shown first, 1 otherwise).

        Raises:
            NotImplementedError: on first iteration, if ``data_modality`` is
                not ``"language"``.
        """
        # Fail loudly instead of silently yielding nothing for other modalities.
        if data_modality != "language":
            raise NotImplementedError(
                f"DataLoader does not support data_modality={data_modality!r}"
            )
        selected = {} if filter_index is None else filter_index

        with open(self.input_dir, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f, start=1):
                # Skip unselected lines and blank lines (json.loads would fail on them).
                if index not in selected or not line.strip():
                    continue

                json_item = json.loads(line)
                query = json_item['conversations'][0]['content']
                chosen = json_item['chosen']['content']
                rejected = json_item['rejected']['content']

                # answer == 0 means the chosen response is presented first.
                if not shuffle or random.random() > 0.5:
                    answer = 0
                    prompt = build_long_omni_prompt(query, chosen, rejected)
                else:
                    answer = 1
                    prompt = build_long_omni_prompt(query, rejected, chosen)

                yield {
                    "suffix": index,
                    "conversation": [{'role': 'user', 'content': prompt}],
                    "answer": answer,
                }

class LanguageDataLoder(BaseLoader):
    """Text-only pairwise-preference loader over a JSONL file.

    Each line is expected to be a JSON object with ``conversations[0].content``
    (the query), ``chosen.content`` and ``rejected.content``.
    """

    def __init__(self, args):
        super().__init__(args)

        self.input_dir = args.input_dir

    def load_index(self):
        """Return ``{line_number: 0}`` for every 1-based line of the input file."""
        with open(self.input_dir, 'r', encoding='utf-8') as f:
            return {index: 0 for index, _ in enumerate(f, start=1)}

    def load_iter(self, filter_index=None):
        """Yield a pairwise-judgement prompt for each selected line.

        Args:
            filter_index: Container of 1-based line numbers to emit, e.g. the
                dict returned by ``load_index``; any object supporting ``in``
                works. ``None`` (the default) selects nothing.

        Yields:
            dict with ``"suffix"`` (1-based line number), ``"conversation"``
            (one user turn whose content is the prompt string) and ``"answer"``
            (0 if the chosen response is presented first, 1 otherwise; the
            order is randomised per item).
        """
        selected = {} if filter_index is None else filter_index

        with open(self.input_dir, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f, start=1):
                # Skip unselected lines and blank lines (json.loads would fail on them).
                if index not in selected or not line.strip():
                    continue

                json_item = json.loads(line)
                query = json_item['conversations'][0]['content']
                chosen = json_item['chosen']['content']
                rejected = json_item['rejected']['content']

                if random.random() > 0.5:
                    answer = 0
                    prompt = build_long_omni_prompt(query, chosen, rejected)
                else:
                    answer = 1
                    prompt = build_long_omni_prompt(query, rejected, chosen)

                yield {
                    "suffix": index,
                    "conversation": [{'role': 'user', 'content': prompt}],
                    "answer": answer,
                }

class ImageDataLoder(BaseLoader):
    """Pairwise-preference loader for items carrying one or two images.

    Each JSONL line is expected to hold ``images`` (a list of 1 or 2 image
    references), ``conversations[0].content`` (the query), ``chosen.content``
    and ``rejected.content``. With two images, ``images[0]`` belongs to the
    chosen response and ``images[1]`` to the rejected one.
    """

    def __init__(self, args):
        super().__init__(args)

        self.input_dir = args.input_dir

    def load_index(self):
        """Return ``{line_number: 0}`` for every 1-based line of the input file."""
        with open(self.input_dir, 'r', encoding='utf-8') as f:
            return {index: 0 for index, _ in enumerate(f, start=1)}

    @staticmethod
    def _random_order(chosen, rejected):
        """Randomly order a (chosen, rejected) pair.

        Returns ``(first, second, answer)`` where ``answer`` is 0 when
        ``chosen`` comes first and 1 when the pair was swapped.
        """
        if random.random() > 0.5:
            return chosen, rejected, 0
        return rejected, chosen, 1

    @staticmethod
    def _text_part(text):
        """Build a text content part for a user message."""
        return {"type": "text", "text": text}

    @staticmethod
    def _image_part(image):
        """Build an image content part with a fixed 512*512 pixel budget."""
        return {
            "type": "image",
            "image": image,
            "min_pixels": 512 * 512,
            "max_pixels": 512 * 512,
        }

    def load_iter(self, filter_index=None):
        """Yield a multimodal pairwise-judgement conversation per selected line.

        Args:
            filter_index: Container of 1-based line numbers to emit, e.g. the
                dict returned by ``load_index``; any object supporting ``in``
                works. ``None`` (the default) selects nothing.

        Yields:
            dict with ``"suffix"`` (1-based line number), ``"conversation"``
            (one user turn whose content interleaves text and image parts) and
            ``"answer"`` (0 if the chosen response is presented first, 1
            otherwise; the order is randomised per item).

        Raises:
            ValueError: if a selected item has neither 1 nor 2 images.
        """
        selected = {} if filter_index is None else filter_index

        with open(self.input_dir, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f, start=1):
                # Skip unselected lines and blank lines (json.loads would fail on them).
                if index not in selected or not line.strip():
                    continue

                json_item = json.loads(line)
                images = json_item['images']
                if len(images) not in (1, 2):
                    raise ValueError(
                        f"Not Support Multiple Image: expected 1 or 2 images, got {len(images)}"
                    )

                query = json_item['conversations'][0]['content']
                chosen = json_item['chosen']['content']
                rejected = json_item['rejected']['content']

                if len(images) == 1:
                    # The image is passed as its own content part, so the textual
                    # placeholder is removed from the query.
                    query = query.replace("Image: <image>", "").strip()
                    first, second, answer = self._random_order(chosen, rejected)
                    prefix_pm, suffix_pm = build_omni_prompt_split(query, first, second)
                    content = [
                        self._text_part(prefix_pm),
                        self._image_part(images[0]),
                        self._text_part(suffix_pm),
                    ]
                else:
                    # Swap each image together with its response so the first
                    # image always belongs to the first response shown.
                    (first, first_image), (second, second_image), answer = self._random_order(
                        (chosen, images[0]), (rejected, images[1])
                    )
                    prefix_pm, infix_pm, suffix_pm = build_omni_prompt_triple(query, first, second)
                    content = [
                        self._text_part(prefix_pm),
                        self._image_part(first_image),
                        self._text_part(infix_pm),
                        self._image_part(second_image),
                        self._text_part(suffix_pm),
                    ]

                yield {
                    "suffix": index,
                    "conversation": [{'role': 'user', 'content': content}],
                    "answer": answer,
                }

class VideoDataLoder(BaseLoader):
    """Pairwise-preference loader for items carrying one or two videos.

    Each JSONL line is expected to hold ``videos`` (a list of 1 or 2 video
    references), ``conversations[0].content`` (the query), ``chosen.content``
    and ``rejected.content``. With two videos, ``videos[0]`` belongs to the
    chosen response and ``videos[1]`` to the rejected one.
    """

    def __init__(self, args):
        super().__init__(args)

        self.input_dir = args.input_dir

    def load_index(self):
        """Return ``{line_number: 0}`` for every 1-based line of the input file."""
        with open(self.input_dir, 'r', encoding='utf-8') as f:
            return {index: 0 for index, _ in enumerate(f, start=1)}

    @staticmethod
    def _random_order(chosen, rejected):
        """Randomly order a (chosen, rejected) pair.

        Returns ``(first, second, answer)`` where ``answer`` is 0 when
        ``chosen`` comes first and 1 when the pair was swapped.
        """
        if random.random() > 0.5:
            return chosen, rejected, 0
        return rejected, chosen, 1

    @staticmethod
    def _text_part(text):
        """Build a text content part for a user message."""
        return {"type": "text", "text": text}

    @staticmethod
    def _video_part(video):
        """Build a video content part with fixed per-frame and total pixel budgets."""
        return {
            "type": "video",
            "video": video,
            "min_pixels": 256 * 256,
            "max_pixels": 256 * 256,
            "total_pixels": 24 * 256 * 256,
        }

    def load_iter(self, filter_index=None):
        """Yield a multimodal pairwise-judgement conversation per selected line.

        Args:
            filter_index: Container of 1-based line numbers to emit, e.g. the
                dict returned by ``load_index``; any object supporting ``in``
                works. ``None`` (the default) selects nothing.

        Yields:
            dict with ``"suffix"`` (1-based line number), ``"conversation"``
            (one user turn whose content interleaves text and video parts) and
            ``"answer"`` (0 if the chosen response is presented first, 1
            otherwise; the order is randomised per item).

        Raises:
            ValueError: if a selected item has neither 1 nor 2 videos.
        """
        selected = {} if filter_index is None else filter_index

        with open(self.input_dir, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f, start=1):
                # Skip unselected lines and blank lines (json.loads would fail on them).
                if index not in selected or not line.strip():
                    continue

                json_item = json.loads(line)
                videos = json_item['videos']
                if len(videos) not in (1, 2):
                    raise ValueError(
                        f"Not Support Multiple Videos: expected 1 or 2 videos, got {len(videos)}"
                    )

                query = json_item['conversations'][0]['content']
                chosen = json_item['chosen']['content']
                rejected = json_item['rejected']['content']

                if len(videos) == 1:
                    # The video is passed as its own content part, so the textual
                    # placeholder is removed from the query.
                    query = query.replace("Video: <video>", "").strip()
                    first, second, answer = self._random_order(chosen, rejected)
                    prefix_pm, suffix_pm = build_omni_prompt_split(query, first, second)
                    content = [
                        self._text_part(prefix_pm),
                        self._video_part(videos[0]),
                        self._text_part(suffix_pm),
                    ]
                else:
                    # Swap each video together with its response so the first
                    # video always belongs to the first response shown.
                    (first, first_video), (second, second_video), answer = self._random_order(
                        (chosen, videos[0]), (rejected, videos[1])
                    )
                    prefix_pm, infix_pm, suffix_pm = build_omni_prompt_triple(query, first, second)
                    content = [
                        self._text_part(prefix_pm),
                        self._video_part(first_video),
                        self._text_part(infix_pm),
                        self._video_part(second_video),
                        self._text_part(suffix_pm),
                    ]

                yield {
                    "suffix": index,
                    "conversation": [{'role': 'user', 'content': content}],
                    "answer": answer,
                }