# Copyright 2023 The Otter Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import base64
from io import BytesIO
import re
import contextlib
import os
import orjson
from pathlib import Path

import torch
from torchvision import transforms
from torchvision.transforms.functional import to_tensor
from PIL import ImageFile
from .transforms import *
from torch.utils.data import Dataset
from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    Resized,
    LoadImaged,
    Orientationd,
    RandSpatialCropd,
    ScaleIntensityRanged,
    ToTensord
)
from med_datasets.input_dataset import pad_or_cut_img_tensors
from med_datasets.data_util.mimic_cxr_utils import create_id2chexpert_dict
from otter.biovil_encoder import (
    load_image,
    get_cxr_bert,
    create_chest_xray_transform_for_inference,
    create_chest_xray_augmentation_for_training,
)
from med_datasets.data_util.mimic_cxr_utils import CATEGORIES

# Per-channel (R, G, B) normalization statistics named after ImageNet.
# NOTE(review): not referenced anywhere in the visible part of this file.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Per-channel statistics passed to transforms.Normalize in
# MimicitDataset.patch_resize_transform.
FLAMINGO_MEAN = [0.481, 0.458, 0.408]
FLAMINGO_STD = [0.269, 0.261, 0.276]

# Let Pillow load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# NOTE(review): Pillow keeps its pixel limit on PIL.Image, not PIL.ImageFile,
# so this assignment looks like a no-op -- confirm before relying on it.
ImageFile.MAX_IMAGE_PIXELS = None
# Disable Pillow's decompression-bomb size check for very large images.
# `Image` is not imported explicitly in this file; presumably it arrives via
# `from .transforms import *` -- TODO confirm.
Image.MAX_IMAGE_PIXELS = None


