import logging
import os
import os.path
import math
from PIL import Image, ImageDraw

import random
import numpy as np

import torch
import torchvision
import torch.utils.data as data
from pycocotools import mask as coco_mask

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.data.datasets.coco import has_valid_annotation
from .od_to_grounding import convert_od_to_grounding_simple, check_for_positive_overflow, sanity_check_target_after_processing, convert_object_detection_to_grounding_optimized_for_od
import pdb
import json

class CocoGrounding(torchvision.datasets.CocoDetection):
    """COCO-format detection dataset that serves each image as a grounding sample.

    Category names are turned into a text caption by the helpers imported from
    ``od_to_grounding``, and each box is linked to the caption span of its
    category, so that a plain detection dataset can feed a grounding model.

    Args:
        img_folder: image root handed to ``torchvision.datasets.CocoDetection``.
        ann_file: COCO-style annotation json.
        transforms: optional callable ``(img, BoxList) -> (img, BoxList)``.
        return_masks: attach ``masks`` / ``is_box_mask`` fields; objects without a
            ``segmentation`` get a rectangle polygon built from their box.
        return_tokens: forwarded to ``ConvertCocoPolysToMask``.
        is_train: stored on the instance.
        tokenizer: text tokenizer used for the caption and overflow checks.
        few_shot: if non-zero, greedily subsample images so that each category
            appears in roughly ``few_shot`` images.
        override_category: replacement for ``coco.dataset["categories"]``.
        use_caption_prompt / caption_prompt: optional prompt for the simple
            caption builder.
        max_query_len: maximum tokenized caption length.
        special_safeguard_for_coco_grounding: use the (LVIS-oriented) caption
            builder that drops boxes whose positive caption would overflow.
        random_sample_negative: forwarded to that builder.
        Remaining flags are forwarded to the caption builders / stored as-is.
    """

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_shuffle=False,
                 add_detection_prompt=False,
                 one_hot=False,
                 disable_clip_to_image=False,
                 no_minus_one_for_one_hot=False,
                 separation_tokens=" ",
                 few_shot=0,
                 no_mask_for_od=False,
                 override_category=None,
                 use_caption_prompt=False,
                 caption_prompt=None,
                 max_query_len=256,
                 special_safeguard_for_coco_grounding=False,
                 random_sample_negative=-1,
                 **kwargs
                 ):
        super(CocoGrounding, self).__init__(img_folder, ann_file)

        # Keep only images that have at least one usable annotation.
        self.ids = [
            img_id for img_id in sorted(self.ids)
            if has_valid_annotation(self.coco.loadAnns(self._get_ann_ids(img_id)))
        ]

        if few_shot:
            self.ids = self._select_few_shot_ids(few_shot)

        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }

        # The contiguous-id mapping above is built from the ORIGINAL categories,
        # so an override presumably has to reuse the same category ids.
        if override_category is not None:
            self.coco.dataset["categories"] = override_category
        self.use_caption_prompt = use_caption_prompt
        self.caption_prompt = caption_prompt
        self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding
        self.random_sample_negative = random_sample_negative
        # contiguous label -> category name (read after any override above)
        self.ind_to_class = self.categories(no_background=False)
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self._transforms = transforms
        self.max_query_len = max_query_len
        # Masks are built in __getitem__, so the converter never builds them itself.
        self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.tokenizer = tokenizer
        self.is_train = is_train

        self.disable_shuffle = disable_shuffle
        self.add_detection_prompt = add_detection_prompt
        self.one_hot = one_hot
        self.no_minus_one_for_one_hot = no_minus_one_for_one_hot

        self.disable_clip_to_image = disable_clip_to_image
        self.separation_tokens = separation_tokens
        self.no_mask_for_od = no_mask_for_od
        self.return_masks = return_masks

    def _get_ann_ids(self, img_id):
        """Return the ids of all annotations (crowd or not) of image ``img_id``."""
        # pycocotools treats anything iterable as a list of ids, so a bare str
        # id would be iterated character by character; wrap it first.
        img_ids = [img_id] if isinstance(img_id, str) else img_id
        return self.coco.getAnnIds(imgIds=img_ids, iscrowd=None)

    def _select_few_shot_ids(self, few_shot):
        """Greedily pick images so every category appears in about ``few_shot`` images.

        An image is kept while at least one of its categories still has budget left.
        Keeping it consumes one unit of budget from every category it contains, so
        some categories may end up over-represented.
        """
        # Keyed by category id, not a list indexed by ``id - 1``: with a list, a
        # dataset that uses category id 0 would alias the last category via index -1.
        remaining = {cat_id: few_shot for cat_id in self.coco.cats}
        selected = []
        for img_id in self.ids:
            anno = self.coco.loadAnns(self._get_ann_ids(img_id))
            cats = {ann['category_id'] for ann in anno}
            if any(remaining.get(c, 0) > 0 for c in cats):
                selected.append(img_id)
                for c in cats:
                    remaining[c] = remaining.get(c, 0) - 1
        return selected

    def categories(self, no_background=True):
        """Return ``{contiguous_label: category_name}``.

        With ``no_background=True``, categories named ``"__background__"`` or with
        id 0 are skipped.
        """
        categories = self.coco.dataset["categories"]
        label_list = {}
        for index, i in enumerate(categories):
            # assert(index + 1 == i["id"])
            if not no_background or (i["name"] != "__background__" and i['id'] != 0):
                label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"]
        return label_list

    def get_box_mask(self, rect, img_size, mode="poly"):
        """Return a rectangle polygon ``[[x1, y1, x1, y2, x2, y2, x2, y1]]`` for ``rect``.

        ``rect`` is read as ``(x1, y1, x2, y2)``; ``img_size`` is unused.
        """
        assert mode=="poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __getitem__(self, idx):
        """Return ``(image, BoxList target, idx)`` with caption/grounding fields attached."""
        img, tgt = super(CocoGrounding, self).__getitem__(idx)
        image_id = self.ids[idx]
        tgt = [obj for obj in tgt if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in tgt]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = [obj["category_id"] for obj in tgt]
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        classes = torch.tensor(classes)
        target.add_field("labels", classes)

        if self.return_masks:
            masks = []
            is_box_mask = []
            for obj, bbox in zip(tgt, target.bbox):
                if "segmentation" in obj:
                    masks.append(obj["segmentation"])
                    is_box_mask.append(0)
                else:
                    masks.append(self.get_box_mask(bbox, img.size, mode="poly"))
                    is_box_mask.append(1)
            masks = SegmentationMask(masks, img.size, mode="poly")
            is_box_mask = torch.tensor(is_box_mask)
            target.add_field("masks", masks)
            target.add_field("is_box_mask", is_box_mask)

        if not self.disable_clip_to_image:
            target = target.clip_to_image(remove_empty=True)

        if self.special_safeguard_for_coco_grounding:
            # Intended for LVIS
            assert(not self.use_caption_prompt)

            original_box_num = len(target)
            target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2) # leave some space for the special tokens
            if len(target) < original_box_num:
                print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target)))

            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=False,
                add_detection_prompt_advanced=False,
                random_sample_negative=self.random_sample_negative,
                control_probabilities=(0.0, 0.0, 1.0, 0.0), # always try to add a lot of negatives
                restricted_negative_list=None,
                separation_tokens=self.separation_tokens,
                max_num_labels=-1,
                positive_caption_length=positive_caption_length,
                tokenizer=self.tokenizer,
                max_seq_length=self.max_query_len-2
            )
        else:
            # Intended for COCO / ODinW
            annotations, caption, greenlight_span_for_masked_lm_objective = convert_od_to_grounding_simple(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=self.add_detection_prompt,
                separation_tokens=self.separation_tokens,
                caption_prompt=self.caption_prompt if self.use_caption_prompt else None,
            )

        anno = {"image_id": image_id, "annotations": annotations, "caption": caption}
        anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
        if self.no_mask_for_od:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
        img, anno = self.prepare(img, anno, box_format="xyxy")

        # for equivalence check
        if self.one_hot:
            logging.info("using one hot for equivalence check.")
            one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float)
            text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64)
            # Use the labels of the boxes that survived clipping / overflow filtering.
            # The ``classes`` tensor built above still includes removed boxes, which
            # would misalign rows (or index past the end) whenever a box was dropped.
            for ii, cls in enumerate(target.get_field("labels")):
                column = cls if self.no_minus_one_for_one_hot else cls - 1
                one_hot_map[ii, column] = 1.0
            if self.no_minus_one_for_one_hot:
                text_mask[:] = 1
            else:
                text_mask[:len(self.ind_to_class)] = 1
            anno["positive_map"] = one_hot_map
            anno["text_mask"] = text_mask

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in anno:
            target.add_field(ann, anno[ann])

        sanity_check_target_after_processing(target)

        return img, target, idx

    def get_img_info(self, index):
        """Return the raw COCO image record for dataset position ``index``."""
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data


