import os
import json
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import Dataset, DataLoader


class SVQADataset(Dataset):
    """Base class for the VQA-style attack / evaluation datasets in this module.

    Loads the JSON annotation file, remembers the image root directory and
    builds the image preprocessing pipeline. Subclasses implement
    ``__getitem__``.

    Args:
        data_file: Path to the JSON annotation file (attack set or eval set).
        vis_root: Directory that the per-sample image paths are relative to.
        image_size: Target side length; images are resized to
            ``(image_size, image_size)`` with bicubic interpolation.
    """

    def __init__(self, data_file, vis_root, image_size):
        # NOTE(review): these values match the OpenAI CLIP normalization
        # statistics; presumably chosen for a CLIP-based vision encoder.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)
        # Context manager closes the annotation file deterministically instead
        # of leaving the handle to the garbage collector. JSON text is UTF-8 by
        # spec, so do not depend on the platform's default encoding.
        with open(data_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)  # attack_set and eval_set
        self.vis_root = vis_root

        # Resize + ToTensor only, so tensors stay in [0, 1] pixel space.
        # `self.normalize` is deliberately left out of the pipeline.
        # NOTE(review): presumably normalization is applied later (e.g. by
        # the model wrapper) so perturbations act on raw pixels; confirm.
        self.transform_processor = transforms.Compose([
            transforms.Resize(
                (image_size, image_size), interpolation=InterpolationMode.BICUBIC
            ),
            transforms.ToTensor(),
            # self.normalize,
        ])

    def __len__(self):
        """Return the number of top-level entries in the loaded JSON."""
        return len(self.data)
    

class DALLE3Dataset(SVQADataset):
    """DALL-E 3 image split; images are read from ``vis_root``."""

    def __init__(self, data_file, vis_root='data/dalle3', image_size=None):
        super().__init__(data_file, vis_root, image_size)

    def __getitem__(self, index):
        """Return the sample stored under the ``index``-th JSON key.

        The result holds the preprocessed image tensor, the sample's
        ``text_input`` field and the image path relative to ``vis_root``.
        The path is returned so adversarial examples can be saved under the
        same name.
        """
        keys = list(self.data.keys())
        record = self.data[keys[index]]
        rel_path = record['image']

        rgb = Image.open(os.path.join(self.vis_root, rel_path)).convert('RGB')
        pixels = self.transform_processor(rgb)

        # The `answer_instructblip` field is intentionally not returned.
        return {
            "image": pixels,
            "text_input": record['text_input'],
            "image_path": rel_path,
        }


class SVITDataset(SVQADataset):
    """SVIT image split; images are read from ``vis_root``."""

    def __init__(self, data_file, vis_root='data/svit/raw/', image_size=None):
        super().__init__(data_file, vis_root, image_size)

    def __getitem__(self, index):
        """Return the sample stored under the ``index``-th JSON key.

        Keys of the returned dict:
            image: preprocessed image tensor.
            text_input: the record's ``text_input`` field.
            image_path: image path relative to ``vis_root``, kept so
                adversarial examples can be written back under that name.
        """
        entry_key = list(self.data.keys())[index]
        entry = self.data[entry_key]
        relative = entry['image']

        full_path = os.path.join(self.vis_root, relative)
        tensor = self.transform_processor(Image.open(full_path).convert('RGB'))

        # The `answer_instructblip` field is intentionally not returned.
        return {
            "image": tensor,
            "text_input": entry['text_input'],
            "image_path": relative,
        }



class SCOCOVQADataset(SVQADataset):
    """COCO-VQA split; images are read from ``vis_root``."""

    def __init__(self, data_file, vis_root='data/coco/images', image_size=None):
        super().__init__(data_file, vis_root, image_size)

    def __getitem__(self, index):
        """Return the sample stored under the ``index``-th JSON key.

        The dict contains the preprocessed ``image`` tensor, the record's
        ``text_input`` field and ``image_path``. ``image_path`` is relative to
        ``vis_root`` and is needed to store adversarial examples.
        """
        ordered_keys = list(self.data.keys())
        sample = self.data[ordered_keys[index]]
        image_rel = sample['image']

        loaded = Image.open(os.path.join(self.vis_root, image_rel)).convert('RGB')
        image_tensor = self.transform_processor(loaded)

        # The `answer_llava` / `answer_instructblip` fields are intentionally
        # not returned.
        return {
            "image": image_tensor,
            "text_input": sample['text_input'],
            "image_path": image_rel,
        }


