#%%
import json

json_path = "combined_filtered/train/train.json"

with open(json_path, "r") as f:
    data = json.load(f)

# Rewrite train.json in place: every entry whose "image" path begins with
# "combined/" loses that single leading component; other entries are left as-is.
for entry in data:
    if "image" not in entry:
        continue
    old_path = entry["image"]
    if old_path.startswith("combined/"):
        entry["image"] = old_path[len("combined/"):]

with open(json_path, "w") as f:
    json.dump(data, f, indent=2)

print("✅ Removed 'combined/' prefix from all image paths in train.json")


#%%
import os
import json
import re
import cv2
from tqdm import tqdm

def extract_class_names(train_json_path, val_ans_json_path):
    """Collect the sorted, lower-cased set of class names used in the prompts.

    A class name is the text between ``"for "`` and the next ``"."`` in a
    prompt, stripped of surrounding whitespace and lower-cased.

    Args:
        train_json_path: JSON list of train items; the ``value`` of every
            conversation turn whose ``from`` is ``"human"`` is scanned.
        val_ans_json_path: JSON list of val answers; each ``prompt`` is scanned.

    Returns:
        A sorted list of unique class names.
    """
    pattern = re.compile(r"for (.+?)\.")
    names = set()

    def harvest(texts):
        # Add the class name found in each text (if any) to ``names``.
        for text in texts:
            hit = pattern.search(text)
            if hit is not None:
                names.add(hit.group(1).strip().lower())

    with open(train_json_path, "r") as fh:
        harvest(
            turn["value"]
            for record in json.load(fh)
            for turn in record.get("conversations", [])
            if turn["from"] == "human"
        )

    with open(val_ans_json_path, "r") as fh:
        harvest(record["prompt"] for record in json.load(fh))

    return sorted(names)

def build_multiline_question(all_classes):
    """Return the multiple-choice recognition prompt for ``all_classes``.

    The prompt is a fixed ``<image>`` header and question line, followed by
    one ``"<k>. <class>"`` option per line, numbered from 1 in the given order.
    """
    header = "<image>\nWhat is in the red bounding box? Choose from the following option:\n"
    options = (f"{number}. {name}" for number, name in enumerate(all_classes, start=1))
    return header + "\n".join(options)

def convert_train(train_path, image_root, out_json, out_img_dir, all_classes):
    """Turn grounding train samples into red-box recognition samples.

    An item of the JSON list at ``train_path`` is kept only if:

    - its image exists under ``image_root``;
    - OpenCV can read the image;
    - a human turn matches ``for <class>.``;
    - a gpt turn holds a JSON bbox.

    For each kept item the bbox is drawn in red on the image and saved in
    ``out_img_dir``. A multiple-choice sample is then emitted whose answer is
    ``"<k>. <class>"``, with ``k`` the 1-based position of the class in
    ``all_classes``. The bbox coordinates are multiplied by the image
    width/height, so they are presumably normalised to [0, 1] (confirm
    against the data).

    Several items may reference the same source image. Each gets its own
    output file: the first keeps the original basename and later ones get a
    ``_N`` suffix. Each sample's image therefore shows only its own box.

    Args:
        train_path: JSON list of items with ``image``, ``id`` and
            ``conversations`` keys (``category``/``area`` optional).
        image_root: Directory that ``item["image"]`` is relative to.
        out_json: Path the converted sample list is written to.
        out_img_dir: Directory for the box-annotated images (created if needed).
        all_classes: Ordered class names used for the answer options.

    Raises:
        ValueError: If an item's class is not in ``all_classes``.
    """
    os.makedirs(out_img_dir, exist_ok=True)
    with open(train_path, "r") as f:
        data = json.load(f)

    # Identical for every sample, so build it once instead of per item.
    question = build_multiline_question(all_classes)
    # O(1) class -> 1-based label lookup; the first occurrence wins, as with list.index.
    label_of = {}
    for idx, cls in enumerate(all_classes, start=1):
        label_of.setdefault(cls, idx)

    # Output file names already written in this run (see docstring).
    used_names = set()

    results = []
    for item in tqdm(data, desc="Converting train"):
        image_path = os.path.join(image_root, item["image"])
        if not os.path.exists(image_path):
            continue

        obj = None
        bbox = None
        for c in item["conversations"]:
            if c["from"] == "human":
                match = re.search(r"for (.+?)\.", c["value"])
                if match:
                    obj = match.group(1).strip().lower()
            elif c["from"] == "gpt":
                bbox = json.loads(c["value"])

        if not obj or not bbox:
            continue
        if obj not in label_of:
            raise ValueError(f"{obj!r} is not in all_classes")

        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None for unreadable/corrupt files; skip like a missing image.
            continue

        # draw bbox (BGR red, 2 px)
        h, w = img.shape[:2]
        x1, y1 = int(bbox[0] * w), int(bbox[1] * h)
        x2, y2 = int(bbox[2] * w), int(bbox[3] * h)
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)

        # Pick a file name not yet used in this run so samples sharing an
        # image do not overwrite each other's annotated copy.
        out_name = os.path.basename(image_path)
        stem, ext = os.path.splitext(out_name)
        suffix = 1
        while out_name in used_names:
            out_name = f"{stem}_{suffix}{ext}"
            suffix += 1
        used_names.add(out_name)
        cv2.imwrite(os.path.join(out_img_dir, out_name), img)

        label_idx = label_of[obj]
        results.append({
            "id": item["id"],
            "image": f"train/images/{out_name}",
            "category": item.get("category", ""),
            "area": item.get("area", None),
            "bbox": bbox,
            "conversations": [
                {"from": "human", "value": question},
                {"from": "gpt", "value": f"{label_idx}. {obj}"}
            ]
        })

    with open(out_json, "w") as f:
        json.dump(results, f, indent=2)

