"""
COCO dataset which returns image_id for evaluation.

Mostly copied from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""

import torch
from PIL import Image, ImageDraw

#from .modulated_coco import ConvertCocoPolysToMask
from .tsv import ODTSVDataset
from pycocotools.coco import COCO
import random
import numpy as np
import pdb, json, random, re
#from maskrcnn_benchmark.structures.bounding_box import BoxList
from .tsv import load_from_yaml_file
from collections import defaultdict
from tqdm import tqdm
import copy
import torchvision


class DiffGenDataset(torchvision.datasets.CocoDetection):
    """COCO-format grounding dataset returning ``(image, BoxList target, idx)``.

    Every image record must provide a ``"caption"``. ``"dataset_name"`` and the
    keys ``"sentence_id"``, ``"original_img_id"``, ``"original_id"`` and
    ``"task_id"`` are copied onto the target when present.

    If ``ann_file`` contains ``"augmented_gen_gt"``, the annotation json must
    also provide top-level ``"descriptions"`` and ``"cat_descriptions"``
    entries. Every image record then needs a ``"cat"`` key, which is removed
    from the negative-category pool handed to caption augmentation.
    """

    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        return_masks,
        return_tokens,
        is_train=False,
        tokenizer=None,
        disable_clip_to_image=False,
        no_mask_for_gold=False,
        max_query_len=256,
        caption_augmentation_version=None,
        gen_caption_augmentation_version=None,
        caption_vocab_file=None,
        **kwargs
    ):
        super(DiffGenDataset, self).__init__(img_folder, ann_file)
        # No annotation-based filtering is applied: every image is kept.
        # Sorting makes the index -> image id mapping deterministic.
        self.ids = sorted(self.ids)
        self.id_to_img_map = dict(enumerate(self.ids))

        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(
            return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len
        )
        self.is_train = is_train
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold
        self.caption_augmentation_version = caption_augmentation_version
        self.gen_caption_augmentation_version = gen_caption_augmentation_version
        # NOTE(review): `caption_vocab_file` is accepted but unused. Nothing in
        # this class assigns `self.gen_caption_augmentation`, which
        # `__getitem__` needs whenever `gen_caption_augmentation_version` is set.

        # Description / category pools for the text-augmented variant.
        if "augmented_gen_gt" in ann_file:
            self.descriptions_pool = self.coco.dataset['descriptions']
            self.cat_descriptions_pool = self.coco.dataset['cat_descriptions']
            self.cat_pool = list(self.coco.dataset['descriptions'].keys())

    def __getitem__(self, idx):
        """Return ``(img, target, idx)``.

        ``target`` is an "xyxy" BoxList. Every entry of the processed
        annotation dict is attached to it as an extra field, together with
        ``dataset_name`` and any extra keys present on the image record.
        """
        img, target = super(DiffGenDataset, self).__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        caption = coco_img["caption"]
        dataset_name = coco_img.get("dataset_name")

        if hasattr(self, 'cat_pool'):
            # Negative categories: every pooled category except this image's own.
            tmp_neg_cls = list(self.cat_pool)
            tmp_neg_cls.remove(coco_img['cat'])
            tmp_descriptions_pool = self.descriptions_pool
        else:
            tmp_neg_cls = None
            tmp_descriptions_pool = None

        if self.gen_caption_augmentation_version is not None:
            augment = getattr(self, "gen_caption_augmentation", None)
            if augment is None:
                raise AttributeError(
                    "gen_caption_augmentation_version is set, but no "
                    "`gen_caption_augmentation` callable is attached to this dataset"
                )
            # NOTE(review): the third return value (spans) is not forwarded to
            # `self.prepare`; confirm whether it should be stored in `anno["spans"]`.
            caption, target, _ = augment(
                caption, target,
                neg_captions=coco_img.get("neg_caption", None),
                struct_neg_captions=coco_img.get("struct_neg_caption", None),
                cat_pool=tmp_neg_cls, des_pool=tmp_descriptions_pool)

        anno = {"image_id": image_id, "annotations": target, "caption": caption}
        # This dataset is used for Flickr & Mixed, so the whole caption is maskable.
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
        img, anno = self.prepare(img, anno)

        # Convert to BoxList (bboxes, labels); the reshape guards against zero boxes.
        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)
        target = BoxList(boxes, img.size, mode="xyxy")
        target.add_field("labels", anno["labels"])
        if self.prepare.return_masks:
            # Popped so the generic field copy below does not add them again.
            target.add_field("masks", anno.pop("masks"))
            target.add_field("is_box_mask", anno.pop("is_box_mask"))
        if not self.disable_clip_to_image:
            num_boxes = len(target.bbox)
            target = target.clip_to_image(remove_empty=True)
            assert num_boxes == len(target.bbox), "Box got removed in MixedDataset!!!"

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # Attach every remaining processed annotation entry as an extra field.
        for key, value in anno.items():
            target.add_field(key, value)

        target.add_field("dataset_name", dataset_name)
        for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]:
            if extra_key in coco_img:
                target.add_field(extra_key, coco_img[extra_key])

        return img, target, idx

    def get_img_info(self, index):
        """Return the raw COCO image record for dataset index ``index``."""
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data

def chunks(lst, n):
    """Split ``lst`` into consecutive slices of at most ``n`` items.

    The slices are returned eagerly as a list, not yielded. The last slice is
    shorter when ``len(lst)`` is not a multiple of ``n``. An assertion checks
    that the slices together cover every element.
    """
    pieces = [lst[start:start + n] for start in range(0, len(lst), n)]
    assert sum(len(piece) for piece in pieces) == len(lst)
    return pieces

def clean_name(name):
    """Turn a raw dataset category name into caption-friendly text.

    Cleaning steps:
    - Parenthesised qualifiers are removed greedily (from the first "(" to
      the last ")").
    - Underscores become spaces.
    - Runs of spaces are collapsed to one.

    Whitespace left behind by the removal is not stripped: ``"bow_(weapon)"``
    becomes ``"bow "``. Callers use the returned string verbatim when
    computing character spans.

    Names of the form ``"object:part"`` are rendered as ``"part of object"``.
    Only the first ":" separates object from part, so additional colons stay
    in the part name instead of raising ``ValueError``.
    """
    def _clean_name(name):
        name = re.sub(r"\(.*\)", "", name)
        name = re.sub(r"_", " ", name)
        # Collapse any run of spaces. A plain two-space pattern would leave a
        # double space behind for runs of three or more.
        name = re.sub(r" {2,}", " ", name)
        return name

    if ":" in name:
        obj_name, part_name = name.split(":", 1)
        return _clean_name(part_name) + " of " + _clean_name(obj_name)
    return _clean_name(name)

def clean_string(input_string):
    """Trim surrounding whitespace, then drop one trailing ";" followed by one
    trailing ".", in that order (so "x.;" -> "x" but "x;." -> "x;")."""
    text = input_string.strip()
    # Applied sequentially: removing ";" may expose a trailing ".".
    for pattern in (r";$", r"\.$"):
        text = re.sub(pattern, "", text)
    return text


class DetectionToGrounding:
    """Placeholder for a detection-to-grounding converter.

    Meant to convert detection data into grounding data and to build prompts
    for training and inference. It currently has no behaviour: the
    constructor ignores ``version`` and stores nothing.
    """

    def __init__(self, version):
        del version  # accepted for interface compatibility; unused

def sanity_check_target_after_processing(target):
    """Assert that ``target`` has exactly one ``extra_fields["boxes"]`` entry
    per row of ``target.bbox``.

    Being an ``assert``, the check (including the field lookup) is skipped
    when Python runs with ``-O``.
    """
    assert len(target.bbox) == len(target.extra_fields["boxes"])


def convert_od_to_grounding_simple(
    target,
    image_id,
    ind_to_class,
    disable_shuffle=True,
    add_detection_prompt=False,
    separation_tokens=" ",
    caption_prompt=None):
    """
    Convert object detection data into grounding data format, on the fly.

    Every key of ``ind_to_class`` (e.g. ``{0: "__background__", 1: "person", ...}``,
    contiguous ids) is written into one caption, in sorted order unless
    ``disable_shuffle`` is False. No key is filtered out, so a background entry
    present in ``ind_to_class`` also becomes part of the caption.

    :param target: BoxList whose ``extra_fields["labels"]`` holds one label per box.
    :param image_id: copied into every produced annotation.
    :param ind_to_class: label id -> raw class name (passed through ``clean_name``).
    :param disable_shuffle: keep the sorted label order when True.
    :param add_detection_prompt: prefix the caption with ``"object detection : "``.
    :param separation_tokens: text placed between consecutive class names.
    :param caption_prompt: optional per-position list of dicts with ``"prefix"``,
        ``"name"`` and ``"suffix"``. The ``"name"`` replaces the class name.
        Must not be combined with shuffling.
    :return: a 3-tuple ``(annotations, caption, greenlight_spans)``:
        - COCO-style dicts whose ``tokens_positive`` holds ``[start, end)``
          character spans;
        - the caption;
        - one span per kept box.
    """

    def generate_sentence_from_labels(positive_label_list, negative_label_list, disable_shuffle=True):
        ordered = negative_label_list + positive_label_list
        if not disable_shuffle:
            random.shuffle(ordered)
            assert (caption_prompt is None), "Should not specify caption_prompt when shuffle is enabled!!"  # avoid potential bug

        use_prompt = caption_prompt is not None
        caption = "object detection : " if add_detection_prompt else ""
        positions = {}
        last_position = len(ordered) - 1

        for position, label in enumerate(ordered):
            if use_prompt:
                caption += caption_prompt[position]['prefix']
            span_start = len(caption)
            if use_prompt:
                raw_name = caption_prompt[position]['name']
            else:
                raw_name = ind_to_class[label]  # NOTE: slight change...
            caption += clean_name(raw_name)
            # e.g. caption "cat dog" with cat = label 4 and dog = label 17
            # gives positions {4: [0, 3], 17: [4, 7]}.
            positions[label] = [span_start, len(caption)]
            if use_prompt:
                caption += caption_prompt[position]['suffix']
            if position != last_position:
                caption += separation_tokens

        return positions, caption

    label_to_positions, pheso_caption = generate_sentence_from_labels(
        positive_label_list=sorted(ind_to_class),
        negative_label_list=[],
        disable_shuffle=disable_shuffle,
    )

    # Records have the keys area, iscrowd, image_id, category_id, id, bbox and
    # tokens_positive. tokens_positive holds character spans inside the caption.
    areas = target.area()
    new_target = []
    greenlight_span_for_masked_lm_objective = []
    for box_index in range(len(target)):
        area = areas[box_index]
        label = target.extra_fields["labels"][box_index].item()
        record = {
            "area": area,
            "iscrowd": 0,
            "image_id": image_id,
            "category_id": label,
            "id": None,
            "bbox": target.bbox[box_index].numpy().tolist(),
        }
        # Only boxes whose label actually appears in the caption are kept.
        if label not in label_to_positions:
            continue
        span = label_to_positions[label]
        record["tokens_positive"] = [span]
        new_target.append(record)
        greenlight_span_for_masked_lm_objective.append(span)

    return new_target, pheso_caption, greenlight_span_for_masked_lm_objective


def check_for_positive_overflow(target, ind_to_class, tokenizer, max_seq_length=256):
    """Drop boxes whose classes would not fit into a ``max_seq_length``-token caption.

    NOTE: Only call this function for OD data; DO NOT USE IT FOR GROUNDING DATA.
    NOTE: called only in coco_dt.

    The distinct labels of ``target`` are shuffled, so different epochs keep
    different subsets. Labels are then accepted greedily while the running
    token count of ``clean_name(name) + ". "`` stays <= ``max_seq_length``.
    The first label that overflows stops the scan.

    Returns:
        ``(filtered_target, length)``: ``target`` restricted to boxes whose
        label was accepted, and the accumulated token count. When the scan
        stops early, ``length`` includes the rejected label's tokens and so
        exceeds ``max_seq_length``.
    """
    # Python-side copy of the per-box labels.
    box_labels = [target.extra_fields["labels"][i].item() for i in range(len(target))]
    positive_label_list = list(set(box_labels))

    # random shuffle so we can sample different annotations at different epochs
    random.shuffle(positive_label_list)

    kept_labels = set()
    length = 0
    for label in positive_label_list:
        label_text = clean_name(ind_to_class[label]) + ". "  # "dog. "
        length += len(tokenizer.tokenize(label_text))
        if length > max_seq_length:
            break
        kept_labels.add(label)

    # Filter boxes; a set makes each membership test O(1).
    keep_box_index = torch.LongTensor(
        [i for i, label in enumerate(box_labels) if label in kept_labels]
    )
    target = target[keep_box_index]

    return target, length


def _label_drop_with_length_limit(label_list, ind_to_class, length_limit, tokenizer):
    """Randomly order ``label_list`` and keep the prefix that fits a token budget.

    Each label costs ``len(tokenizer.tokenize(clean_name(name) + ". "))``
    tokens. Labels are accepted while the remaining budget stays strictly
    positive; the first label that exhausts it ends the scan.

    The shuffle is applied to a copy, so the caller's ``label_list`` (a list
    or a 1-D numpy array) is left unmodified. The random draws are identical
    to shuffling it in place.
    """
    candidates = list(label_list)
    random.shuffle(candidates)  # randomly drop labels
    screened_label_list = []
    for label in candidates:
        label_text = clean_name(ind_to_class[label]) + ". "  # "dog. "
        length_limit -= len(tokenizer.tokenize(label_text))
        if length_limit <= 0:
            break
        screened_label_list.append(label)  # keep this label
    return screened_label_list

def _randomv1_od_to_grounding(all_labels, ind_to_class, max_seq_length, max_num_labels, tokenizer):
    """Sample a random label subset for the "random.v1" caption recipe.

    The subset size is drawn from ``[1, max_num_labels)``; numpy's ``randint``
    upper bound is exclusive. The size is capped at ``len(all_labels)``
    because sampling without replacement cannot return more labels than exist.
    For example, ``max_num_labels=85`` exceeds an 80-class label set such as
    COCO's. The subset is then trimmed to ``max_seq_length`` tokens by
    ``_label_drop_with_length_limit``.
    """
    label_num = min(np.random.randint(1, max_num_labels), len(all_labels))
    selected_label_list = np.random.choice(all_labels, label_num, replace=False)
    screened_label_list = _label_drop_with_length_limit(selected_label_list, ind_to_class, max_seq_length, tokenizer)

    return screened_label_list

def _randomv2_od_to_grounding(all_labels, ind_to_class, max_seq_length, max_num_labels, tokenizer, positive_label_set):
    """Pick positive and negative labels for the "random.v2" caption recipe.

    One uniform draw selects one of three cases:
    - Probability 0.8: all positives plus ``max_num_labels - #positives``
      negatives.
    - Probability 0.1: all positives plus a random number (>= 1) of negatives.
    - Probability 0.1: a random number (>= 1) of positives plus the full
      negative quota.

    The selection (positives first) is then trimmed to ``max_seq_length``
    tokens by ``_label_drop_with_length_limit``.
    """
    full_positive = len(positive_label_set)
    # Clamp at zero. With more positives than `max_num_labels` the difference
    # is negative, and the negative slice bound below would then keep all but
    # the last few negatives instead of none.
    full_negative = max(0, max_num_labels - full_positive)

    outer_prob = random.random()

    if outer_prob < 0.8:
        num_negatives = full_negative
        num_positives = full_positive
    elif outer_prob < 0.9:
        num_negatives = np.random.choice(max(1, full_negative)) + 1  # minimum 1
        num_positives = full_positive
    else:
        num_positives = np.random.choice(max(1, full_positive)) + 1  # minimum 1
        num_negatives = full_negative

    # Keep some negatives
    negative_label_list = [label for label in all_labels if label not in positive_label_set]
    random.shuffle(negative_label_list)
    negative_label_list = negative_label_list[:num_negatives]

    # Keep some positives
    positive_label_list = list(positive_label_set)
    random.shuffle(positive_label_list)
    positive_label_list = positive_label_list[:num_positives]

    selected_label_list = positive_label_list + negative_label_list
    screened_label_list = _label_drop_with_length_limit(selected_label_list, ind_to_class, max_seq_length, tokenizer)
    return screened_label_list

def od_to_grounding_optimized_streamlined(
        target,
        image_id,
        ind_to_class,
        tokenizer,
        od_to_grounding_version,
    ):
    """Convert a detection BoxList into a randomly sampled grounding caption.

    Args:
        target: BoxList with an ``extra_fields["labels"]`` tensor.
        image_id: copied into every produced annotation.
        ind_to_class: mapping label id -> raw class name.
        tokenizer: object exposing ``tokenize(text)``; used for the token budget.
        od_to_grounding_version: ``"random.v1"`` or ``"random.v2"``. Any other
            value raises ``NotImplementedError``.

    Returns:
        A 5-tuple:

        - ``new_target``: COCO-style dicts, each with ``tokens_positive``
          character spans and ``original_od_label``, for boxes whose label
          made it into the caption.
        - ``pheso_caption``: the sampled class names joined by ``". "``.
        - ``greenlight_span_for_masked_lm_objective``: the span of every
          kept box.
        - ``label_to_positions``: label -> ``[start, end)`` character span.
        - ``new_target_boxlist``: "xyxy" BoxList of the kept boxes with an
          integer ``"labels"`` field (also when it is empty).
    """
    # version -> (separation_tokens, max_num_labels, max_seq_length)
    version_configs = {
        "random.v1": (". ", 85, 254),
        "random.v2": (". ", 60, 254),
    }
    if od_to_grounding_version not in version_configs:
        raise NotImplementedError
    separation_tokens, max_num_labels, max_seq_length = version_configs[od_to_grounding_version]

    def generate_senetence_given_labels(
            label_list,
            disable_shuffle=False,
        ):
        label_to_positions = {}
        if not disable_shuffle:
            random.shuffle(label_list)

        pheso_caption = ""

        for index, label in enumerate(label_list):

            start_index = len(pheso_caption)
            pheso_caption += clean_name(ind_to_class[label])  # NOTE: slight change...
            end_index = len(pheso_caption)

            # e.g.: pheso_caption = "cat. dog", where cat is label 4, and dog is label 17
            # label_to_positions: {4: [0, 3], 17: [5, 8]}
            label_to_positions[label] = [start_index, end_index]

            if index != len(label_list) - 1:
                pheso_caption += separation_tokens

        return label_to_positions, pheso_caption

    all_labels = list(ind_to_class.keys())
    if od_to_grounding_version == "random.v1":
        screened_label_list = _randomv1_od_to_grounding(
            all_labels=all_labels,
            ind_to_class=ind_to_class,
            max_seq_length=max_seq_length,
            max_num_labels=max_num_labels,
            tokenizer=tokenizer,
        )
    else:  # "random.v2"
        screened_label_list = _randomv2_od_to_grounding(
            all_labels=all_labels,
            ind_to_class=ind_to_class,
            max_seq_length=max_seq_length,
            max_num_labels=max_num_labels,
            tokenizer=tokenizer,
            positive_label_set=set(target.extra_fields["labels"].tolist()),
        )
    label_to_positions, pheso_caption = generate_senetence_given_labels(
        label_list=screened_label_list)

    # Records carry area, iscrowd, image_id, category_id, id, bbox,
    # original_od_label and tokens_positive. tokens_positive holds character
    # spans inside the caption.
    areas = target.area()
    new_target = []
    greenlight_span_for_masked_lm_objective = []
    for i in range(len(target)):
        label_i = target.extra_fields["labels"][i].item()
        new_target_i = {
            "area": areas[i],
            "iscrowd": 0,
            "image_id": image_id,
            "category_id": label_i,
            "id": None,
            "bbox": target.bbox[i].numpy().tolist(),
            "original_od_label": label_i,
        }

        # NOTE: Only add labels that actually appear in the final caption
        if label_i in label_to_positions:
            span = label_to_positions[label_i]
            new_target_i["tokens_positive"] = [span]
            new_target.append(new_target_i)
            greenlight_span_for_masked_lm_objective.append(span)

    # reconstruct the target
    new_target_boxlist = BoxList(
        torch.as_tensor([ann['bbox'] for ann in new_target]).reshape(-1, 4), target.size, mode="xyxy")
    category_ids = [ann['category_id'] for ann in new_target]
    if category_ids:
        new_labels = torch.as_tensor(category_ids)
    else:
        # torch.as_tensor([]) would be float32; keep labels integer when no box survives.
        new_labels = torch.zeros(0, dtype=torch.int64)
    new_target_boxlist.add_field("labels", new_labels)

    return new_target, pheso_caption, greenlight_span_for_masked_lm_objective, label_to_positions, new_target_boxlist



def generate_control_options_given_probabilities(
        control_probabilities,
        full_positive,
        full_negative):
    """Sample how many negative and positive classes to put in a prompt.

    ``control_probabilities`` is indexed as
    ``[p_one_negative, p_one_positive, p_full, p_drop_positive]``. One uniform
    draw selects, cumulatively in that order:

    * ``p_one_negative`` -> ``(1, 0)``: a single negative class,
    * ``p_one_positive`` -> ``(0, 1)``: a single positive class,
    * ``p_full``         -> ``(full_negative, full_positive)``,
    * otherwise -> a random negative count in ``[1, max(1, full_negative)]``.
      The positive count is drawn only with probability ``p_drop_positive``
      and is ``full_positive`` otherwise.

    ``p_drop_positive`` must be 0 because dropping positives was abandoned.
    The returned positive count is meant to be ignored by callers.

    Returns:
        ``(num_negatives, num_positives)``.
    """
    draw = random.random()

    p_one_negative = control_probabilities[0]
    p_one_positive = control_probabilities[1]
    p_full = control_probabilities[2]
    p_drop_positive = control_probabilities[3]
    assert p_drop_positive == 0

    if draw < p_one_negative:
        return 1, 0
    if draw < p_one_positive + p_one_negative:
        return 0, 1
    if draw < p_full + p_one_positive + p_one_negative:
        return full_negative, full_positive

    # Both extra draws are kept so the RNG stream stays the same, although the
    # first always succeeds and, with p_drop_positive == 0, the second never does.
    num_negatives = (
        np.random.choice(max(1, full_negative)) + 1  # minimum 1
        if random.random() < 1.0
        else full_negative
    )
    num_positives = (
        np.random.choice(max(1, full_positive)) + 1
        if random.random() < p_drop_positive
        else full_positive
    )
    return num_negatives, num_positives

class CocoDetectionTSV(ODTSVDataset):
    """TSV-backed detection dataset built on ``ODTSVDataset``.

    Besides creating ``self.prepare``, the constructor stores its prompt and
    OD-to-grounding options as attributes. None of them are read by the
    methods defined in this class, so they presumably configure code
    elsewhere; confirm before relying on a particular option.
    """

    def __init__(
        self,
        name,
        yaml_file,
        transforms,
        return_tokens,
        tokenizer,
        extra_fields,
        random_sample_negative=-1,
        add_detection_prompt=False,
        add_detection_prompt_advanced=False,
        use_od_data_aug=False,
        control_probabilities=None,
        disable_shuffle=False,
        prompt_engineer_version="v2",
        prompt_limit_negative=-1,
        positive_question_probability=0.6,
        negative_question_probability=0.8,
        full_question_probability=0.5,
        disable_clip_to_image=False,
        separation_tokens=" ",
        no_mask_for_od=False,
        max_num_labels=-1,
        max_query_len=256,
        od_to_grounding_version="legacy",
        description_file=None,
        similarity_file=None,
        mllm_description_file=None,
        **kwargs
    ):
        super(CocoDetectionTSV, self).__init__(yaml_file, extra_fields, **kwargs)

        self._transforms = transforms
        self.name = name
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(
            return_masks=False, return_tokens=return_tokens, tokenizer=tokenizer, max_query_len=max_query_len
        )
        self.tokenizer = tokenizer

        # A fresh dict per instance; a `{}` default would be shared by all instances.
        self.control_probabilities = {} if control_probabilities is None else control_probabilities
        self.random_sample_negative = random_sample_negative
        self.add_detection_prompt = add_detection_prompt
        self.add_detection_prompt_advanced = add_detection_prompt_advanced
        self.use_od_data_aug = use_od_data_aug

        self.prompt_engineer_version = prompt_engineer_version
        self.prompt_limit_negative = prompt_limit_negative
        self.positive_question_probability = positive_question_probability
        self.negative_question_probability = negative_question_probability
        self.full_question_probability = full_question_probability
        self.separation_tokens = separation_tokens
        self.disable_clip_to_image = disable_clip_to_image
        self.disable_shuffle = disable_shuffle
        self.no_mask_for_od = no_mask_for_od
        self.max_num_labels = max_num_labels

        self.od_to_grounding_version = od_to_grounding_version
        self.description_file = description_file
        self.similarity_file = similarity_file
        # NOTE(review): `mllm_description_file` is accepted but unused.
        if "description" in self.od_to_grounding_version:
            # NOTE(review): this only rebinds the local argument and has no effect.
            # If the intent was to override the stored version, it should assign
            # `self.od_to_grounding_version`; confirm before changing.
            od_to_grounding_version = "description.gpt.v10.mixed.allow_zero.v1"

        ### stat
        self.pos_rate = defaultdict(list)

    def __len__(self):
        return super(CocoDetectionTSV, self).__len__()

    def categories(self, no_background=True):
        """Map category id -> name from ``self.coco.dataset["categories"]``.

        With ``no_background`` (the default), entries named
        ``"__background__"`` or with id 0 are skipped. ``self.coco`` is
        assumed to be provided by ``ODTSVDataset`` (TODO confirm).
        """
        label_list = {}
        for category in self.coco.dataset["categories"]:
            if not no_background or (category["name"] != "__background__" and category["id"] != 0):
                label_list[category["id"]] = category["name"]
        return label_list

    def __getitem__(self, idx):
        """Return ``(img, target, image_id)``; ``target`` is a BoxList."""
        img, target, _, _scale = super(CocoDetectionTSV, self).__getitem__(idx)
        image_id = self.get_img_id(idx)
        return img, target, image_id

    def get_raw_image(self, idx):
        """Return only the image from the parent's ``__getitem__``."""
        image, *_ = super(CocoDetectionTSV, self).__getitem__(idx)
        return image

    def get_img_id(self, idx):
        """Return the image id stored in column 0 of the label-TSV row for ``idx``.

        The value is returned as an ``int`` when it parses as one; otherwise
        ``idx`` is returned as a fallback. Returns ``None`` when no label TSV
        is configured.
        """
        line_no = self.get_line_no(idx)
        if self.label_tsv is None:
            return None
        row = self.label_tsv.seek(line_no)
        img_id = row[0]
        try:
            return int(img_id)
        except (TypeError, ValueError, OverflowError):
            return idx
            
            
