import os
import json
import uuid
import random
import ast
import pandas as pd
import requests
from PIL import Image
from tqdm import tqdm

# === Config ===
# Input table. The script reads its "source", "image", "anchor_bbox",
# "target_bbox", "anchor_object", "target_object" and "spatial_answer" columns.
VAL_CSV = "val_final.csv"
# Outputs: directory for the re-saved images, the question file and the
# answer file (both JSON files sit in the parent of the image directory).
VAL_IMAGE_DIR = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial/val/images"
VAL_JSON = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial/val/val.json"
VAL_ANS_JSON = "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial/val/val_ans.json"

# Local roots used to resolve NYU and Object365 image paths.
NYU_DIR = "/fs/scratch/PAS2099/Lemeng/NYU_depth/nyu_rgb_images/"
OBJ365_DIR = "/fs/scratch/PAS2099/Lemeng/object365/"

# Maps a stripped, lower-cased "spatial_answer" label to the multiple-choice
# option text; labels missing from this table are written as "Unknown".
answer_map = {
    "left up": "A. Left above",
    "left bottom": "B. Left below",
    "right up": "C. Right above",
    "right bottom": "D. Right below"
}

# Runs at import time: creates the image directory (and any missing parents)
# and loads the whole CSV into memory.
os.makedirs(VAL_IMAGE_DIR, exist_ok=True)
df = pd.read_csv(VAL_CSV)

# Parallel lists filled by the main loop: the entries at the same position
# share one "question_id".
val_entries = []
val_ans_entries = []

def get_image_path(row):
    """Resolve where a row's image lives and which dataset it belongs to.

    Args:
        row: Mapping-like record (for example a pandas row) that provides
            the "source" and "image" values.

    Returns:
        A ``(location, source_type)`` tuple where ``source_type`` is
        ``"lvis"``, ``"nyu"`` or ``"object365"``:

        * LVIS: the "image" value is returned unchanged.
        * NYU: ``image_NNNN.png`` (zero-padded integer id) under ``NYU_DIR``.
        * Object365: the "image" value joined onto ``OBJ365_DIR``.

        ``(None, None)`` is returned when the source prefix is not
        recognised or the NYU id cannot be converted to an integer.
    """
    source = str(row["source"]).lower()
    image_id = str(row["image"])

    if source.startswith("lvis"):
        return image_id, "lvis"

    if source.startswith("nyu"):
        try:
            # Going through float() accepts float-formatted ids such as "12.0".
            int_id = int(float(image_id))
        except (ValueError, OverflowError):
            # ValueError: non-numeric text or NaN; OverflowError: infinity.
            print(f"⚠️ Invalid NYU ID: {image_id}")
            return None, None
        return os.path.join(NYU_DIR, f"image_{int_id:04d}.png"), "nyu"

    if source.startswith("object365"):
        return os.path.join(OBJ365_DIR, image_id), "object365"

    return None, None

