import transformers
import cv2
import os
import json
import torch
import copy

import numpy as np
from PIL import Image
from typing import Dict, Optional, Sequence, List
from torch.utils.data import Dataset
from torchvision import transforms

import pycocotools.mask as maskUtils

from .base_class import DataArguments, rank0_print, preprocess, preprocess_multimodal
from TraceVLM.utils_format import NumberSegFormatter, MultiNumberSegFormatter, BinPolar36pointGDSegFormatter, BinPolar36pointcenterGDSegFormatter, MultiPolar36pointGDSegFormatter, DictNumberSegFormatter, BinSegFormatter, BinPolarGDSegFormatter, BinPolarSegFormatter, BinPolar36pointSegFormatter, BinPolar36point_allmask_SegFormatter, BinPolarInstanceSegFormatter, BinPolar_allmask_SegFormatter
from TraceVLM.utils_format import MultiPolarSegFormatter, MultiPolarGDSegFormatter, MultiPolar36centerGDSegFormatter
from TraceVLM.utils_format import BinPolar12GDSegFormatter, BinPolar24GDSegFormatter, BinPolar48GDSegFormatter, BinPolar72GDSegFormatter
from TraceVLM.utils_format import BinPolar48pointGDSegFormatter, BinPolar12pointGDSegFormatter, BinPolar24pointGDSegFormatter, BinPolar72pointGDSegFormatter
from TraceVLM.utils_format import BinPolarAngleSegFormatter, BinPolarAngle7SegFormatter, BinPolarAngleSEGSegFormatter, BinPolarAngleSEGALLSegFormatter
from TraceVLM.utils import box_xyxy_expand2square, resize_box, resize_seg, resize_mask, reorder, reorder_seg_and_merge, get_single_centerpoint, get_all_centerpoint, reorder_mask_and_merge
from TraceVLM.utils import is_clockwise, reorder_points, revert_direction, interpolate_polygons, approximate_polygons, polygons_to_string, sample_polygons, calculate_angle
from TraceVLM.constants import REF_TASK_PALCEHOLDER, DEFAULT_SEGES_PLACEHOLDER

import re

def sample_point_seg(segmentation, number=200):
    """Uniformly thin out every over-long entry of a two-level nested segmentation.

    ``segmentation`` is iterated as a sequence of groups and every group as a
    sequence of entries.  An entry holding fewer than ``number`` elements is
    passed through untouched (the very same object); any other entry is replaced
    by a new list of ``number`` elements taken at evenly spaced positions running
    from its first to its last element.

    Args:
        segmentation: Nested sequence ``[group][entry][element]``.
        number: Maximum number of elements kept per entry.

    Returns:
        A new nested list with the same group/entry layout as the input.
    """
    # NOTE(review): elements are picked one by one; if entries are flat
    # [x0, y0, x1, y1, ...] coordinate lists this does not keep x/y pairs
    # together -- confirm the expected entry layout with the callers.
    def thin(entry):
        size = len(entry)
        if size < number:
            return entry
        positions = np.linspace(0, size - 1, num=number)
        return [entry[int(position)] for position in positions]

    return [[thin(entry) for entry in group] for group in segmentation]

