import copy
import json
import math
import os
import random
import re
import ast
from typing import Dict

import torch
import transformers
import yaml
from qwen_vl_utils import smart_resize, process_vision_info
from torch.utils.data import Dataset

from V2P.constants import (
    IGNORE_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_POINTER_START_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
    DEFAULT_POINTER_END_TOKEN,
    ACTION_PATTENS_XY,
    ADDITIONAL_SPECIAL_TOKENS,
    assistant_template,
    chat_template,
    grounding_system_message,
)
from V2P.trainer import rank0_print


def reformat_coordinates(text):
    """
    Replace coordinate literals in ``text`` with pointer special tokens.

    (1) Find all the coordinates in the text using every pattern in ``ACTION_PATTENS_XY``.
    (2) Replace each match with pointer tokens: one pointer for matches of the first
        pattern, two comma-separated pointers for matches of every other pattern.
    (3) Return the new text and the coordinates as a list of (x, y) pairs, ordered by the
        position of their pointer tokens in the returned text.

    Values within ``epsilon`` of 0 or 1 are nudged to ``epsilon`` / ``1 - epsilon``; other
    values are returned as parsed (they are not clipped to [0, 1]).
    """
    epsilon = 0.001

    def adjust_coord(c):
        """Adjust a coordinate if it is too close to 0 or 1."""
        if abs(c) < epsilon:
            return epsilon
        elif abs(c - 1) < epsilon:
            return 1 - epsilon
        return c

    single_pointer = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"

    # Each entry is [offset of the match in the *current* text, regex groups].
    all_matches = []
    for pattern in ACTION_PATTENS_XY:
        if pattern == ACTION_PATTENS_XY[0]:
            target_text = single_pointer
        else:
            target_text = f"{single_pointer}, {single_pointer}"

        matches = list(re.finditer(pattern, text))
        # ``Match.expand`` renders the replacement exactly as ``re.sub`` will, so this is
        # the true change in length each substitution causes.
        deltas = [(m, len(m.expand(target_text)) - (m.end() - m.start())) for m in matches]

        # Offsets recorded for earlier patterns refer to the text *before* this pass;
        # shift them by every substitution that ends at or before them so that all
        # offsets end up in the frame of the final text and sort correctly.
        for entry in all_matches:
            entry[0] += sum(delta for m, delta in deltas if m.end() <= entry[0])

        growth = 0
        for m, delta in deltas:
            all_matches.append([m.start() + growth, m.groups()])
            growth += delta

        text = re.sub(pattern, target_text, text)

    coordinates = []
    # Extract coordinates in the order their pointer tokens appear in the final text.
    for _, groups in sorted(all_matches, key=lambda entry: entry[0]):
        # 2 groups -> one (x, y) pair; 4 groups -> two pairs. Other group counts are
        # ignored, as before.
        if len(groups) in (2, 4):
            values = [adjust_coord(ast.literal_eval(g)) for g in groups]
            coordinates.extend(zip(values[0::2], values[1::2]))

    return text, coordinates

