import os.path

from torch.utils.data import random_split
import torch
import json
from transformers import Qwen2VLProcessor
# from torch.utils.data import Dataset
from PIL import Image
from qwen_vl_utils import process_vision_info

class DetectionDataset(torch.utils.data.Dataset):
    """Labelled instruction/answer samples loaded from a JSON list.

    Every record must provide ``instruction``, ``answer`` and ``label``.
    Records that also carry an ``image`` key are returned as multimodal
    samples ``(image_path, text, label)``; the others as ``(text, label)``.

    Args:
        data_root: Path to the JSON file holding the list of records.
        prompter: Callable ``prompter(instruction, answer) -> text`` used to
            build the model input text for each record.
        images_root: Directory that image paths in the records are relative
            to. When ``None``, the record's ``image`` value is returned as-is.
    """

    def __init__(self, data_root, prompter, images_root=None):
        # Context manager closes the handle; JSON is UTF-8 by spec (RFC 8259).
        with open(data_root, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.images_root = images_root
        self.prompter = prompter

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(text, label)`` or ``(image_path, text, label)`` for ``idx``."""
        item = self.data[idx]
        question, answer, label = item['instruction'], item['answer'], item['label']
        text = self.prompter(question, answer)
        if 'image' not in item:
            return text, label
        image = item['image']
        # os.path.join(None, ...) would raise TypeError, so only join when a root was given.
        if self.images_root is not None:
            image = os.path.join(self.images_root, image)
        return image, text, label


class DataCollator:
    """Collate dataset samples into one processor-encoded batch.

    Accepts the two sample layouts produced by ``DetectionDataset``:
    ``(text, label)`` for text-only data and ``(image_path, text, label)``
    for multimodal data. The layout of the first sample decides how the
    whole batch is handled.

    Args:
        processor: Callable (e.g. a Hugging Face processor/tokenizer) invoked
            with ``text=`` (and ``images=`` for multimodal batches) plus
            ``return_tensors``, ``max_length``, ``padding`` and ``truncation``.
        max_length: Maximum sequence length passed to the processor.
    """

    def __init__(self, processor, max_length=2048):
        self.processor = processor
        self.max_length = max_length

    def __call__(self, batch):
        """Encode ``batch`` and attach a ``labels`` tensor.

        Raises:
            ValueError: If samples are neither 2- nor 3-tuples.
        """
        sample_width = len(batch[0])
        encode_kwargs = dict(return_tensors="pt", max_length=self.max_length,
                             padding=True, truncation=True)
        if sample_width == 3:
            image_paths, texts, labels = zip(*batch)
            images = [self._load_image(path) for path in image_paths]
            # Pass a list: some processors (e.g. Qwen2-VL) wrap any non-list
            # `text` argument in a list, which would break on a tuple.
            inputs = self.processor(images=images, text=list(texts), **encode_kwargs)
        elif sample_width == 2:
            texts, labels = zip(*batch)
            inputs = self.processor(text=list(texts), **encode_kwargs)
        else:
            # Previously this fell through and returned an unbound `inputs`.
            raise ValueError(
                f"expected samples of length 2 (text, label) or 3 (image, text, label), got {sample_width}"
            )
        inputs['labels'] = torch.tensor(labels)
        return inputs

    @staticmethod
    def _load_image(path):
        """Open and fully decode the image at ``path``, closing its file handle."""
        with Image.open(path) as img:
            img.load()
        return img


def construct_detection_dataset(data_root, train_size=0.8, prompter=None, image_root=None, seed=None):
    """Load a ``DetectionDataset`` and split it into train/test subsets.

    Args:
        data_root: Path to the JSON records file.
        train_size: Fraction of samples used for training, in ``[0, 1]``.
            With ``1`` the full dataset is returned and no test split is made.
        prompter: Callable ``prompter(instruction, answer) -> text``.
        image_root: Directory that record image paths are relative to.
        seed: Optional integer seed that makes the random split reproducible.
            With ``None``, torch's default global generator is used.

    Returns:
        ``(train_dataset, test_dataset)``; ``test_dataset`` is ``None`` when
        ``train_size == 1``.

    Raises:
        ValueError: If ``train_size`` is outside ``[0, 1]`` (or NaN).
    """
    # Validate before loading the (possibly large) JSON file. The chained
    # comparison also rejects NaN, which separate `>`/`<` tests let through.
    if not 0 <= train_size <= 1:
        raise ValueError("train_size should be in the range of [0, 1]")
    dataset = DetectionDataset(data_root, prompter=prompter, images_root=image_root)
    if train_size == 1:
        return dataset, None
    # Convert the fraction into a sample count; the remainder forms the test split.
    n_train = int(train_size * len(dataset))
    n_test = len(dataset) - n_train

    if seed is None:
        train_dataset, test_dataset = random_split(dataset, [n_train, n_test])
    else:
        generator = torch.Generator().manual_seed(seed)
        train_dataset, test_dataset = random_split(dataset, [n_train, n_test], generator=generator)
    return train_dataset, test_dataset


class Adapt_Dataset(torch.utils.data.Dataset):
    """Text samples from a JSON list, each paired with a fixed 0/1 label.

    Args:
        dataset_path: Path to a JSON file containing a list of samples.
        attack: When true every sample is labelled ``1``, otherwise ``0``.
    """

    def __init__(self, dataset_path, attack=False):
        # Close the handle deterministically; JSON is UTF-8 by spec.
        with open(dataset_path, encoding="utf-8") as f:
            self.dataset = json.load(f)
        self.attack = attack

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` where the label derives from ``attack``."""
        sample = self.dataset[idx]
        return sample, 1 if self.attack else 0


class Adapt_VQA_Dataset(torch.utils.data.Dataset):
    """Image/question samples from a JSON list with a fixed 0/1 label.

    Each record needs an ``image`` path (relative to ``image_root``) and
    either an ``instruction`` or a ``question`` field; ``instruction`` wins
    when both are present.

    Args:
        dataset_path: Path to a JSON file containing a list of records.
        image_root: Directory that record image paths are joined onto.
        attack: When true every sample is labelled ``1``, otherwise ``0``.
    """

    def __init__(self, dataset_path, image_root, attack=False):
        with open(dataset_path, encoding="utf-8") as f:
            self.dataset = json.load(f)
        self.image_root = image_root
        self.attack = attack

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image_path, question, label)`` for ``idx``."""
        sample = self.dataset[idx]
        image = sample['image']
        # Explicit branch instead of relying on tuple/conditional precedence.
        if 'instruction' in sample:
            question = sample['instruction']
        else:
            question = sample['question']
        image = os.path.join(self.image_root, image)
        return image, question, 1 if self.attack else 0


class QA_Dataset(torch.utils.data.Dataset):
    """Raw ``(instruction, answer)`` pairs loaded from a JSON list.

    Args:
        dataset_path: Path to a JSON file of records with ``instruction``
            and ``answer`` keys.
        prompter: Stored on the instance but not applied by ``__getitem__``;
            NOTE(review): kept for interface compatibility — confirm whether
            callers expect prompts to be formatted here.
    """

    def __init__(self, dataset_path, prompter=None):
        with open(dataset_path, encoding="utf-8") as f:
            self.dataset = json.load(f)
        self.prompter = prompter

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the unformatted ``(instruction, answer)`` pair for ``idx``."""
        sample = self.dataset[idx]
        return sample['instruction'], sample['answer']

class VQA_Dataset(torch.utils.data.Dataset):
    """``(image_path, question, answer)`` triples loaded from a JSON list.

    Each record needs ``image`` (relative to ``image_root``), ``answer`` and
    a ``question``; records that only provide ``instruction`` are accepted
    as well, matching ``Adapt_VQA_Dataset``.

    Args:
        dataset_path: Path to a JSON file containing a list of records.
        image_root: Directory that record image paths are joined onto.
        prompter: Stored on the instance but not applied by ``__getitem__``.
    """

    def __init__(self, dataset_path, image_root, prompter=None):
        with open(dataset_path, encoding="utf-8") as f:
            self.dataset = json.load(f)
        self.image_root = image_root
        self.prompter = prompter

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image_path, question, answer)`` for ``idx``."""
        sample = self.dataset[idx]
        image = sample['image']
        # 'question' is the primary key; fall back to 'instruction' only when absent.
        if 'question' in sample:
            question = sample['question']
        else:
            question = sample['instruction']
        answer = sample['answer']
        image = os.path.join(self.image_root, image)
        return image, question, answer