class ModulatedDataset(torchvision.datasets.CocoDetection):
    """Grounding dataset whose captions are stored on the COCO image records.

    Per the original code comments, it is used for Flickr and mixed grounding data.
    Each image record provides a ``caption``. It may also provide
    ``dataset_name``, ``sentence_id``, ``original_img_id``, ``original_id``,
    ``task_id`` and ``tokens_positive_eval``, which are copied onto the target.
    Annotations are converted by ``ConvertCocoPolysToMask``.
    """

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        super(ModulatedDataset, self).__init__(img_folder, ann_file)

        def _is_usable(image_key):
            # String ids are wrapped so pycocotools does not iterate their characters.
            lookup = [image_key] if isinstance(image_key, str) else image_key
            annotation_ids = self.coco.getAnnIds(imgIds=lookup, iscrowd=None)
            return has_valid_annotation(self.coco.loadAnns(annotation_ids))

        self.ids = [image_key for image_key in sorted(self.ids) if _is_usable(image_key)]

        self.id_to_img_map = dict(enumerate(self.ids))
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.is_train = is_train
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def __getitem__(self, idx):
        """Return ``(image, BoxList target, idx)`` for the captioned image at ``idx``."""
        img, raw_annotations = super(ModulatedDataset, self).__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        caption = coco_img["caption"]
        dataset_name = coco_img.get("dataset_name")

        # The whole caption is grounded text, so all of it may be masked for MLM;
        # the 3-tuple sentinel disables masking entirely.
        greenlight_spans = [(0, len(caption))]
        if self.no_mask_for_gold:
            greenlight_spans.append((-1, -1, -1))
        anno = {
            "image_id": image_id,
            "annotations": raw_annotations,
            "caption": caption,
            "greenlight_span_for_masked_lm_objective": greenlight_spans,
        }
        img, anno = self.prepare(img, anno)

        # Wrap the converted boxes/labels in a BoxList; reshape guards against no boxes.
        target = BoxList(torch.as_tensor(anno["boxes"]).reshape(-1, 4), img.size, mode="xyxy")
        target.add_field("labels", anno["labels"])
        if self.prepare.return_masks:
            # popped so they are not attached a second time by the loop below
            target.add_field("masks", anno.pop("masks"))
            target.add_field("is_box_mask", anno.pop("is_box_mask"))
        if not self.disable_clip_to_image:
            boxes_before = len(target.bbox)
            target = target.clip_to_image(remove_empty=True)
            assert boxes_before == len(target.bbox), "Box got removed in MixedDataset!!!"

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # Every remaining converted entry becomes an extra field on the target.
        for key, value in anno.items():
            target.add_field(key, value)

        target.add_field("dataset_name", dataset_name)
        for extra_key in ("sentence_id", "original_img_id", "original_id", "task_id"):
            if extra_key in coco_img:
                target.add_field(extra_key, coco_img[extra_key])

        if not self.is_train and "tokens_positive_eval" in coco_img:
            tokenized = self.prepare.tokenizer(caption, return_tensors="pt")
            eval_map = create_positive_map(tokenized, coco_img["tokens_positive_eval"])
            target.add_field("positive_map_eval", eval_map)
            target.add_field("nb_eval", len(target.get_field("positive_map_eval")))

        sanity_check_target_after_processing(target)
        return img, target, idx

    def get_img_info(self, index):
        """Return the raw COCO image record for dataset position ``index``."""
        return self.coco.imgs[self.id_to_img_map[index]]


class CocoDetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Images are loaded with ``pil_loader`` and returned as RGB.

    Args:
        root (string): Directory that image ``file_name`` entries are relative to.
        annFile (string): Path to the COCO json annotation file.
        transform (callable, optional): Applied to the loaded image.
        target_transform (callable, optional): Applied to the annotation list.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index, return_meta=False):
        """Load the sample at ``index``.

        Args:
            index (int): Position in ``self.ids``.
            return_meta (bool): Also return the COCO image record.

        Returns:
            tuple: ``(image, target)``, or ``(image, target, meta)`` when
            ``return_meta`` is true. ``target`` is the list returned by
            ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]
        # String ids are wrapped so pycocotools does not iterate their characters.
        query = [img_id] if isinstance(img_id, str) else img_id
        target = coco.loadAnns(coco.getAnnIds(imgIds=query))

        meta = coco.loadImgs(query)[0]
        img = pil_loader(os.path.join(self.root, meta['file_name']))

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (img, target, meta) if return_meta else (img, target)

    def __len__(self):
        return len(self.ids)

    def __repr__(self):
        def _labelled(label, obj):
            # Continuation lines of a multi-line repr are aligned under the label.
            return label + obj.__repr__().replace('\n', '\n' + ' ' * len(label))

        lines = [
            'Dataset ' + self.__class__.__name__,
            '    Number of datapoints: {}'.format(len(self)),
            '    Root Location: {}'.format(self.root),
            _labelled('    Transforms (if any): ', self.transform),
            _labelled('    Target Transforms (if any): ', self.target_transform),
        ]
        return '\n'.join(lines)


class ConvertCocoPolysToMask(object):
    """Convert the raw COCO-style annotations of one image into a dict of tensors.

    Args:
        return_masks: also build ``masks`` (a ``SegmentationMask`` of polygons) and
            ``is_box_mask``. Objects without ``segmentation`` get a box polygon.
        return_tokens: collect each object's ``tokens`` / ``tokens_positive`` spans.
        tokenizer: used together with ``return_tokens`` to build ``positive_map``,
            ``greenlight_map`` and ``positive_map_for_od_labels``.
        max_query_len: ``max_length`` (with truncation) used to tokenize the caption.
    """

    def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256):
        self.return_masks = return_masks
        self.return_tokens = return_tokens
        self.tokenizer = tokenizer
        self.max_query_len = max_query_len

    def get_box_mask(self, rect, img_size, mode="poly"):
        """Return a rectangle polygon ``[[x1, y1, x1, y2, x2, y2, x2, y1]]`` for ``rect``.

        ``rect`` is read as ``(x1, y1, x2, y2)``; ``img_size`` is unused.
        """
        assert mode=="poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"):
        """Convert ``target`` for ``image`` and return ``(image, converted_target)``.

        Args:
            image: object exposing ``size`` as ``(width, height)`` (e.g. a PIL image).
            target: dict with ``image_id`` and ``annotations``, plus optional
                ``caption``, ``label_to_positions`` and
                ``greenlight_span_for_masked_lm_objective``.
            ignore_box_screen: keep the token spans of boxes dropped as degenerate.
            box_format: ``"xywh"`` boxes are converted to inclusive, clamped
                ``xyxy``; any other value means the boxes are used unchanged.

        Returns:
            ``(image, dict)``. Crowd annotations are removed. Boxes with
            non-positive width/height are dropped, together with every per-box
            field derived from them.
        """
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])
        caption = target.get("caption")
        label_to_positions = target.get("label_to_positions", {})
        greenlight_span_for_masked_lm_objective = target.get("greenlight_span_for_masked_lm_objective", None)

        # Crowd annotations are discarded; a missing "iscrowd" counts as non-crowd.
        anno = [obj for obj in target["annotations"] if obj.get("iscrowd", 0) == 0]

        # reshape guards against images without boxes
        boxes = torch.as_tensor([obj["bbox"] for obj in anno], dtype=torch.float32).reshape(-1, 4)
        if box_format == "xywh":
            # xywh -> inclusive xyxy clamped to the image (TO_REMOVE = 1 convention)
            boxes[:, 2:] += boxes[:, :2] - 1
            boxes[:, 0::2].clamp_(min=0, max=w-1)
            boxes[:, 1::2].clamp_(min=0, max=h-1)

        classes = torch.tensor([obj["category_id"] for obj in anno], dtype=torch.int64)

        if self.return_masks:
            masks = []
            is_box_mask = []
            for obj, bbox in zip(anno, boxes):
                if "segmentation" in obj:
                    masks.append(obj["segmentation"])
                    is_box_mask.append(0)
                else:
                    masks.append(self.get_box_mask(bbox, image.size, mode='poly'))
                    is_box_mask.append(1)
            masks = SegmentationMask(masks, image.size, mode='poly')
            is_box_mask = torch.tensor(is_box_mask)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in anno], dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        isfinal = None
        if anno and "isfinal" in anno[0]:
            isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float)

        tokens_positive = [] if self.return_tokens else None
        if self.return_tokens and anno and "tokens" in anno[0]:
            tokens_positive = [obj["tokens"] for obj in anno]
        elif self.return_tokens and anno and "tokens_positive" in anno[0]:
            tokens_positive = [obj["tokens_positive"] for obj in anno]

        # Drop degenerate boxes. Every per-box quantity is filtered with the same
        # mask so that all per-box fields stay aligned with ``boxes``.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
            is_box_mask = is_box_mask[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]
        if isfinal is not None:
            isfinal = isfinal[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if caption is not None:
            target["caption"] = caption
        if self.return_masks:
            target["masks"] = masks
            target["is_box_mask"] = is_box_mask
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        if tokens_positive is not None:
            # With ignore_box_screen the spans of dropped boxes are kept as well.
            target["tokens_positive"] = [
                tokens_positive[i] for i, k in enumerate(keep.tolist()) if k or ignore_box_screen
            ]

        if isfinal is not None:
            target["isfinal"] = isfinal

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj.get("iscrowd", 0) for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self.return_tokens and self.tokenizer is not None:
            if not ignore_box_screen:
                assert len(target["boxes"]) == len(target["tokens_positive"])
            tokenized = self.tokenizer(caption, return_tensors="pt",
                max_length=self.max_query_len,
                truncation=True)
            target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"])
            target['greenlight_map'] = create_greenlight_map(greenlight_span_for_masked_lm_objective,tokenized)
            target["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions)

        # NOTE: the padding value has to differ from -1 and -100.
        original_od_label = torch.as_tensor([obj.get("original_od_label", -10) for obj in anno])
        target["original_od_label"] = original_od_label[keep]

        return image, target

def create_greenlight_map(tok_list, tokenized, max_len=256):
    """Mark which token positions may be masked for the masked-LM objective.

    Args:
        tok_list: character spans ``(beg, end)`` (``end`` exclusive) of maskable
            text, e.g. ``[(0, 5), (10, 13), (-1, -1, -1)]``. A 3-tuple is a special
            sentinel meaning "nothing is maskable"; processing stops there.
            ``None`` is treated as an empty list.
        tokenized: tokenizer output providing ``char_to_token``.
        max_len: number of token slots in the returned map.

    Returns:
        float tensor of shape ``(max_len,)``. Positions covered by a span are 1
        and all others 0. If the sentinel is met, every position is -1. Spans
        whose boundaries cannot be mapped to tokens are skipped.
    """
    greenlight_map = torch.zeros(max_len, dtype=torch.float)
    for item in tok_list or ():
        if len(item) != 2:
            assert(len(item) == 3)
            # Sentinel: make every position unmaskable.
            greenlight_map[:] = -1
            break

        beg, end = item
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        # A boundary character may map to no token (presumably whitespace);
        # nudge it up to two characters inward before giving up on the span.
        if beg_pos is None:
            try:
                beg_pos = tokenized.char_to_token(beg + 1)
                if beg_pos is None:
                    beg_pos = tokenized.char_to_token(beg + 2)
            except Exception:
                beg_pos = None
        if end_pos is None:
            try:
                end_pos = tokenized.char_to_token(end - 2)
                if end_pos is None:
                    end_pos = tokenized.char_to_token(end - 3)
            except Exception:
                end_pos = None
        if beg_pos is None or end_pos is None:
            continue

        greenlight_map[beg_pos: end_pos + 1].fill_(1)
    return greenlight_map


def create_positive_map_for_od_labels(tokenized, label_to_positions, max_len=256):
    """Build a per-token map of object-detection labels.

    ``positive_map[i] = j`` when token ``i`` lies inside the caption span of label
    ``j``, and ``-1`` when it belongs to no label. For example, with
    ``label_to_positions = {3: (1, 5)}``, the tokens covering characters 1..4 get
    the value 3 and every other position stays -1.

    Args:
        tokenized: tokenizer output providing ``char_to_token``.
        label_to_positions: ``{label: (beg, end)}`` with one character span per
            label (``end`` exclusive).
        max_len: number of token slots in the returned map.

    Returns:
        float tensor of shape ``(max_len,)``. Labels whose span boundaries cannot
        be mapped to tokens are skipped.
    """
    positive_map = torch.full((max_len,), -1.0, dtype=torch.float)  # -1 means no match
    for label, span in label_to_positions.items():
        # one label maps to exactly one location
        beg, end = span
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        # A boundary character may map to no token; nudge it up to two
        # characters inward before giving up on this label.
        if beg_pos is None:
            try:
                beg_pos = tokenized.char_to_token(beg + 1)
                if beg_pos is None:
                    beg_pos = tokenized.char_to_token(beg + 2)
            except Exception:
                beg_pos = None
        if end_pos is None:
            try:
                end_pos = tokenized.char_to_token(end - 2)
                if end_pos is None:
                    end_pos = tokenized.char_to_token(end - 3)
            except Exception:
                end_pos = None
        if beg_pos is None or end_pos is None:
            continue
        positive_map[beg_pos: end_pos + 1].fill_(label)
    return positive_map


def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize COCO polygon segmentations into one mask per object.

    Args:
        segmentations: iterable with one entry per object; each entry is the
            polygon list accepted by ``pycocotools.mask.frPyObjects``.
        height, width: size of the output masks.

    Returns:
        tensor of shape ``(N, height, width)``. Each object's polygons are merged
        with a logical OR over the decoded channels. When there are no objects, a
        ``uint8`` tensor of shape ``(0, height, width)`` is returned. Otherwise the
        dtype is whatever ``Tensor.any`` yields (bool on recent PyTorch).
    """
    decoded = []
    for polygons in segmentations:
        rle = coco_mask.frPyObjects(polygons, height, width)
        raw = coco_mask.decode(rle)
        if raw.ndim < 3:
            # a single polygon decodes to (H, W); add the channel axis
            raw = raw[..., None]
        decoded.append(torch.as_tensor(raw, dtype=torch.uint8).any(dim=2))

    if not decoded:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(decoded, dim=0)


