# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/data/custom_dataset_mapper.py
import copy
import numpy as np
import torch

from detectron2.config import configurable

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.dataset_mapper import DatasetMapper
from .custom_build_augmentation import build_custom_augmentation
from itertools import compress
import logging

# Public API of this module.
__all__ = ["CustomDatasetMapper", "ObjDescription"]
# Module-level logger; CustomDatasetMapper uses it to report repeated re-augmentation.
logger = logging.getLogger(__name__)


class CustomDatasetMapper(DatasetMapper):
    """DatasetMapper with per-dataset training augmentations and object descriptions.

    During training, every record must carry an integer ``dataset_source`` that
    selects which entry of ``self.dataset_augs`` is applied to it. At inference the
    standard ``self.augmentations`` inherited from ``DatasetMapper`` is used.
    Ground-truth instances additionally get a ``gt_object_descriptions`` field
    (an ``ObjDescription``) built from each annotation's ``object_description``.
    """

    # Minimum height and width (pixels) an augmented image must have.
    MIN_IMAGE_SIZE = 32
    # Retry count at which a warning with the offending record is logged.
    RETRY_WARN_AT = 100
    # Hard cap on re-augmentation attempts, so an image that can never satisfy
    # MIN_IMAGE_SIZE (e.g. a deterministic eval pipeline) fails loudly instead of
    # hanging the data loader forever.
    MAX_RETRIES = 1000

    @configurable
    def __init__(self, is_train: bool,
        dataset_augs=None,
        **kwargs):
        """
        Args:
            is_train: whether the mapper is used for training.
            dataset_augs: training only; one entry per source dataset, each wrapped
                in ``T.AugmentationList`` (presumably a list of augmentations as
                returned by ``build_custom_augmentation`` -- confirm). ``None`` is
                treated as an empty list.
            **kwargs: forwarded unchanged to ``DatasetMapper.__init__``.
        """
        # Always define the attribute, including in eval mode, where it is unused.
        if is_train:
            self.dataset_augs = [T.AugmentationList(x) for x in (dataset_augs or [])]
        else:
            self.dataset_augs = []
        super().__init__(is_train, **kwargs)

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """Build constructor kwargs from ``cfg``, adding the per-dataset augmentations.

        For training, ``cfg.INPUT.CUSTOM_AUG`` selects how the pipelines are built:
          * ``'EfficientDetResizeCrop'``: zipped from ``DATALOADER.DATASET_INPUT_SCALE``
            and ``DATALOADER.DATASET_INPUT_SIZE``;
          * ``'ResizeShortestEdge'``: zipped from ``DATALOADER.DATASET_MIN_SIZES``
            and ``DATALOADER.DATASET_MAX_SIZES``.

        Raises:
            ValueError: if, in training, ``CUSTOM_AUG`` is neither of the above.
        """
        ret = super().from_config(cfg, is_train)
        if not is_train:
            ret['dataset_augs'] = []
            return ret

        custom_aug = cfg.INPUT.CUSTOM_AUG
        if custom_aug == 'EfficientDetResizeCrop':
            dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE
            dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE
            ret['dataset_augs'] = [
                build_custom_augmentation(cfg, True, scale, size)
                for scale, size in zip(dataset_scales, dataset_sizes)
            ]
        elif custom_aug == 'ResizeShortestEdge':
            min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES
            max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES
            ret['dataset_augs'] = [
                build_custom_augmentation(cfg, True, min_size=mi, max_size=ma)
                for mi, ma in zip(min_sizes, max_sizes)
            ]
        else:
            raise ValueError(
                "Unsupported cfg.INPUT.CUSTOM_AUG {!r}; expected "
                "'EfficientDetResizeCrop' or 'ResizeShortestEdge'".format(custom_aug))
        return ret

    def __call__(self, dataset_dict):
        """Map one record, re-augmenting until the image is large enough.

        Random augmentations (e.g. crops) can yield an image smaller than
        ``MIN_IMAGE_SIZE`` on either side; the record is then re-processed from
        scratch via ``prepare_data``.

        Raises:
            RuntimeError: if no acceptable image is produced within ``MAX_RETRIES``
                re-augmentations.
        """
        min_size = self.MIN_IMAGE_SIZE

        def too_small(out):
            # "image" is a (C, H, W) tensor (see prepare_data).
            return out["image"].shape[1] < min_size or out["image"].shape[2] < min_size

        dataset_dict_out = self.prepare_data(dataset_dict)
        retry = 0
        while too_small(dataset_dict_out):
            retry += 1
            if retry == self.RETRY_WARN_AT:
                logger.warning(
                    'Retry %d times for augmentation. Make sure the image size is not too small.',
                    self.RETRY_WARN_AT)
                logger.warning('Find image information below')
                logger.warning(dataset_dict)
            if retry > self.MAX_RETRIES:
                raise RuntimeError(
                    "Augmented image stayed smaller than {0}x{0} after {1} retries "
                    "(file_name={2!r}, tar_index={3!r})".format(
                        min_size, self.MAX_RETRIES,
                        dataset_dict.get("file_name"), dataset_dict.get("tar_index")))
            dataset_dict_out = self.prepare_data(dataset_dict)

        return dataset_dict_out

    def prepare_data(self, dataset_dict_in):
        """Read, augment and convert a single record; the input is not modified.

        Args:
            dataset_dict_in: a detectron2-style record with either ``file_name`` or
                ``tar_index``; in training it must also contain ``dataset_source``.

        Returns:
            dict: a deep copy of the record with ``image`` as a (C, H, W) tensor.
            In training, ``annotations`` (if present) is replaced by ``instances``;
            in eval, ``annotations`` is removed.
        """
        dataset_dict = copy.deepcopy(dataset_dict_in)
        if 'file_name' in dataset_dict:
            ori_image = utils.read_image(
                dataset_dict["file_name"], format=self.image_format)
        else:
            # NOTE(review): self.tar_dataset is not set in this class; presumably
            # attached externally for tar-backed datasets -- confirm.
            ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]]
            ori_image = utils._apply_exif_orientation(ori_image)
            ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format)
        utils.check_image_size(dataset_dict, ori_image)

        # ori_image is not used after this point, so no defensive copy is needed.
        aug_input = T.AugInput(ori_image, sem_seg=None)
        if self.is_train:
            transforms = self.dataset_augs[dataset_dict['dataset_source']](aug_input)
        else:
            transforms = self.augmentations(aug_input)
        image = aug_input.image

        image_shape = image.shape[:2]  # (H, W)
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            dataset_dict["instances"] = self._annotations_to_instances(
                dataset_dict.pop("annotations"), transforms, image_shape)

        return dataset_dict

    def _annotations_to_instances(self, annotations, transforms, image_shape):
        """Transform non-crowd annotations into filtered Instances with descriptions."""
        # Crowd annotations are dropped. The descriptions are taken from the same
        # filtered list, so they stay aligned with (and the same length as) the
        # other instance fields.
        kept = [anno for anno in annotations if anno.get("iscrowd", 0) == 0]

        # USER: Modify this if you want to keep them for some reason.
        for anno in kept:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        annos = [
            utils.transform_instance_annotations(
                anno, transforms, image_shape,
                keypoint_hflip_indices=self.keypoint_hflip_indices,
            )
            for anno in kept
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )
        instances.gt_object_descriptions = ObjDescription(
            [anno['object_description'] for anno in kept])

        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        return utils.filter_empty_instances(instances)