def polygon_reorder(segmentation, max_length=400):
    """Normalise the orientation, start vertex and ordering of polygons.

    Every polygon is reversed with ``revert_direction`` when ``is_clockwise``
    reports it is not clockwise, then passed through ``reorder_points``.
    Polygons are sorted by the squared distance of their first two values
    (``x[0] ** 2 + x[1] ** 2``, ties broken by ``x[0]`` then ``x[1]``) and
    simplified with ``approximate_polygons``.

    Args:
        segmentation: Either a list of flat polygons, or a list of instances
            where each instance is itself a list of flat polygons (detected by
            ``isinstance(entry[0], list)``).
        max_length: Forwarded to ``approximate_polygons``.

    Returns:
        The processed polygons.  For nested input there is one processed polygon
        list per instance.
    """
    def first_vertex_key(poly):
        return (poly[0] ** 2 + poly[1] ** 2, poly[0], poly[1])

    polygons_processed = []
    for polygon in segmentation:
        if isinstance(polygon[0], list):
            polygons_per_processed = []
            for segm in polygon:
                # Make the sub-polygon clockwise.  The result is kept in `segm`
                # so that a sub-polygon which is already clockwise is the one
                # that gets reordered below (previously the enclosing
                # `polygon` list was reordered instead).
                if not is_clockwise(segm):
                    segm = revert_direction(segm)

                # Reorder so that the first vertex is the one closest to the
                # image origin.
                polygons_per_processed.append(reorder_points(segm))
            polygons_per_processed = sorted(polygons_per_processed, key=first_vertex_key)
            polygons = approximate_polygons(polygons_per_processed, tolerance=1.0, max_length=max_length)
            polygons_processed.append(polygons)
        else:
            # Make the polygon clockwise.
            if not is_clockwise(polygon):
                polygon = revert_direction(polygon)

            # Reorder so that the first vertex is the one closest to the image
            # origin.
            polygon = reorder_points(polygon)
            polygons_processed.append(polygon)
            # NOTE(review): the sort / interpolate / approximate steps run on the
            # whole accumulator at every iteration, so earlier polygons are
            # processed repeatedly.  This looks like it was meant to run once
            # after the loop -- behaviour kept as-is, please confirm.
            polygons_processed = sorted(polygons_processed, key=first_vertex_key)
            polygons_interpolated = interpolate_polygons(polygons_processed)
            polygons_processed = approximate_polygons(polygons_interpolated, tolerance=1.0, max_length=max_length)
    return polygons_processed