def get_token_index(image_processor, image, point_x, point_y):
    """
    Get the index of the visual token that contains the point (x, y).

    Args:
        image_processor: the image processor; one visual token covers a square of
            ``patch_size * merge_size`` pixels.
        image: a one-element list holding the processed image (an object whose
            ``.size`` is ``(width, height)``, e.g. a PIL image).
        point_x: the x coordinate of the point, nominally in [0, 1].
        point_y: the y coordinate of the point, nominally in [0, 1].

    Returns:
        The row-major index into the ``(h // m) x (w // m)`` token grid. Points that
        fall outside the grid (x or y >= 1, negative, or in a partial last row/column)
        are clamped to the nearest border token instead of wrapping into another row.

    Raises:
        ValueError: if ``image`` does not contain exactly one image.
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")

    w, h = image[0].size
    px, py = w * point_x, h * point_y

    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    grid_w = w // merge_patch_size
    grid_h = h // merge_patch_size

    x_index = math.floor(px / merge_patch_size)
    y_index = math.floor(py / merge_patch_size)
    # Clamp so an out-of-range point never aliases a token in a different row.
    x_index = max(0, min(x_index, grid_w - 1))
    y_index = max(0, min(y_index, grid_h - 1))

    return y_index * grid_w + x_index

def get_multi_patch_labels(image_processor, image, bbox_gt):
    """
    Mark every visual token whose patch overlaps the ground-truth box.

    Args:
        image_processor: provides ``patch_size`` and ``merge_size``; one visual token
            covers a ``patch_size * merge_size`` pixel square.
        image: a one-element list with the processed image (``.size == (w, h)``).
        bbox_gt: the bounding box (x_min, y_min, x_max, y_max), normalised to [0, 1].

    Returns:
        A 1-D tensor (default float dtype) of length ``grid_h * grid_w`` in row-major
        token order: 1 where the patch overlaps the box (touching edges do not count),
        0 elsewhere.

    Raises:
        ValueError: if ``image`` does not contain exactly one image.
        AssertionError: if the image size is not a multiple of the token size.
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")

    w, h = image[0].size
    # Box in absolute pixels, clipped to the image.
    left = max(0, bbox_gt[0] * w)
    top = max(0, bbox_gt[1] * h)
    right = min(w, bbox_gt[2] * w)
    bottom = min(h, bbox_gt[3] * h)

    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    assert w % merge_patch_size == 0 and h % merge_patch_size == 0, f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
    rows, cols = h // merge_patch_size, w // merge_patch_size

    # float64 keeps the integer patch edges exact when compared with the float box.
    col_starts = torch.arange(cols, dtype=torch.float64) * merge_patch_size
    row_starts = torch.arange(rows, dtype=torch.float64) * merge_patch_size
    # A patch misses the box iff it lies entirely left/right (or above/below) of it.
    hits_x = ~((col_starts + merge_patch_size <= left) | (col_starts >= right))
    hits_y = ~((row_starts + merge_patch_size <= top) | (row_starts >= bottom))

    overlap = hits_y[:, None] & hits_x[None, :]
    return overlap.flatten().to(torch.get_default_dtype())


