from datasets import load_from_disk, DatasetDict, Features, Value, Image
import json
from refer_seg_dataset import ReferSegDataset
from pycocotools import mask
import numpy as np
import cv2
from PIL import Image as PILImage
import io
import base64
import torch

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Checkpoint and config file handed to build_sam2 below.
sam2_checkpoint = "facebook/sam2.1_hiera_large.pt" # your own SAM2 path
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE: these are top-level statements, so the model is built as soon as the
# module is executed or imported.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# Module-level predictor used by add_mask_field.
predictor = SAM2ImagePredictor(sam2_model)

def scale_box_coordinates(bbox_2d, x_factor, y_factor):
    """Rescale an ``[x1, y1, x2, y2]`` box and round each value to an int.

    Values at even positions (the x coordinates) are multiplied by
    ``x_factor`` and values at odd positions (the y coordinates) by
    ``y_factor``. Each product is rounded with ``int(value + 0.5)``.
    Only the first four entries of ``bbox_2d`` are read.
    """
    axis_factors = (x_factor, y_factor)
    return [int(bbox_2d[i] * axis_factors[i % 2] + 0.5) for i in range(4)]

def scale_point_coordinates(point_2d, x_factor, y_factor):
    """Rescale an ``[x, y]`` point and round each value to an int.

    ``point_2d[0]`` is multiplied by ``x_factor`` and ``point_2d[1]`` by
    ``y_factor``; each product is rounded with ``int(value + 0.5)``.
    """
    return [
        int(point_2d[axis] * factor + 0.5)
        for axis, factor in enumerate((x_factor, y_factor))
    ]

def encode_mask_to_base64(mask_array):
    """Encode a binary mask as the base64 text of a grayscale PNG.

    Args:
        mask_array: array-like mask; every non-zero / True entry is treated
            as foreground.

    Returns:
        str: base64 encoding (decoded as UTF-8 text) of a PNG whose pixels
        are 0 for background and 255 for foreground.
    """
    # Normalise to uint8 0/1 before scaling: a bool mask multiplied by 255
    # becomes int64, which PILImage.fromarray cannot handle, and uint8 values
    # larger than 1 would wrap around instead of saturating at 255.
    binary = np.asarray(mask_array).astype(bool).astype(np.uint8)
    pil_img = PILImage.fromarray(binary * 255)
    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def compute_iou(boxA, boxB):
    """
    Intersection-over-union of two boxes given as [x1, y1, x2, y2].

    Widths and heights are measured as ``x2 - x1 + 1`` / ``y2 - y1 + 1``, and
    1e-6 is added to the denominator so the division never hits zero.
    return: the IoU value (float)
    """
    def _area(box):
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)

    # Overlap extent along each axis, clamped to 0 when the boxes are disjoint.
    overlap_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]) + 1)
    overlap_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]) + 1)
    overlap_area = overlap_w * overlap_h

    denominator = float(_area(boxA) + _area(boxB) - overlap_area + 1e-6)
    return overlap_area / denominator

def compute_iou_mask(mask1, mask2):
    """Intersection-over-union of two masks, treating non-zero entries as set.

    Returns 0 when neither mask has any set entry; otherwise the number of
    entries set in both masks divided by the number set in either.
    """
    union_count = np.logical_or(mask1, mask2).sum()
    if union_count == 0:
        return 0
    overlap_count = np.logical_and(mask1, mask2).sum()
    return overlap_count / union_count

def get_mask_from_point(predictor, input_point, input_label, box):
    """Run a point + box prompted prediction and return the masks, best first.

    Args:
        predictor: object exposing ``predict(point_coords=..., point_labels=...,
            box=..., multimask_output=...)`` that returns
            ``(masks, scores, logits)``.
        input_point: point prompt(s), forwarded as ``point_coords``.
        input_label: label per point, forwarded as ``point_labels``.
        box: box prompt, forwarded as ``box``.

    Returns:
        The predicted masks reordered by descending score.
    """
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=box,
        multimask_output=False,
    )
    # Only the masks are returned, so scores and logits need no reordering.
    return masks[np.argsort(scores)[::-1]]