class ObjDescription:
    """List-like container of per-instance object descriptions.

    Supports the indexing that detectron2 ``Instances`` applies to its fields
    (1-D boolean masks, 1-D integer index tensors, and slices). This lets the
    descriptions live in an ``Instances`` field and stay aligned with the other
    per-instance fields.
    """

    # Integer tensor dtypes accepted as index lists (uint8 is excluded because
    # older PyTorch code used it for masks, which would be ambiguous).
    _INDEX_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)

    def __init__(self, object_descriptions):
        """
        Args:
            object_descriptions: a sequence with one description per instance.
        """
        self.data = object_descriptions

    def __getitem__(self, item):
        """Return a new ObjDescription holding the selected descriptions.

        Args:
            item: a slice, or a 1-D torch tensor that is either a bool mask of
                length ``len(self)`` or integer indices (negative indices count
                from the end, as for Python lists).

        Raises:
            TypeError: if ``item`` is neither a slice nor a tensor, or the
                tensor's dtype is neither bool nor a signed integer type.
            IndexError: if the tensor is not 1-D, a bool mask has the wrong
                length, or an integer index is out of range.
        """
        if isinstance(item, slice):
            return ObjDescription(self.data[item])
        if not isinstance(item, torch.Tensor):
            raise TypeError(
                "ObjDescription indices must be slices or torch tensors, "
                "got {}".format(type(item).__name__))
        if item.dim() != 1:
            raise IndexError(
                "ObjDescription expects a 1-D index tensor, got {}-D".format(item.dim()))
        if item.numel() == 0:
            # An empty index selects nothing, whatever its dtype.
            return ObjDescription([])

        # tolist() converts once instead of iterating 0-d tensors element by element.
        if item.dtype == torch.bool:
            if len(item) != len(self.data):
                raise IndexError(
                    "Boolean mask of length {} does not match {} descriptions".format(
                        len(item), len(self.data)))
            return ObjDescription(
                [desc for desc, keep in zip(self.data, item.tolist()) if keep])
        if item.dtype in self._INDEX_DTYPES:
            return ObjDescription([self.data[idx] for idx in item.tolist()])
        raise TypeError(
            "Unsupported index tensor dtype {}; expected bool or integer".format(item.dtype))

    def __len__(self):
        """Return the number of descriptions."""
        return len(self.data)

    def __repr__(self):
        return "ObjDescription({})".format(self.data)