from torch.utils.data import Dataset, DataLoader
import torch
import pandas as pd
from transformers.image_utils import load_image
import os
from transformers import AutoProcessor


def prefderence_data(preference_data_root):
    """Return the ``'id0'`` records from one or more preference datasets.

    Each dataset root is a directory containing ``preference_data.jsonl``,
    which is read with :func:`pandas.read_json`.

    Args:
        preference_data_root: A single dataset directory, or a list/tuple of
            dataset directories whose records are concatenated in order.

    Returns:
        A list of the values in the ``'id0'`` column across all roots. An
        empty list of roots yields an empty list.
    """
    if isinstance(preference_data_root, (list, tuple)):
        roots = list(preference_data_root)
    else:
        roots = [preference_data_root]

    if not roots:
        # pd.concat raises "No objects to concatenate" on an empty list.
        return []

    frames = [pd.read_json(os.path.join(root, 'preference_data.jsonl')) for root in roots]
    # ignore_index=True gives one contiguous 0..n-1 index instead of
    # repeating each file's own index labels.
    all_data = pd.concat(frames, ignore_index=True)
    return all_data['id0'].tolist()



class PreferenceData(Dataset):
    """Vision-language preference pairs (prompt, chosen reply, rejected reply).

    Records come from the ``'id0'`` column of ``preference_data.jsonl``. Each
    record is accessed with the keys ``'images'`` (passed to ``load_image``),
    ``'prompt'`` (a sequence of chat messages), ``'chosen'`` and
    ``'rejected'`` (reply contents, formatted with an f-string).
    """

    def __init__(self, preference_data_root, vlm_processor=None):
        """Load all records from one or more dataset roots.

        Args:
            preference_data_root: A directory containing
                ``preference_data.jsonl``, or a list/tuple of such
                directories whose records are concatenated in order.
            vlm_processor: Processor exposing ``apply_chat_template`` and
                ``__call__``; only needed by :meth:`collate_fn`.

        Raises:
            ValueError: If ``preference_data_root`` is an empty list/tuple.
        """
        super().__init__()
        if isinstance(preference_data_root, (list, tuple)):
            if not preference_data_root:
                raise ValueError('preference_data_root must not be an empty list')
            frames = [
                pd.read_json(os.path.join(p, 'preference_data.jsonl'))
                for p in preference_data_root
            ]
            # ignore_index=True: without it each file's index labels repeat,
            # which breaks per-record lookups.
            all_data = pd.concat(frames, ignore_index=True)
        else:
            all_data = pd.read_json(os.path.join(preference_data_root, 'preference_data.jsonl'))

        self.all_data = all_data['id0'].reset_index(drop=True)
        self.vlm_processor = vlm_processor

    def __len__(self):
        """Return the number of preference records."""
        return len(self.all_data)

    def __getitem__(self, index):
        """Return ``(image_path, prompt, prompt_chosen, prompt_rejected)`` for record *index*.

        ``prompt`` is a fresh list copy of the record's messages.
        ``prompt_chosen`` / ``prompt_rejected`` are that prompt with the chosen /
        rejected reply appended as an assistant message.
        """
        # Positional lookup so DataLoader indices 0..len-1 always map to a
        # single record, whatever the Series' index labels are.
        record = self.all_data.iloc[index]

        image_path = record['images']  # An image path
        prompt = list(record['prompt'])  # copy so callers never alias the stored list

        chosen = {"role": "assistant", "content": [{"type": "text", "text": f"{record['chosen']}"}]}
        rejected = {"role": "assistant", "content": [{"type": "text", "text": f"{record['rejected']}"}]}

        prompt_chosen = prompt + [chosen]
        prompt_rejected = prompt + [rejected]

        return image_path, prompt, prompt_chosen, prompt_rejected

    def collate_fn(self, batch):
        """Encode a batch of ``__getitem__`` tuples for the chosen and rejected branches.

        Args:
            batch: Sequence of ``(image_path, prompt, prompt_chosen,
                prompt_rejected)`` tuples.

        Returns:
            ``(chosen_inputs, rejected_inputs)``: the processor outputs
            (``return_tensors="pt"``, ``padding=True``) for the full chosen and
            rejected conversations. In each, ``attention_mask`` is zeroed from
            column ``prompt_len`` onward, where ``prompt_len`` is the padded
            width of the prompt-only batch.

        Raises:
            ValueError: If the dataset was created without a ``vlm_processor``.
        """
        processor = self.vlm_processor
        if processor is None:
            raise ValueError('collate_fn requires PreferenceData to be built with a vlm_processor')

        images = [load_image(item[0]) for item in batch]
        prompts = [item[1] for item in batch]
        chosen = [item[2] for item in batch]
        rejected = [item[3] for item in batch]

        # Prompt-only texts are rendered with add_generation_prompt=True; the
        # full conversations (which already end in the reply) with False.
        prompt_texts = [
            processor.apply_chat_template(p, tokenize=False, add_generation_prompt=True)
            for p in prompts
        ]
        chosen_texts = [
            processor.apply_chat_template(c, tokenize=False, add_generation_prompt=False)
            for c in chosen
        ]
        rejected_texts = [
            processor.apply_chat_template(r, tokenize=False, add_generation_prompt=False)
            for r in rejected
        ]

        base_inputs = processor(text=prompt_texts, images=images, padding=True, return_tensors="pt")
        chosen_inputs = processor(text=chosen_texts, images=images, padding=True, return_tensors="pt")
        rejected_inputs = processor(text=rejected_texts, images=images, padding=True, return_tensors="pt")

        # NOTE(review): this zeroes attention at every position at or beyond
        # the prompt batch's padded width. Depending on the processor's padding
        # side, that may cover the reply tokens themselves. Confirm this is
        # intended and not meant to be a loss/label mask.
        prompt_len = base_inputs.input_ids.shape[-1]
        chosen_inputs['attention_mask'][:, prompt_len:] = 0
        rejected_inputs['attention_mask'][:, prompt_len:] = 0
        return chosen_inputs, rejected_inputs
if __name__ == '__main__':
    # Smoke check: load the preference records from several dataset roots
    # and print the first one.
    dataset_roots = [
        'preference_json_data/real_world',
        'preference_json_data/mmbench',
        'preference_json_data/mmstar',
        'preference_json_data/seedbench',
        'preference_json_data/science',
    ]
    records = prefderence_data(dataset_roots)
    first_record = records[0]
    print(first_record)






