import contextlib
import io
import logging
import os

import pycocotools.mask as mask_util
from fvcore.common.timer import Timer
from detectron2.data import DatasetCatalog, MetadataCatalog

from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
from detectron2.utils.file_io import PathManager
from .refcoco import _get_refcoco_meta

"""
This file contains functions to parse gRefCOCO-format annotations into dicts in "Detectron2 format".
"""


logger = logging.getLogger(__name__)

# Names exported by `from <this module> import *`. Every entry must be defined
# in this module, otherwise the star-import raises AttributeError.
__all__ = ["load_grefcoco_json"]


def load_grefcoco_json(image_root, ref_root, json_file, dataset_name, split, extra_refer_keys=None, dataset_name_in_dict='grefcoco'):
    """
    Load one gRefCOCO split through the G_REFER API and convert it into a list
    of per-sentence dicts in "Detectron2 format".

    Args:
        image_root (str or path-like): directory which contains the images. It is
            joined with each image's "file_name" and is also handed to G_REFER
            as ``data_root``.
        ref_root (str or path-like): forwarded to G_REFER as ``ref_root``.
        json_file (str or path-like): forwarded to G_REFER as ``json_file``.
        dataset_name (str): only used in the "Loading ... takes" log message.
        split (str): split passed to ``getRefIds``; every loaded reference is
            asserted to carry this split.
        extra_refer_keys (list[str] or None): extra per-sentence keys to copy
            into the "sentence" dict, in addition to "raw" and "sent_id".
        dataset_name_in_dict (str): currently unused; kept so existing callers
            that pass it keep working.

    Returns:
        list[dict]: one record per referring sentence with the keys "source",
        "file_name", "height", "width", "image_id", "empty", "annotations",
        "task", "sentence", "expressions" and "has_mask". Records built from the
        same reference share one "annotations" list (the record is copied
        shallowly), so treat it as read-only or deep-copy it before mutating.

    Raises:
        AssertionError: if image, annotation and reference information do not
            match each other, or an annotation has no segmentation.
    """
    from .grefer import G_REFER

    timer = Timer()
    # G_REFER output on stdout is captured and discarded while loading.
    with contextlib.redirect_stdout(io.StringIO()):
        refer_api = G_REFER(data_root=image_root, ref_root=ref_root, json_file=json_file,)
    if timer.seconds() > 1:
        logger.info("Loading %s takes %.2f seconds.", dataset_name, timer.seconds())

    ref_ids = refer_api.getRefIds(split=split)
    refs = refer_api.loadRefs(ref_ids)
    imgs = [refer_api.loadImgs(ref['image_id'])[0] for ref in refs]
    anns = [refer_api.loadAnns(ref['ann_id']) for ref in refs]

    dataset_dicts = []

    ann_keys = ["iscrowd", "bbox", "category_id"]
    ref_keys = ["raw", "sent_id"] + (extra_refer_keys or [])

    # Converted annotations keyed by annotation id, so an annotation referenced
    # by several references is converted only once.
    ann_lib = {}

    # Counts every skipped occurrence; skipped annotations are not cached, so an
    # annotation shared by several references is counted once per reference.
    num_instances_without_valid_segmentation = 0

    for img_dict, ref_dict, anno_dicts in zip(imgs, refs, anns):
        record = {}
        record["source"] = 'grefcoco'
        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        image_id = record["image_id"] = img_dict["id"]

        # Check that information of image, ann and ref match each other.
        # This fails only when the data parsing logic or the annotation file is buggy.
        assert ref_dict['image_id'] == image_id
        assert ref_dict['split'] == split
        if not isinstance(ref_dict['ann_id'], list):
            ref_dict['ann_id'] = [ref_dict['ann_id']]

        if None in anno_dicts:
            # No-target sample: a single placeholder annotation with empty fields.
            assert anno_dicts == [None]
            assert ref_dict['ann_id'] == [-1]
            record['empty'] = True
            placeholder = dict.fromkeys(ann_keys)
            placeholder["bbox_mode"] = BoxMode.XYWH_ABS
            placeholder["empty"] = True
            obj = [placeholder]
        else:
            # Single- or multi-target sample.
            record['empty'] = False
            obj = []
            for anno_dict in anno_dicts:
                ann_id = anno_dict['id']
                if anno_dict['iscrowd']:
                    continue
                assert anno_dict["image_id"] == image_id
                assert ann_id in ref_dict['ann_id']

                if ann_id in ann_lib:
                    ann = ann_lib[ann_id]
                else:
                    ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict}
                    ann["bbox_mode"] = BoxMode.XYWH_ABS
                    ann["empty"] = False

                    segm = anno_dict.get("segmentation", None)
                    assert segm  # either list[list[float]] or dict(RLE)
                    if isinstance(segm, dict):
                        if isinstance(segm["counts"], list):
                            # convert to compressed RLE
                            segm = mask_util.frPyObjects(segm, *segm["size"])
                    else:
                        # filter out invalid polygons (< 3 points)
                        segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                        if len(segm) == 0:
                            num_instances_without_valid_segmentation += 1
                            continue  # ignore this instance
                    ann["segmentation"] = segm
                    ann_lib[ann_id] = ann

                obj.append(ann)

        record["annotations"] = obj

        # language-guided detection
        record["task"] = "grounding"

        # Emit one record per referring sentence of this reference.
        for sent in ref_dict['sentences']:
            ref_record = record.copy()
            ref = {key: sent[key] for key in ref_keys if key in sent}
            ref["ref_id"] = ref_dict["ref_id"]
            ref_record["sentence"] = ref
            ref_record["expressions"] = ref["raw"]
            ref_record["has_mask"] = True

            dataset_dicts.append(ref_record)

    if num_instances_without_valid_segmentation > 0:
        logger.warning(
            "Filtered out %d instances without valid segmentation.",
            num_instances_without_valid_segmentation,
        )

    # Return every sample; callers that want a small debugging subset can slice
    # the returned list themselves.
    return dataset_dicts


