from torch.utils.data import Dataset
import json
import os
from PIL import Image

class VQADataset(Dataset):
    """Map-style dataset pairing VQA questions with their images and answers.

    Questions are read from the ``"questions"`` list of ``question_path`` and
    annotations from the ``"annotations"`` list of ``annotations_path``.
    Item ``idx`` combines ``questions[idx]`` with ``answers[idx]``.

    NOTE(review): this assumes both JSON lists are stored in the same order;
    nothing here verifies that the question ids match - confirm against the
    data files.
    """

    def __init__(
        self,
        image_dir_path,
        question_path,
        annotations_path,
        is_train,
        dataset_name="VQA",
    ):
        """Load the question and annotation files and record the settings.

        Args:
            image_dir_path: Directory containing the images. For the
                ``"vqav2"`` and ``"ok-vqa"`` datasets its last path component
                must be ``train2014``, ``val2014`` or ``test2015``.
            question_path: JSON file with a top-level ``"questions"`` list.
            annotations_path: JSON file with a top-level ``"annotations"``
                list.
            is_train: Stored on the instance as ``self.is_train``; it does
                not influence image path construction.
            dataset_name: One of ``"vqav2"``, ``"ok-vqa"``, ``"vizwiz"``,
                ``"textvqa"`` or ``"docvqa"`` (compared case-sensitively).
                The default ``"VQA"`` matches none of these, so
                ``get_img_path`` raises for it.

        Raises:
            AssertionError: If a COCO-based dataset is given an image
                directory whose last component is not a known COCO split
                (not checked when Python runs with ``-O``).
        """
        print("Loading VQA dataset...")
        print("Loading questions from", question_path)
        self.questions = self._load_json_field(question_path, "questions")
        self.answers = self._load_json_field(annotations_path, "annotations")
        self.image_dir_path = image_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name
        if self.dataset_name in {"vqav2", "ok-vqa"}:
            # The COCO split name is embedded in every image file name, so it
            # is derived once from the last component of the image directory.
            self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
            assert self.img_coco_split in {"train2014", "val2014", "test2015"}, (
                f"Unexpected COCO split directory: {self.img_coco_split}"
            )

    @staticmethod
    def _load_json_field(path, key):
        """Return ``key`` from the JSON document at ``path``, closing the file."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)[key]

    def __len__(self):
        """Return the number of questions."""
        return len(self.questions)

    def get_img_path(self, question):
        """Build the image file path for a question entry.

        Args:
            question: A question dict providing an ``"image_id"`` entry.

        Returns:
            The image path under ``self.image_dir_path``, using the file
            naming scheme of ``self.dataset_name``.

        Raises:
            ValueError: If ``self.dataset_name`` is not a supported dataset.
        """
        image_id = question["image_id"]
        if self.dataset_name in {"vqav2", "ok-vqa"}:
            # Same file name pattern for every split; the id is zero-padded
            # to 12 digits.
            file_name = f"COCO_{self.img_coco_split}_{image_id:012d}.jpg"
        elif self.dataset_name == "vizwiz":
            file_name = image_id
        elif self.dataset_name == "textvqa":
            file_name = f"{image_id}.jpg"
        elif self.dataset_name == "docvqa":
            file_name = f"{image_id}"
        else:
            raise ValueError(f"Unknown VQA dataset {self.dataset_name}")
        return os.path.join(self.image_dir_path, file_name)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            A dict with keys ``"image"`` (a loaded PIL image), ``"question"``,
            ``"answers"`` (the ``"answer"`` value of every entry in the
            annotation's ``"answers"`` list) and ``"question_id"``.
        """
        question = self.questions[idx]
        answers = self.answers[idx]
        img_path = self.get_img_path(question)
        try:
            image = Image.open(img_path)
        except Exception:
            # Best-effort fallback: an image that cannot be opened is replaced
            # by a white placeholder so iteration continues. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            print("Could not open image, return a white image", img_path)
            image = Image.new("RGB", (256, 256), (255, 255, 255))
        # Force the pixel data to be read now rather than lazily later.
        image.load()

        return {
            "image": image,
            "question": question["question"],
            "answers": [a["answer"] for a in answers["answers"]],
            "question_id": question["question_id"],
        }