from datasets import load_from_disk, DatasetDict, Features, Value, Image
import json
from refer_seg_dataset import ReferSegDataset
from pycocotools import mask
import numpy as np
import cv2
from PIL import Image as PILImage
import io
import base64

def scale_box_coordinates(bbox_2d, x_factor, y_factor):
    """Scale an [x1, y1, x2, y2] box by per-axis factors and round to ints.

    Coordinates at positions 0 and 2 are multiplied by ``x_factor``; those at
    positions 1 and 3 are multiplied by ``y_factor``. Each product is rounded
    with ``int(value + 0.5)``. That is round-half-up for non-negative values;
    negative values are truncated toward zero instead.

    Returns:
        list[int]: the four scaled coordinates.
    """
    per_axis = (x_factor, y_factor, x_factor, y_factor)
    scaled = []
    for position, factor in enumerate(per_axis):
        scaled.append(int(bbox_2d[position] * factor + 0.5))
    return scaled

def scale_point_coordinates(point_2d, x_factor, y_factor):
    """Scale an [x, y] point by per-axis factors and round each coordinate.

    Rounding uses ``int(value + 0.5)``: round-half-up for non-negative values,
    truncation toward zero for negative ones.

    Returns:
        list[int]: ``[round(x * x_factor), round(y * y_factor)]``.
    """
    scaled_x = point_2d[0] * x_factor
    scaled_y = point_2d[1] * y_factor
    return [int(scaled_x + 0.5), int(scaled_y + 0.5)]

def encode_mask_to_base64(mask_array):
    """Encode a mask as base64 text of an 8-bit grayscale PNG.

    Elements greater than 0 are written as 255 and all other elements as 0.
    Bool masks, and masks whose foreground value exceeds 1, therefore encode
    exactly like 0/1 uint8 masks.

    Args:
        mask_array: array-like mask, presumably 2-D (H, W). A 2-D uint8 array
            is saved by PIL as a single-channel grayscale image.

    Returns:
        str: the PNG bytes, base64-encoded and decoded as UTF-8 text.
    """
    # Build an explicit uint8 0/255 image. `mask * 255` wraps around in uint8
    # for values > 1 (2 * 255 -> 254), and for bool input it yields a
    # non-uint8 (int64) array, which PIL cannot save as grayscale.
    pixels = np.where(np.asarray(mask_array) > 0, 255, 0).astype(np.uint8)
    with io.BytesIO() as buffer:
        PILImage.fromarray(pixels).save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")

def compute_iou(boxA, boxB):
    """Return the intersection-over-union of two boxes.

    Args:
        boxA, boxB: boxes as [x1, y1, x2, y2]. Both corners are treated as
            inclusive pixel indices, so a box spans ``x2 - x1 + 1`` columns
            and ``y2 - y1 + 1`` rows.

    Returns:
        float: the IoU value. 1e-6 is added to the union so the division
        never divides by zero.
    """
    def _inclusive_area(box):
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)

    # Overlap extent along each axis; clamp at 0 when the boxes are disjoint.
    overlap_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]) + 1)
    overlap_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]) + 1)
    overlap = overlap_w * overlap_h

    union = _inclusive_area(boxA) + _inclusive_area(boxB) - overlap
    return overlap / float(union + 1e-6)


# Square side length (pixels) that masks are resized to. The input and output
# dataset directory names carry this value as a suffix. Presumably it is the
# size the dataset's images and bbox_2d values were resized to — confirm.
image_resize = 840
ref_dataset_name = "refcocog"

# Derive the input path from `image_resize` so the two cannot drift apart.
dataset = load_from_disk(f"data/VisionReasoner_multi_object_1k_{image_resize}")["train"]

# Source of the segmentation annotations and original image sizes.
ref_dataset = ReferSegDataset(base_image_dir="data", refer_seg_data=ref_dataset_name, data_split="train")
refer_seg_ds = ref_dataset.refer_seg_data[ref_dataset_name]
annotations = refer_seg_ds["annotations"]
img2refs = refer_seg_ds["img2refs"]
images = refer_seg_ds["images"]

# Add a base64-encoded "mask" field (and a mask-derived "bbox_2d") to each solution entry.

# Lazily built {image_id: image_info} index over the module-level `images` list.
_IMAGE_INFO_BY_ID = None


def _find_image_info(image_id):
    """Return the first entry of `images` whose int(id) equals image_id, or None.

    The index is built once on the first call. Each lookup is then a dict hit
    instead of a linear scan over every image for every annotation.
    """
    global _IMAGE_INFO_BY_ID
    if _IMAGE_INFO_BY_ID is None:
        index = {}
        for image_info in images:
            # setdefault keeps the FIRST image with a given id, like a linear scan would.
            index.setdefault(int(image_info["id"]), image_info)
        _IMAGE_INFO_BY_ID = index
    return _IMAGE_INFO_BY_ID.get(image_id)