# Built-in gRefCOCO splits, keyed by the name each one is registered under.
# Every value is (split, image directory, annotation json); `register_grefcoco`
# joins both paths onto the root it is given.
# NOTE(review): the "val" annotation file name has no "unc" infix, unlike the
# other three splits -- confirm this matches the file on disk.
_PREDEFINED_SPLITS_REFCOCO = {
    "grefcoco-unc-train": (
        "train",
        "coco/train2014",
        "gres/annotations/grefcoco_unc_train.json",
    ),
    "grefcoco-unc-val": (
        "val",
        "coco/train2014",
        "gres/annotations/grefcoco_val.json",
    ),
    "grefcoco-unc-testA": (
        "testA",
        "coco/train2014",
        "gres/annotations/grefcoco_unc_testA.json",
    ),
    "grefcoco-unc-testB": (
        "testB",
        "coco/train2014",
        "gres/annotations/grefcoco_unc_testB.json",
    ),
}
        
        
def _get_grefcoco_builtin_metadata(name, split, image_root, json_file, ref_root):
    """
    Register one gRefCOCO split in DatasetCatalog and set its metadata.

    Args:
        name (str): the name that identifies the dataset, e.g. "grefcoco-unc-train".
        split (str): split forwarded to `load_grefcoco_json` and stored in the metadata.
        image_root (str or path-like): directory which contains all the images.
        json_file (str or path-like): path to the json annotation file.
        ref_root: forwarded unchanged to `load_grefcoco_json`; not type-checked here.
    """
    assert isinstance(name, str), name
    assert isinstance(json_file, (str, os.PathLike)), json_file
    assert isinstance(image_root, (str, os.PathLike)), image_root

    # 1. Register a callable which returns the dataset dicts for this split.
    def _load_split():
        return load_grefcoco_json(image_root, ref_root, json_file, name, split)

    DatasetCatalog.register(name, _load_split)

    # 2. Attach metadata that might be useful in evaluation, visualization or logging.
    # NOTE(review): "root" is a fixed absolute path that does not depend on any
    # argument of this function -- confirm that is intended.
    split_metadata = {
        "evaluator_type": "grefcoco",
        "dataset_name": name,
        "splitby": "unc",
        "split": split,
        "root": "/data/gres",
        "image_root": image_root,
    }
    MetadataCatalog.get(name).set(**split_metadata)
    
def register_grefcoco(root):
    """
    Register every split listed in `_PREDEFINED_SPLITS_REFCOCO`.

    Args:
        root (str or path-like): directory the predefined relative image and
            annotation paths are joined onto.
    """
    for dataset_name, (split, image_dir, ann_file) in _PREDEFINED_SPLITS_REFCOCO.items():
        images_path = os.path.join(root, image_dir)
        # Annotation paths containing "://" are used as-is, not joined onto `root`.
        if "://" in ann_file:
            json_path = ann_file
        else:
            json_path = os.path.join(root, ann_file)
        refs_path = os.path.join(root, "gres/annotations")
        _get_grefcoco_builtin_metadata(dataset_name, split, images_path, json_path, refs_path)