#%%
import ast
import json
import os
import uuid
from io import BytesIO

import pandas as pd
import requests
from PIL import Image, ImageDraw
from tqdm import tqdm

# === Paths ===
# Input annotations. This is a relative path, so it is resolved against the
# current working directory.
VAL_CSV = "val_final.csv"
# Outputs: rendered images, the question file, and the reference-answer file.
VAL_IMAGE_DIR = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial_bbox/val/images"
VAL_JSON = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial_bbox/val/val.json"
VAL_ANS_JSON = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial_bbox/val/val_ans.json"

# Local image roots for the sources read from disk (LVIS images are fetched by URL).
NYU_DIR = "/fs/scratch/PAS2099/Lemeng/NYU_depth/nyu_rgb_images/"
OBJ365_DIR = "/fs/scratch/PAS2099/Lemeng/object365/"

# === Answer Map ===
# Normalised (stripped, lower-cased) `spatial_answer` CSV value -> multiple-choice answer text.
answer_map = {
    "left up": "A. Left above",
    "left bottom": "B. Left below",
    "right up": "C. Right above",
    "right bottom": "D. Right below",
}

# Run-time side effects: ensure the image output directory exists, then load
# the whole annotation CSV into memory.
os.makedirs(VAL_IMAGE_DIR, exist_ok=True)
df = pd.read_csv(VAL_CSV)

# Accumulators filled row by row by the main loop (questions, reference answers).
val_entries, val_ans_entries = [], []

def get_image_path(row):
    """Resolve where a CSV row's image lives and which dataset it comes from.

    Args:
        row: Mapping (e.g. a pandas row) with ``"source"`` and ``"image"`` keys.
            ``source`` is matched case-insensitively by prefix.

    Returns:
        A ``(path_or_url, source_type)`` pair:
          * source starting with "lvis": the ``image`` value as-is (``load_image``
            fetches it over HTTP) and ``"lvis"``;
          * source starting with "nyu": ``NYU_DIR/image_NNNN.png`` built from the
            numeric ``image`` id, and ``"nyu"``;
          * source starting with "object365": ``OBJ365_DIR/<image>`` and
            ``"object365"``;
          * ``(None, None)`` for any other source, or for an NYU id that is not
            a finite number.
    """
    source = str(row["source"]).lower()
    image_id = str(row["image"])

    if source.startswith("lvis"):
        return image_id, "lvis"

    if source.startswith("nyu"):
        try:
            # The id may arrive as a float string such as "12.0" (presumably
            # pandas parsed the column as float), hence float() before int().
            int_id = int(float(image_id))
        except (ValueError, OverflowError):
            # ValueError: non-numeric text or NaN; OverflowError: +/-inf.
            # Catching only these keeps KeyboardInterrupt/SystemExit working.
            return None, None
        return os.path.join(NYU_DIR, f"image_{int_id:04d}.png"), "nyu"

    if source.startswith("object365"):
        return os.path.join(OBJ365_DIR, image_id), "object365"

    return None, None

def load_image(image_path, source_type, timeout=30):
    """Load an image as RGB, or return ``None`` (after printing why) on failure.

    Args:
        image_path: For ``source_type == "lvis"`` a URL fetched over HTTP;
            otherwise a local filesystem path.
        source_type: Tag produced by ``get_image_path`` ("lvis", "nyu",
            "object365"). NYU images are additionally transposed with
            ``Image.ROTATE_270``.
        timeout: Seconds to wait on the HTTP request for LVIS images. Without
            a timeout ``requests`` can block indefinitely on a stalled server.

    Returns:
        An RGB ``PIL.Image.Image``, or ``None`` if the image could not be
        fetched, found, or decoded.
    """
    if source_type == "lvis":
        try:
            response = requests.get(image_path, timeout=timeout)
            # Treat HTTP 4xx/5xx as failures instead of trying to decode an error page.
            response.raise_for_status()
            # Decode from the fully-read body: ``response.content`` undoes any
            # Content-Encoding (e.g. gzip), whereas ``response.raw`` does not.
            with Image.open(BytesIO(response.content)) as src:
                return src.convert("RGB")
        except (requests.RequestException, OSError, ValueError, Image.DecompressionBombError):
            # Narrow catch: unlike a bare ``except:``, Ctrl-C still stops the run.
            print(f"❌ Failed to load LVIS image: {image_path}")
            return None

    if not os.path.exists(image_path):
        print(f"❌ Not found: {image_path}")
        return None

    try:
        # ``with`` closes the file handle; convert() returns an independent copy.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError):
        # Corrupt/unreadable file (PIL's UnidentifiedImageError is an OSError):
        # skip this row rather than abort the whole run.
        print(f"❌ Failed to load image: {image_path}")
        return None

    if source_type == "nyu":
        img = img.transpose(Image.ROTATE_270)
    return img