def get_multi_patch_labels_center(image_processor, image, bbox_gt):
    """
    Mark the single visual token whose patch contains the centre of the box.

    Args:
        image_processor: provides ``patch_size`` and ``merge_size``; one visual token
            covers a ``patch_size * merge_size`` pixel square.
        image: a one-element list with the processed image (``.size == (w, h)``).
        bbox_gt: the bounding box (x_min, y_min, x_max, y_max), normalised to [0, 1].

    Returns:
        A 1-D tensor (default float dtype) of length ``grid_h * grid_w`` in row-major
        token order, 1 at the patch containing the centre of the image-clipped box
        (patch intervals are half-open, ``[start, start + size)``) and 0 elsewhere.
        The mask is all zeros if the centre lies on the far right/bottom image border.

    Raises:
        ValueError: if ``image`` does not contain exactly one image.
        AssertionError: if the image size is not a multiple of the token size.
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")

    w, h = image[0].size
    # Box in absolute pixels, clipped to the image, then its centre point.
    left = max(0, bbox_gt[0] * w)
    top = max(0, bbox_gt[1] * h)
    right = min(w, bbox_gt[2] * w)
    bottom = min(h, bbox_gt[3] * h)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2

    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    assert w % merge_patch_size == 0 and h % merge_patch_size == 0, f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
    rows, cols = h // merge_patch_size, w // merge_patch_size

    # float64 keeps the integer patch edges exact when compared with the float centre.
    col_starts = torch.arange(cols, dtype=torch.float64) * merge_patch_size
    row_starts = torch.arange(rows, dtype=torch.float64) * merge_patch_size
    in_col = (col_starts <= center_x) & (center_x < col_starts + merge_patch_size)
    in_row = (row_starts <= center_y) & (center_y < row_starts + merge_patch_size)

    hit = in_row[:, None] & in_col[None, :]
    return hit.flatten().to(torch.get_default_dtype())

def get_multi_patch_labels_gauss(image_processor, image, bbox_gt, sigma_factor=4.1325):
    """
    Get soft multi-patch labels from a 2-D Gaussian fitted to the bounding box.

    The Gaussian is centred on the (image-clipped) box centre, with per-axis standard
    deviation ``box_side / sigma_factor``. The default 4.1325 puts the box edges at
    about +/-2.066 sigma, i.e. roughly 96% of the 1-D mass lies inside the box. Each
    visual token whose patch overlaps the box receives the Gaussian mass of the
    patch/box intersection rectangle. Every other token is 0.

    If a box side has zero length, the Gaussian along that axis degenerates to a point
    mass. The overlapping patch then gets mass 1.0 on that axis, instead of the NaN
    that a zero ``scale`` would produce in ``norm.cdf``.

    Args:
        image_processor: provides ``patch_size`` and ``merge_size``; one visual token
            covers a ``patch_size * merge_size`` pixel square.
        image: a one-element list with the processed image (``.size == (w, h)``).
        bbox_gt: the bounding box (x_min, y_min, x_max, y_max), normalised to [0, 1].
        sigma_factor: box side length (pixels) divided by this gives the std.

    Returns:
        A 1-D float tensor of length ``grid_h * grid_w`` in row-major token order.

    Raises:
        ValueError: if ``image`` does not contain exactly one image.
        AssertionError: if the image size is not a multiple of the token size.
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")

    # scipy is only needed for this label style, so import it lazily (once per call).
    from scipy.stats import norm

    w, h = image[0].size
    # Box in absolute pixel coordinates, clipped to the image.
    x_min = max(0, bbox_gt[0] * w)
    y_min = max(0, bbox_gt[1] * h)
    x_max = min(w, bbox_gt[2] * w)
    y_max = min(h, bbox_gt[3] * h)

    x_center = x_min + (x_max - x_min) / 2
    y_center = y_min + (y_max - y_min) / 2
    sigma_x = (x_max - x_min) / sigma_factor
    sigma_y = (y_max - y_min) / sigma_factor

    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    assert w % merge_patch_size == 0 and h % merge_patch_size == 0, f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
    grid_h, grid_w = h // merge_patch_size, w // merge_patch_size

    def axis_mass(lo, hi, center, sigma, n_cells):
        """Per-cell 1-D Gaussian mass of cell ∩ [lo, hi]; None where the cell misses [lo, hi]."""
        masses = []
        for idx in range(n_cells):
            cell_lo = idx * merge_patch_size
            cell_hi = cell_lo + merge_patch_size
            if cell_hi <= lo or cell_lo >= hi:
                masses.append(None)
            elif sigma > 0:
                # Integrate only over the part of the cell that lies inside the box.
                masses.append(
                    norm.cdf(min(cell_hi, hi), center, sigma)
                    - norm.cdf(max(cell_lo, lo), center, sigma)
                )
            else:
                # Zero-length box side: the whole (point) mass sits in this cell.
                masses.append(1.0)
        return masses

    # The 2-D Gaussian is separable, so each axis only needs to be evaluated once.
    x_masses = axis_mass(x_min, x_max, x_center, sigma_x, grid_w)
    y_masses = axis_mass(y_min, y_max, y_center, sigma_y, grid_h)

    soft_mask = torch.zeros(grid_h * grid_w)
    for y_idx, prob_y in enumerate(y_masses):
        if prob_y is None:
            continue
        for x_idx, prob_x in enumerate(x_masses):
            if prob_x is not None:
                soft_mask[y_idx * grid_w + x_idx] = prob_x * prob_y

    return soft_mask

def token_index_to_coordinates(image_processor, visual_token_index, image_width, image_height):
    """
    Map a row-major visual token index back to the pixel centre of its patch.

    Args:
        image_processor: provides ``patch_size`` and ``merge_size``; one visual token
            covers a ``patch_size * merge_size`` pixel square.
        visual_token_index: row-major index into the token grid.
        image_width: width in pixels of the processed image (sets tokens per row).
        image_height: accepted for symmetry; not used by the computation.

    Returns:
        ``(px, py)``: pixel coordinates of the centre of the token's patch.
    """
    cell = image_processor.patch_size * image_processor.merge_size
    tokens_per_row = image_width // cell
    col = visual_token_index % tokens_per_row
    row = visual_token_index // tokens_per_row
    half_cell = cell / 2
    return col * cell + half_cell, row * cell + half_cell

