import re
from pathlib import Path

import numpy as np
import cv2
import skimage


# System prompt used as the first (system) turn of each conversation built by
# format_data. It asks the model to wrap its reasoning in <think>...</think>
# and its final answer in <answer>...</answer>, and to refer back to image
# regions while reasoning. This text is sent to the model verbatim.
system_message = """You are a helpful and thoughtful vision language agent.

Your response must be structured as follows:

<think>THOUGHT</think>: A detailed breakdown of the reasoning process, referencing specific regions of the image if necessary.
<answer>ANSWER</answer>: The final concise answer derived from the reasoning.

Whenever necessary, refer to relevant regions in the image to revisit and better ground your reasoning.
"""


def build_exp_name(model_name: str, data_name: str):
    """Return an experiment identifier of the form ``<data_name>_<model_key>``.

    ``model_key`` is the final path component of *model_name* (so a hub id
    such as ``org/model`` or a local checkpoint directory both reduce to their
    last segment), with every ``.`` turned into ``_``.
    """
    last_component = Path(model_name).name
    dotless = "_".join(last_component.split("."))
    return f"{data_name}_{dotless}"


# A bbox tag whose bracket content is only digits, commas and whitespace.
_BBOX_PATTERN = re.compile(r"<bbox>\[([\d,\s]+)\]</bbox>")
# Placeholder that replaces each recognised bbox tag in the text.
_REGION_TOKEN = "<|region|>"


def extract_bboxes(txt: str):
    """Pull inline ``<bbox>[...]</bbox>`` annotations out of *txt*.

    Every ``<bbox>[n1, n2, ...]</bbox>`` tag whose bracket content consists
    only of digits, commas and whitespace is replaced by ``<|region|>``, and
    its numbers are collected as a list of ints. Numbers may be separated by
    commas and/or whitespace. Tags with any other content (e.g. negative or
    decimal numbers) do not match; they stay in the text and are not
    collected.

    Args:
        txt: Text that may contain bbox tags.

    Returns:
        ``(converted_text, bboxes)``. ``bboxes[i]`` holds the integers of the
        i-th replaced tag, in the same left-to-right order as the
        placeholders in ``converted_text``.
    """
    bboxes = []

    def _take(match):
        # Extract digit runs rather than splitting on "," only: the pattern
        # also admits whitespace-separated values, for which int("10 20")
        # would raise ValueError.
        bboxes.append([int(num) for num in re.findall(r"\d+", match.group(1))])
        return _REGION_TOKEN

    # One substitution pass both collects and replaces, so the placeholders
    # and the returned bboxes cannot get out of step.
    txt_converted = _BBOX_PATTERN.sub(_take, txt)
    return txt_converted, bboxes


def format_data(image_root, sample):
    """Convert one annotated sample into a chat-style message list.

    Args:
        image_root: Directory that ``sample["image"]`` is joined onto. A
            ``pathlib.Path`` (or other object supporting ``/`` with a str) is
            used as-is. A plain ``str`` is wrapped in ``Path`` first.
        sample: Mapping with keys ``"image"`` (joined onto *image_root*),
            ``"question"`` (user text) and ``"reasoning"`` (assistant text
            that may contain ``<bbox>[...]</bbox>`` tags).

    Returns:
        ``(messages, bboxes)``. ``messages`` has three turns: system prompt,
        user (image path string + question), and assistant (reasoning with
        recognised bbox tags replaced by placeholders). ``bboxes`` is the
        list of coordinate lists extracted by ``extract_bboxes``.
    """
    # `str / str` raises TypeError, so accept a string root by converting it.
    # Other path-like roots are left alone to keep their own `/` semantics.
    if isinstance(image_root, str):
        image_root = Path(image_root)

    reasoning, bboxes = extract_bboxes(sample["reasoning"])
    image_path = str(image_root / sample["image"])

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path,
                },
                {
                    "type": "text",
                    "text": sample["question"],
                },
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": reasoning}],
        },
    ]
    return messages, bboxes


def polygon2mask(image_size, pts):
    """Rasterize a polygon into a boolean mask using OpenCV's fillPoly.

    This is an OpenCV-based stand-in for ``skimage.draw.polygon2mask``. Edge
    pixels may rasterize slightly differently; this has not been verified.

    Args:
        image_size: Shape of the output mask, expected as ``(h, w)``.
        pts: Polygon vertices given as ``(y, x)`` pairs.

    Returns:
        Boolean array of shape *image_size*, True where ``cv2.fillPoly``
        painted the polygon.
    """
    canvas = np.zeros(image_size, dtype=np.uint8)
    # OpenCV expects (x, y) vertex order, so reverse the last axis. The
    # int32 cast truncates any fractional coordinates toward zero.
    xy_vertices = np.flip(pts, -1)
    xy_vertices = xy_vertices.astype(np.int32)
    cv2.fillPoly(canvas, [xy_vertices], 1)
    return canvas != 0