def convert_val(val_path, val_ans_path, image_root, out_val_json, out_val_ans_json, out_img_dir, all_classes):
    """Turn grounding val questions into red-box recognition questions.

    Records from ``val_path`` (question metadata) and ``val_ans_path``
    (answers) are joined on ``question_id``. If an id repeats inside one
    file, the last record wins. A question is skipped when any of these hold:

    - it has no (truthy) answer;
    - the answer prompt has no ``for <class>.``;
    - the image under ``image_root/val/images`` is missing or unreadable.

    For each remaining question:

    - The JSON bbox in the answer ``text`` is drawn in red on the image. Its
      coordinates are multiplied by the image width/height, so they are
      presumably normalised; confirm against the data.
    - The annotated image is saved in ``out_img_dir``.
    - A question record and an answer record are emitted. The answer is
      ``"<k>. <class>"``, with ``k`` the 1-based position in ``all_classes``.

    When several questions share one image, each gets its own output file.
    The first keeps the original name and later ones get a ``_N`` suffix, so
    no question's box is overwritten by another's.

    Args:
        val_path: JSON list of question records (``question_id``, ``image``,
            optional ``category``).
        val_ans_path: JSON list of answer records (``question_id``,
            ``prompt``, ``text``, optional ``area``).
        image_root: Dataset root; images are read from ``<root>/val/images``.
        out_val_json: Path for the converted question records.
        out_val_ans_json: Path for the converted answer records.
        out_img_dir: Directory for the box-annotated images (created if needed).
        all_classes: Ordered class names used for the answer options.

    Raises:
        ValueError: If a question's class is not in ``all_classes``.
    """
    os.makedirs(out_img_dir, exist_ok=True)

    with open(val_path, "r") as f:
        val_meta = {x["question_id"]: x for x in json.load(f)}

    with open(val_ans_path, "r") as f:
        val_ans = {x["question_id"]: x for x in json.load(f)}

    # Identical for every question, so build it once instead of per item.
    question = build_multiline_question(all_classes)
    # O(1) class -> 1-based label lookup; the first occurrence wins, as with list.index.
    label_of = {}
    for idx, cls in enumerate(all_classes, start=1):
        label_of.setdefault(cls, idx)

    # Output image names already written in this run (see docstring).
    used_names = set()

    val_result = []
    val_ans_result = []

    for qid in tqdm(val_meta, desc="Converting val"):
        meta = val_meta[qid]
        ans = val_ans.get(qid)
        if not ans:
            continue

        bbox = json.loads(ans["text"])
        obj = None
        match = re.search(r"for (.+?)\.", ans["prompt"])
        if match:
            obj = match.group(1).strip().lower()
        if obj is None:
            continue

        image_path = os.path.join(image_root, "val/images", meta["image"])
        if not os.path.exists(image_path):
            continue
        if obj not in label_of:
            raise ValueError(f"{obj!r} is not in all_classes")

        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None for unreadable/corrupt files; skip like a missing image.
            continue

        # draw box (BGR red, 2 px)
        h, w = img.shape[:2]
        x1, y1 = int(bbox[0] * w), int(bbox[1] * h)
        x2, y2 = int(bbox[2] * w), int(bbox[3] * h)
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)

        # Pick a name not yet used in this run so questions sharing an image
        # do not overwrite each other's annotated copy.
        out_name = meta["image"]
        stem, ext = os.path.splitext(out_name)
        suffix = 1
        while out_name in used_names:
            out_name = f"{stem}_{suffix}{ext}"
            suffix += 1
        used_names.add(out_name)

        out_path = os.path.join(out_img_dir, out_name)
        out_dir = os.path.dirname(out_path)
        if out_dir:
            # meta["image"] may contain sub-folders; cv2.imwrite does not
            # create missing directories (it just returns False).
            os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(out_path, img)

        label_idx = label_of[obj]

        val_result.append({
            "question_id": qid,
            "image": out_name,
            "category": meta.get("category", ""),
            "bbox": bbox,
            "area": ans.get("area", None),
            "text": question,
            "id": qid
        })
        val_ans_result.append({
            "question_id": qid,
            "prompt": question,
            "text": f"{label_idx}. {obj}",
            "area": ans.get("area", None),
            "answer_id": None,
            "model_id": None
        })

    with open(out_val_json, "w") as f:
        json.dump(val_result, f, indent=2)

    with open(out_val_ans_json, "w") as f:
        json.dump(val_ans_result, f, indent=2)

# === Main ===
# Read from the filtered grounding dataset; write the recognition copy elsewhere.
input_root = "combined_filtered"
output_root = "recognition"
os.makedirs(output_root, exist_ok=True)

# Input annotation files
train_json, val_json, val_ans_json = (
    os.path.join(input_root, rel)
    for rel in ("train/train.json", "val/val.json", "val/val_ans.json")
)

# Output annotation files and image folders
train_out_json, val_out_json, val_out_ans_json = (
    os.path.join(output_root, rel)
    for rel in ("train/train.json", "val/val.json", "val/val_ans.json")
)
train_img_dir, val_img_dir = (
    os.path.join(output_root, rel)
    for rel in ("train/images", "val/images")
)

# Run: the class vocabulary comes first because both converters number their options from it.
classes = extract_class_names(train_json, val_ans_json)
convert_train(
    train_path=train_json,
    image_root=input_root,
    out_json=train_out_json,
    out_img_dir=train_img_dir,
    all_classes=classes,
)
convert_val(
    val_path=val_json,
    val_ans_path=val_ans_json,
    image_root=input_root,
    out_val_json=val_out_json,
    out_val_ans_json=val_out_ans_json,
    out_img_dir=val_img_dir,
    all_classes=classes,
)