# Load the original dataset
dataset = load_from_disk("data/VisionReasoner_multi_object_1k_840")['train']
# Referring-segmentation data (refcocog, train split); its annotations supply
# the ground-truth segmentations decoded in add_mask_field.
ref_dataset_name = "refcocog"
ref_dataset = ReferSegDataset(base_image_dir="data", refer_seg_data=ref_dataset_name, data_split="train")
refer_seg_ds = ref_dataset.refer_seg_data[ref_dataset_name]
annotations = refer_seg_ds["annotations"]
img2refs = refer_seg_ds["img2refs"]  # not referenced in the visible part of this file
images = refer_seg_ds["images"]

# Side length that images and ground-truth masks are resized to (square).
image_resize = 840
# Mask IoU values below this threshold are printed by add_mask_field.
threshold_iou = 0.6
# Add "mask" field (see the NOTE in the docstring: no field is written yet)
def add_mask_field(example):
    """Score each solution prompt of ``example`` against its ground-truth mask.

    For every annotation id encoded in ``example["id"]`` (the ``_``-separated
    parts after the first one), the ground-truth segmentation is decoded and
    resized to ``image_resize`` x ``image_resize``, the predictor is prompted
    with the matching solution's ``point_2d`` and ``bbox_2d``, and the mask IoU
    is printed when it is below ``threshold_iou``.

    NOTE(review): despite the name, nothing is added to ``example`` and the
    function returns None. With ``Dataset.map`` that presumably leaves the
    dataset unchanged - confirm whether a "mask" field was meant to be written.

    Args:
        example: dataset row providing ``"id"`` and ``"solution"`` (a list of
            dicts, or the JSON string of such a list).

    Raises:
        FileNotFoundError: if the image file cannot be read.
        AssertionError: if the image id has no record in ``images`` or the
            annotations do not all belong to the same image.
    """
    solution_dict_list = example["solution"]
    if isinstance(solution_dict_list, str):
        solution_dict_list = json.loads(solution_dict_list)

    ann_id_list = example["id"].split("_")[1:]
    image_id = None
    image_info = None
    for ann_idx, ann_id in enumerate(ann_id_list):
        ann = annotations[int(ann_id)]
        if ann_idx == 0:
            # All annotations of one example share a single image (asserted
            # below), so look it up, load it and hand it to the predictor once
            # instead of repeating that work for every annotation.
            image_id = int(ann["image_id"])
            image_info = next(
                (info for info in images if int(info["id"]) == image_id), None
            )
            assert image_info is not None, f"no image record for image_id {image_id}"

            file_name = image_info["file_name"]
            image = cv2.imread(file_name)
            if image is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(f"could not read image: {file_name}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image, (image_resize, image_resize), interpolation=cv2.INTER_NEAREST)
            predictor.set_image(image)
        else:
            assert image_id == int(ann["image_id"]), "annotations span multiple images"

        # Decode the ground-truth segmentation at the original image size,
        # merge its parts into one mask, then resize it like the image.
        rle = mask.frPyObjects(
            ann["segmentation"], image_info["height"], image_info["width"]
        )
        gt_mask = mask.decode(rle)
        gt_mask = np.sum(gt_mask, axis=2).astype(np.uint8)
        gt_mask = cv2.resize(gt_mask, (image_resize, image_resize), interpolation=cv2.INTER_NEAREST)

        solution_dict = solution_dict_list[ann_idx]
        mask_pred = get_mask_from_point(predictor, np.array([solution_dict["point_2d"]]), np.array([1]), np.array(solution_dict["bbox_2d"]))

        # An empty ground-truth mask simply yields an IoU of 0 here.
        iou = compute_iou_mask(mask_pred[0].astype(bool), gt_mask.astype(bool))
        if iou < threshold_iou:
            print("iou:", iou)

# Run add_mask_field over every row of the dataset.
new_dataset = dataset.map(add_mask_field)

