import os
import json
import torch
import clip
from PIL import Image
from tqdm import tqdm

# Locations: evaluation JSON files are read from input_folder, and the scored
# copies are written to output_folder (created up front if missing).
_results_root = "/mnt/task_runtime/IIE/code/inference_result_may8"
input_folder = f"{_results_root}/mc_eval"
output_folder = f"{_results_root}/clip_whole_image"
os.makedirs(output_folder, exist_ok=True)

# CLIP ViT-B/32 model plus the preprocessing transform returned alongside it,
# placed on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def compute_clip_similarity(img_path1, img_path2):
    """Return the CLIP cosine similarity between two image files.

    Each image is converted to RGB and passed through the module-level
    ``preprocess`` transform. It is then embedded with the module-level CLIP
    ``model`` on ``device``. The two embeddings are L2-normalised, and their
    dot product (i.e. cosine similarity) is returned.

    Args:
        img_path1: Path to the first image.
        img_path2: Path to the second image.

    Returns:
        The similarity as a Python float. Returns ``None`` if either image
        could not be opened, converted or preprocessed; the failure is
        printed rather than raised.
    """
    try:
        # Image.open is lazy and keeps the file open; the context managers
        # guarantee both handles are released even if convert/preprocess
        # raises partway through.
        with Image.open(img_path1) as im1, Image.open(img_path2) as im2:
            image1 = preprocess(im1.convert("RGB")).unsqueeze(0).to(device)
            image2 = preprocess(im2.convert("RGB")).unsqueeze(0).to(device)
    except Exception as e:
        # Best-effort: one bad/missing image skips this pair instead of
        # aborting the whole batch.
        print(f"Failed to load images: {img_path1}, {img_path2} — {e}")
        return None

    with torch.no_grad():
        feat1 = model.encode_image(image1)
        feat2 = model.encode_image(image2)

        # Unit-normalise so the dot product below is the cosine similarity.
        feat1 = feat1 / feat1.norm(dim=-1, keepdim=True)
        feat2 = feat2 / feat2.norm(dim=-1, keepdim=True)

        similarity = (feat1 @ feat2.T).item()
    return similarity

# Only these files from input_folder are scored; everything else is skipped.
target_files = {"univg_mc_eval.json"}

# Process each selected JSON file. For every entry, attach the whole-image CLIP
# similarity between its "image" and "edited_image_path" images. The augmented
# list is written to output_folder under the same file name.
for fname in sorted(os.listdir(input_folder)):
    if fname not in target_files or not fname.endswith(".json"):
        continue

    in_path = os.path.join(input_folder, fname)
    out_path = os.path.join(output_folder, fname)

    try:
        with open(in_path, "r", encoding="utf-8") as f:
            entries = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Failed to load JSON {fname}: {e}")
        continue

    # Each entry is read with .get(), so a list of objects is required;
    # anything else would crash mid-run with an AttributeError.
    if not isinstance(entries, list):
        print(f"Unexpected JSON structure in {fname}: expected a list of entries")
        continue

    for entry in tqdm(entries, desc=fname):
        if not isinstance(entry, dict):
            print(f"Skipping non-object entry in {fname}")
            continue

        original_path = entry.get("image")
        edited_path = entry.get("edited_image_path")

        if not original_path or not edited_path:
            print(f"Missing paths in entry of {fname}")
            continue

        score = compute_clip_similarity(original_path, edited_path)
        # Entries whose images failed to load are kept, just without a score.
        if score is not None:
            entry["clip_score_whole_image"] = score

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(entries, f, indent=2)