class LazySupervisedDataset(Dataset):
    """
    Lazily processed supervised grounding dataset.

    Sample metadata is loaded when the dataset is constructed. Images are loaded and
    conversations tokenized only when an item is requested. Coordinates in assistant
    turns are replaced by pointer special tokens (via ``reformat_coordinates``) and
    mapped to the visual token that contains them. When a turn carries ``bbox_gt``, a
    per-token label mask is also built according to ``label_style``.
    """

    # Rough per-image token estimate, used only by the length properties.
    _IMAGE_TOKEN_ESTIMATE = 1200
    # Upper bound on consecutive resampling attempts in ``__getitem__``.
    _MAX_FETCH_RETRIES = 100

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        data_path: str,
        label_style: str,
        sigma_factor: float,
        data_args,
    ):
        """
        Args:
            tokenizer: tokenizer that already contains the pointer special tokens.
            processor: multimodal processor; its ``image_processor`` is used for sizing.
            data_path: one of
                * ``"<prefix>{a,b}.json"``: loads ``<prefix>a.json``, ``<prefix>b.json``;
                * a ``.yaml`` manifest with a ``datasets`` list (see ``_load_yaml``);
                * a single ``.json`` file.
            label_style: "all", "center" or "gauss"; selects how ``bbox_gt`` becomes
                per-token labels.
            sigma_factor: divisor for the Gaussian std (used by "gauss" only).
            data_args: data arguments; ``dataset_paths`` is overwritten here, and
                ``image_folder``, ``max_conv_turns`` and ``early_mix_text`` are read later.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.processor = processor
        self.data_args = data_args
        self.list_data_dict = []
        # Image sub-folder of each sample, kept parallel to ``list_data_dict``.
        self.list_image_path = []
        self.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0]
        self.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0]
        self.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
        self.label_style = label_style
        self.sigma_factor = sigma_factor

        rank0_print(f"label_style:{self.label_style}")
        rank0_print(f"sigma_factor:{self.sigma_factor}")

        if "{" in data_path and "}" in data_path:
            self._load_brace_pattern(data_path, data_args)
        elif data_path.endswith(".yaml"):
            self._load_yaml(data_path, data_args)
        else:
            self._load_single_json(data_path, data_args)

        rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
        rank0_print("Formatting inputs...Skip in lazy mode")

    # ------------------------------------------------------------------ loading

    @staticmethod
    def _read_json(path):
        """Read a JSON file (UTF-8) and return the decoded object."""
        with open(path, encoding="utf-8") as json_file:
            return json.load(json_file)

    @staticmethod
    def _read_jsonl(path):
        """Read a JSON-lines file (UTF-8), skipping blank lines."""
        with open(path, encoding="utf-8") as json_file:
            return [json.loads(line) for line in json_file if line.strip()]

    def _add_samples(self, samples, images_folder):
        """Append samples and their shared image sub-folder, keeping both lists aligned."""
        self.list_data_dict.extend(samples)
        self.list_image_path.extend([images_folder] * len(samples))

    def _load_brace_pattern(self, data_path, data_args):
        """Load every ``<prefix><name>.json`` listed in ``<prefix>{name1,name2}.json``."""
        match = re.match(r"^(.*)\{(.*)\}\.json$", data_path)
        if match is None:
            raise ValueError(f"Malformed multi-file data_path: {data_path}")
        base_path, file_pattern = match.groups()
        file_names = file_pattern.split(",")
        rank0_print(f"Loading {file_names} from {base_path}")
        data_args.dataset_paths = []
        for file_name in file_names:
            full_path = f"{base_path}{file_name}.json"
            data_args.dataset_paths.append(full_path)
            rank0_print(f"Loading {full_path}")
            cur_data_dict = self._read_json(full_path)
            rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
            # Images live directly under data_args.image_folder for this format.
            self._add_samples(cur_data_dict, "")

    def _load_yaml(self, data_path, data_args):
        """
        Load the datasets listed in a YAML manifest of the form::

            datasets:
              - json_path: xxxx1.json
                sampling_strategy: first:1000
              - json_path: xxxx2.json
                sampling_strategy: end:3000
              - json_path: xxxx3.jsonl
                sampling_strategy: random:10%
                images_folder: sub/folder

        ``sampling_strategy`` defaults to "all"; its number may be absolute or a
        percentage of that file's samples.
        """
        with open(data_path, encoding="utf-8") as file:
            yaml_data = yaml.safe_load(file)
        datasets = yaml_data.get("datasets")
        if datasets is None:
            raise ValueError(f"No 'datasets' entry found in {data_path}")

        data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
        for dataset in datasets:
            json_path = dataset.get("json_path")
            sampling_strategy = dataset.get("sampling_strategy", "all")
            # A missing folder means "directly under data_args.image_folder".
            images_folder = dataset.get("images_folder") or ""
            sampling_number = None

            rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")

            if json_path.endswith(".jsonl"):
                cur_data_dict = self._read_jsonl(json_path)
            elif json_path.endswith(".json"):
                cur_data_dict = self._read_json(json_path)
            else:
                raise ValueError(f"Unsupported file type: {json_path}")

            if ":" in sampling_strategy:
                sampling_strategy, sampling_number = sampling_strategy.split(":")
                if "%" in sampling_number:
                    sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                else:
                    sampling_number = int(sampling_number)

            if sampling_number is not None:
                if sampling_strategy == "first":
                    cur_data_dict = cur_data_dict[:sampling_number]
                elif sampling_strategy == "end":
                    cur_data_dict = cur_data_dict[-sampling_number:]
                elif sampling_strategy == "random":
                    random.shuffle(cur_data_dict)
                    cur_data_dict = cur_data_dict[:sampling_number]

            rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
            self._add_samples(cur_data_dict, images_folder)

    def _load_single_json(self, data_path, data_args):
        """Load one JSON file whose images live directly under ``data_args.image_folder``."""
        data_args.dataset_paths = [data_path]
        rank0_print(f"Loading {data_path}")
        cur_data_dict = self._read_json(data_path)
        rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
        self._add_samples(cur_data_dict, "")

    # ------------------------------------------------------------------ lengths

    @classmethod
    def _estimate_image_tokens(cls, sample):
        """Approximate image-token count of a sample (0 when it has no ``image`` key)."""
        image = sample.get("image")
        if isinstance(image, list):
            return cls._IMAGE_TOKEN_ESTIMATE * len(image)
        return cls._IMAGE_TOKEN_ESTIMATE if "image" in sample else 0

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        """Approximate length of each sample: whitespace-split words plus image tokens."""
        return [
            sum(len(conv["value"].split()) for conv in sample["conversations"])
            + self._estimate_image_tokens(sample)
            for sample in self.list_data_dict
        ]

    @property
    def modality_lengths(self):
        """Like ``lengths``, but text-only samples are reported as negative lengths
        unless ``data_args.early_mix_text`` is set."""
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            assert cur_len > 0, f"Conversation length is 0 for {sample}"

            if "image" in sample or "video" in sample or self.data_args.early_mix_text:
                length_list.append(cur_len + self._estimate_image_tokens(sample))
            else:
                length_list.append(-cur_len)
        return length_list

    # ------------------------------------------------------------------ items

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return sample ``i``, or a random replacement if ``_get_item`` rejects it
        (e.g. because it exceeds ``model_max_length``).

        Raises:
            RuntimeError: if ``_MAX_FETCH_RETRIES`` consecutive samples are rejected.
        """
        sample = self._get_item(i)
        retries = 0
        while sample is None:
            retries += 1
            if retries > self._MAX_FETCH_RETRIES:
                raise RuntimeError(
                    f"Could not fetch a valid sample after {self._MAX_FETCH_RETRIES} retries"
                )
            sample = self._get_item(random.randint(0, len(self.list_data_dict) - 1))
        return sample

    def _get_item(self, i) -> Dict[str, torch.Tensor]:
        """Build the model inputs for sample ``i``; return None if it is too long.

        Raises:
            NotImplementedError: for video samples.
            ValueError: for samples without an ``image`` entry.
        """
        record = self.list_data_dict[i]
        image_path = os.path.join(self.data_args.image_folder, self.list_image_path[i])

        if "image" in record:
            image_file = record["image"]
            image_files = image_file if isinstance(image_file, list) else [image_file]
            image_list = [os.path.join(image_path, name) for name in image_files]
        elif "video" in record:
            raise NotImplementedError("Video is not supported for Qwen2VL")
        else:
            raise ValueError(f"Sample {i} has no image; text-only samples are not supported")
        sources = copy.deepcopy(record["conversations"])

        item_id = record.get("id", i)

        data_dict = self.preprocess_qwen2vl(sources, self.tokenizer, self.processor, image_list, id=item_id, label_style=self.label_style, sigma_factor=self.sigma_factor)
        if isinstance(i, int):
            # Drop the leading batch dimension added by preprocess_qwen2vl.
            data_dict = {
                "input_ids": data_dict["input_ids"][0],
                "labels": data_dict["labels"][0],
                "coordinates": data_dict["coordinates"][0],
                "visual_token_indices_of_coordinates": data_dict["visual_token_indices_of_coordinates"][0],
                "pixel_values": data_dict["pixel_values"],
                "image_grid_thw": data_dict["image_grid_thw"],
                "multi_patch_labels": data_dict["multi_patch_labels"][0],
            }

        data_dict["id"] = item_id

        # Return None if the text plus image tokens would exceed model_max_length.
        n_image_tokens = (
            data_dict["image_grid_thw"][0][0] *
            data_dict["image_grid_thw"][0][1] *
            data_dict["image_grid_thw"][0][2] /
            self.processor.image_processor.merge_size /
            self.processor.image_processor.merge_size
        )
        if (len(data_dict["input_ids"]) + n_image_tokens) > self.tokenizer.model_max_length:
            rank0_print(f"=== Removed data_dict {i} because it is longer than the model_max_length: {len(data_dict['input_ids'])} + {n_image_tokens} > {self.tokenizer.model_max_length}")
            return None

        return data_dict

    @staticmethod
    def _build_patch_mask(image_processor, images, bbox_gt, label_style, sigma_factor):
        """Build the per-token label mask for ``bbox_gt`` using ``label_style``.

        Raises:
            ValueError: for an unknown ``label_style``.
        """
        if label_style == "all":
            return get_multi_patch_labels(image_processor, images, bbox_gt)
        if label_style == "center":
            return get_multi_patch_labels_center(image_processor, images, bbox_gt)
        if label_style == "gauss":
            return get_multi_patch_labels_gauss(image_processor, images, bbox_gt, sigma_factor)
        raise ValueError(f"Unknown label_style {label_style!r}; expected 'all', 'center' or 'gauss'")

    def preprocess_qwen2vl(
        self,
        source, # conversations
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        image: list,
        system_message: str = grounding_system_message,
        agent_mode: bool = True,
        chat_template: str = chat_template,
        assistant_template: str = assistant_template,
        id: int = None,
        label_style: str = "all",
        sigma_factor: float = 4.1325
    ) -> Dict:
        """
        Tokenize one conversation and build its grounding targets.

        Args:
            source: list of turns, each either ``{"from", "value"}`` or ``{"role", "content"}``.
                A leading system turn replaces ``system_message``. Assistant turns may
                carry ``recipient`` (default "os"), ``end_turn`` and ``bbox_gt``.
            tokenizer / processor: used for templating and image processing.
            image: image paths consumed, in order, by the ``<image>`` tokens in user turns.
            agent_mode: use ``assistant_template`` for assistant turns (else ``chat_template``).
            id: sample id (not used by the computation).
            label_style / sigma_factor: forwarded to the ``bbox_gt`` label builder.

        Returns:
            dict with ``input_ids`` and ``labels`` (``(1, seq_len)`` long tensors), and
            ``coordinates`` (``[list of (x, y)]`` or ``[None]``). It also holds
            ``visual_token_indices_of_coordinates`` (``(1, n)`` long tensor or ``[None]``)
            and ``multi_patch_labels`` (``[stacked masks]`` or ``[None]``). When images
            are present it adds ``pixel_values`` and ``image_grid_thw``.
        """
        roles = {"human": "user", "gpt": "assistant", "system": "system"}
        assistant_template = assistant_template if agent_mode else chat_template
        processor.tokenizer = tokenizer
        assert tokenizer.additional_special_tokens == ADDITIONAL_SPECIAL_TOKENS

        pixel_values, image_grid_thw = None, None

        input_id, target = [], []
        coordinates = []
        visual_token_indices_of_coordinates = []
        multi_patch_labels = []

        image_list = []    # all images loaded so far (PIL)
        image_inputs = []  # images of the most recent user turn that contained any
        image_index = 0

        # Prepare the system message; accept both turn formats, like the loop below.
        first_turn = source[0]
        first_role = first_turn.get("from", first_turn.get("role"))
        if roles.get(first_role, first_role) == "system":
            system_message = first_turn.get("value", first_turn.get("content"))
            # NOTE(review): truncation to max_conv_turns only happens when the
            # conversation starts with a system turn; confirm this is intended.
            source = source[1:self.data_args.max_conv_turns]
        # Otherwise the default system message is used.
        system_input_id = tokenizer.apply_chat_template(
            conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}],
            chat_template=chat_template,
        )
        input_id += system_input_id
        target += [IGNORE_INDEX] * len(system_input_id)

        # Prepare the user-assistant conversation.
        for conv in source:
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]
            role = roles.get(role, role)

            image_count = content.count(DEFAULT_IMAGE_TOKEN)
            if image_count > 0:
                assert role == "user", "Images are only supported for user messages"
                image_placeholders = [
                    {
                        "type": "image",
                        "image": image[image_index + offset],
                        "min_pixels": self.processor.image_processor.min_pixels,
                        "max_pixels": self.processor.image_processor.max_pixels,
                    }
                    for offset in range(image_count)
                ]
                image_index += image_count

                content = content.replace(DEFAULT_IMAGE_TOKEN, "")
                conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]}

                image_inputs, _ = process_vision_info([conv])  # list of PIL.Image.Image
                image_list.extend(image_inputs)

                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv], chat_template=chat_template, tokenize=False
                )
                inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt")

                if pixel_values is None and image_grid_thw is None:
                    pixel_values = inputs["pixel_values"]
                    image_grid_thw = inputs["image_grid_thw"]
                else:
                    pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0)
                    image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0)
            else:
                if role in ["user", "system"]:
                    conv = {"role": role, "content": [{"type": "text", "text": content}]}
                else:  # assistant
                    conv = {
                        "role": role,
                        "content": [{"type": "text", "text": content}],
                        "recipient": conv.get("recipient", "os"),
                        "end_turn": conv.get("end_turn", True),
                        "bbox_gt": conv.get("bbox_gt", None),
                    }
                    if conv["recipient"] == "os":
                        if len(image_inputs) == 0:
                            raise ValueError("No image found for visual grounding")
                        # Replace the coordinates with the pointer special tokens.
                        text, coord = reformat_coordinates(conv["content"][0]["text"])
                        conv["content"][0]["text"] = text

                        # Map each coordinate to the visual token containing it.
                        coordinates.extend(coord)
                        for point_x, point_y in coord:
                            visual_token_indices_of_coordinates.append(
                                get_token_index(processor.image_processor, image_list, point_x, point_y)
                            )

                        # The mask depends only on this turn's bbox, so build it once and
                        # repeat it so there is one mask per coordinate.
                        if conv["bbox_gt"] is not None and coord:
                            patch_mask = self._build_patch_mask(
                                processor.image_processor,
                                image_list,
                                conv["bbox_gt"],
                                label_style,
                                sigma_factor,
                            )
                            multi_patch_labels.extend([patch_mask] * len(coord))

                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv],
                    chat_template=assistant_template,
                    tokenize=False,
                )
                inputs = processor(text=[templated_conv], return_tensors="pt")

            encode_id = inputs.input_ids[0].tolist()

            input_id += encode_id
            # Only assistant tokens are supervised.
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"

        # Do not supervise the pointer end tokens.
        target = [IGNORE_INDEX if token == self.pointer_end_token_id else token for token in target]

        input_ids = torch.tensor([input_id], dtype=torch.long)
        targets = torch.tensor([target], dtype=torch.long)
        visual_token_indices_of_coordinates = torch.tensor([visual_token_indices_of_coordinates], dtype=torch.long) if len(visual_token_indices_of_coordinates) > 0 else [None]
        coordinates = [coordinates] if len(coordinates) > 0 else [None]
        multi_patch_labels = [torch.stack(multi_patch_labels)] if len(multi_patch_labels) > 0 else [None]

        data_dict = {
            "input_ids": input_ids,  # tensor(bs x seq_len)
            "labels": targets,  # tensor(bs x seq_len)
        }

        if pixel_values is not None:
            data_dict["pixel_values"] = pixel_values
            data_dict["image_grid_thw"] = image_grid_thw

        data_dict["coordinates"] = coordinates
        data_dict["visual_token_indices_of_coordinates"] = visual_token_indices_of_coordinates
        data_dict["multi_patch_labels"] = multi_patch_labels

        return data_dict
