from datasets import Dataset, load_from_disk, DatasetDict, Features, Value, Image, load_dataset,Sequence
import json
from refer_seg_dataset import ReferSegDataset
from pycocotools import mask
import numpy as np
import cv2
from PIL import Image as PILImage
from PIL import ImageOps
import matplotlib.pyplot as plt
import io
import os
import base64
import torch
from scipy import ndimage
# from sam2.build_sam import build_sam2
# from sam2.sam2_image_predictor import SAM2ImagePredictor

# sam2_checkpoint = "facebook/sam2.1_hiera_large.pt" # your own SAM2 path
# model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# device = "cuda" if torch.cuda.is_available() else "cpu"
# sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# predictor = SAM2ImagePredictor(sam2_model)


def show_mask(mask, ax, random_color=False, borders=True):
    """Draw a translucent overlay of a binary mask on ``ax``.

    Args:
        mask: array whose first two dimensions are (height, width); it is
            cast to ``uint8`` and reshaped to ``(height, width, 1)``, so it is
            expected to hold 0/1 values.
        ax: object exposing ``imshow`` (e.g. a Matplotlib axes).
        random_color: when true, use a random RGB colour with alpha 0.6
            instead of the fixed blue.
        borders: when true, outline the mask's external contours.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[:2]
    binary = mask.astype(np.uint8)
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) colour -> (h, w, 4).
    overlay = binary.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        outlines, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        simplified = []
        for outline in outlines:
            # Simplify each contour with a tolerance of 1% of its perimeter.
            tolerance = 0.01 * cv2.arcLength(outline, True)
            simplified.append(cv2.approxPolyDP(outline, epsilon=tolerance, closed=True))
        overlay = cv2.drawContours(overlay, simplified, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=100):
    """Scatter labelled points on ``ax`` as star markers.

    Args:
        coords: N x 2 sequence of (x, y) positions.
        labels: length-N sequence; points labelled 1 are drawn green,
            points labelled 0 are drawn red.
        ax: object exposing ``scatter`` (e.g. a Matplotlib axes).
        marker_size: marker size forwarded as ``s``.
    """
    pts = np.array(coords)
    flags = np.array(labels)
    # Select both groups up front, then draw positives before negatives.
    groups = [(pts[flags == 1], 'green'), (pts[flags == 0], 'red')]
    for chosen, colour in groups:
        ax.scatter(
            chosen[:, 0],
            chosen[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Add an unfilled green rectangle for ``box`` to ``ax``.

    Args:
        box: ``[x0, y0, x1, y1]`` corner coordinates.
        ax: object exposing ``add_patch`` (e.g. a Matplotlib axes).
    """
    left, top = box[0], box[1]
    box_width = box[2] - box[0]
    box_height = box[3] - box[1]
    # Fully transparent face so only the green outline is visible.
    outline = plt.Rectangle(
        (left, top),
        box_width,
        box_height,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Render each mask over ``image`` and save the plot to ``my_plot.png``.

    One 10x10 figure is created per (mask, score) pair. Every iteration saves
    to the same file name, so after the call the file holds the plot of the
    last mask only.

    Args:
        image: anything accepted by ``plt.imshow``.
        masks: iterable of 2D mask arrays.
        scores: sequence of floats paired with ``masks``; shown in the title
            only when more than one score is given.
        point_coords: accepted for interface compatibility; currently not drawn.
        box_coords: optional ``[x0, y0, x1, y1]`` box drawn on every figure.
        input_labels: accepted for interface compatibility; currently not drawn.
        borders: forwarded to ``show_mask``.
    """
    for i, (mask, score) in enumerate(zip(masks, scores)):
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(image)
            ax = plt.gca()
            show_mask(mask, ax, borders=borders)

            if box_coords is not None:
                show_box(box_coords, ax)

            if len(scores) > 1:
                plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)

            plt.axis('off')
            plt.savefig("my_plot.png")
        finally:
            # Figures created through pyplot stay registered until closed;
            # release each one so repeated calls do not accumulate them.
            plt.close(fig)

def scale_box_coordinates(bbox_2d, x_factor, y_factor):
    """Scale an ``[x0, y0, x1, y1]`` box and convert it to integers.

    x coordinates are multiplied by ``x_factor`` and y coordinates by
    ``y_factor``; 0.5 is added before ``int()`` truncation.

    Returns:
        A new list of four ints.
    """
    # Factor for each position of the box: x, y, x, y.
    axis_factors = (x_factor, y_factor, x_factor, y_factor)
    return [
        int(bbox_2d[idx] * factor + 0.5)
        for idx, factor in enumerate(axis_factors)
    ]

def scale_point_coordinates(point_2d, x_factor, y_factor):
    """Scale an ``[x, y]`` point and convert it to integers.

    0.5 is added to each scaled coordinate before ``int()`` truncation.

    Returns:
        A new two-element list ``[x, y]`` of ints.
    """
    scaled_x = int(point_2d[0] * x_factor + 0.5)
    scaled_y = int(point_2d[1] * y_factor + 0.5)
    return [scaled_x, scaled_y]

def encode_mask_to_base64(mask_array):
    """Encode a 0/1 mask as a base64 string holding a grayscale PNG.

    Args:
        mask_array: array-like of 0/1 values (bool, integer or float).

    Returns:
        The PNG bytes, base64-encoded and decoded to a UTF-8 ``str``.
    """
    # Cast to uint8 *before* scaling: multiplying a bool/int64/float array by
    # 255 yields a dtype that cannot be written as an 8-bit PNG, and a plain
    # Python list would be repeated 255 times. uint8 input is unaffected.
    pixels = np.asarray(mask_array).astype(np.uint8) * np.uint8(255)
    pil_img = PILImage.fromarray(pixels)  # 0/1 -> 0/255 grayscale
    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def compute_iou(boxA, boxB):
    """
    Intersection-over-union of two axis-aligned boxes.

    boxA, boxB: [x1, y1, x2, y2]
    return: IoU value (float)

    Widths and heights are computed as ``max - min + 1``, and 1e-6 is added
    to the denominator.
    """
    # Corners of the overlap rectangle.
    left, top = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    right, bottom = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])

    # Clamp at zero so disjoint boxes give an empty intersection.
    overlap = max(0, right - left + 1) * max(0, bottom - top + 1)

    area_a = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    area_b = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    return overlap / float(area_a + area_b - overlap + 1e-6)

def compute_iou_mask(mask1, mask2):
    """Intersection-over-union of two masks, treating non-zero entries as set.

    Returns 0 when neither mask has any set element.
    """
    overlap = np.logical_and(mask1, mask2).sum()
    combined = np.logical_or(mask1, mask2).sum()
    return overlap / combined if combined != 0 else 0

def get_mask_from_point(predictor, input_point, input_label, box):
    """Run ``predictor.predict`` and return its masks, highest score first.

    Args:
        predictor: object with a ``predict`` method that accepts the keyword
            arguments used below and returns ``(masks, scores, logits)``.
            NOTE(review): presumably a SAM2 image predictor, per the
            commented-out setup at the top of the file - confirm.
        input_point: forwarded as ``point_coords``.
        input_label: forwarded as ``point_labels``.
        box: forwarded as ``box``.

    Returns:
        ``masks`` reordered by descending score.
    """
    masks, scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=box,
        multimask_output=False,
    )
    # Only the masks are returned, so only they need reordering.
    best_first = np.argsort(scores)[::-1]
    return masks[best_first]


# Build a Hugging Face dataset from the images in `data_path`.
# Each "<id>.jpg" is paired with "<id>.json", whose "text" entry becomes the
# problem list and whose "target" shapes become masks, boxes and points, all
# expressed in the resized (image_resize x image_resize) frame.
data_path = "data/ReasonSeg_train/train"
jpg_files = sorted(f for f in os.listdir(data_path) if f.lower().endswith(".jpg"))

image_resize = 840
threshold_iou = 0.6  # not referenced by the code below

features = Features({
    "id": Value("string"),
    "problem": Sequence(Value("string")),
    "solution":  Sequence({
        "mask": Image(),
        "bbox_2d": Sequence(Value("int32")),
        "point_2d": Sequence(Value("int32")),
    }),
    "image": Image(),
    "img_height": Value("int32"),
    "img_width": Value("int32"),
})

data_list = []

for jpg_file in jpg_files:
    # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), whereas
    # split(".")[0] would truncate it and look up the wrong JSON file.
    id_name = os.path.splitext(jpg_file)[0]
    json_path = os.path.join(data_path, id_name + '.json')
    jpg_path = os.path.join(data_path, jpg_file)

    # Open the file once and release the handle as soon as the transposed and
    # resized copies exist.
    with PILImage.open(jpg_path) as raw_image:
        image = ImageOps.exif_transpose(raw_image)
        # Original size after EXIF orientation; the polygons are rasterized
        # at this size before being resized.
        width, height = image.size
        print("jpg_path:", jpg_path)
        resized_image = image.resize((image_resize, image_resize), PILImage.Resampling.BILINEAR)

    with open(json_path, "r", encoding="utf-8") as f:
        anno = json.load(f)
    problem_list = anno["text"]
    solution_dict_list = []
    for ann in anno["shapes"]:
        if ann["label"] != "target":
            continue
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(ann['points']) < 3:
            raise ValueError(
                f"{json_path}: target polygon needs at least 3 points, "
                f"got {len(ann['points'])}"
            )

        # Flatten [[x, y], ...] into [x0, y0, x1, y1, ...] and rasterize it.
        segmentation = [np.array(ann['points'], dtype=int).flatten()]
        rle = mask.frPyObjects(segmentation, height, width)
        m = mask.decode(rle)
        m = np.sum(m, axis=2).astype(np.uint8)
        # Nearest-neighbour keeps the mask binary after resizing.
        m = cv2.resize(m, (image_resize, image_resize), interpolation=cv2.INTER_NEAREST)

        ys, xs = np.where(m == 1)
        if xs.size == 0:
            # A degenerate or very small polygon can leave no pixels; a box
            # and centre point cannot be derived from it, so skip the shape.
            print("skipping empty target mask in:", json_path)
            continue

        # NOTE(review): the centre of mass can fall outside a concave mask.
        cy, cx = ndimage.center_of_mass(m)
        solution_dict_list.append({
            "mask": PILImage.fromarray(m * 255).convert("L"),
            # [left, top, right, bottom] as plain Python ints.
            "bbox_2d": [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())],
            "point_2d": [int(cx), int(cy)],
        })

    data_list.append({
        "id": id_name,
        "problem": problem_list,
        "solution": solution_dict_list,
        "image": resized_image,
        "img_height": height,
        "img_width": width,
    })

dataset = Dataset.from_list(data_list, features=features)
dataset.save_to_disk("data/ReasonSeg_train_hf")
