import os
import json
from tqdm import tqdm
from ImageReward import load
from PIL import Image
import torch

# Paths
input_path = "carrot-bowl/metadata.json"
output_path = "carrot-bowl/metadata_IR.json"
image_root = "carrot-bowl"

# Load the ImageReward model directly onto the chosen device.
# Passing `device` to load() places the weights there. Assigning
# `model.device` after loading would only rebind a plain attribute on the
# module; it would not move parameters that are already loaded.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load("ImageReward-v1.0", device=device)
print(f"ImageReward loaded on device: {device}")

# Function to load and convert an image
def load_image(img_path):
    """Open the image at ``img_path`` and return it converted to RGB.

    Any error while opening or decoding is printed, and None is returned
    so the caller can skip unreadable files instead of aborting the run.
    """
    try:
        # Image.open keeps the file handle open (lazy loading). Closing it
        # via the context manager is safe because convert() returns a new,
        # fully loaded image that does not depend on the source file.
        with Image.open(img_path) as img:
            return img.convert("RGB")
    except Exception as e:
        # Deliberately broad: this is a best-effort loader for a batch job.
        print(f"❌ Loading error {img_path}: {e}")
        return None

# Load existing JSON. Entries are expected to be dicts with "prompt",
# "model" and "filenames" keys (read below).
with open(input_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# New content with ImageReward
new_data = []

for entry in tqdm(data, desc="Computing ImageReward scores"):
    prompt = entry["prompt"]
    model_name = entry["model"]
    filenames = entry["filenames"]

    # Load images, remembering the position of each one that loaded
    # successfully. Its score can then be written back to the slot of its
    # own filename rather than shifting onto the wrong file.
    loaded_positions = []
    images = []
    for pos, fname in enumerate(filenames):
        img = load_image(os.path.join(image_root, fname))
        if img is not None:
            loaded_positions.append(pos)
            images.append(img)

    # Exactly one score slot per filename. Files that failed to load, or
    # prompts whose scoring failed, keep None.
    scores = [None] * len(filenames)
    if images:
        try:
            # Pure inference: no autograd graph needed.
            with torch.no_grad():
                raw_scores = model.score(prompt, images)
        except Exception as e:
            print(f"⚠️ Erreur pour prompt: {prompt} -> {e}")
        else:
            # NOTE(review): a single-image batch may come back as a bare
            # scalar rather than a list. Normalize so zip() below works.
            if not isinstance(raw_scores, (list, tuple)):
                raw_scores = [raw_scores]
            for pos, score in zip(loaded_positions, raw_scores):
                scores[pos] = score

    # Update structure
    new_entry = {
        "prompt": prompt,
        "model": model_name,
        "filenames": filenames,
        "image_reward_scores": scores,  # replaces clip_scores
    }
    new_data.append(new_entry)

# Save new JSON. The file is written as UTF-8, so keep non-ASCII prompt
# text readable instead of \u-escaping it.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(new_data, f, indent=2, ensure_ascii=False)

print(f"\n✅ File saved to: {output_path}")