class SCOCOVQADataset_old(SVQADataset):
    """Legacy COCO-VQA dataset that also returns answers and their weights.

    A ``collate_fn`` attribute is attached to each instance.
    ``create_dataloader`` passes it to the DataLoader whenever the dataset
    has one.
    """

    def __init__(self, data_file, vis_root='data/coco/images', image_size=None):
        super().__init__(data_file, vis_root, image_size)

        # vqa_collate_fn flattens the per-sample answer/weight lists and
        # records how many answers each sample contributed.
        self.collate_fn = vqa_collate_fn

    def __getitem__(self, index):
        """Return the ``index``-th sample together with its answers/weights.

        ``image_path`` is relative to ``vis_root``; it is returned so
        adversarial examples can be stored under the same name.
        """
        key_list = list(self.data.keys())
        entry = self.data[key_list[index]]
        relative_path = entry['image']

        on_disk = os.path.join(self.vis_root, relative_path)
        pixels = self.transform_processor(Image.open(on_disk).convert('RGB'))

        return {
            "image": pixels,
            "text_input": entry['text_input'],
            "answers": entry['answers'],
            "weights": entry['weights'],
            "image_path": relative_path,
        }


def vqa_collate_fn(batch):
    """Collate VQA samples whose answer lists may differ in length.

    Images are stacked along a new leading batch dimension. Answers and
    weights from every sample are concatenated into flat lists, and
    ``n_answers[i]`` records how many of them belong to sample ``i``.

    Args:
        batch: Sequence of sample dicts with keys ``'image'``,
            ``'text_input'``, ``'answers'``, ``'weights'`` and
            ``'image_path'``.

    Returns:
        Dict with keys ``'image'``, ``'text_input'``, ``'answer'``,
        ``'weight'``, ``'n_answers'`` and ``'paths'``.
    """
    images = [sample['image'] for sample in batch]
    questions = [sample['text_input'] for sample in batch]
    flat_weights = [w for sample in batch for w in sample['weights']]
    flat_answers = [a for sample in batch for a in sample['answers']]
    answer_counts = [len(sample['answers']) for sample in batch]
    paths = [sample['image_path'] for sample in batch]

    return {
        'image': torch.stack(images, dim=0),
        'text_input': questions,
        'answer': flat_answers,
        'weight': flat_weights,
        'n_answers': answer_counts,
        "paths": paths,
    }



def create_dataloader(dataset_class, data_file, image_size, batch_size, for_eval=False,
                      *, num_workers=2):
    """Instantiate ``dataset_class`` and wrap it in a non-shuffling DataLoader.

    Args:
        dataset_class: Dataset class, constructed as
            ``dataset_class(data_file=data_file, image_size=image_size)``.
        data_file: Path to the JSON annotation file.
        image_size: Forwarded to the dataset constructor.
        batch_size: Batch size; ignored (forced to 1) when ``for_eval`` is true.
        for_eval: If true, iterate one sample at a time.
        num_workers: DataLoader worker processes (default 2, the previous
            hard-coded value).

    Returns:
        A ``DataLoader`` with ``shuffle=False``. If the dataset instance has a
        ``collate_fn`` attribute it is used; otherwise PyTorch's default
        collation applies.
    """
    dataset = dataset_class(data_file=data_file, image_size=image_size)

    # DataLoader treats collate_fn=None as "use the default collate", so one
    # call covers both the custom-collate and default-collate cases.
    return DataLoader(
        dataset,
        batch_size=1 if for_eval else batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=getattr(dataset, 'collate_fn', None),
    )


def get_dataloader(data_name, attack_set, eval_set, image_size, batch_size):
    """Build the three DataLoaders used for an attack run.

    Args:
        data_name: Dataset identifier: ``'coco_vqa'``, ``'svit'`` or ``'dalle3'``.
        attack_set: JSON annotation file for the attack split.
        eval_set: JSON annotation file for the evaluation split.
        image_size: Square resize target forwarded to the dataset.
        batch_size: Batch size for the attack loader.

    Returns:
        Tuple ``(attack_loader, attack_loader_for_eval, eval_loader)``. The
        last two iterate one sample at a time.

    Raises:
        ValueError: If ``data_name`` is not one of the supported identifiers.
    """
    # Choose Dataset based on data_name
    if data_name == 'coco_vqa':
        dataset_class = SCOCOVQADataset
    elif data_name == 'svit':
        dataset_class = SVITDataset
    elif data_name == 'dalle3':
        dataset_class = DALLE3Dataset
    else:
        # Name the offending value and the accepted ones so config typos are
        # obvious; keep the original message prefix for anyone grepping logs.
        raise ValueError(
            f"Invalid data name: {data_name!r} "
            "(expected one of 'coco_vqa', 'svit', 'dalle3')"
        )

    # Create DataLoaders
    vqa_attack_dataloader = create_dataloader(dataset_class, attack_set, image_size, batch_size)
    vqa_attack_dataloader_for_eval = create_dataloader(dataset_class, attack_set, image_size, 1, for_eval=True)
    vqa_eval_dataloader = create_dataloader(dataset_class, eval_set, image_size, 1, for_eval=True)

    return vqa_attack_dataloader, vqa_attack_dataloader_for_eval, vqa_eval_dataloader