@contextlib.contextmanager
def random_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global PRNG and Python's ``random`` module.

    Both generators are seeded with the same value on entry and both have
    their previous state restored on exit, even if the body raises.

    :param seed: Seed value. ``None`` turns the context manager into a no-op.
    :param addl_seeds: Optional extra values. When given, the effective seed
        is ``int(hash((seed, *addl_seeds)) % 1e6)``.
    """
    if seed is None:
        # Nothing to seed and nothing to restore.
        yield
        return

    if addl_seeds:
        seed = int(hash((seed, *addl_seeds)) % 1e6)

    # Snapshot both generators before touching either of them.
    saved_numpy_state = np.random.get_state()
    saved_python_state = random.getstate()

    np.random.seed(seed)
    random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(saved_numpy_state)
        random.setstate(saved_python_state)


class ExpandChannels:
    """
    Callable transform that turns a one-channel image tensor into a
    three-channel one by repeating it along the first dimension.
    """

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """
        :param data: Tensor of shape [1, H, W].
        :return: Tensor of shape [3, H, W]; every channel is a copy of the input channel.
        :raises ValueError: If the first dimension of ``data`` does not have size 1.
        """
        num_channels = data.shape[0]
        if num_channels == 1:
            return data.repeat_interleave(3, dim=0)
        raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")


class MimicitDataset(Dataset):
    def __init__(
        self,
        args,
        cur_mimicit_path,
        cur_images_path,
        cur_train_config_path,
        is_test=False,
        # supported_data_types=["caption", "qa"],
    ):
        """Load one instruction dataset from disk and set up its image transforms.

        :param args: Configuration object. Attributes read here:
            ``max_src_length``, ``max_tgt_length``, ``dataset_type``,
            ``vision_encoder_type``, ``task``, ``tokenizer``, ``seed``,
            ``patch_image_size`` and, when ``dataset_type == "mimic_cxr"``,
            ``chexpert_csv_path``.
        :param cur_mimicit_path: JSON file whose ``"data"`` entry maps
            instruction ids to records.
        :param cur_images_path: JSON file mapping image ids to image data
            (base64 strings or file paths, depending on which ``process_*``
            method consumes them).
        :param cur_train_config_path: JSON file mapping each instruction id to
            its list of in-context example ids.
        :param is_test: Stored as ``self.is_test``; ``process_mimic_cxr`` skips
            the training augmentation when it is true.
        :raises ValueError: If ``dataset_type`` is ``"mimic_cxr"`` or
            ``"custom_2d"`` and ``vision_encoder_type`` is neither ``"biovil"``
            nor ``"unimedi2d"``.
        :raises AssertionError: If any of the three paths does not exist.
        """
        # super().__init__(args, is_test)

        self.max_src_length = args.max_src_length
        self.max_tgt_length = args.max_tgt_length
        self.max_length = self.max_src_length

        # MIMIC-CXR specific
        if args.dataset_type in ["mimic_cxr", "custom_2d"]:
            self.num_images_per_sample = 2
            self.max_length = 500
            if args.vision_encoder_type == "biovil":
                self.med_patch_image_size = 480
            elif args.vision_encoder_type == "unimedi2d":
                self.med_patch_image_size = 224
            else:
                # Name the offending values so a misconfiguration is easy to diagnose.
                raise ValueError(
                    f"Unsupported vision_encoder_type {args.vision_encoder_type!r} for dataset_type "
                    f"{args.dataset_type!r}; expected 'biovil' or 'unimedi2d'."
                )
        elif args.dataset_type in ["bimcv_covid19", "custom_3d"]:
            self.num_images_per_sample = 32
            self.med_patch_image_size = 128
            self.max_length = 500

        self.args = args
        self.task_name = args.task
        self.is_test = is_test
        self.tokenizer = args.tokenizer

        self.seed = args.seed
        self.patch_image_size = args.patch_image_size
        # self.supported_data_types = supported_data_types

        self.epoch = 0

        scales = [(args.patch_image_size, args.patch_image_size)]

        # Unused in mimic cxr
        self.patch_resize_transform = transforms.Compose(
            [
                RandomResize(scales),
                transforms.CenterCrop(args.patch_image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=FLAMINGO_MEAN, std=FLAMINGO_STD),
            ]
        )

        self.mimicit_path = cur_mimicit_path
        self.images_path = cur_images_path
        self.train_config_path = cur_train_config_path

        # NOTE: these checks are plain asserts and are therefore skipped under `python -O`.
        assert os.path.exists(cur_mimicit_path), f"Error: The local mimicit_path {cur_mimicit_path} not exists!"

        assert os.path.exists(cur_images_path), f"Error: The local images_path {cur_images_path} not exists!"

        assert os.path.exists(cur_train_config_path), f"Error: The local train_config_path {cur_train_config_path} not exists!"

        # Load the dataset
        with open(self.mimicit_path, "rb") as f:
            self.dataset = orjson.loads(f.read())["data"]

        # Load the images
        with open(self.images_path, "rb") as f:
            self.images = orjson.loads(f.read())

        # Load the train_config
        with open(self.train_config_path, "rb") as f:
            self.train_config = orjson.loads(f.read())

        self.train_data_list = list(self.dataset.keys())
        print(f"Loading {len(self.train_data_list)} pairs from {self.mimicit_path}")

        self.bos_item = torch.LongTensor([args.tokenizer.bos_token_id])
        self.eos_item = torch.LongTensor([args.tokenizer.eos_token_id])
        self.bos_mask = torch.LongTensor([1])
        self.eos_mask = torch.LongTensor([1])

        if self.args.dataset_type in ["mimic_cxr", "custom_2d"]:
            if self.args.dataset_type == "mimic_cxr":
                # `id2label` only exists for mimic_cxr; process_mimic_cxr checks for it with hasattr.
                self.id2label = create_id2chexpert_dict(args.chexpert_csv_path)
                self.biovil_tokenizer, _ = get_cxr_bert()
            if self.args.vision_encoder_type == "biovil":
                # Transform for BioVil encoder, values from biovil_encoder/image/utils.py
                self.med_patch_resize_transform = create_chest_xray_transform_for_inference(
                    resize=512, center_crop_size=self.med_patch_image_size
                )
            elif self.args.vision_encoder_type == "unimedi2d":
                self.med_patch_resize_transform = transforms.Compose([
                    transforms.Resize((256,256)),  # Use function
                    transforms.CenterCrop((224, 224)),
                    transforms.ToTensor(),
                    ExpandChannels(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])
            # Augmentation values from https://arxiv.org/abs/2204.09817 Table E.1: Fine-tuning for Downstream Tasks, for mimic_cxr
            self.training_image_augmentation = create_chest_xray_augmentation_for_training(
                degree=45, shear=25, brightness=0.2, contrast=0.2, flip_p=0.5
            )
        elif self.args.dataset_type in ["bimcv_covid19", "custom_3d"] \
        and self.args.vision_encoder_type == "unimedi3d":
            # Transform for 3D unified medical encoder, from xiaoxuan
            # NOTE(review): for a 3D dataset type with any other encoder type no
            # med_patch_resize_transform is set, and process_bimcv_covid19 would
            # then fail with AttributeError -- confirm whether that is intended.
            self.med_patch_resize_transform = Compose([
                LoadImaged(keys=["image"]),
                EnsureChannelFirstd(keys=["image"]),
                Orientationd(keys=["image"], axcodes="RAS"),
                ScaleIntensityRanged(
                    keys=["image"], a_min=-1000, a_max=3000, b_min=0.0, b_max=1.0, clip=True
                ),
                Resized(keys="image", spatial_size=(160, 160, 64)),
                RandSpatialCropd(keys="image", roi_size=[128, 128, 32], random_size=False),
                ToTensord(keys=["image"])
            ])

        # Resize to fit in original flamingo encoder
        self.resize_transform = transforms.Resize(self.patch_image_size, antialias=True)


    def random_init_case(self, question):
        if len(question) == 0:
            return question

        first_letter = question[0]
        if random.choice([True, False]):
            first_letter = first_letter.upper()
        else:
            first_letter = first_letter.lower()

        return first_letter + question[1:]

    def pre_question(self, question, max_ques_words):
        question = question.lower().lstrip(",.!?*#:;~").replace("-", " ").replace("/", " ")
        question = self.random_init_case(question)

        question = re.sub(
            r"\s{2,}",
            " ",
            question,
        )
        question = question.rstrip("\n")
        question = question.strip(" ")

        # truncate question
        question_words = question.split(" ")
        if len(question_words) > max_ques_words:
            question = " ".join(question_words[:max_ques_words])

        return question

    def pre_answer(self, answer, max_ans_words):
        answer = re.sub(
            r"\s{2,}",
            " ",
            answer,
        )
        answer = answer.rstrip("\n")
        answer = answer.strip(" ")

        # truncate question
        return_answer = ""
        answers = answer.split(".")

        for _ in answers:
            if return_answer == "":
                cur_answer = _
            else:
                cur_answer = ".".join([return_answer, _])
            if len(cur_answer.split(" ")) <= max_ans_words:
                return_answer = cur_answer
            else:
                break

        if return_answer == "":
            answer_words = answer.split(" ")
            return_answer = " ".join(answer_words[:max_ans_words])
        else:
            if return_answer[-1] != "." and return_answer != answers:
                return_answer += "."

        return return_answer

    def pre_caption(self, caption, max_words):
        caption = caption.lower().lstrip(",.!?*#:;~").replace("-", " ").replace("/", " ").replace("<person>", "person")

        caption = re.sub(
            r"\s{2,}",
            " ",
            caption,
        )
        caption = caption.rstrip("\n")
        caption = caption.strip(" ")

        # truncate caption
        caption_words = caption.split(" ")
        if len(caption_words) > max_words:
            caption = " ".join(caption_words[:max_words])

        return caption

    def set_epoch(self, epoch, **unused):
        """Record the current epoch number on the dataset.

        :param epoch: Value stored as ``self.epoch``.
        :param unused: Extra keyword arguments are accepted and ignored.
        """
        self.epoch = epoch

    def resample_frames(self, image_ids, resample_frames):
        indices = np.linspace(0, len(image_ids) - 1, resample_frames, dtype=int)
        image_ids = [image_ids[i] for i in indices]
        assert len(image_ids) == resample_frames
        return image_ids

    def process_llavar(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        """Build the text and image tensor for a LLaVAR sample.

        The in-context examples followed by the query are rendered as
        ``User: ... GPT:<answer> ...<|endofchunk|>`` turns behind a single
        leading ``<image>`` token. Only the first image of the query record is
        decoded. ``instruction``, ``answer`` and ``image_ids`` are not read;
        everything comes from ``self.dataset``.

        :return: ``(patch_images, all_texts)``; ``patch_images`` is the
            transformed image with two leading singleton dimensions added.
        """
        ordered_ids = in_context_example_ids + [instruction_id]
        # random.shuffle(ordered_ids)

        chunks = []
        for example_id in ordered_ids:
            record = self.dataset[example_id]
            # Value is unused, but the lookup is kept so that a record without
            # images still fails at this point, as it did before.
            _first_image_id = record["image_ids"][0]
            raw_question, raw_answer = record["instruction"], record["answer"]
            question = self.pre_question(raw_question, self.max_src_length)
            reply = self.pre_answer(raw_answer, self.max_tgt_length)
            chunks.append(f"User: {question} GPT:<answer> {reply}<|endofchunk|>")
        all_texts = "<image>" + "".join(chunks)

        # The last id in the list is the query itself.
        query_image_id = self.dataset[ordered_ids[-1]]["image_ids"][0]
        encoded = self.images[query_image_id]
        picture = Image.open(BytesIO(base64.urlsafe_b64decode(encoded))).convert("RGB")
        patch_images = self.patch_resize_transform(picture).unsqueeze(0).unsqueeze(0)
        return patch_images, all_texts  # incontext_text, query_text

    def process_llava(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        """Build the text and image tensor for a LLaVA sample.

        Two layouts, selected by whether ``"CONV"`` occurs in ``instruction_id``:

        * conversation ids: one leading ``<image>`` token, all turns share the
          first image of the query record;
        * other ids: every turn carries its own ``<image>`` token and the first
          image of its own record; the per-turn tensors are concatenated along
          dimension 0.

        ``instruction``, ``answer`` and ``image_ids`` are not read; everything
        comes from ``self.dataset``.

        :return: ``(patch_images, all_texts)``.
        """

        def load_patch(image_id):
            # Decode a base64 image and add two leading singleton dimensions.
            encoded = self.images[image_id]
            picture = Image.open(BytesIO(base64.urlsafe_b64decode(encoded))).convert("RGB")
            return self.patch_resize_transform(picture).unsqueeze(0).unsqueeze(0)

        ordered_ids = in_context_example_ids + [instruction_id]
        # random.shuffle(ordered_ids)

        if "CONV" in instruction_id:
            chunks = []
            for example_id in ordered_ids:
                record = self.dataset[example_id]
                # Unused value; lookup kept so records without images still fail here.
                _first_image_id = record["image_ids"][0]
                raw_question, raw_answer = record["instruction"], record["answer"]
                question = self.pre_question(raw_question, self.max_src_length)
                reply = self.pre_answer(raw_answer, self.max_tgt_length)
                chunks.append(f"User: {question} GPT:<answer> {reply}<|endofchunk|>")
            all_texts = "<image>" + "".join(chunks)
            # The last id in the list is the query itself.
            patch_images = load_patch(self.dataset[ordered_ids[-1]]["image_ids"][0])
            return patch_images, all_texts  # incontext_text, query_text

        chunks, patches = [], []
        for example_id in ordered_ids:
            record = self.dataset[example_id]
            first_image_id = record["image_ids"][0]
            raw_question, raw_answer = record["instruction"], record["answer"]
            # Image is decoded before the text is processed, as in the original order.
            patches.append(load_patch(first_image_id))
            question = self.pre_question(raw_question, self.max_src_length)
            reply = self.pre_answer(raw_answer, self.max_tgt_length)
            chunks.append(f"<image>User: {question} GPT:<answer> {reply}<|endofchunk|>")
        # Resulting text layout:
        # <image>User: {in-context instruction} GPT:<answer> {in-context answer}<|endofchunk|><image>User: {instruction} GPT:<answer> {answer}<|endofchunk|>
        return torch.cat(patches), "".join(chunks)  # incontext_text, query_text

    def process_dense_caption(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        patch_images = torch.tensor([])
        all_texts = ""
        all_instruction_ids = in_context_example_ids + [instruction_id]
        random.shuffle(all_instruction_ids)
        for cur_instruction_id in all_instruction_ids[:]:
            cur_instruction = self.dataset[cur_instruction_id]["instruction"]
            cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
            cur_answer = self.dataset[cur_instruction_id]["answer"]
            cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
            cur_text = f"User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>"
            all_texts += cur_text

        all_texts = f"<image>{all_texts}"
        # <image>User: {cur_incontext_instruction} GPT:<answer> {cur_incontext_answer}<|endofchunk|>User: {instruction} GPT:<answer> {answer}<|endofchunk|>
        # <image>User: what does the image describe? GPT: XXX <|endofchunk|>User: Do you think this image is funny GPT:<answer> YYY <|endofchunk|>
        for cur_image_id in image_ids:
            cur_image = self.images[cur_image_id]
            cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
            cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0)
            if len(patch_images) == 0:
                patch_images = cur_patch_image
            else:
                patch_images = torch.cat((patch_images, cur_patch_image))

        patch_images = patch_images.unsqueeze(0)
        return patch_images, all_texts

    def process_tv_caption(self, instruction_id, instruction, answer, image_ids, in_context_example_ids, resample_frames=16):
        patch_images = torch.tensor([])
        all_texts = ""
        all_instruction_ids = in_context_example_ids + [instruction_id]
        random.shuffle(all_instruction_ids)
        for cur_instruction_id in all_instruction_ids[:]:
            cur_instruction = self.dataset[cur_instruction_id]["instruction"]
            cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
            cur_answer = self.dataset[cur_instruction_id]["answer"]
            cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
            cur_text = f"User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>"
            all_texts += cur_text

        all_texts = f"<image>{all_texts}"
        # <image>User: {cur_incontext_instruction} GPT:<answer> {cur_incontext_answer}<|endofchunk|>User: {instruction} GPT:<answer> {answer}<|endofchunk|>
        # <image>User: what does the image describe? GPT: XXX <|endofchunk|>User: Do you think this image is funny GPT:<answer> YYY <|endofchunk|>

        # make sure the frames are evenly sampled to certain number to enable batch processing
        image_ids = self.resample_frames(image_ids, resample_frames)
        for cur_image_id in image_ids:
            cur_image = self.images[cur_image_id]
            cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
            cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0)
            if len(patch_images) == 0:
                patch_images = cur_patch_image
            else:
                patch_images = torch.cat((patch_images, cur_patch_image))

        patch_images = patch_images.unsqueeze(0)
        return patch_images, all_texts

    def process_e4d(self, instruction_id, instruction, answer, image_ids, in_context_example_ids, resample_frames=16):
        patch_images = torch.tensor([])
        all_texts = ""
        all_instruction_ids = in_context_example_ids + [instruction_id]
        random.shuffle(all_instruction_ids)
        for cur_instruction_id in all_instruction_ids[:]:
            cur_instruction = self.dataset[cur_instruction_id]["instruction"]
            cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
            cur_answer = self.dataset[cur_instruction_id]["answer"]
            cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
            cur_text = f"User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>"
            all_texts += cur_text

        all_texts = f"<image>{all_texts}"
        # <image>User: {cur_incontext_instruction} GPT:<answer> {cur_incontext_answer}<|endofchunk|>User: {instruction} GPT:<answer> {answer}<|endofchunk|>
        # <image>User: what does the image describe? GPT: XXX <|endofchunk|>User: Do you think this image is funny GPT:<answer> YYY <|endofchunk|>
        # make sure the frames are evenly sampled to certain number to enable batch processing
        image_ids = self.resample_frames(image_ids, resample_frames)
        for cur_image_id in image_ids:
            cur_image = self.images[cur_image_id]
            cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
            cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0)
            if len(patch_images) == 0:
                patch_images = cur_patch_image
            else:
                patch_images = torch.cat((patch_images, cur_patch_image))

        patch_images = patch_images.unsqueeze(0)
        return patch_images, all_texts

    def process_spot_the_difference(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        patch_images = torch.tensor([])
        incontext_text = ""
        # <image>User: {instruction} GPT:<answer> {answer}<|endofchunk|>
        for cur_image_id in image_ids:
            cur_image = self.images[cur_image_id]
            cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
            cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0)
            if len(patch_images) == 0:
                patch_images = cur_patch_image
            else:
                patch_images = torch.cat((patch_images, cur_patch_image))

        patch_images = patch_images.unsqueeze(0)
        instruction = self.pre_question(instruction, self.max_src_length)
        answer = self.pre_answer(answer, self.max_tgt_length)
        query_text = f"<image>User: {instruction} GPT:<answer> {answer}<|endofchunk|>"
        all_texts = f"{incontext_text}{query_text}"
        return patch_images, all_texts

    def process_scene_navigation(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        patch_images = torch.tensor([])
        incontext_text = ""
        for cur_incontext_id in in_context_example_ids:
            cur_incontext_instruction = self.dataset[cur_incontext_id]["instruction"]
            cur_incontext_instruction = self.pre_question(cur_incontext_instruction, self.max_src_length)
            cur_incontext_answer = self.dataset[cur_incontext_id]["answer"]
            cur_incontext_answer = self.pre_answer(cur_incontext_answer, self.max_tgt_length)
            cur_incontext_text = f"User: {cur_incontext_instruction} GPT:<answer> {cur_incontext_answer}<|endofchunk|>"
            incontext_text += cur_incontext_text

        incontext_text = f"<image>{incontext_text}"
        # <image>User: {cur_incontext_instruction} GPT:<answer> {cur_incontext_answer}<|endofchunk|>User: {instruction} GPT:<answer> {answer}<|endofchunk|>
        for cur_image_id in image_ids:
            cur_image = self.images[cur_image_id]
            cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
            cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0)
            if len(patch_images) == 0:
                patch_images = cur_patch_image
            else:
                patch_images = torch.cat((patch_images, cur_patch_image))

        patch_images = patch_images.unsqueeze(0)
        instruction = self.pre_question(instruction, self.max_src_length)
        answer = self.pre_answer(answer, self.max_tgt_length)
        query_text = f"User: {instruction} GPT:<answer> {answer}<|endofchunk|>"
        all_texts = f"{incontext_text}{all_texts}"
        return patch_images, all_texts

    def process_funqa(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        patch_images = torch.tensor([])
        all_texts = ""
        all_instruction_ids = in_context_example_ids + [instruction_id]
        random.shuffle(all_instruction_ids)
        for cur_instruction_id in all_instruction_ids[:]:
            cur_instruction = self.dataset[cur_instruction_id]["instruction"]
            cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
            cur_answer = self.dataset[cur_instruction_id]["answer"]
            cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
            cur_text = f"User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>"
            all_texts += cur_text

        all_texts = f"<image>{all_texts}"
        # <image>User: {cur_incontext_instruction} GPT:<answer> {cur_incontext_answer}<|endofchunk|>User: {instruction} GPT:<answer> {answer}<|endofchunk|>
        # <image>User: what does the image describe? GPT: XXX <|endofchunk|>User: Do you think this image is funny GPT:<answer> YYY <|endofchunk|>
        for cur_image_id in image_ids:
            cur_image = self.images[cur_image_id]
            cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
            cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0)
            if len(patch_images) == 0:
                patch_images = cur_patch_image
            else:
                patch_images = torch.cat((patch_images, cur_patch_image))

        patch_images = patch_images.unsqueeze(0)
        return patch_images, all_texts

    def process_general_vqa(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        """Build the text and image tensor for a generic VQA sample.

        Each in-context example and then the query is rendered as
        ``<image>User: ... GPT:<answer> ...<|endofchunk|>`` and contributes one
        image: the first entry of its record's ``image_ids`` when that is a
        list, otherwise the ``image_ids`` value itself. The per-turn tensors
        (each with two leading singleton dimensions) are concatenated along
        dimension 0. ``instruction``, ``answer`` and ``image_ids`` are not
        read; everything comes from ``self.dataset``.

        :return: ``(patch_images, all_texts)``.
        :raises RuntimeError: If an image cannot be decoded. The message names
            the instruction and image ids and the underlying error is chained.
        """
        patches = []
        chunks = []
        for cur_instruction_id in in_context_example_ids + [instruction_id]:
            record = self.dataset[cur_instruction_id]
            record_image_ids = record["image_ids"]
            cur_instruction_image_id = record_image_ids[0] if isinstance(record_image_ids, list) else record_image_ids
            cur_instruction = record["instruction"]
            cur_answer = record["answer"]
            cur_image = self.images[cur_instruction_image_id]
            try:
                cur_image = Image.open(BytesIO(base64.urlsafe_b64decode(cur_image))).convert("RGB")
            except Exception as err:
                # The original used a bare `except:` followed by print + exit(),
                # which also swallowed KeyboardInterrupt and ended the process
                # with a success status. Raise with context instead.
                raise RuntimeError(
                    f"Failed to decode image {cur_instruction_image_id!r} for instruction {cur_instruction_id!r}"
                ) from err
            patches.append(self.patch_resize_transform(cur_image).unsqueeze(0).unsqueeze(0))
            cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
            cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
            chunks.append(f"<image>User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>")
        # The query id is always present, so `patches` is never empty here.
        return torch.cat(patches), "".join(chunks)

    def cls_to_query_text(self, cls_label):
        """Turn a per-category label vector into a lower-cased, comma-separated string.

        ``cls_label`` is paired element-wise with ``CATEGORIES``; every
        category whose label is truthy is kept. When none is kept (e.g.
        inference results with no positive prediction) the text is
        ``"no finding"``.

        :param cls_label: Iterable of labels aligned with ``CATEGORIES``.
        :return: The query text.
        """
        positives = []
        for flag, category in zip(cls_label, CATEGORIES):
            if flag:
                positives.append(category)

        if len(positives) == 0:
            positives = ["no finding"]

        joined = ", ".join(positives)
        return joined.lower()

    def process_mimic_cxr(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        """Build image tensors, labels and text for a MIMIC-CXR / custom 2D sample.

        For every in-context example and then the query, each image listed in
        the record is loaded from disk (``self.images`` maps ids to paths) and
        the turn ``<image>User: ... GPT:<answer> ...<|endofchunk|>`` is
        appended. ``instruction``, ``answer`` and ``image_ids`` are not read;
        everything comes from ``self.dataset``.

        :return: A 7-tuple
            ``(patch_images, med_patch_images, noaug_patch_images, orig_patch_images, chexpert_labels, query_text, all_texts)``:

            * ``patch_images`` -- stacked outputs of ``self.resize_transform``
              applied to the (possibly augmented) encoder inputs;
            * ``med_patch_images`` -- stacked encoder inputs, augmented unless
              ``self.is_test``;
            * ``noaug_patch_images`` -- stacked clones taken before augmentation;
            * ``orig_patch_images`` -- plain list of ``to_tensor`` results for
              the loaded images (not stacked);
            * ``chexpert_labels`` -- tensor from ``self.id2label`` or ``None``;
            * ``query_text`` -- token ids of the label text or ``None``;
            * ``all_texts`` -- the concatenated turns.
        """
        resized_views, encoder_views, raw_views, clean_views = [], [], [], []
        chunks = []
        for example_id in in_context_example_ids + [instruction_id]:
            record = self.dataset[example_id]
            example_image_ids = record["image_ids"]
            raw_question = record["instruction"]
            raw_answer = record["answer"]
            for image_id in example_image_ids:
                location = Path(self.images[image_id])
                if not location.is_absolute():
                    # NOTE(review): relative paths are resolved against
                    # args.mimicit_path, whereas process_bimcv_covid19 uses
                    # self.mimicit_path -- confirm which base is intended.
                    location = Path(self.args.mimicit_path).parent / self.images[image_id]
                loaded = load_image(location)
                encoder_input = self.med_patch_resize_transform(loaded)  # biovil: (3,480,480)
                # Keep an un-augmented copy before the augmentation below.
                clean_copy = torch.clone(encoder_input)
                raw_tensor = to_tensor(loaded)
                if not self.is_test:
                    encoder_input = self.training_image_augmentation(encoder_input)
                encoder_views.append(encoder_input)
                raw_views.append(raw_tensor)
                clean_views.append(clean_copy)
                resized_views.append(self.resize_transform(encoder_input))
            question = self.pre_question(raw_question, self.max_src_length)
            reply = self.pre_answer(raw_answer, self.max_tgt_length)
            chunks.append(f"<image>User: {question} GPT:<answer> {reply}<|endofchunk|>")
        all_texts = "".join(chunks)

        patch_images = torch.stack(resized_views)  # (T,C,H,W)
        med_patch_images = torch.stack(encoder_views)
        noaug_patch_images = torch.stack(clean_views)

        chexpert_labels, query_text = None, None
        if hasattr(self, "id2label"):
            # Drop everything from the first underscore on, unless the id starts with it.
            split_at = instruction_id.find('_')
            label_key = instruction_id[:split_at] if split_at > 0 else instruction_id
            chexpert_labels = self.id2label.get(label_key, None)
            if chexpert_labels is not None:
                chexpert_labels = torch.Tensor(chexpert_labels)
                query_text = self.cls_to_query_text(chexpert_labels)

        if query_text is not None:
            encoded = self.biovil_tokenizer.batch_encode_plus(
                batch_text_or_text_pairs=[query_text],
                add_special_tokens=True,
                padding='longest',
                return_tensors='pt'
            )
            query_text = encoded.input_ids[0]
        return patch_images, med_patch_images, noaug_patch_images, raw_views, chexpert_labels, query_text, all_texts

    def process_bimcv_covid19(self, instruction_id, instruction, answer, image_ids, in_context_example_ids):
        patch_images, med_patch_images = [], []
        all_texts = ""
        all_instruction_ids = in_context_example_ids + [instruction_id]
        for cur_instruction_id in all_instruction_ids[:]:
            cur_instruction_image_ids = self.dataset[cur_instruction_id]["image_ids"]
            cur_instruction = self.dataset[cur_instruction_id]["instruction"]
            cur_answer = self.dataset[cur_instruction_id]["answer"]
            for image_id in cur_instruction_image_ids:
                image_path = Path(self.mimicit_path).parent / self.images[image_id]
                med_patch_image = self.med_patch_resize_transform({"image": image_path})["image"]
                med_patch_image = torch.permute(med_patch_image, (3, 0, 1, 2))
                patch_image = self.resize_transform(med_patch_image)
                patch_image = torch.repeat_interleave(patch_image, 3, dim=1)
                med_patch_images.append(med_patch_image)
                patch_images.append(patch_image)
            cur_instruction = self.pre_question(cur_instruction, self.max_src_length)
            cur_answer = self.pre_answer(cur_answer, self.max_tgt_length)
            cur_text = f"<image>User: {cur_instruction} GPT:<answer> {cur_answer}<|endofchunk|>"
            all_texts += cur_text
        patch_images = torch.cat(patch_images)  # (T,C,H,W)
        med_patch_images = torch.cat(med_patch_images)
        return patch_images, med_patch_images, all_texts

    def process_image_text_pair(self, index):
        """Build one training example for the instruction at position ``index``.

        The instruction id comes from ``self.train_data_list``; its record is
        looked up in ``self.dataset`` and its in-context example ids in
        ``self.train_config``.  A known id prefix selects the matching
        ``process_*`` method; otherwise ``self.args.dataset_type`` decides which
        one is used.  The resulting prompt text is tokenized and wrapped with
        ``self.bos_item`` / ``self.eos_item`` (and the matching mask items).

        Args:
            index (int): position inside ``self.train_data_list``.

        Returns:
            dict: always holds ``id``, ``source``, ``text_mask`` and
            ``patch_images``.  ``med_patch_images`` is added for the medical
            dataset types unless ``self.args.vision_encode_mode`` is
            ``"original"``; the CheXpert/query entries are added for
            ``"mimic_cxr"`` only.
        """
        instruction_id = self.train_data_list[index]
        record = self.dataset[instruction_id]
        instruction = record["instruction"]
        answer = record["answer"]
        image_ids = record["image_ids"]
        in_context_example_ids = self.train_config[instruction_id]
        call_args = (instruction_id, instruction, answer, image_ids, in_context_example_ids)

        # Id prefix -> name of the method handling that sub-dataset.  Names are
        # resolved lazily so only the selected method is ever looked up.
        prefix_handlers = (
            ("LA", "process_llava"),
            ("DC", "process_dense_caption"),
            ("TVC", "process_tv_caption"),
            ("E4D", "process_e4d"),
            ("SD", "process_spot_the_difference"),
            ("SN", "process_scene_navigation"),
            ("FunQA", "process_funqa"),
            ("LLAVAR", "process_llavar"),
        )
        handler_name = None
        for prefix, method_name in prefix_handlers:
            if instruction_id.startswith(prefix):
                handler_name = method_name
                break

        if handler_name is not None:
            patch_images, all_texts = getattr(self, handler_name)(*call_args)
        elif self.args.dataset_type in ["mimic_cxr", "custom_2d"]:
            (
                patch_images,
                med_patch_images,
                noaug_patch_images,
                orig_patch_images,
                chexpert_labels,
                query_text,
                all_texts,
            ) = self.process_mimic_cxr(*call_args)
        elif self.args.dataset_type in ["bimcv_covid19", "custom_3d"]:
            patch_images, med_patch_images, all_texts = self.process_bimcv_covid19(*call_args)
        else:
            patch_images, all_texts = self.process_general_vqa(*call_args)

        encoded = self.tokenizer(
            f"{all_texts}",
            return_tensors="pt",
            add_special_tokens=False,
            max_length=self.max_length
        )
        token_ids = encoded["input_ids"].squeeze(0)
        token_mask = encoded["attention_mask"].squeeze(0)

        example = {
            "id": instruction_id,
            "source": torch.cat([self.bos_item, token_ids, self.eos_item]),
            "text_mask": torch.cat([self.bos_mask, token_mask, self.eos_mask]),
            "patch_images": patch_images,
        }

        medical_types = ["mimic_cxr", "bimcv_covid19", "custom_3d", "custom_2d"]
        if self.args.dataset_type in medical_types and self.args.vision_encode_mode != "original":
            example["med_patch_images"] = med_patch_images

        if self.args.dataset_type == "mimic_cxr":
            example["chexpert_labels"] = chexpert_labels
            example["orig_patch_images"] = orig_patch_images
            example["query_text"] = query_text
            if query_text is not None:
                example["query_mask"] = torch.ones(query_text.shape, dtype=int)
            else:
                example["query_mask"] = None
            example["noaug_patch_images"] = noaug_patch_images

        return example

    def __str__(self):
        """Return a short summary naming the concrete class and the item count."""
        cls, size = type(self), len(self)
        return f"type: {cls}, length: {size}"

    def __len__(self):
        """Return how many instruction ids ``self.train_data_list`` holds."""
        train_ids = self.train_data_list
        return len(train_ids)

    def __getitem__(self, index):
        """Return the example at ``index``, moving on to later indices if needed.

        Every attempt runs inside ``random_seed(self.seed, self.epoch)``.  When
        ``process_image_text_pair`` yields ``None`` (dataset not supported) the
        next index is tried, and so on until an example is produced.
        """
        while True:
            # Enter the seeding context anew for each attempt so every candidate
            # index is processed from the same freshly seeded state.
            with random_seed(self.seed, self.epoch):
                example = self.process_image_text_pair(index)
            if example is not None:
                return example
            index += 1

    def collate(self, samples):
        """Merge per-item examples into a single batch dictionary.

        For the ``"mimic_cxr"`` dataset type each sample is adjusted in place
        first: its image tensors go through ``pad_or_cut_img_tensors`` with the
        configured image size and ``self.num_images_per_sample``, and its
        ``orig_patch_images`` list is zero-padded or truncated to that count.
        The batch itself is assembled by ``collate_fn``.

        Args:
            samples (List[dict]): examples to collate.
        Returns:
            dict: the batch produced by ``collate_fn``.
        """
        if self.args.dataset_type == "mimic_cxr":
            for item in samples:
                item["patch_images"] = pad_or_cut_img_tensors(
                    item["patch_images"], self.patch_image_size, self.num_images_per_sample
                )
                # Both medical-encoder inputs are optional and use the same target size.
                for optional_key in ("med_patch_images", "noaug_patch_images"):
                    if optional_key in item:
                        item[optional_key] = pad_or_cut_img_tensors(
                            item[optional_key], self.med_patch_image_size, self.num_images_per_sample
                        )
                if "orig_patch_images" not in item:
                    continue
                wanted = self.num_images_per_sample
                originals = item["orig_patch_images"]
                # Too few images: append all-zero images shaped like the last one.
                for _ in range(wanted - len(originals)):
                    originals.append(torch.zeros_like(originals[-1]))
                # Too many images: keep only the leading ones.
                if len(originals) > wanted:
                    item["orig_patch_images"] = originals[:wanted]

        image_text_samples = list(samples)  # containing image-text pairs
        return collate_fn(
            image_text_samples,
            pad_idx=self.tokenizer.pad_token_id,
            eos_idx=self.tokenizer.eos_token_id,
        )


def collate_fn(samples, pad_idx, eos_idx):
    """Assemble a list of example dicts into one batch dict.

    Args:
        samples: example dicts; each must hold ``id``, ``source`` and
            ``text_mask``.  Image and label entries are optional.
        pad_idx: padding value used for ``source``.
        eos_idx: forwarded to ``collate_tokens``.

    Returns:
        dict: ``{}`` for an empty ``samples``; otherwise a dict with ``id``
        (numpy array), ``nsentences`` and ``net_input``.  ``net_input`` always
        holds ``input_ids`` and ``attention_masks`` and, when available,
        the stacked image tensors, the per-sample ``orig_patch_images`` lists
        and the query/label tensors.
    """
    if not samples:
        return {}

    def merge(key, pad_idx, pading_size=None):
        return collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
            pad_to_length=pading_size,
        )

    longest_source = max(s["source"].size(0) for s in samples)
    sample_ids = np.array([s["id"] for s in samples])
    src_tokens = merge("source", pad_idx=pad_idx, pading_size=longest_source)
    src_tokens_masks = merge("text_mask", pad_idx=0, pading_size=longest_source)

    net_input = {
        "input_ids": src_tokens,
        "attention_masks": src_tokens_masks,
    }
    batch = {
        "id": sample_ids,
        "nsentences": len(samples),
        "net_input": net_input,
    }

    # Image tensors are stacked only when the first sample provides them, so a
    # batch without a given image entry simply omits that key.
    for key in ("patch_images", "med_patch_images", "noaug_patch_images"):
        if samples[0].get(key) is not None:
            net_input[key] = torch.stack([s[key] for s in samples], dim=0)
    # Kept as a plain list of per-sample lists (not stacked).
    if samples[0].get("orig_patch_images") is not None:
        net_input["orig_patch_images"] = [s["orig_patch_images"] for s in samples]

    # Query/label tensors are emitted only if every sample carries all three.
    label_keys = ("chexpert_labels", "query_text", "query_mask")
    has_label = all(s.get(key) is not None for s in samples for key in label_keys)
    if has_label:
        longest_query = max(s["query_text"].size(0) for s in samples)
        net_input["query_ids"] = merge("query_text", pad_idx=0, pading_size=longest_query)
        net_input["query_masks"] = merge("query_mask", pad_idx=0, pading_size=longest_query)
        net_input["chexpert_labels"] = torch.stack([s["chexpert_labels"] for s in samples], dim=0)

    return batch


def collate_tokens(
    values,
    pad_idx,
    eos_idx=None,
    left_pad=False,
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
    pad_to_bsz=None,
):
    """Convert a list of 1d (or 2d) tensors into one padded batch tensor.

    Args:
        values: non-empty sequence of tensors.  1d tensors are padded along
            their only axis; 2d tensors along axis 0, with axis 1 taken from
            ``values[0]``.
        pad_idx: fill value for every position not covered by an input.
        eos_idx: value written at position 0 when ``move_eos_to_beginning`` is
            set; if ``None`` the last element of each input is used instead.
        left_pad: put each input at the end of its row instead of the start.
        move_eos_to_beginning: drop each input's last element and shift the
            rest right by one, filling position 0 as described for ``eos_idx``
            (1d inputs only).
        pad_to_length: lower bound for the padded length.
        pad_to_multiple: round the padded length up to a multiple of this.
        pad_to_bsz: lower bound for the number of rows; rows beyond
            ``len(values)`` contain only ``pad_idx``.

    Returns:
        Tensor of shape ``(rows, length)`` or
        ``(rows, length, values[0].size(1))`` created from ``values[0]`` (same
        dtype and device).

    Raises:
        NotImplementedError: if ``values[0]`` is neither 1d nor 2d.
    """
    size = max(v.size(0) for v in values)
    if pad_to_length is not None:
        size = max(size, pad_to_length)
    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
        # size is not a multiple here, so the next multiple is one step up.
        size = (size // pad_to_multiple + 1) * pad_to_multiple

    rows = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            if eos_idx is None:
                # if no eos_idx is specified, then use the last token in src
                dst[0] = src[-1]
            else:
                dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    first = values[0]
    if first.dim() == 1:
        shape = (rows, size)
    elif first.dim() == 2:
        assert move_eos_to_beginning is False
        shape = (rows, size, first.size(1))
    else:
        raise NotImplementedError
    res = first.new_empty(shape).fill_(pad_idx)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][: len(v)])
    return res