def load_image(image_path, source_type, *, timeout=30):
    """Load an image as RGB, returning ``None`` instead of raising on failure.

    Args:
        image_path: URL when ``source_type`` is ``"lvis"``, otherwise a path
            on the local filesystem.
        source_type: ``"lvis"`` downloads the image over HTTP; any other
            value reads from disk. ``"nyu"`` images are additionally
            transposed with ``Image.ROTATE_270``.
        timeout: Seconds passed to ``requests.get`` so that a stalled
            download cannot block the run indefinitely.

    Returns:
        A PIL image in RGB mode, or ``None`` when the image is missing,
        cannot be downloaded, or cannot be decoded. A message is printed in
        each failure case.
    """
    if source_type == "lvis":
        try:
            # The context manager releases the streamed connection.
            with requests.get(image_path, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                # The raw stream is not content-decoded unless asked to be.
                response.raw.decode_content = True
                # convert() reads all pixel data before the response closes.
                return Image.open(response.raw).convert("RGB")
        except Exception as exc:
            # Deliberately broad: reading the raw stream can raise urllib3
            # errors as well as requests/PIL ones, and one bad image must
            # not abort the batch. Ctrl-C still propagates.
            print(f"❌ Failed to load LVIS image: {image_path} ({exc})")
            return None

    if not os.path.exists(image_path):
        print(f"❌ Not found: {image_path}")
        return None

    try:
        with Image.open(image_path) as source_img:
            img = source_img.convert("RGB")
    except Exception as exc:
        # Mirror the LVIS branch: skip an unreadable file, do not crash.
        print(f"❌ Failed to load image: {image_path} ({exc})")
        return None

    if source_type == "nyu":
        img = img.transpose(Image.ROTATE_270)
    return img

def compute_area_ratio(bbox, img_size, source_type):
    """Return the fraction of the image area covered by a bounding box.

    Args:
        bbox: Four numbers. When ``source_type`` is ``"nyu"`` they are read
            as ``(y1, y2, x1, x2)``; for every other source they are read as
            ``(x, y, width, height)``.
        img_size: ``(width, height)`` pair of the image.
        source_type: Dataset tag that selects the bbox layout.

    Returns:
        ``box_area / image_area`` rounded to 6 decimal places, or ``None``
        when the bbox is not a sequence of four numbers or the image area
        is zero.
    """
    w, h = img_size
    img_area = w * h
    try:
        if source_type == "nyu":
            y1, y2, x1, x2 = bbox
            area = (x2 - x1) * (y2 - y1)
        else:
            _, _, bw, bh = bbox
            area = bw * bh
        return round(area / img_area, 6)
    except (TypeError, ValueError, ArithmeticError):
        # TypeError: bbox is None or holds non-numbers; ValueError: wrong
        # element count; ArithmeticError: zero image area or overflow.
        return None

# === Main loop ===
# Every usable CSV row yields one saved image, one question entry and one
# answer entry; the two entries share the id generated for the image.
for row_idx, row in tqdm(df.iterrows(), total=len(df), desc="🧪 Generating val entries"):
    image_path, source_type = get_image_path(row)
    if not image_path:
        continue

    try:
        anchor_bbox = ast.literal_eval(row["anchor_bbox"])
        target_bbox = ast.literal_eval(row["target_bbox"])
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as exc:
        # Exception set documented for ast.literal_eval on malformed input.
        print(f"⚠️ Invalid bbox at row {row_idx}: {exc}")
        continue

    img = load_image(image_path, source_type)
    if img is None:
        continue

    # A fresh random id names the saved file and links question to answer.
    image_id = uuid.uuid4().hex
    image_filename = f"{image_id}.jpg"
    save_path = os.path.join(VAL_IMAGE_DIR, image_filename)

    try:
        img.save(save_path)
    except (OSError, ValueError) as exc:
        print(f"❌ Failed to save: {save_path} ({exc})")
        continue

    anchor_area_ratio = compute_area_ratio(anchor_bbox, img.size, source_type)
    target_area_ratio = compute_area_ratio(target_bbox, img.size, source_type)

    anchor = row["anchor_object"]
    target = row["target_object"]
    spatial_answer = str(row["spatial_answer"]).strip().lower()

    category = (
        "lvis" if "lvis" in source_type else
        "object365" if "object365" in source_type else
        "nyu_depth"
    )

    prompt = (
        f"<image>\nConsidering the relative positions of two objects in the image, "
        f"where is the {target} located with respect to the "
        f"{anchor}? Choose from A. Left above B. Left below C. Right above D. Right below."
    )

    answer_text = answer_map.get(spatial_answer)
    if answer_text is None:
        # Keep the entry, but make the unusable ground-truth label visible.
        print(f"⚠️ Unrecognised spatial_answer {spatial_answer!r} at row {row_idx}; writing 'Unknown'")
        answer_text = "Unknown"

    # NOTE: the "anchor_bbox" / "target_bbox" keys hold area ratios (box area
    # divided by image area, or None), not box coordinates.
    val_entries.append({
        "question_id": image_id,
        "image": image_filename,
        "category": category,
        "text": prompt,
        "anchor_bbox": anchor_area_ratio,
        "target_bbox": target_area_ratio
    })

    val_ans_entries.append({
        "question_id": image_id,
        "prompt": prompt,
        "text": answer_text,
        "answer_id": None,
        "model_id": None,
        "metadata": {}
    })

# === Shuffle together with fixed seed to keep alignment ===
# Questions and answers are shuffled as pairs so that position i in both
# output files still refers to the same question_id.
random.seed(42)
combined = list(zip(val_entries, val_ans_entries))
random.shuffle(combined)
# Rebuilt with comprehensions: `a, b = zip(*combined)` raises ValueError when
# no row survived the main loop, which would crash before anything is written.
val_entries = [question for question, _ in combined]
val_ans_entries = [answer for _, answer in combined]

with open(VAL_JSON, "w", encoding="utf-8") as f:
    json.dump(val_entries, f, indent=2)

with open(VAL_ANS_JSON, "w", encoding="utf-8") as f:
    json.dump(val_ans_entries, f, indent=2)

print(f"\n✅ Saved {len(val_entries)} QA entries to val.json and val_ans.json")
