"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os
import random
from collections import Counter
from collections import OrderedDict

import torch
from PIL import Image

from lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset
from lavis.datasets.datasets.dataloader_utils import insert_img_backdoor_image_captioning_eval, insert_img_backdoor_vqa








class __DisplMixin:
    """Mixin that renders one dataset entry as a readable, ordered record.

    The host class must provide ``__getitem__`` (returning a mapping with an
    ``"image"`` entry) and an indexable ``annotation`` list of dicts.
    """

    def displ_item(self, index):
        """Return an OrderedDict describing the entry at ``index``.

        Combines the raw annotation (file, question, question id, answers
        joined with "; ") with the processed image from ``__getitem__``.
        """
        processed = self.__getitem__(index)
        record = self.annotation[index]

        pairs = [
            ("file", record["image"]),
            ("question", record["question"]),
            ("question_id", record["question_id"]),
            ("answers", "; ".join(record["answer"])),
            ("image", processed["image"]),
        ]
        return OrderedDict(pairs)


class COCOVQADataset(VQADataset, __DisplMixin):
    """COCO VQA training dataset.

    Each item pairs a processed image and question with the distinct
    ground-truth answers for that question and their relative frequencies.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths, config):
        """
        vis_processor (callable): applied to the RGB PIL image in ``__getitem__``.
        text_processor (callable): applied to the raw question string in ``__getitem__``.
        vis_root (string): directory that each annotation's ``image`` path is joined to.
        ann_paths (list): annotation file paths, forwarded unchanged to ``VQADataset``.
        config: accepted for interface compatibility; not used by this class.
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return the training sample at ``index``.

        Returns a dict with:
            image: ``vis_processor`` output for the RGB-converted image.
            text_input: ``text_processor`` output for the question.
            answers: distinct annotated answers, in first-occurrence order.
            weights: parallel list; each entry is the fraction of annotated
                answers equal to the corresponding ``answers`` entry.
                Both lists are empty when the annotation has no answers.
        """
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        # Context manager closes the underlying file once the pixels are decoded.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])

        # Count first, divide once: summing 1/n repeatedly drifts
        # (ten additions of 0.1 give 0.9999999999999999, not 1.0).
        raw_answers = ann["answer"]
        counts = Counter(raw_answers)  # insertion-ordered, like a plain dict
        total = len(raw_answers)
        answers = list(counts)
        weights = [count / total for count in counts.values()]

        return {
            "image": image,
            "text_input": question,
            "answers": answers,
            "weights": weights,
        }



class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin):
    """COCO VQA evaluation dataset: yields processed images and questions with ids."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths, config):
        """
        vis_processor (callable): applied to the RGB PIL image in ``__getitem__``.
        text_processor (callable): applied to the raw question string in ``__getitem__``.
        vis_root (string): Root directory of images (e.g. coco/images/).
        ann_paths (list of string): annotation files, in order:
            [0] evaluation questions JSON (e.g. vqa_val_eval.json); required.
            [1] candidate answer list JSON (e.g. answer_list.json); optional.
                ``self.answer_list`` is None if the entry is absent or the
                file does not exist.
            [2] COCO-format questions file
                (e.g. v2_OpenEnded_mscoco_val2014_questions.json); optional.
            [3] COCO-format annotations file
                (e.g. v2_mscoco_val2014_annotations.json); optional.
            Entries [2] and [3] are only used as a pair: unless both are
            given, ``coco_fmt_qust_file`` and ``coco_fmt_anno_file`` are None.
        config: accepted for interface compatibility; not used by this class.
        """
        # NOTE(review): super().__init__ is not called; every attribute used
        # below is assigned explicitly here.
        self.vis_root = vis_root

        with open(ann_paths[0], encoding="utf-8") as f:
            self.annotation = json.load(f)

        answer_list_path = ann_paths[1] if len(ann_paths) > 1 else None
        if answer_list_path is not None and os.path.exists(answer_list_path):
            with open(answer_list_path, encoding="utf-8") as f:
                self.answer_list = json.load(f)
        else:
            self.answer_list = None

        if len(ann_paths) > 3:
            self.coco_fmt_qust_file = ann_paths[2]
            self.coco_fmt_anno_file = ann_paths[3]
        else:
            self.coco_fmt_qust_file = None
            self.coco_fmt_anno_file = None

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # Inherited helper; __getitem__ below reads ann["instance_id"], which
        # presumably this call populates — verify against the base class.
        self._add_instance_ids()

    def __getitem__(self, index):
        """Return the evaluation sample at ``index``.

        Returns a dict with the processed ``image``, the processed question as
        ``text_input``, and the annotation's ``question_id`` and ``instance_id``.
        """
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        # Context manager closes the underlying file once the pixels are decoded.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])

        return {
            "image": image,
            "text_input": question,
            "question_id": ann["question_id"],
            "instance_id": ann["instance_id"],
        }