class BoxList(object):
    """
    This class represents a set of bounding boxes.
    The bounding boxes are represented as a Nx4 Tensor.
    In order to uniquely determine the bounding boxes with respect
    to an image, we also store the corresponding image dimensions.
    They can contain extra information that is specific to each bounding box, such as
    labels.

    Supported modes are "xyxy" (corner coordinates) and "xywh" (top-left corner
    plus width/height). Mode conversion, ``area``, ``clip_to_image`` and the
    left-right flip use the legacy "+1 pixel" convention (``TO_REMOVE = 1``).
    """

    # Methods accepted by `transpose`. The values equal PIL's
    # Image.FLIP_LEFT_RIGHT (0) and Image.FLIP_TOP_BOTTOM (1), so either may be
    # passed. Defining them on the class means `transpose` does not rely on
    # module-level constants, which this file's imports do not provide.
    FLIP_LEFT_RIGHT = 0
    FLIP_TOP_BOTTOM = 1

    def __init__(self, bbox, image_size, mode="xyxy"):
        """
        :param bbox: Nx4 boxes (a tensor or anything ``torch.as_tensor`` accepts),
            stored as float32 on the input tensor's device (CPU otherwise).
        :param image_size: ``(image_width, image_height)``.
        :param mode: "xyxy" or "xywh".
        :raises ValueError: if the input is not 2-D, its last dimension is not 4,
            or the mode is unknown.
        """
        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
        # only do as_tensor if isn't a "no-op", because it hurts JIT tracing
        if not isinstance(bbox, torch.Tensor) or bbox.dtype != torch.float32 or bbox.device != device:
            bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
        if bbox.ndimension() != 2:
            raise ValueError("bbox should have 2 dimensions, got {}".format(bbox.ndimension()))
        if bbox.size(-1) != 4:
            raise ValueError("last dimension of bbox should have a " "size of 4, got {}".format(bbox.size(-1)))
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")

        self.bbox = bbox
        self.size = image_size  # (image_width, image_height)
        self.mode = mode
        self.extra_fields = {}

    # note: _jit_wrap/_jit_unwrap only work if the keys and the sizes don't change in between
    def _jit_unwrap(self):
        return (self.bbox,) + tuple(
            f for f in (self.get_field(field) for field in sorted(self.fields())) if isinstance(f, torch.Tensor)
        )

    def _jit_wrap(self, input_stream):
        self.bbox = input_stream[0]
        num_consumed = 1
        for f in sorted(self.fields()):
            if isinstance(self.extra_fields[f], torch.Tensor):
                self.extra_fields[f] = input_stream[num_consumed]
                num_consumed += 1
        return self, input_stream[num_consumed:]

    def add_field(self, field, field_data):
        self.extra_fields[field] = field_data

    def get_field(self, field):
        return self.extra_fields[field]

    def has_field(self, field):
        return field in self.extra_fields

    def fields(self):
        return list(self.extra_fields.keys())

    def _copy_extra_fields(self, bbox):
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v

    def _add_fields_to(self, bbox, transform):
        """Copy every extra field onto ``bbox``.

        Fields that are neither tensors nor lists (e.g. mask containers) are
        passed through ``transform`` first; tensors and lists are copied as is.
        """
        for k, v in self.extra_fields.items():
            if not isinstance(v, (torch.Tensor, list)):
                v = transform(v)
            bbox.add_field(k, v)

    def convert(self, mode):
        """Return the boxes in ``mode``; ``self`` is returned unchanged if already in it."""
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")
        if mode == self.mode:
            return self
        # we only have two modes, so don't need to check
        # self.mode
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if mode == "xyxy":
            bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
            bbox = BoxList(bbox, self.size, mode=mode)
        else:
            TO_REMOVE = 1
            # NOTE: explicitly specify dim to avoid tracing error in GPU
            bbox = torch.cat((xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=1)
            bbox = BoxList(bbox, self.size, mode=mode)
        bbox._copy_extra_fields(self)
        return bbox

    def _split_into_xyxy(self):
        """Return ``(xmin, ymin, xmax, ymax)``, each an Nx1 tensor."""
        if self.mode == "xyxy":
            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmax, ymax
        elif self.mode == "xywh":
            TO_REMOVE = 1
            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
            return (
                xmin,
                ymin,
                xmin + (w - TO_REMOVE).clamp(min=0),
                ymin + (h - TO_REMOVE).clamp(min=0),
            )
        else:
            raise RuntimeError("Should not be here")

    def resize(self, size, *args, **kwargs):
        """
        Returns a resized copy of this bounding box

        :param size: The requested size in pixels, as a 2-tuple:
            (width, height).
        """
        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            # Uniform scaling is valid in either mode, so the current mode is kept.
            bbox = BoxList(self.bbox * ratios[0], size, mode=self.mode)
            self._add_fields_to(bbox, lambda v: v.resize(size, *args, **kwargs))
            return bbox

        ratio_width, ratio_height = ratios
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        scaled_box = torch.cat(
            (xmin * ratio_width, ymin * ratio_height, xmax * ratio_width, ymax * ratio_height), dim=-1
        )
        bbox = BoxList(scaled_box, size, mode="xyxy")
        self._add_fields_to(bbox, lambda v: v.resize(size, *args, **kwargs))
        return bbox.convert(self.mode)

    def transpose(self, method):
        """
        Transpose bounding box (flip or rotate in 90 degree steps)
        :param method: One of ``BoxList.FLIP_LEFT_RIGHT`` / ``PIL.Image.FLIP_LEFT_RIGHT``
          (0) or ``BoxList.FLIP_TOP_BOTTOM`` / ``PIL.Image.FLIP_TOP_BOTTOM`` (1).
          Rotations and the other PIL transpose methods raise ``NotImplementedError``.
        """
        if method not in (self.FLIP_LEFT_RIGHT, self.FLIP_TOP_BOTTOM):
            raise NotImplementedError("Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented")

        image_width, image_height = self.size
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if method == self.FLIP_LEFT_RIGHT:
            TO_REMOVE = 1
            transposed_xmin = image_width - xmax - TO_REMOVE
            transposed_xmax = image_width - xmin - TO_REMOVE
            transposed_ymin = ymin
            transposed_ymax = ymax
        else:  # FLIP_TOP_BOTTOM (guaranteed by the check above)
            transposed_xmin = xmin
            transposed_xmax = xmax
            transposed_ymin = image_height - ymax
            transposed_ymax = image_height - ymin

        transposed_boxes = torch.cat((transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1)
        bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
        self._add_fields_to(bbox, lambda v: v.transpose(method))
        return bbox.convert(self.mode)

    def crop(self, box):
        """
        Crops a rectangular region from this bounding box. The box is a
        4-tuple defining the left, upper, right, and lower pixel
        coordinate. Boxes are clamped to the crop; empty boxes are kept.
        """
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        w, h = box[2] - box[0], box[3] - box[1]
        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)

        # TODO should I filter empty boxes here?
        cropped_box = torch.cat((cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1)
        bbox = BoxList(cropped_box, (w, h), mode="xyxy")
        self._add_fields_to(bbox, lambda v: v.crop(box))
        return bbox.convert(self.mode)

    # Tensor-like methods

    def to(self, device):
        bbox = BoxList(self.bbox.to(device), self.size, self.mode)
        for k, v in self.extra_fields.items():
            if hasattr(v, "to"):
                v = v.to(device)
            bbox.add_field(k, v)
        return bbox

    def __getitem__(self, item):
        bbox = BoxList(self.bbox[item], self.size, self.mode)
        for k, v in self.extra_fields.items():
            bbox.add_field(k, v[item])
        return bbox

    def __len__(self):
        return self.bbox.shape[0]

    def clip_to_image(self, remove_empty=True):
        """Clamp boxes to the image in place; optionally return only non-empty ones."""
        TO_REMOVE = 1
        x1s = self.bbox[:, 0].clamp(min=0, max=self.size[0] - TO_REMOVE)
        y1s = self.bbox[:, 1].clamp(min=0, max=self.size[1] - TO_REMOVE)
        x2s = self.bbox[:, 2].clamp(min=0, max=self.size[0] - TO_REMOVE)
        y2s = self.bbox[:, 3].clamp(min=0, max=self.size[1] - TO_REMOVE)
        self.bbox = torch.stack((x1s, y1s, x2s, y2s), dim=-1)
        if remove_empty:
            box = self.bbox
            keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
            return self[keep]
        return self

    def area(self):
        if self.mode == "xyxy":
            TO_REMOVE = 1
            box = self.bbox
            area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
        elif self.mode == "xywh":
            box = self.bbox
            area = box[:, 2] * box[:, 3]
        else:
            raise RuntimeError("Should not be here")

        return area

    def copy_with_fields(self, fields):
        bbox = BoxList(self.bbox, self.size, self.mode)
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            bbox.add_field(field, self.get_field(field))
        return bbox

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_boxes={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s

    @staticmethod
    def concate_box_list(list_of_boxes):
        """Concatenate BoxLists sharing the first list's field names.

        Every field must be a tensor. Image size and mode come from the first list.
        """
        boxes = torch.cat([i.bbox for i in list_of_boxes], dim=0)
        extra_fields_keys = list(list_of_boxes[0].extra_fields.keys())
        extra_fields = {}
        for key in extra_fields_keys:
            extra_fields[key] = torch.cat([i.extra_fields[key] for i in list_of_boxes], dim=0)

        final = list_of_boxes[0].copy_with_fields(extra_fields_keys)

        final.bbox = boxes
        final.extra_fields = extra_fields
        return final
    
    
class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256):
        self.return_masks = return_masks
        self.return_tokens = return_tokens
        self.tokenizer = tokenizer
        self.max_query_len = max_query_len

    def get_box_mask(self, rect, img_size, mode="poly"):
        assert mode == "poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]
        caption = target["caption"] if "caption" in target else None
        label_to_positions = target.get("label_to_positions", {})
        spans = target.get("spans", [])
        greenlight_span_for_masked_lm_objective = target.get("greenlight_span_for_masked_lm_objective", None)

        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        if box_format == "xywh":
            boxes[:, 2:] += boxes[:, :2] - 1  # TO_REMOVE = 1
            boxes[:, 0::2].clamp_(min=0, max=w - 1)  # TO_REMOVE = 1
            boxes[:, 1::2].clamp_(min=0, max=h - 1)  # TO_REMOVE = 1

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            masks = []
            is_box_mask = []
            for obj, bbox in zip(anno, boxes):
                if "segmentation" in obj:
                    masks.append(obj["segmentation"])
                    is_box_mask.append(0)
                else:
                    masks.append(self.get_box_mask(bbox, image.size, mode="poly"))
                    is_box_mask.append(1)
            masks = SegmentationMask(masks, image.size, mode="poly")
            is_box_mask = torch.tensor(is_box_mask)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        isfinal = None
        if anno and "isfinal" in anno[0]:
            isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float)

        tokens_positive = [] if self.return_tokens else None
        if self.return_tokens and anno and "tokens" in anno[0]:
            tokens_positive = [obj["tokens"] for obj in anno]
        elif self.return_tokens and anno and "tokens_positive" in anno[0]:
            tokens_positive = [obj["tokens_positive"] for obj in anno]

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
            is_box_mask = is_box_mask[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if caption is not None:
            target["caption"] = caption
        if self.return_masks:
            target["masks"] = masks
            target["is_box_mask"] = is_box_mask
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        if tokens_positive is not None:
            target["tokens_positive"] = []

            for i, k in enumerate(keep):
                if k or ignore_box_screen:
                    target["tokens_positive"].append(tokens_positive[i])

        if isfinal is not None:
            target["isfinal"] = isfinal

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self.return_tokens and self.tokenizer is not None:
            if not ignore_box_screen:
                assert len(target["boxes"]) == len(target["tokens_positive"])
            tokenized = self.tokenizer(caption, return_tensors="pt", max_length=self.max_query_len, truncation=True)
            target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"])
            target["greenlight_map"] = create_greenlight_map(greenlight_span_for_masked_lm_objective, tokenized)
            target["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions)
            if len(anno) > 0 and "spans_positive" in anno[0]:
                try:
                    target["span_map"] = transfer_token_mapping_to_span_mapping(spans, [i['spans_positive'] for i in anno])
                except:
                    pass
            # create another field called the span boundaries
            # target["span_boundaries"] = create_span_boundaries(tokenized)

        original_od_label = []
        for obj in anno:
            original_od_label.append(
                obj.get("original_od_label", -10)
            )  # NOTE: The padding value has to be not the same as -1 or -100
        target["original_od_label"] = torch.as_tensor(original_od_label)

        return image, target