# %%
import ast
import io
import json
import os
import uuid

import pandas as pd
import requests
from PIL import Image, ImageDraw
from tqdm import tqdm

# === Config ===
# Input: CSV whose rows are read one by one by the processing loop below.
CSV_PATH = "train_final_limited.csv"

# Roots of the source datasets that are stored on local disk.
NYU_ROOT = "/fs/scratch/PAS2099/Lemeng/NYU_depth/nyu_rgb_images/"
OBJECT365_ROOT = "/fs/scratch/PAS2099/Lemeng/object365/"

# Outputs: the annotated images and the JSON file that references them.
OUTPUT_JSON_PATH = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial_bbox/train/train.json"
OUTPUT_IMAGE_DIR = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial_bbox/train/images"

# === Spatial Answer Mapping ===
# Lower-cased CSV label -> multiple-choice answer text written to the dataset.
answer_map = {
    "left up": "A. Left above",
    "left bottom": "B. Left below",
    "right up": "C. Right above",
    "right bottom": "D. Right below",
}

# Conversation entries collected during processing and dumped to JSON at the end.
train_data = []

# Side effects last: ensure the image directory exists, then load the rows.
os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
df = pd.read_csv(CSV_PATH)

def get_image_path(row):
    """Resolve where a row's image lives and which dataset it belongs to.

    Args:
        row: Mapping-like row providing the ``"source"`` and ``"image"`` fields.

    Returns:
        A ``(location, source_type)`` tuple. ``source_type`` is ``"lvis"``,
        ``"nyu"`` or ``"object365"``, chosen by the case-insensitive prefix of
        ``row["source"]``. Returns ``(None, None)`` when the source is not
        recognised or an NYU image id cannot be converted to an integer.
    """
    source = str(row["source"]).lower()
    image_id = row["image"]

    if source.startswith("lvis"):
        # The value is passed through untouched; load_image fetches it over HTTP.
        return image_id, "lvis"

    if source.startswith("nyu"):
        try:
            int_id = int(float(image_id))  # Handles both "76" and "76.0"
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures (None, text, NaN, inf) are treated as a
            # bad id; KeyboardInterrupt and similar must still propagate.
            print(f"⚠️ Failed to parse NYU image ID: {image_id}")
            return None, None
        return os.path.join(NYU_ROOT, f"image_{int_id:04d}.png"), "nyu"

    if source.startswith("object365"):
        return os.path.join(OBJECT365_ROOT, image_id), "object365"

    return None, None

def load_image(image_path, source_type):
    """Load an image as an RGB ``PIL.Image``, or return ``None`` on failure.

    Args:
        image_path: URL to fetch when ``source_type`` is ``"lvis"``; otherwise
            a path on the local filesystem.
        source_type: Dataset tag. ``"lvis"`` selects the HTTP branch and
            ``"nyu"`` additionally applies ``Image.ROTATE_270`` to the image.

    Returns:
        The loaded RGB image, or ``None`` if the image is missing or cannot be
        downloaded/decoded. Failures are reported with ``print`` rather than
        raised so that a single bad image does not stop the caller.
    """
    if source_type == "lvis":
        try:
            # One request only, with a timeout, and fail on HTTP error statuses
            # instead of trying to decode an error page as an image.
            response = requests.get(image_path, timeout=10)
            response.raise_for_status()
            # response.content is already content-decoded, unlike response.raw.
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception as e:  # best-effort boundary: report and skip
            print(f"❌ Failed to load LVIS image: {image_path} ({e})")
            return None

    if not os.path.exists(image_path):
        print(f"❌ Image not found: {image_path}")
        return None

    try:
        # convert() returns an independent in-memory image, so the file handle
        # can be closed as soon as the conversion is done.
        with Image.open(image_path) as source_img:
            img = source_img.convert("RGB")
        if source_type == "nyu":
            # NOTE(review): presumably aligns the stored NYU frames with their
            # bbox annotations — confirm against the dataset export.
            img = img.transpose(Image.ROTATE_270)
        return img
    except Exception as e:  # best-effort boundary: report and skip
        print(f"❌ Failed to open: {image_path} ({e})")
        return None

# === Processing ===
def draw_nyu_box(draw, bbox, color):
    """Outline an NYU box given as ``[y1, y2, x1, x2]``."""
    y1, y2, x1, x2 = bbox
    draw.rectangle([x1, y1, x2, y2], outline=color, width=3)


def draw_xywh_box(draw, bbox, color):
    """Outline a box given as ``[x, y, w, h]`` (LVIS and Object365)."""
    x, y, w, h = bbox
    draw.rectangle([x, y, x + w, y + h], outline=color, width=3)


# For every CSV row: load the image, draw the target (red) and anchor (blue)
# boxes, save the annotated copy under a fresh UUID name and record one
# question/answer conversation for it. Rows that cannot be processed are
# skipped with a printed message.
for i, row in tqdm(df.iterrows(), total=len(df), desc="🔄 Generating annotated dataset"):
    image_path, source_type = get_image_path(row)
    if not image_path:
        continue

    img = load_image(image_path, source_type)
    if img is None:
        continue

    try:
        anchor_bbox = ast.literal_eval(row["anchor_bbox"])
        target_bbox = ast.literal_eval(row["target_bbox"])
        # Both helpers unpack exactly four values; reject other shapes here
        # so a malformed row is skipped instead of aborting the whole run.
        if len(anchor_bbox) != 4 or len(target_bbox) != 4:
            raise ValueError("bbox must contain exactly 4 values")
    except (ValueError, TypeError, SyntaxError):
        print(f"⚠️ Invalid bbox at row {i}")
        continue

    # The bbox layout depends on the dataset the row came from.
    draw_box = draw_nyu_box if source_type == "nyu" else draw_xywh_box
    draw = ImageDraw.Draw(img)
    try:
        draw_box(draw, target_bbox, "red")
        draw_box(draw, anchor_bbox, "blue")
    except (ValueError, TypeError):
        # Non-numeric or otherwise undrawable coordinates: skip this row.
        print(f"⚠️ Invalid bbox at row {i}")
        continue

    img_id = uuid.uuid4().hex
    save_path = os.path.join(OUTPUT_IMAGE_DIR, f"{img_id}.jpg")
    try:
        img.save(save_path)
    except Exception as e:  # best-effort boundary: report and skip
        print(f"❌ Could not save image: {save_path} ({e})")
        continue

    question = (
        f"<image>\nConsidering the relative positions of two objects in the image, "
        f"where is the {row['target_object']} (annotated by the red box) located with respect to the "
        f"{row['anchor_object']} (annotated by the blue box)? Choose from A. Left above B. Left below C. Right above D. Right below."
    )

    # str() keeps a missing/NaN label from raising after the image has already
    # been written; such labels fall through to "Unknown" like any other
    # unrecognised value.
    spatial_answer = str(row["spatial_answer"])
    answer = answer_map.get(spatial_answer.strip().lower(), "Unknown")

    entry = {
        "id": img_id,
        "image": f"images/{img_id}.jpg",
        "conversations": [
            {"from": "human", "value": question},
            {"from": "gpt", "value": answer},
        ],
    }

    train_data.append(entry)

# === Save final JSON ===
# Serialise every collected entry into one pretty-printed JSON document.
with open(OUTPUT_JSON_PATH, "w") as json_file:
    json_file.write(json.dumps(train_data, indent=2))

print(f"\n✅ Done! Saved {len(train_data)} entries to {OUTPUT_JSON_PATH}")
