"""
Dataset classes and utilities for the V2 SAM data. The V2 dataset is composed of
image, segmentation mask pairs pertaining to medically relevant image artifacts
found in various modalities (CT, mammograms, ultrasounds, etc). Medical artifacts
(organs, clinically significant findings) are linked to their corresponding UMLS
terms stored in a separate directory. This is very similar to the v1 data, except
that here we are also generating bounding boxes as prompts for SAM.
"""

from glob import glob
import json
import os
import os.path as osp

from data import velcro_v1_utils as m_utils
from joblib import delayed
from joblib import Parallel
from lightning import LightningDataModule
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import v2
from torchvision.transforms import PILToTensor, ToPILImage
import torchvision.tv_tensors as tv
from tqdm import tqdm
from transformers import AutoProcessor, AutoImageProcessor
from transformers import AutoTokenizer
from transformers import BatchEncoding
from utils import RankedLogger
from skimage.measure import regionprops
from skimage.measure import label
from torch.nn.utils.rnn import pad_sequence
from itertools import groupby
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import DistributedSampler
import random
from catalyst.data.sampler import DistributedSamplerWrapper
from skimage.measure import regionprops
import itertools

tqdm.pandas()

logger = RankedLogger(__name__)

class VELCRODataModule(LightningDataModule):
    """DataModule containing processed train/val/test dataloaders for our
    dataset.

    This class handles the loading, splitting, and pre-processing of the dataset.

    Params:
        data_dir (str): The directory containing the raw dataset files.
        tensor_dir (str): The directory to save the processed tensors to.
        image_dir (str): The directory containing the images.
        train_batch_size (int): The total batch size for training. Must be
            divisible by the number of GPUs.
        test_batch_size (int): The total batch size for testing. Must be
            divisible by the number of GPUs.
        train_val_test_split (tuple[float, float, float]): A tuple containing the
            percentage split between train, val, and test datasets.
        num_workers (int): The number of workers to use for data loading.
        force_remake (bool): Whether to force remake the dataset cache. Relevant
            if the dataset is configured to cache to disk.
        pin_memory (bool): Whether the dataloaders should pin batches in
            page-locked host memory. This speeds up host-to-GPU transfers, but
            can cause issues with large datasets.
        sam_tokenizer_path (str): The path to the huggingface image model to use
            for tokenization.
        text_model_path (str): The path to the huggingface text model to use for
            tokenization.
        debug (bool): Whether to run in debug mode. In debug mode, only a subset
            of the dataset (first 25 rows) will be loaded. Default is False.
            Currently not implemented, so will raise an error if True.
        from_mem (bool): Whether to load all data into memory. This can
            speed up training, but requires more memory.
        prefetch_factor (int): Number of samples loaded in advance by each worker.
        persistent_workers (bool): Whether to keep workers alive between epochs.
    """

    # Datasets are loaded in lazily during "setup" to assist with DDP
    _train_dataset: Dataset | None = None
    _val_dataset: Dataset | None = None
    _test_dataset: Dataset | None = None

    _train_device_batch_size: int = 1
    _test_device_batch_size: int = 1

    def __init__(
        self,
        data_dir: str,
        tensor_dir: str,
        image_dir: str,
        train_batch_size: int,
        test_batch_size: int,
        train_val_test_split: tuple[float, float, float],
        num_workers: int,
        force_remake: bool,
        pin_memory: bool,
        sam_tokenizer_path: str,
        text_model_path: str,
        debug: bool = False,
        from_mem: bool = False,
        prefetch_factor: int = 2,
        persistent_workers: bool = True,
    ) -> None:
        """Validate the configuration and store it as hyperparameters.

        Raises:
            AssertionError: If ``train_val_test_split`` does not sum to 1.0
                (within floating point tolerance).
            NotImplementedError: If ``debug`` is True; debug mode is not
                implemented.
        """
        # Imported locally so this method does not depend on a new
        # module-level import.
        import math

        # Compare with a tolerance: exact float equality rejects valid splits
        # such as (0.7, 0.2, 0.1), whose sum is 0.9999999999999999.
        assert math.isclose(
            sum(train_val_test_split), 1.0, abs_tol=1e-9
        ), f"Train/val/test split must sum to 1.0. Got {train_val_test_split=}"

        if debug:
            # NotImplementedError is still an Exception, so existing
            # `except Exception` handlers keep working.
            raise NotImplementedError(
                "Feature not implemented. Please switch debug mode to False."
            )

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        super().__init__()

    def prepare_data(self):
        """Convert the raw ``.npz`` files into cached images, masks and annotations.

        This is only called once on the rank 0 gpu per run, and results in
        memory are not replicated across gpus. This is useful for downloading.

        The method:

        1. Returns early when ``tensor_dir`` already exists and
           ``force_remake`` is False.
        2. Loads ``UMLS_formatted.json`` from ``<data_dir>/v1`` and tokenizes
           every description with the text tokenizer.
        3. Collects every ``.npz`` file below ``<data_dir>/v1/ground_truths``;
           the name of the containing folder is used as the dataset name.
        4. For each file writes, below ``tensor_dir``: a 1024x1024 RGB image
           (``images/``), one binary 1024x1024 mask per label id (``masks/``)
           and a ``(class index, tokenized description, [bbox])`` tuple per
           mask (``annotations/``).
        """
        data_dir = os.path.join(self.hparams.data_dir, "v1")

        tensor_dir = self.hparams.tensor_dir
        image_dir = self.hparams.image_dir

        assert isinstance(data_dir, str)
        assert isinstance(tensor_dir, str)
        assert isinstance(image_dir, str)

        # TODO(XXXX-1): Change osp.exists check to ensure we have exactly the
        # correct number of processed files, rather than just check if the
        # directory exists.
        if not self.hparams.force_remake and osp.exists(tensor_dir):
            logger.warning(
                f"Skipping data preparation"
                f"({not self.hparams.force_remake=} or {osp.exists(tensor_dir)=})"
            )
            return

        # NOTE(review): outputs are written to the masks/, annotations/ and
        # images/ sub-directories, so this top-level "*.pt" glob only matches
        # files from an older layout. Existing masks are therefore kept and
        # skipped below even with force_remake=True. Left unchanged because
        # that behavior may be relied on to resume an interrupted run.
        if self.hparams.force_remake and glob(tensor_dir + "/*.pt"):
            logger.warning(
                f"Removing existing tensor directory: {tensor_dir}"
                f"({self.hparams.force_remake=} and"
                f"{list(glob(tensor_dir + '/*.pt'))[:5]=})"
            )
            for file in tqdm(glob(tensor_dir + "/*.pt")):
                os.remove(file)

        logger.info(
            "Preparing data... "
            f"({self.hparams.force_remake=} or {not osp.exists(tensor_dir)=})"
        )

        # get umls master dict:
        umls_path = data_dir
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(
            umls_path + "/" + "UMLS_formatted.json", encoding="utf-8"
        ) as json_file:
            umls_terms = json.load(json_file)

        # Tokenize the UMLS terms
        text_tokenizer = AutoTokenizer.from_pretrained(
            self.hparams.text_model_path
        )
        # We tokenize all the terms together so that we don't have to worry about
        # padding issues when we batch the data. That is, it will automatically
        # pad all the terms to be the same length (the largest sequence).
        umls_text = [x["desc"] for x in umls_terms.values()]

        # For entity description ablation studies, we can replace the UMLS description
        # with a simple caption of the form "An image of {entity}".
        # TODO: Make this a datamodule argument.
        #caption_ablation = [f'An image of {i.split("[BODY]")[0].split("[TITLE]")[1]}' for i in umls_text]
        #umls_text = caption_ablation
        tokenized_umls = text_tokenizer(
            umls_text, return_tensors="pt", padding=True
        )
        assert isinstance(tokenized_umls, BatchEncoding)
        # Turn the batched encoding (key -> stacked tensor) into one dict per
        # term (key -> that term's row).
        expanded_umls_values = [
            dict(zip(tokenized_umls.keys(), values))
            for values in zip(*tokenized_umls.values())
        ]

        for values, tokenized in zip(umls_terms.values(), expanded_umls_values):
            values["desc"] = tokenized
            values["idx"] = torch.tensor(values["idx"])

        master_files = []
        folders = []

        # TODO(XXXX-2): implement logic to isolate multi-concept
        # masks and split into single-concept masks. This may require
        # pulling the original datasets and performing manual preprocessing.
        # For now, all multi-concept datasets have been removed from the
        # v1 dataset directory.
        img_mask_path = os.path.join(data_dir, "ground_truths")
        for root, _, files in os.walk(img_mask_path):
            for file in files:
                if file.endswith(".npz"):
                    folders.append(os.path.basename(root))
                    master_files.append(os.path.join(root, file))
        mega = pd.DataFrame({"File": master_files, "Dataset": folders})

        mega["index"] = 1
        mega["index"] = mega["index"].cumsum() - 1  # 0, 1, 2, 3 etc.

        # A dictionary to assist with mapping UMLS terms to dataset instances.
        # Mapping is performed on a per-dataset basis to make adding new datasets
        # easier.
        with open(
            umls_path + "/" + "dataset_directory.json", encoding="utf-8"
        ) as json_file:
            term_mapping = json.load(json_file)

        os.makedirs(tensor_dir, exist_ok=True)
        os.makedirs(tensor_dir + "/masks", exist_ok=True)
        os.makedirs(tensor_dir + "/annotations", exist_ok=True)
        os.makedirs(tensor_dir + "/images", exist_ok=True)
        # NOTE(review): these statistics are updated from worker threads
        # without a lock, so the counts should be treated as approximate.
        self.removed_masks = 0
        self.num_entities_per_case = []
        self.num_entities_per_slice = []

        self.data_modality_distribs = {}

        def describe_size(obj):
            """Return a printable size for a PIL image or an array."""
            if isinstance(obj, Image.Image):
                return obj.size
            return getattr(obj, "shape", None)

        # Function for resizing and processing masks to convert them into tensors.
        def process(row):
            # NOTE(review): this path does not match the files written below
            # (they live in sub-directories), so the check looks stale.
            if (
                os.path.exists(tensor_dir + "/" + str(row.index) + "-0.pt")
                and not self.hparams.force_remake
            ):
                return

            index = str(row.index)
            dataset = row.Dataset
            packed_data = np.load(row.File)
            img = packed_data["imgs"]
            mask = packed_data["gts"]

            # A mask with a single value has nothing to segment.
            if len(np.unique(mask)) == 1:
                return

            if len(img.shape) > 2 and img.shape[2] != 3:
                # Make sure that 3D volumes have the same shape between images and
                # masks. If not, then there is no 1-1 matching between image and
                # mask slices.
                assert (
                    img.shape == mask.shape
                ), f"3D volume shapes do not match. Got (image) \
                    {img.shape=} and (mask) {mask.shape=}."
                # This converts 3D volumes into 2D slices, with each image slice
                # corresponding to a mask slice
                imgs, masks = m_utils.extract_2d_masks(img, mask)
            else:
                # It is possible for images to be RGB and masks to
                # be greyscale/2D arrays. To check shape agreement,
                # only check the first and second shapes
                assert (
                    img.shape[0] == mask.shape[0]
                    and img.shape[1] == mask.shape[1]
                ), f"Image and mask shapes do not match. Got (image) \
                    {img.shape=} and (mask) {mask.shape=}."
                imgs = [img]
                masks = [mask]

            # For multi-concept datasets, split up masks so that each submask
            # only contains the segmentation labels of a single concept.
            imgs, masks = m_utils.multi_mask_processing(imgs, masks, dataset)

            potential_terms = term_mapping[dataset]
            # Grabbing UMLS terms and standardizing list length between
            # images, masks, and terms.
            if len(potential_terms) == 1:
                # If this file is from a one-concept dataset, no need
                # for additional parsing. The list of candidate terms
                # must still be extended to the length of the image list
                # to account for 3D volumes, though.
                candidate_terms = [umls_terms[potential_terms[0]]] * len(imgs)

            else:
                # Extract appropriate concept from dataset files where
                # the correct concept is embedded in the file name.
                # In this case, masks are already standardized but the
                # length of the potential terms is greater than 1. This
                # is different from multi-concept datasets, where pixel
                # values represent different classes and are thus not
                # normalized.
                concept = m_utils.parse_concept_from_file_name(
                    dataset, row.File, potential_terms
                )
                # Check for a missing concept BEFORE indexing umls_terms;
                # indexing with None would raise KeyError and make this
                # guard unreachable.
                if concept is None:
                    return
                candidate_terms = [umls_terms[concept]] * len(imgs)

            entities_per_case = 0
            for i, (img, mask, term) in enumerate(
                zip(imgs, masks, candidate_terms)
            ):

                y = term["idx"]
                candidate_text = term["desc"]
                mask = mask.astype(np.uint8)
                try:
                    # Drop the smallest value.
                    # NOTE(review): presumably the background (0) - confirm
                    # that every mask reaching this point contains it.
                    label_ids = np.unique(mask)[1:]
                    img = Image.fromarray(img).convert("RGB")
                    mask = Image.fromarray(mask)

                    img = img.resize((1024, 1024), Image.LANCZOS)
                    bad_img_votes = 0
                    for k, label_id in enumerate(label_ids):
                        if os.path.exists(f"{tensor_dir}/masks/{index}-{i}-{k}-{y}.png"):
                            continue
                        segment_mask = np.zeros_like(np.asarray(mask))
                        segment_mask[np.asarray(mask) == label_id] = 1
                        segment_mask = Image.fromarray(segment_mask)

                        # Resize to 1024x1024 for ViT input. Also resize to 224x224 to ensure that
                        # masks are valid for both SAM and CLIP (this prevents issues where some
                        # masks are valid at 224x224 but become empty when resized to 1024x1024 and
                        # vice versa).
                        check = segment_mask.resize((224, 224), Image.NEAREST)
                        resized_mask = segment_mask.resize((1024, 1024), Image.NEAREST)
                        resized_mask = PILToTensor()(resized_mask)
                        check = PILToTensor()(check)

                        if len(np.unique(resized_mask.numpy())) < 2 or len(np.unique(check.numpy())) < 2:
                            bad_img_votes += 1
                            self.removed_masks += 1
                            continue

                        # Generate bounding box
                        y_indices, x_indices = np.where(resized_mask.numpy()[0] > 0)

                        x_min, x_max = np.min(x_indices), np.max(x_indices)
                        y_min, y_max = np.min(y_indices), np.max(y_indices)
                        bbox = np.array([x_min, y_min, x_max, y_max]).astype(int).tolist()
                        segment_mask = resized_mask

                        # Ensure mask is valid
                        assert len(np.unique(segment_mask.numpy())) == 2
                        entities_per_case += 1
                        segment_mask = ToPILImage()(segment_mask)
                        segment_mask.save(f"{tensor_dir}/masks/{index}-{i}-{k}-{y}.png")
                        torch.save(
                            (y, candidate_text, [bbox]),
                            (f"{tensor_dir}/annotations/{index}-{i}-{k}-{y}.pt")
                        )

                    # Every label of this slice was rejected: skip the image.
                    if bad_img_votes == len(label_ids):
                        continue
                    self.num_entities_per_case.append(entities_per_case)
                    img.save(
                        (f"{tensor_dir}/images/{index}-{i}-{y}.png"),
                        #(img.numpy()),
                    )
                except Exception as e:
                    # Best effort: report the failing file and move on.
                    # img and mask may each be either an array or a PIL
                    # image depending on where the failure happened, so the
                    # size is looked up per object.
                    print(f"Error on file when resizing: {row.File}")
                    print(describe_size(img), describe_size(mask))
                    print(e)

        logger.info("pre-tokenizing data....")

        Parallel(n_jobs=-1, backend="threading")(
            delayed(process)(row)
            for row in tqdm(mega.itertuples(index=False), total=len(mega))
        )
        del text_tokenizer

    def setup(self, stage: str):
        """Load dataset for training/validation/testing.

        NOTE: When using DDP (multiple GPUs), this is run once per GPU.
        As a result, this function should be deterministic and not download
        or have side effects. As a result, all data processing should be done in
        prepare_data and cached to disk, or done prior to training.

        The cached masks are grouped by case (the first field of the file
        name) before splitting so that all masks of one case end up in the
        same split. A weighted sampler is built for the training set, with
        each sample weighted by the inverse frequency of its class.

        Args:
            stage: either 'fit' (train), 'validate', 'test', or 'predict'

        Raises:
            RuntimeError: If a batch size is not divisible by the number of
                devices.
            ValueError: If no processed masks are found in ``tensor_dir``.
        """
        # Imported locally so this method does not depend on a new
        # module-level import.
        from collections import Counter

        logger.info(f"Setting up data for stage: {stage}")

        def case_id(path):
            """Return the case index (first field of the file name)."""
            return os.path.basename(path).split("-")[0]

        def class_index(path):
            """Return the class index (last field of the file name)."""
            return int(os.path.basename(path).split("-")[-1].split(".")[0])

        # We only have access to trainer in setup, so we need to calculate
        # these parameters here.
        world_size = self.trainer.world_size if self.trainer is not None else 1
        train_batch_size = self.hparams.train_batch_size
        test_batch_size = self.hparams.test_batch_size
        # We test both here to fail quickly if misconfigured
        if train_batch_size % world_size != 0 or test_batch_size % world_size != 0:
            raise RuntimeError(
                f"Batch sizes ({train_batch_size=}, {test_batch_size=}) are "
                f"not divisible by the number of devices ({world_size})."
            )
        self._train_device_batch_size = train_batch_size // world_size
        self._test_device_batch_size = test_batch_size // world_size

        tensor_dir = self.hparams.tensor_dir  # type: ignore
        # Get list of all processed examples. Sort to ensure ordering
        # is consistent between runs.
        examples = sorted(glob(tensor_dir + "/masks/*.png"))
        if not examples:
            raise ValueError(
                f"No processed masks found in {tensor_dir}/masks. "
                "Run prepare_data first."
            )
        random.seed(42)

        # This is similar to MedSAM's data processing approach, where they randomly select
        # one bounding box from the set of possible bounding boxes for a single datapoint.
        # The actual dataset size is only the amount of distinct images in the dataset
        new_dataset_size = len(glob(tensor_dir + "/images/*.png"))
        # Group datapoints by case to ensure no data leakage happens
        by_case = [list(group) for _, group in groupby(examples, case_id)]
        case_labels = [class_index(case[0]) for case in by_case]

        # Perform train/val/test splitting
        train, val, test = self.hparams.train_val_test_split  # type: ignore
        train_set, val_test_set, _, val_test_y = train_test_split(
            by_case,
            case_labels,
            train_size=train,
            test_size=val + test,
            random_state=42,
        )
        val_set, test_set, _, _ = train_test_split(
            val_test_set,
            val_test_y,
            test_size=test / (val + test),
            random_state=3,
        )

        # Flatten cases into final lists. The size ratios will likely not be exact
        # to the desired ratios, but the goal is to get a relatively even amount
        # through random splitting.
        final_train_set = [path for case in train_set for path in case]
        final_test_set = [path for case in test_set for path in case]
        final_val_set = [path for case in val_set for path in case]
        print("train", len(train_set), len(final_train_set))
        print("test", len(test_set), len(final_test_set))
        print("val", len(val_set), len(final_val_set))

        # Compute weights for weighted random sampler: each sample is weighted
        # by the inverse frequency of its class. A Counter avoids assuming a
        # fixed number of classes.
        train_labels = [class_index(path) for path in final_train_set]
        class_counts = Counter(train_labels)
        sample_weights = [1 / class_counts[lbl] for lbl in train_labels]

        # Get number of distinct image instances in testing dataset:
        # The points in the test set correspond to the masks for each distinct
        # segmentation in an image, but we need the number of distinct images
        # within the test set. This is denoted by the first two elements
        # in the file name.
        flat_val_test_set = [path for case in val_test_set for path in case]
        distinct_val_test_images = sum(
            1
            for _ in groupby(
                flat_val_test_set,
                lambda x: os.path.basename(x).split("-")[0:2],
            )
        )
        # What is left is the number of images not held out for val/test.
        new_dataset_size = new_dataset_size - distinct_val_test_images
        del flat_val_test_set

        self.sampler = WeightedRandomSampler(
            sample_weights, replacement=True, num_samples=new_dataset_size
        )

        # The WeightedRandomSampler does not support distributed training
        # out of the box, so we need to wrap it in a DistributedSamplerWrapper
        # to ensure each GPU gets a different subset of the data.
        if world_size > 1:
            self.distributed_sampler = DistributedSamplerWrapper(
                self.sampler, num_replicas=world_size, shuffle=False
            )
        else:
            self.distributed_sampler = self.sampler

        if self._train_dataset is None:
            # make training dataset
            self._train_dataset = VELCRODataset(
                items=final_train_set,
                model_path=self.hparams.sam_tokenizer_path,  # type: ignore
                is_testing=False,
                from_mem=self.hparams.from_mem,
            )
        if self._val_dataset is None:
            # make validation dataset
            # NOTE(review): is_testing=False means random augmentations are
            # also applied to validation samples - confirm this is intended.
            self._val_dataset = VELCRODataset(
                items=final_val_set,
                model_path=self.hparams.sam_tokenizer_path,  # type: ignore
                is_testing=False,
                from_mem=self.hparams.from_mem,
            )
        if self._test_dataset is None:
            # Make test dataset
            self._test_dataset = VELCRODataset(
                items=final_test_set,
                model_path=self.hparams.sam_tokenizer_path,  # type: ignore
                is_testing=True,
                from_mem=self.hparams.from_mem,
            )

    def train_dataloader(self) -> DataLoader:
        """
        Return the training dataloader. Only this
        dataloader is given a weighted sampler. Shuffling must be
        false when using the weighted random sampler.
        """
        assert self._train_dataset is not None
        num_workers = self.hparams.num_workers  # type: ignore
        # DataLoader rejects prefetch_factor / persistent_workers when loading
        # happens in the main process (num_workers == 0).
        uses_workers = num_workers > 0
        return DataLoader(
            self._train_dataset,
            batch_size=self._train_device_batch_size,  # type: ignore
            sampler=self.distributed_sampler,
            shuffle=False,
            num_workers=num_workers,
            # Honor the documented pin_memory hyperparameter.
            pin_memory=self.hparams.pin_memory,
            prefetch_factor=(
                self.hparams.prefetch_factor if uses_workers else None
            ),
            persistent_workers=(
                self.hparams.persistent_workers and uses_workers
            ),
        )

    def val_dataloader(self) -> DataLoader:
        """
        Return the validation dataloader.

        When running on more than one device, a non-shuffling
        DistributedSampler splits the dataset between the devices.
        """
        assert self._val_dataset is not None
        sampler = None
        if self.trainer is not None and self.trainer.world_size > 1:
            # DistributedSampler shuffles by default; evaluation should keep
            # a fixed order, matching shuffle=False below.
            sampler = DistributedSampler(self._val_dataset, shuffle=False)
        num_workers = self.hparams.num_workers  # type: ignore
        # DataLoader rejects prefetch_factor / persistent_workers when loading
        # happens in the main process (num_workers == 0).
        uses_workers = num_workers > 0
        return DataLoader(
            self._val_dataset,
            batch_size=self._test_device_batch_size,  # type: ignore
            sampler=sampler,
            shuffle=False,
            num_workers=num_workers,
            # Honor the documented pin_memory hyperparameter.
            pin_memory=self.hparams.pin_memory,
            prefetch_factor=(
                self.hparams.prefetch_factor if uses_workers else None
            ),
            persistent_workers=(
                self.hparams.persistent_workers and uses_workers
            ),
        )

    def test_dataloader(self) -> DataLoader:
        """
        Return the test dataloader.

        When running on more than one device, a non-shuffling
        DistributedSampler splits the dataset between the devices.
        """
        assert self._test_dataset is not None
        sampler = None
        if self.trainer is not None and self.trainer.world_size > 1:
            # DistributedSampler shuffles by default; evaluation should keep
            # a fixed order, matching shuffle=False below.
            sampler = DistributedSampler(self._test_dataset, shuffle=False)
        num_workers = self.hparams.num_workers  # type: ignore
        # DataLoader rejects prefetch_factor / persistent_workers when loading
        # happens in the main process (num_workers == 0).
        uses_workers = num_workers > 0
        return DataLoader(
            self._test_dataset,
            batch_size=self._test_device_batch_size,  # type: ignore
            shuffle=False,
            sampler=sampler,
            num_workers=num_workers,
            # Honor the documented pin_memory hyperparameter.
            pin_memory=self.hparams.pin_memory,
            prefetch_factor=(
                self.hparams.prefetch_factor if uses_workers else None
            ),
            persistent_workers=(
                self.hparams.persistent_workers and uses_workers
            ),
        )