class RESSegDataset(Dataset):
    """Dataset for referring image segmentation backed by a JSON-lines annotation file."""
    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments,
                 image_folder: str = None,
                 template_file: str = None,
                 prompt: str = None,
                 debug=False,
                ):
        super(RESSegDataset, self).__init__()

        f = open(data_path, "r", encoding="utf-8")
        list_data_dict = f.readlines()

        rank0_print("Formatting inputs...RES Dataset")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.image_folder = image_folder
        assert not(prompt != None and template_file != None)
        if template_file is not None:
            self.prompts = json.load(open(template_file, 'r', encoding='utf8'))
            self.rng = np.random.default_rng(1203)
            self.prompt = None
        else:
            self.prompt = prompt # "<image><expr><single>"
            self.prompts = None
        # 这里暂定，目前没有更多的详细设置
        # TODO: 后续更新
        # polygon结合、不结合、只保留最大
        # 不管点数、固定点数
        self.point_format = 1 # 1:不管点数 2:固定点数 3:随机点数

        # data format # NumberSegFormatter、DictNumber、Bin、BinPolar、BinPolarInstance、BinPolar36point
        if data_args.formatter.lower() == "NumberSegFormatter".lower():
            self.segformat = NumberSegFormatter(self.data_args.image_processor, precision=3)
        elif data_args.formatter.lower() == "DictNumber".lower():
            self.segformat = DictNumberSegFormatter(self.data_args.image_processor, precision=3)
        elif data_args.formatter.lower() == "Bin".lower():  # 得需要顺序
            self.segformat = BinSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolar".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolarSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolarGD".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolarGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolarGDAngle".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolarAngleSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolarGDAngleSEG".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolarAngleSEGSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolarGDAngleSEGALL".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolarAngleSEGALLSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolarGDAngle7".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolarAngle7SegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolar12GD".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolar12GDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolar24GD".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolar24GDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolar48GD".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolar48GDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolar72GD".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolar72GDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolar_allmask".lower():   # 得需要确认是1.只要最大的polygon；2.所有polygon结合; 3.所有polygon都要
            self.segformat = BinPolar_allmask_SegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolarInstance".lower():   
            self.segformat = BinPolarInstanceSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "BinPolar36point".lower():   
            self.segformat = BinPolar36pointSegFormatter(self.data_args.image_processor, precision=3)
        elif data_args.formatter.lower() == "BinPolar36point_allmask".lower():   
            self.segformat = BinPolar36point_allmask_SegFormatter(self.data_args.image_processor, precision=3)
        elif data_args.formatter.lower() == "MultiNumber".lower():
            self.segformat = MultiNumberSegFormatter(self.data_args.image_processor, seg_split_placeholder='&', precision=3)
        elif data_args.formatter.lower() == "MultiNumberpolarpoint".lower():
            self.segformat = MultiPolarSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "MultiNumberpolarpointGD".lower():
            self.segformat = MultiPolarGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "MultiNumberpolar36centerGD".lower():
            self.segformat = MultiPolar36centerGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "MultiNumber36polarpGDoint".lower():
            self.segformat = MultiPolar36pointGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "Bin12polarGDpoint".lower():
            self.segformat = BinPolar12pointGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "Bin24polarGDpoint".lower():
            self.segformat = BinPolar24pointGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "Bin36polarGDpoint".lower():
            self.segformat = BinPolar36pointGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "Bin36polarcenterGDpoint".lower():
            self.segformat = BinPolar36pointcenterGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "Bin48polarGDpoint".lower():
            self.segformat = BinPolar48pointGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        elif data_args.formatter.lower() == "Bin72polarGDpoint".lower():
            self.segformat = BinPolar72pointGDSegFormatter(self.data_args.image_processor, seg_split_placeholder=self.data_args.seg_split, precision=3)
        # image resize
        self.debug = debug
        if not self.data_args.image_processor.do_resize:
            self.resize = transforms.Resize((self.data_args.image_processor.size["shortest_edge"], self.data_args.image_processor.size["shortest_edge"]))

    def __len__(self):
        return len(self.list_data_dict)

    def transform2conv(self, i):
        if self.prompts is not None:
            prompt = self.rng.choice(self.prompts)
        else:
            prompt = self.prompt
        if not prompt.endswith("\n"):
            prompt += "\n"
        if "<image>" not in prompt:
            prompt = "<image>" + prompt

        item = json.loads(self.list_data_dict[i])
        if "file_path" in item:
            img_path = item["file_path"]
        elif "img_path" in item:
            img_path = item["img_path"]
        else:
            img_path = item["image_name"]
        expr = item["expression"]
        if "bbox" in item:
            bbox = item["bbox"]
        else:
            bbox = item["bboxes"]
        if "segmentations" in item:
            segmentation = item['segmentations']
        else:
            segmentation = item['seg']

        width = item["width"]
        height = item["height"]
        sentence_flag = True
        # bbox = [bbox[0], bbox[1], bbox[0] + bbox[2] - 1, bbox[1] + bbox[3] - 1]
        # with mask   
        if self.data_args.with_mask:
            if type(segmentation[0]) == list:
                rles = maskUtils.frPyObjects(segmentation, height, width)
                rle = maskUtils.merge(rles)
            else:
                rle = segmentation
                if segmentation[0]['size'][0] != height:
                    height, width = width, height
            masks = maskUtils.decode(rle)
            binary_mask = np.array(masks, dtype=np.uint8)
            contours_mask, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)  # 这里contours可能有个polygon
            segmentation = [contour_mask.reshape(contour_mask.shape[0] * 2).tolist() for contour_mask in contours_mask]

            if self.data_args.with_only_gd:
                gd_polygon = contours_mask[0][:, 0, :].tolist()
                gd_polygon.extend([gd_polygon[0]])
                gd_angle = []
                for i_gd in range(0, len(gd_polygon) - 1):
                    if i_gd == 0:
                        p1 = gd_polygon[len(gd_polygon) - 2]
                        p2 = gd_polygon[i_gd]
                        p3 = gd_polygon[i_gd+1]
                    else:
                        p1 = gd_polygon[i_gd-1]
                        p2 = gd_polygon[i_gd]
                        p3 = gd_polygon[i_gd+1]
                    angle = calculate_angle(p1, p2, p3)
                    gd_angle.append(angle)
                gd_angle = np.array(gd_angle)
                angle_indice = np.argsort(gd_angle)[::-1][:36]
                norm_angle_indice = np.argsort(angle_indice)
                gd_contour_mask = contours_mask[0][angle_indice[norm_angle_indice], :, :]
                segmentation = [gd_contour_mask.reshape(gd_contour_mask.shape[0] * 2).tolist()]

            if self.data_args.with_polar:  # for data_format BinPolarInstance
                mask_centers = []
                mask_contours = []
                if masks.ndim == 2:
                    masks = [masks]
                else:
                    mask_list = []
                    for index in range(masks.shape[-1]):
                        mask_list.append(masks[:, :, index])
                    masks = mask_list
                for mask in masks:
                    try:
                        cnt, contour = get_single_centerpoint(mask)
                        # cnt, contour = get_all_centerpoint(mask)
                    except:
                        continue
                    contour = contour[0][:, 0, :]
                    x, y = cnt
                    mask_centers.append([x,y])
                    mask_contours.append(contour.tolist())
                    # mask_centers.append(cnt)
                    # mask_contours.append([cont[:, 0, :].tolist() for cont in contour])
                # if self.data_args.with_gd:
                #     mask_final_angle = []
                #     each_mask_angle = []
                #     try:
                #         for each_contour in contour:
                #             gd_polygon = each_contour[:, 0, :].tolist()
                #             gd_polygon.extend([gd_polygon[0]])
                #             gd_angle = []
                #             for i_gd in range(0, len(gd_polygon) - 1):
                #                 if i_gd == 0:
                #                     p1 = gd_polygon[len(gd_polygon) - 2]
                #                     p2 = gd_polygon[i_gd]
                #                     p3 = gd_polygon[i_gd+1]
                #                 else:
                #                     p1 = gd_polygon[i_gd-1]
                #                     p2 = gd_polygon[i_gd]
                #                     p3 = gd_polygon[i_gd+1]
                #                 angle = calculate_angle(p1, p2, p3)
                #                 gd_angle.append(angle)
                #             gd_angle = np.array(gd_angle)
                #             each_mask_angle.append(torch.Tensor(gd_angle).float())
                #         mask_final_angle.append(each_mask_angle)
                #     except:
                #         print(item)
                #         print('*****************')
                #         import pdb;pdb.set_trace()
                # if self.data_args.with_gd:
                #     gd_polygon = contour.tolist()
                #     gd_polygon.extend([gd_polygon[0]])
                #     gd_angle = []
                #     for i_gd in range(0, len(gd_polygon) - 1):
                #         if i_gd == 0:
                #             p1 = gd_polygon[len(gd_polygon) - 2]
                #             p2 = gd_polygon[i_gd]
                #             p3 = gd_polygon[i_gd+1]
                #         else:
                #             p1 = gd_polygon[i_gd-1]
                #             p2 = gd_polygon[i_gd]
                #             p3 = gd_polygon[i_gd+1]
                #         angle = calculate_angle(p1, p2, p3)
                #         gd_angle.append(angle)
                #     gd_angle = np.array(gd_angle)
                #     mask_final_angle = [torch.Tensor(gd_angle).float()]
                # if len(mask_centers) == 0:
                #     sentence_flag = False
        else:
            if type(segmentation[0]) == list:  # refcoco系列、grefcoco
                segmentation = segmentation  # [[183.06, 220.54, 200.36, 230.63, 221.98, 240.72, 224.86, 247.93, 220.54, ...], [253.69, 155.68, 229.19, 142.7, 204.68, 134.05, 185.95, 138.38, 171.53, ...]]
                # polygons = sample_polygons(segmentation, max_length=256)
                # segmentation = polygons
            else:  # refclef、flickr30k,其中refclef的表现形式：；flickr30k的表现形式：
                segmentation_list = []
                for seg in segmentation:  # [{'size': [...], 'counts': 'am>2U::I6FW1oN3L5L3M3N3L4L4L4L3L4M3iHcMY6[3M3M2O1N2O1O1O100OInITLQ6m3oISLn5Y3ZJnLIHi5]3bJhLFJg5Z3hJlLBId5X3QKmLV5n2QKQMn4o2SKQMm4l2WKSMi4k2YKTMh4k2YKUMh4k2XKTMh4l2\\11O01M3F:G9L4L5O0000001bNgGf0Y8WOQHa0S8UObGK`0l0Q9L2N3N2L5L1N2K9JfmP4'}, {'size': [...], 'counts': '[k_11Z:7J2M3J6M2O2N1L2QOZOWHi0j7XOSHl0k7VOoGo0P8SOmGP1Y2hNP3;dJo0Y2jNQ3:cJm0[2jNo2m2RLUMl3`3_KaL`4W4hJjKX5f4000O10000O2N1O2O1O1lN]JZMe5b2YKbLk4Y3Z1OO120K6J4L7J4L1O1dIQLU6U40002OO001N1UNoIXOQ6c0jJfNX5:hIXOP2KY4j0\\LTOf3i0]LUOd3j0]LSOg3j0S3L4K6I4L4JdZP3'}]
                    binary_mask = np.array(maskUtils.decode(seg), dtype=np.uint8)
                    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)  # 这里contours可能有个polygon
                    polygons = [contour.reshape(contour.shape[0] * 2).tolist() for contour in contours]
                    # polygons = sample_polygons(polygons, max_length=256)
                    segmentation_list.append(polygons)
                    # segmentation_list.append([contour.reshape(contour.shape[0] * 2).tolist() for contour in contours])
                segmentation = segmentation_list  # segmentation_list存在多个实例，每个实例可能有多个polygon
                if not isinstance(bbox[0], list):  # for refclef
                    bbox = [bbox]
        
        try:
            if self.data_args.with_polygon_reorder:
                segmentation = polygon_reorder(segmentation, max_length=400)
        except:
            print("The error is {}".format(item))

        try:
            if self.data_args.sample_polygon:
                segmentation = sample_polygons(segmentation, max_length=400)
        except:
            print("The error is {}".format(item))

        # if len(bbox) > 0 and sentence_flag == True and len(mask_contours) > 0:
        if len(bbox) > 0 and sentence_flag == True:
            if not self.data_args.image_processor.do_resize:
                bbox, _, _ = resize_box(bbox, self.data_args.image_processor.size, height, width, pre=True)
                if self.data_args.with_mask:
                    ori_masks = masks.copy()
                    masks, _, _ = resize_mask(masks, self.data_args.image_processor.size, height, width, pre=True)
                    if self.data_args.with_polar:
                        mask_centers, _, _ = resize_seg(mask_centers, self.data_args.image_processor.size, height, width, pre=True)
                        mask_contours, _, _ = resize_seg(mask_contours, self.data_args.image_processor.size, height, width, pre=True)
                else:
                    segmentation, _, _ = resize_seg(segmentation, self.data_args.image_processor.size, height, width, pre=True)
            sentence = "{}".format(DEFAULT_SEGES_PLACEHOLDER) 

            if self.data_args.with_gd:
                for contour_test in mask_contours:
                    gd_polygon = contour_test.copy()
                    gd_polygon.extend([gd_polygon[0]])
                    gd_angle = []
                    for i_gd in range(0, len(gd_polygon) - 1):
                        if i_gd == 0:
                            p1 = gd_polygon[len(gd_polygon) - 2]
                            p2 = gd_polygon[i_gd]
                            p3 = gd_polygon[i_gd+1]
                        else:
                            p1 = gd_polygon[i_gd-1]
                            p2 = gd_polygon[i_gd]
                            p3 = gd_polygon[i_gd+1]
                        angle = calculate_angle(p1, p2, p3)
                        gd_angle.append(angle)
                    gd_angle = np.array(gd_angle)
                    mask_final_angle = [torch.Tensor(gd_angle).float()]

            # new_size
            if self.data_args.image_aspect_ratio == "pad":
                new_height = max(width, height)
                new_width = max(width, height)
            elif self.data_args.image_aspect_ratio == "org_pad":
                # 直接padding右侧 或者下方，从而保持原始的横纵比，且坐标直接resize即可，不需要进行padding处理
                new_height = max(width, height)
                new_width = max(width, height)
            elif not self.data_args.image_processor.do_resize:
                # 直接resize ！！！
                new_height = self.data_args.image_processor.size["shortest_edge"]
                new_width = self.data_args.image_processor.size["shortest_edge"]
            else:
                new_height = height
                new_width = width
            
            # data format
            if self.data_args.formatter == "NumberBoxFormatter":
                if isinstance(bbox[0], list):
                    sentence = self.segformat(sentence, [segmentation], new_height, new_width)
                else:
                    sentence = self.segformat(sentence, [[segmentation]], new_height, new_width)
            elif self.data_args.formatter == "DictNumber":
                pass
            elif self.data_args.formatter == "Bin":  # <Bin_0>, <Bin_1>
                if isinstance(bbox[0], list):  # 一般是什么数据会进行，是flickr30k还是VG
                    sentence = self.segformat(sentence, [(len(segmentation) * ['referent'], segmentation)], new_height, new_width)
                else:
                    if self.data_args.sample_point:
                        segmentation = sample_point_seg([segmentation])[0]
                    sentence = self.segformat(sentence, [(['referent'], [segmentation])], new_height, new_width)
                masks = None
            elif self.data_args.formatter == "BinPolar": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours)], new_height, new_width)
            elif self.data_args.formatter == "BinPolarGD": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolarGDAngle": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolarGDAngleSEG": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolarGDAngleSEGALL": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolarGDAngle7": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolar12GD": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolar24GD": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolar48GD": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolar72GD": # <Bin_中心点>, <Bin_距离> 
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "BinPolar_allmask": # <Bin_中心点>, <Bin_距离>
                sentence = self.segformat(sentence, [(['referent'], [masks])], new_height, new_width)
            elif self.data_args.formatter == "BinPolarInstance":  # with_polar
                if isinstance(bbox[0], list):
                    cate_names, _, masks, mask_centers, mask_contours = reorder_mask_and_merge(len(segmentation) * ['referent'], bbox, masks, mask_centers, mask_contours)
                    sentence = self.segformat(sentence, [(cate_names, masks, mask_centers, mask_contours)], new_height, new_width)
                else:
                    sentence = self.segformat(sentence, [(['referent'], masks, mask_centers, mask_contours)], new_height, new_width)
            elif self.data_args.formatter == "BinPolar36point":  # <Bin_point>
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours)], new_height, new_width)
            elif self.data_args.formatter == "BinPolar36point_allmask":  # <Bin_point>
                sentence = self.segformat(sentence, [(['referent'], [masks])], new_height, new_width)
            elif self.data_args.formatter == "MultiNumberpolarpoint":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours)], new_height, new_width)      
            elif self.data_args.formatter == "MultiNumberpolarpointGD":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)   
            elif self.data_args.formatter == "MultiNumberpolar36centerGD":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "MultiNumber36polarpGDoint":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "Bin12polarGDpoint":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "Bin24polarGDpoint":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "Bin36polarGDpoint":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "Bin36polarcenterGDpoint":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "Bin48polarGDpoint":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "Bin72polarGDpoint":
                sentence = self.segformat(sentence, [(['referent'], mask_centers, mask_contours, mask_final_angle)], new_height, new_width)
            elif self.data_args.formatter == "MultiNumber":
                if isinstance(segmentation[0], list): # Multi instance
                    if isinstance(bbox[0], list):  # grit 系列 
                        if self.data_args.sample_point:
                            segmentation = sample_point_seg(segmentation)
                        cate_names, _, seges = reorder_seg_and_merge(len(segmentation) * ['referent'], segmentation, bbox)
                        sentence = self.segformat(sentence, [(cate_names, seges)], new_height, new_width)
                    elif isinstance(bbox[0], int) or isinstance(bbox[0], float):  # refcoco系列
                        cate_names, _, seges = reorder_seg_and_merge(len(segmentation) * ['referent'], segmentation, [bbox])
                        sentence = self.segformat(sentence, [(cate_names, [seges])], new_height, new_width)
                else:
                    sentence = self.segformat(sentence, [([expr], [segmentation])], new_height, new_width)
        else:
            sentence = "None"

        ret = {
            "image": img_path,
            "conversations": [
                {
                    "from": "human",
                    "value": prompt.replace("<expr>", expr)
                },
                {
                    "from": "gpt",
                    "value": sentence
                }
            ],
            "mask": masks,
            # "ori_mask": ori_masks,
        }
        if i == 0:
            print(prompt)
            print(sentence)
        return ret
    
    @property
    def lengths(self):
        """Return one integer length estimate per entry of ``self.list_data_dict``.

        Each entry is a JSON string. Its estimate is ``140`` times the number
        of items under ``"segmentations"`` (or under ``"seg"`` when that key
        is absent) plus a fixed ``128`` image-token allowance.

        NOTE(review): presumably consumed by a length-aware sampler; the
        meaning of the 140 weight is not visible here -- confirm with caller.
        """
        img_tokens = 128
        per_seg_weight = 140
        estimates = []
        for raw_record in self.list_data_dict:
            record = json.loads(raw_record)
            # Records carry their segmentations under one of two keys.
            seg_key = "segmentations" if "segmentations" in record else "seg"
            estimates.append(per_seg_weight * len(record[seg_key]) + img_tokens)
        return estimates
    
    @property
    def modality_types(self):
        """Return the task-type code for every entry of ``self.list_data_dict``.

        Every sample of this dataset gets the same code, ``2``.
        Code table: 1:LAZY 2:RES 3:GRES 4:IS 5:SEG
        """
        # Original note: pure-text samples are given a length of -cur_len to
        # tell them apart; region samples could be handled the same way.
        res_type = 2
        return [res_type for _ in self.list_data_dict]
    
    @property
    def modality_lengths(self):
        """Return one integer length estimate per entry of ``self.list_data_dict``.

        Each JSON-encoded entry contributes ``140 * seg_num + 128``, where
        ``seg_num`` is the number of items under ``"segmentations"`` if that
        key exists and under ``"seg"`` otherwise.
        """
        img_tokens = 128
        # Decode lazily; each record is only needed once.
        decoded = (json.loads(raw_record) for raw_record in self.list_data_dict)
        return [
            140 * len(rec['segmentations'] if "segmentations" in rec else rec['seg']) + img_tokens
            for rec in decoded
        ]

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the training sample for index ``i``.

        The record produced by ``self.transform2conv(i)`` is turned into
        tokenized ``input_ids``/``labels`` via ``preprocess``; when the record
        has an ``'image'`` entry the image is loaded and run through
        ``self.data_args.image_processor``.

        Returns a dict with:
            * ``input_ids`` / ``labels`` -- from ``preprocess`` (first element
              only when ``i`` is an ``int``),
            * ``mask`` -- copied from the record when present,
            * ``source`` -- the (preprocessed) conversations,
            * ``image_path`` -- joined folder/file path, or ``None`` when the
              record has no image,
            * ``image`` -- processed pixel values, or an all-zero tensor when
              there is no image but ``self.data_args.is_multimodal`` is set.

        Raises:
            AssertionError: if the wrapped record list does not have exactly
                one element.
        """
        # Fetch the raw conversation-style record.
        sources = self.transform2conv(i)
        is_single_index = isinstance(i, int)
        if is_single_index:
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME

        record = sources[0]
        has_mask = 'mask' in record
        mask = record['mask'] if has_mask else None

        has_image = 'image' in record
        image = None
        image_path = None
        if has_image:
            image_folder = (
                self.data_args.image_folder
                if self.image_folder is None
                else self.image_folder
            )
            image_path = os.path.join(image_folder, record['image'])
            # Image processor (per the original note: CLIP processing --
            # resize, normalize, etc.).
            processor = self.data_args.image_processor
            try:
                image = Image.open(image_path).convert('RGB')
            except Exception:
                # Best-effort fallback: let OpenCV decode what PIL could not.
                # `except Exception` (not bare) so Ctrl-C / SystemExit still
                # propagate. If OpenCV also fails, cvtColor raises.
                image = Image.fromarray(
                    cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
                )

            if self.data_args.image_aspect_ratio == 'pad':
                def expand2square(pil_img, background_color):
                    """Center ``pil_img`` on a square canvas of its longer side."""
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    side = max(width, height)
                    result = Image.new(pil_img.mode, (side, side), background_color)
                    if width > height:
                        result.paste(pil_img, (0, (width - height) // 2))
                    else:
                        result.paste(pil_img, ((height - width) // 2, 0))
                    return result

                # Pad to a square first, filling with the processor's mean color.
                fill_color = tuple(int(x * 255) for x in processor.image_mean)
                image = expand2square(image, fill_color)
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            elif not processor.do_resize:
                # The processor will not resize, so apply the dataset's own resize.
                image = self.resize(image)
                if self.debug:
                    # Debug mode skips normalization.
                    image = processor.preprocess(image, do_normalize=False, return_tensors='pt')['pixel_values'][0]
                else:
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            else:
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=has_image,
            )
        if is_single_index:
            # Unwrap the batch dimension added for a single index.
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0],
                             )
        if has_mask:
            data_dict['mask'] = mask
        data_dict['source'] = sources
        # None when the record carries no image (previously an unbound-name error).
        data_dict['image_path'] = image_path
        if has_image:
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # No image in the data, but the model is multimodal: feed zeros.
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict
