import json
from pathlib import Path

import torch
import torchvision

from .modulated_coco import ConvertCocoPolysToMask, ModulatedDataset


class GQADataset(ModulatedDataset):
    """GQA grounding dataset; reuses ``ModulatedDataset`` without any changes."""


class GQAQuestionAnswering(torchvision.datasets.CocoDetection):
    """GQA question-answering dataset read from a COCO-format annotation file.

    Each COCO "image" record is expected to carry, besides the usual COCO image
    fields, the keys ``caption``, ``dataset_name``, ``questionId``, ``answer``
    and ``question_type`` (all read in ``__getitem__``).

    Args:
        img_folder: image directory, forwarded to ``CocoDetection``.
        ann_file: COCO-format annotation file, forwarded to ``CocoDetection``.
        transforms: optional callable ``(img, target) -> (img, target)``;
            skipped when ``None``.
        return_masks: forwarded to ``ConvertCocoPolysToMask``.
        return_tokens: forwarded to ``ConvertCocoPolysToMask``.
        tokenizer: forwarded to ``ConvertCocoPolysToMask``.
        ann_folder: directory (``str`` or ``pathlib.Path``) containing
            ``gqa_answer2id.json`` and ``gqa_answer2id_by_type.json``.
    """

    # Per-type answer targets are written into ``target`` in this order
    # (kept identical to the historical key order).
    _PER_TYPE_ORDER = ("attr", "global", "rel", "cat", "obj")
    # Vocabulary entry used for answers missing from an answer table.
    _UNKNOWN_ANSWER = "unknown"
    # Label for per-type heads that do not apply to the current question.
    # -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``;
    # presumably the training loss relies on that -- confirm against the caller.
    _IGNORE_INDEX = -100

    def __init__(self, img_folder, ann_file, transforms, return_masks, return_tokens, tokenizer, ann_folder):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer)
        # Accept a plain string as well as a Path; ``str / str`` would raise TypeError.
        ann_folder = Path(ann_folder)
        with open(ann_folder / "gqa_answer2id.json", "r", encoding="utf-8") as f:
            self.answer2id = json.load(f)
        with open(ann_folder / "gqa_answer2id_by_type.json", "r", encoding="utf-8") as f:
            self.answer2id_by_type = json.load(f)
        self.type2id = {"obj": 0, "attr": 1, "rel": 2, "global": 3, "cat": 4}

    @classmethod
    def _answer_index(cls, table, answer):
        """Return ``table[answer]``, or ``table["unknown"]`` when ``answer`` is not in ``table``.

        Raises:
            KeyError: if ``answer`` is missing and ``table`` has no ``"unknown"`` entry.
        """
        return table[answer] if answer in table else table[cls._UNKNOWN_ANSWER]

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the question at position ``idx``.

        ``target`` is the dict produced by ``self.prepare`` (and
        ``self._transforms`` if set), extended with:

        - ``dataset_name`` and ``questionId``, copied from the COCO record;
        - ``answer``: index into ``answer2id`` (falls back to ``"unknown"``);
        - ``answer_type``: ``type2id[question_type]``;
        - ``answer_<t>`` for each t in attr/global/rel/cat/obj: index into
          ``answer2id_by_type["answer_<t>"]`` when ``question_type == t``,
          otherwise ``-100``.

        All answer fields are 0-d ``torch.long`` tensors.

        Raises:
            KeyError: if ``question_type`` is not a key of ``type2id``.
        """
        img, annotations = super().__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        caption = coco_img["caption"]
        dataset_name = coco_img["dataset_name"]
        question_id = coco_img["questionId"]
        answer = coco_img["answer"]
        question_type = coco_img["question_type"]

        target = {"image_id": image_id, "annotations": annotations, "caption": caption}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        target["dataset_name"] = dataset_name
        target["questionId"] = question_id

        target["answer"] = torch.as_tensor(self._answer_index(self.answer2id, answer), dtype=torch.long)
        target["answer_type"] = torch.as_tensor(self.type2id[question_type], dtype=torch.long)

        # Only the head matching this question's type gets a real label; the
        # others are filled with the ignore label. The per-type table is only
        # consulted when it is actually needed.
        for qtype in self._PER_TYPE_ORDER:
            if question_type == qtype:
                label = self._answer_index(self.answer2id_by_type[f"answer_{qtype}"], answer)
            else:
                label = self._IGNORE_INDEX
            target[f"answer_{qtype}"] = torch.as_tensor(label, dtype=torch.long)
        return img, target
