import io
import json
import os
import random
import uuid

import pandas as pd
import requests
from PIL import Image
from tqdm import tqdm

# === Config ===
# Input CSV (relative path, so it is resolved against the current working
# directory). Columns read by this script: "source", "image",
# "target_object", "anchor_object" and "spatial_answer".
CSV_PATH = "train_final_limited.csv"
# Directory that receives the re-saved images, and the path of the JSON file
# written at the end of the run.
OUTPUT_IMAGE_DIR = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial/train/images"
OUTPUT_JSON_PATH = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial/train/train.json"

# Local dataset roots: the object365 root is joined with the row's image id,
# the NYU root with a filename derived from the numeric image id.
OBJECT365_ROOT = "/fs/scratch/PAS2099/Lemeng/object365/"
NYU_ROOT = "/fs/scratch/PAS2099/Lemeng/NYU_depth/nyu_rgb_images/"

# NOTE: both statements run as soon as the module is executed/imported:
# the output directory is created and the whole CSV is loaded into memory.
os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
df = pd.read_csv(CSV_PATH)

# Maps the stripped, lower-cased "spatial_answer" value of a row to the
# multiple-choice option text that is stored as the answer.
answer_map = {
    "left up": "A. Left above",
    "left bottom": "B. Left below",
    "right up": "C. Right above",
    "right bottom": "D. Right below"
}

# Accumulates one conversation record per successfully processed row.
train_data = []

def get_image_path(row):
    """Resolve the image location and dataset tag for one CSV row.

    Args:
        row: Mapping-like record (e.g. a pandas Series) providing a
            "source" entry and an "image" entry.

    Returns:
        A ``(location, source_type)`` tuple:

        * source starting with "lvis": the "image" value unchanged (it is
          later fetched over HTTP by ``load_image``) and ``"lvis"``;
        * source starting with "nyu": ``NYU_ROOT/image_<id:04d>.png`` and
          ``"nyu"``, where ``<id>`` is the "image" value parsed as a number;
        * source starting with "object365": ``OBJECT365_ROOT/<image>`` and
          ``"object365"``;
        * ``(None, None)`` for any other source, or when an NYU id cannot
          be parsed as a number (a warning is printed in that case).
    """
    source = str(row["source"]).lower()
    image_id = row["image"]

    if source.startswith("lvis"):
        return image_id, "lvis"

    if source.startswith("nyu"):
        try:
            # Go through float first so ids such as "12.0" are accepted.
            int_id = int(float(image_id))
        except (TypeError, ValueError, OverflowError):
            # Only parse failures (None, non-numeric text, NaN, inf) are
            # treated as a bad id; Ctrl-C and other errors propagate.
            print(f"⚠️ Invalid NYU image ID: {image_id}")
            return None, None
        filename = f"image_{int_id:04d}.png"
        return os.path.join(NYU_ROOT, filename), "nyu"

    if source.startswith("object365"):
        return os.path.join(OBJECT365_ROOT, image_id), "object365"

    return None, None

def load_image(image_path, source_type):
    """Load an image as an RGB PIL image, from a URL (LVIS) or from disk.

    Args:
        image_path: URL to download when ``source_type`` is ``"lvis"``,
            otherwise a filesystem path.
        source_type: Dataset tag. ``"lvis"`` selects the HTTP download;
            ``"nyu"`` additionally applies a ``ROTATE_270`` transpose to
            the loaded image.

    Returns:
        The image converted to RGB, or ``None`` when the file does not
        exist or the image cannot be fetched or decoded. A message is
        printed for every failure; no exception is raised to the caller
        for ordinary errors.
    """
    if source_type == "lvis":
        try:
            # A single request, always bounded by the timeout. The body is
            # read fully into memory so the connection is released and any
            # content-encoding is decoded before PIL sees the bytes.
            response = requests.get(image_path, timeout=10)
            # Fail early on 4xx/5xx instead of trying to decode an error page.
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception:
            print(f"❌ Failed to load LVIS image: {image_path}")
            return None

    if not os.path.exists(image_path):
        print(f"❌ Image not found: {image_path}")
        return None

    try:
        # convert() returns an independent image, so the source file handle
        # can be closed as soon as the with-block ends.
        with Image.open(image_path) as source_image:
            img = source_image.convert("RGB")
        if source_type == "nyu":
            img = img.transpose(Image.ROTATE_270)
        return img
    except Exception:
        print(f"❌ Failed to open: {image_path}")
        return None

# === Main loop ===
# For every CSV row: resolve and load the image, re-save it under a fresh
# UUID-based name, and append one question/answer conversation record.
# Rows whose image cannot be resolved, loaded or saved are skipped.
for _, row in tqdm(df.iterrows(), total=len(df), desc="🧪 Generating train entries"):
    image_path, source_type = get_image_path(row)
    if not image_path:
        continue

    img = load_image(image_path, source_type)
    if img is None:
        continue

    # A random id decouples the output name from the source dataset's naming.
    image_id = uuid.uuid4().hex
    save_path = os.path.join(OUTPUT_IMAGE_DIR, f"{image_id}.jpg")
    try:
        img.save(save_path)
    except Exception:
        # Best-effort per row, but not a bare except: Ctrl-C and SystemExit
        # must still be able to stop a long-running job.
        print(f"❌ Could not save image: {save_path}")
        continue

    question = (
        f"<image>\nConsidering the relative positions of two objects in the image, "
        f"where is the {row['target_object']} located with respect to the "
        f"{row['anchor_object']}? Choose from A. Left above B. Left below C. Right above D. Right below."
    )
    # NOTE(review): a spatial_answer outside answer_map is kept with the label
    # "Unknown", which is not one of the four offered options — confirm this
    # is intended rather than skipping such rows.
    answer = answer_map.get(str(row["spatial_answer"]).strip().lower(), "Unknown")

    train_data.append({
        "id": image_id,
        # Stored relative to the directory that holds the JSON file.
        "image": f"images/{image_id}.jpg",
        "conversations": [
            {"from": "human", "value": question},
            {"from": "gpt", "value": answer}
        ]
    })

# === Shuffle and save
# The shuffle is unseeded, so record order differs between runs.
random.shuffle(train_data)
with open(OUTPUT_JSON_PATH, "w") as f:
    json.dump(train_data, f, indent=2)

print(f"\n✅ Saved {len(train_data)} entries to {OUTPUT_JSON_PATH}")