class VELCRODataset(Dataset):
    """Dataset instance for a dataloader.

    Params:
        items (list[str]): A list of paths to the processed tensors.
        model_path (str): The huggingface name of the image model to use for
            tokenization.
        is_testing (bool): Whether the dataset is for testing. If true, no data
            augmentations will be applied.
        from_mem (bool): Whether to load all data into memory. This can
            speed up training, but requires more memory.
    """

    def __init__(self, items: list[str], model_path: str, is_testing: bool, from_mem: bool):
        """Store the item list, build the augmentations and load the processor.

        Each entry of ``items`` is the path of a mask file. The matching
        annotation file is looked up in the ``annotations`` directory that
        sits next to the directory containing the mask, under the same file
        name with a ``.pt`` extension.

        When ``from_mem`` is True every annotation file is loaded up front
        into ``self.annotations`` (keyed by item position); images and masks
        are still read from disk in ``__getitem__``.
        """
        # assume our dataset contains image path, segmentation mask path, label,
        # bounding boxes corresponding to each distinct segment
        if from_mem:
            self.annotations = {}
            for idx, path in enumerate(items):
                # Derive the directory per item, as __getitem__ does. This
                # also keeps an empty item list from raising IndexError.
                tensor_dir = os.path.dirname(os.path.dirname(path))
                stem = os.path.splitext(os.path.basename(path))[0]
                # NOTE: weights_only=False unpickles arbitrary objects; only
                # use this on locally generated cache files.
                self.annotations[idx] = torch.load(
                    tensor_dir + f"/annotations/{stem}.pt",
                    weights_only=False,
                )
        self.items = items
        self.is_testing = is_testing
        self.from_mem = from_mem
        # These transforms will always work on non-empty masks
        self.safe_transforms = v2.Compose(
            [
                v2.PILToTensor(),
                v2.RandomHorizontalFlip(p=0.5),
                v2.RandomVerticalFlip(p=0.5),
            ]
        )
        # These transforms may result in empty masks, so we need to
        # try them and revert if the mask is empty
        self.danger_transforms = v2.Compose([v2.RandomRotation(90)])
        self.processor = AutoProcessor.from_pretrained(
            model_path, local_files_only=False
        )

    def __getitem__(self, idx: int):
        """Load, augment and preprocess the sample at position ``idx``.

        Params:
            idx: The index of the item to fetch.

        Returns:
            A dictionary with two entries. ``"x"`` holds the tokenized
            candidate text, the processor's pixel values and the processor's
            input boxes. ``"y"`` holds the class index, the float mask and
            the path of the mask file.

        Raises:
            Exception: If the mask is empty before augmentation or after
                preprocessing.
        """
        path = self.items[idx]
        mask = Image.open(path)

        # The image is shared by all masks of a slice: its file name is the
        # mask file name without the per-mask field.
        name_parts = os.path.basename(path).split("-")
        stem = f"{name_parts[0]}-{name_parts[1]}-{name_parts[-1]}"
        tensor_dir = os.path.dirname(os.path.dirname(path))
        image_path = tensor_dir + f"/images/{stem}"

        if self.from_mem:
            img = Image.open(image_path)
            label, candidate_text, bboxes = self.annotations[idx]
        else:
            annotation_name = os.path.splitext(os.path.basename(path))[0]
            label, candidate_text, bboxes = torch.load(
                tensor_dir + f"/annotations/{annotation_name}.pt",
                weights_only=False,
            )
            img = Image.open(image_path)

        assert isinstance(candidate_text, dict), f"{type(candidate_text)=}"
        assert all(
            isinstance(value, torch.Tensor)
            for value in candidate_text.values()
        )
        assert isinstance(bboxes, list), f"{type(bboxes)=}"

        # Wrap everything in tv tensors so the v2 transforms treat image,
        # mask and boxes consistently.
        mask = tv.Mask(mask).to(torch.int8)
        img = tv.Image(img)
        bboxes = tv.BoundingBoxes(
            bboxes, format="XYXY", canvas_size=mask.shape[1:]
        )
        if torch.max(mask) == 0:
            raise Exception("Empty mask pre")

        # Augment training samples only.
        if not self.is_testing:
            img, mask, bboxes = self.safe_transforms(img, mask, bboxes)
            rotated = self.danger_transforms(img, mask, bboxes)
            # Keep the rotated version only when the mask is still non-empty.
            if torch.max(rotated[1]) != 0:
                img, mask, bboxes = rotated

        # The augmentations are random, so the processor has to run here
        # rather than ahead of time.
        inputs = self.processor(
            images=img,
            input_boxes=[bboxes.tolist()],
            return_tensors="pt",
            do_normalize=True,
            do_rescale=True,
            do_resize=True,
        )

        mask = mask.float()
        if torch.max(mask) == 0:
            raise Exception("Empty mask after")

        features = {
            "candidate_input": candidate_text,
            "image_input": {"img": inputs.pixel_values},
            "bounding_boxes": inputs.input_boxes,
        }
        targets = {"class_indices": label, "gold_mask": mask, "path": path}
        return {"x": features, "y": targets}

    def __len__(self):
        """
        Report how many items (mask paths) this dataset holds.
        """
        item_count = len(self.items)
        return item_count