def create_positive_map(tokenized, tokens_positive, max_len=256):
    """Build a map such that ``positive_map[i, j] > 0`` iff box ``i`` is associated to token ``j``.

    Args:
        tokenized: tokenizer output providing ``char_to_token``.
        tokens_positive: one list of character spans ``(beg, end)`` per box
            (``end`` exclusive).
        max_len: number of token slots per row.

    Returns:
        float tensor of shape ``(len(tokens_positive), max_len)``. Each row is
        normalized to sum to (almost) 1; the epsilon keeps rows without any
        matched token at 0 instead of NaN. Spans whose boundaries cannot be
        mapped to tokens are skipped.
    """
    positive_map = torch.zeros((len(tokens_positive), max_len), dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            beg_pos = tokenized.char_to_token(beg)
            end_pos = tokenized.char_to_token(end - 1)
            # A boundary character may map to no token; nudge it up to two
            # characters inward before giving up on this span.
            if beg_pos is None:
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except Exception:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                continue

            positive_map[j, beg_pos: end_pos + 1].fill_(1)
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)


def pil_loader(path, retry=5):
    """Open the image at ``path`` with PIL and return it converted to RGB.

    Loading is attempted up to ``retry`` times (at least once) to ride out
    transient read failures.

    Raises:
        OSError: if every attempt failed. The last underlying error is chained as
            ``__cause__``. Previously ``None`` was returned silently, which only
            failed later with an unrelated error.
    """
    attempts = max(retry, 1)
    last_error = None
    for _ in range(attempts):
        try:
            # Open the file ourselves so the handle is closed deterministically
            # (avoids ResourceWarning, see https://github.com/python-pillow/Pillow/issues/835).
            with open(path, 'rb') as f:
                img = Image.open(f)
                return img.convert('RGB')
        except Exception as err:  # retry I/O and decode failures, but not KeyboardInterrupt
            last_error = err
    raise OSError("Failed to load image {} after {} attempt(s)".format(path, attempts)) from last_error