def _decode_binary_mask(ann, image_info, size):
    """Rasterise ann["segmentation"] into a 0/1 uint8 mask resized to size x size."""
    rle = mask.frPyObjects(ann["segmentation"], image_info["height"], image_info["width"])
    decoded = mask.decode(rle)
    if decoded.ndim == 3:
        # A list of polygons decodes to one channel per polygon. Merge them with
        # any() so pixels covered by overlapping polygons become 1, not 2.
        binary = decoded.any(axis=2)
    else:
        # A single RLE object decodes to a 2-D array.
        binary = decoded > 0
    # Nearest-neighbour interpolation keeps the resized mask strictly 0/1.
    return cv2.resize(binary.astype(np.uint8), (size, size), interpolation=cv2.INTER_NEAREST)


def _mask_bbox(binary_mask):
    """Return the inclusive [x1, y1, x2, y2] bounds of the nonzero pixels as Python ints.

    Raises:
        ValueError: if the mask has no nonzero pixels.
    """
    ys, xs = np.nonzero(binary_mask)
    if xs.size == 0:
        raise ValueError("mask is empty after resizing; cannot compute its bounding box")
    return [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]


def add_mask_field(example):
    """Attach a base64 PNG mask to each solution entry and replace its bbox_2d.

    Pairing of annotations and solutions:
    - ``example["id"]`` is split on "_", and every token after the first is
      taken as an annotation id (looked up in ``annotations``).
    - The i-th annotation is paired with the i-th entry of
      ``example["solution"]``, which is a list of dicts or its JSON string.
    - All annotations must belong to the same image.

    Processing of each annotation:
    - Its segmentation is decoded, binarised and resized to
      ``image_resize`` x ``image_resize`` (nearest-neighbour).
    - The mask's inclusive bounding box must have IoU > 0.9 with the stored
      ``bbox_2d``.
    - If so, the entry gets ``"mask"`` (base64 PNG) and the recomputed box.

    Returns:
        dict with keys id, problem, solution, image, img_height, img_width. The
        "solution" value is a list of updated copies of the input entries.

    Raises:
        ValueError: if the numbers of annotation ids and solution entries
            differ, the annotations span several images, the image id is
            unknown, a resized mask is empty, or the IoU check fails.
    """
    solution_dict_list = example["solution"]
    if isinstance(solution_dict_list, str):
        solution_dict_list = json.loads(solution_dict_list)

    ann_id_list = example["id"].split("_")[1:]
    if len(ann_id_list) != len(solution_dict_list):
        raise ValueError(
            f"example {example['id']}: {len(ann_id_list)} annotation ids but "
            f"{len(solution_dict_list)} solution entries"
        )

    new_solution_dict_list = []
    image_id = image_info = None
    for ann_id, solution in zip(ann_id_list, solution_dict_list):
        ann = annotations[int(ann_id)]
        ann_image_id = int(ann["image_id"])
        if image_id is None:
            # All annotations share one image, so resolve its info only once.
            image_id = ann_image_id
            image_info = _find_image_info(image_id)
            if image_info is None:
                raise ValueError(f"image id {image_id} not found in refer-seg images")
        elif ann_image_id != image_id:
            raise ValueError(
                f"example {example['id']}: annotation {ann_id} belongs to image "
                f"{ann_image_id}, expected {image_id}"
            )

        m = _decode_binary_mask(ann, image_info, image_resize)
        box = _mask_bbox(m)

        # Copy so the caller's solution dicts are not mutated in place.
        solution_dict = dict(solution)
        # Sanity check: the mask-derived box must agree with the stored one.
        iou = compute_iou(box, solution_dict["bbox_2d"])
        if iou <= 0.9:
            raise ValueError(
                f"example {example['id']}, annotation {ann_id}: mask box {box} vs "
                f"stored bbox_2d {solution_dict['bbox_2d']} has IoU {iou:.3f} <= 0.9"
            )
        solution_dict["mask"] = encode_mask_to_base64(m)
        solution_dict["bbox_2d"] = box
        new_solution_dict_list.append(solution_dict)

    return {
        "id": example["id"],
        "problem": example["problem"],
        "solution": new_solution_dict_list,
        "image": example["image"],
        "img_height": example["img_height"],
        "img_width": example["img_width"],
    }


# Attach masks to every training example.
new_dataset = dataset.map(add_mask_field)

# Output schema. Each solution entry carries an int box, an int point and the
# base64 PNG mask.
features = Features({
    "id": Value("string"),
    "problem": Value("string"),
    "solution": [{
        "bbox_2d": [Value("int64")],
        "point_2d": [Value("int64")],
        "mask": Value("string")  # base64
    }],
    "image": Image(),
    "img_height": Value("int64"),
    "img_width": Value("int64")
})

# Cast the mapped dataset to the declared schema before saving.
new_dataset = new_dataset.cast(features)

updated_dataset_dict = DatasetDict({"train": new_dataset})
# Derive the output path from `image_resize` so it matches the mask size used above.
save_path = f"data/VisionReasoner_multi_object_1k_{image_resize}_with_mask"
updated_dataset_dict.save_to_disk(save_path)

