#%%
import os
import json
import shortuuid
import shutil
from collections import OrderedDict
from tqdm import tqdm 

def extract_all_classes(*json_paths):
    """Collect the distinct ``annotation`` values found across JSON files.

    Args:
        *json_paths: Paths to JSON files. Each file must decode to an
            iterable of objects that have an ``"annotation"`` key.

    Returns:
        A sorted list of the unique annotation values. The list is empty
        when no paths are given.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a file is not valid JSON.
        KeyError: If an item has no ``"annotation"`` key.
    """
    classes = set()
    for path in json_paths:
        # JSON text is UTF-8; do not depend on the platform's locale default,
        # which would mis-decode non-ASCII labels on some systems.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        classes.update(item["annotation"] for item in data)
    return sorted(classes)

def _write_json(path, payload):
    """Write *payload* to *path* as indented JSON, creating parent directories."""
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises, so skip when path has no directory part
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def convert_json(input_path, output_path, mode, all_classes, val_ans_output=None, image_root="image_500", save_root="action"):
    """Convert a raw annotation file into a multiple-choice question dataset split.

    For every item in the input JSON the referenced image is copied to
    ``<save_root>/<mode>/images/<uid>.jpg`` (the ``.jpg`` suffix is used
    regardless of the source file's extension) and a record is emitted
    that asks which of ``all_classes`` the image shows. Items whose source
    image does not exist are reported with a warning and skipped.

    Args:
        input_path: Path to a JSON list of items with the keys
            ``image_path``, ``annotation``, ``image_height`` and
            ``image_width``.
        output_path: Destination of the converted JSON, relative to
            ``save_root``.
        mode: ``"train"`` writes conversation-style records; ``"val"``
            writes question records plus, optionally, an answer file.
        all_classes: Sequence of hashable class labels. Options are
            numbered from 1 in this order.
        val_ans_output: Destination of the ground-truth answer JSON,
            relative to ``save_root``. Only used when ``mode == "val"``.
        image_root: Directory that ``image_path`` values are relative to.
        save_root: Directory under which images and JSON files are written.

    Raises:
        ValueError: If ``mode`` is neither ``"train"`` nor ``"val"``, or if
            an item's annotation is not contained in ``all_classes``.
        KeyError: If an item lacks one of the required keys.
        OSError: If a file cannot be read, copied or written.
    """
    # Previously any other value silently copied images into val/ and wrote
    # an empty list; fail loudly before touching the filesystem instead.
    if mode not in ("train", "val"):
        raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")

    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    result = []
    val_ans_result = []

    print(f"🔁 Processing {mode} set from {input_path} ...")

    # Everything below is identical for every item, so build it once.
    set_name = mode
    os.makedirs(os.path.join(save_root, set_name, "images"), exist_ok=True)

    question = "<image>\nWhich action or activity is shown in the image? Choose from the following option: " + \
               " ".join(f"({i}) {cls}" for i, cls in enumerate(all_classes, start=1))

    # 1-based option number per class; setdefault keeps the first occurrence,
    # matching what list.index() returned for duplicated labels.
    class_number = {}
    for i, cls in enumerate(all_classes, start=1):
        class_number.setdefault(cls, i)

    for item in tqdm(data, desc=f"Copying {mode} images", unit="img"):
        old_img_rel_path = item["image_path"]
        annotation = item["annotation"]
        height = item["image_height"]
        width = item["image_width"]
        uid = shortuuid.uuid()

        # Stored in the JSON, so keep forward slashes independent of the OS.
        new_img_rel_path = f"{set_name}/images/{uid}.jpg"
        new_img_abs_path = os.path.join(save_root, new_img_rel_path)

        old_img_abs_path = os.path.join(image_root, old_img_rel_path)
        if not os.path.exists(old_img_abs_path):
            print(f"[WARNING] Image not found: {old_img_abs_path}, skipping")
            continue

        # Validate the label before copying so a bad item leaves no orphan image.
        try:
            option = class_number[annotation]
        except KeyError:
            raise ValueError(
                f"annotation {annotation!r} of {old_img_rel_path!r} is not in all_classes"
            ) from None
        answer = f"({option}) {annotation}"

        shutil.copy2(old_img_abs_path, new_img_abs_path)

        if mode == "train":
            result.append({
                "id": uid,
                "image": new_img_rel_path,
                "height": height,
                "width": width,
                "conversations": [
                    {"from": "human", "value": question},
                    {"from": "gpt", "value": answer},
                ],
            })
        else:  # mode == "val"
            image_id = os.path.splitext(os.path.basename(old_img_rel_path))[0]
            result.append({
                "question_id": uid,
                "image": new_img_rel_path,
                "height": height,
                "width": width,
                "text": question,
                "id": image_id,
            })
            val_ans_result.append({
                "question_id": uid,
                "prompt": question,
                "text": answer,
                "height": height,
                "width": width,
                "answer_id": None,
                "model_id": None,
            })

    # The parent directory is created explicitly: it is not guaranteed to be
    # the images directory's parent, and must exist even for an empty input.
    _write_json(os.path.join(save_root, output_path), result)
    print(f"✅ Saved: {save_root}/{output_path}")

    if mode == "val" and val_ans_output:
        _write_json(os.path.join(save_root, val_ans_output), val_ans_result)
        print(f"✅ Saved: {save_root}/{val_ans_output}")

# === USAGE ===

# Raw annotation files, one per split.
raw_train_json = "image_500/train_output.json"
raw_val_json = "image_500/val_output.json"

# All converted files and copied images are written beneath this directory.
output_root = "action"
os.makedirs(output_root, exist_ok=True)

# The option list is the union of labels from both splits, so train and val
# questions share the same numbering.
all_classes = extract_all_classes(raw_train_json, raw_val_json)

# Training split: copy the images and write the conversation-style JSON.
convert_json(
    raw_train_json,
    "train/train.json",
    "train",
    all_classes,
    image_root="image_500",
    save_root=output_root,
)

# Validation split: copy the images, then write the question file and the
# matching answer file.
convert_json(
    raw_val_json,
    "val/val.json",
    "val",
    all_classes,
    val_ans_output="val/val_ans.json",
    image_root="image_500",
    save_root=output_root,
)