def draw_bbox(draw, bbox, color, source_type):
    """Outline ``bbox`` on ``draw`` in ``color`` (line width 3) and return its area.

    The bbox layout depends on the source:
      * ``"nyu"``: ``(y1, y2, x1, x2)`` corner coordinates;
      * any other source: ``(x, y, w, h)``, a top-left corner plus width/height.

    The returned area is in the same units as the bbox values.
    """
    nyu_layout = source_type == "nyu"
    if nyu_layout:
        top, bottom, left, right = bbox
    else:
        left, top, box_w, box_h = bbox
        right, bottom = left + box_w, top + box_h

    draw.rectangle([left, top, right, bottom], outline=color, width=3)

    if nyu_layout:
        return (right - left) * (bottom - top)
    return box_w * box_h

# === Main loop ===
# "category" value stored in val.json for each source_type tag returned by
# get_image_path(). Rows it cannot resolve are skipped before this is consulted.
_CATEGORY_BY_SOURCE_TYPE = {"lvis": "lvis", "object365": "object365", "nyu": "nyu_depth"}


def _parse_bbox(value):
    """Parse one bbox CSV cell (e.g. ``"[10, 20, 30, 40]"``) into four numbers.

    Returns the parsed list/tuple, or ``None`` if the cell is not a Python
    literal (an empty cell is presumably read by pandas as a float NaN, which
    ``ast.literal_eval`` rejects) or is not a sequence of exactly four
    ints/floats. Validating the length here keeps one malformed row from
    crashing the 4-way unpack in ``draw_bbox``.
    """
    try:
        bbox = ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # The exception set documented for ast.literal_eval on malformed input.
        return None
    if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
        return None
    if not all(isinstance(v, (int, float)) for v in bbox):
        return None
    return bbox


for idx, row in tqdm(df.iterrows(), total=len(df), desc="🖼️ Processing val rows"):
    anchor = row["anchor_object"]
    target = row["target_object"]
    spatial_answer = str(row["spatial_answer"]).strip().lower()

    anchor_bbox = _parse_bbox(row["anchor_bbox"])
    target_bbox = _parse_bbox(row["target_bbox"])
    if anchor_bbox is None or target_bbox is None:
        print(f"⚠️ Skipping row {idx} due to invalid bbox")
        continue

    image_path, source_type = get_image_path(row)
    if not image_path:
        continue

    img = load_image(image_path, source_type)
    if img is None:
        continue

    # Draw the anchor (blue) and target (red) boxes directly onto the image.
    draw = ImageDraw.Draw(img)
    anchor_area = draw_bbox(draw, anchor_bbox, "blue", source_type)
    target_area = draw_bbox(draw, target_bbox, "red", source_type)
    img_width, img_height = img.size
    img_area = img_width * img_height

    # Fraction of the image covered by each box, rounded to 6 decimals.
    anchor_ratio = round(anchor_area / img_area, 6)
    target_ratio = round(target_area / img_area, 6)

    # One random hex id serves as both the question_id and the saved file stem.
    image_id = uuid.uuid4().hex
    image_filename = f"{image_id}.jpg"
    img_save_path = os.path.join(VAL_IMAGE_DIR, image_filename)
    img.save(img_save_path)

    # Derived from get_image_path()'s classification, so the two cannot drift apart.
    category = _CATEGORY_BY_SOURCE_TYPE[source_type]

    prompt = (
        f"<image>\nConsidering the relative positions of two objects in the image, "
        f"where is the {target} (annotated by the red box) located with respect to the "
        f"{anchor} (annotated by the blue box)? Choose from A. Left above B. Left below C. Right above D. Right below."
    )
    answer_text = answer_map.get(spatial_answer)
    if answer_text is None:
        # Surface unexpected labels instead of silently writing "Unknown".
        print(f"⚠️ Row {idx}: unrecognized spatial_answer {spatial_answer!r}, recording 'Unknown'")
        answer_text = "Unknown"

    val_entries.append({
        "question_id": image_id,
        "image": image_filename,
        "category": category,
        "text": prompt,
        # NOTE: despite the key names, these hold area ratios
        # (box area / image area), not box coordinates.
        "anchor_bbox": anchor_ratio,
        "target_bbox": target_ratio,
    })

    val_ans_entries.append({
        "question_id": image_id,
        "prompt": prompt,
        "text": answer_text,
        "answer_id": None,
        "model_id": None,
        "metadata": {},
    })

# === Save JSON outputs ===
# Write the question file and the reference-answer file as 2-space-indented JSON,
# in that order.
for out_path, payload in ((VAL_JSON, val_entries), (VAL_ANS_JSON, val_ans_entries)):
    with open(out_path, "w") as f:
        json.dump(payload, f, indent=2)

print(f"\n✅ Done! Saved {len(val_entries)} entries to val.json and val_ans.json")
