import os
from PIL import Image, ExifTags
from tqdm import tqdm

import torch
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel

# Local path of the pretrained CLIP checkpoint used for scoring.
_CLIP_MODEL_PATH = "/root/autodl-tmp/pretrained_models/laion/CLIP-ViT-g-14-laion2B-s12B-b42K"

# File extensions (compared case-insensitively) that are treated as images.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Numeric id of the EXIF "ImageDescription" tag, where the prompt is stored.
_EXIF_IMAGE_DESCRIPTION = 0x010E


def _read_exif_prompt(img_path):
    """Return the EXIF ImageDescription of the image at *img_path*, or None."""
    # The context manager releases the file handle as soon as the tag is read.
    with Image.open(img_path) as img:
        return img.getexif().get(_EXIF_IMAGE_DESCRIPTION)


def _collect_image_prompt_pairs(image_folder, limit=None):
    """Return ``(image_path, prompt)`` tuples for images that carry an EXIF prompt.

    Files are visited in sorted filename order. Images that cannot be read are
    reported on stdout and skipped; images without a non-empty prompt are
    skipped silently. Scanning stops once *limit* pairs were collected (a
    falsy *limit* means no cap).
    """
    pairs = []
    for filename in sorted(os.listdir(image_folder)):
        if not filename.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        img_path = os.path.join(image_folder, filename)
        try:
            prompt = _read_exif_prompt(img_path)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the scan.
            print(f"处理图像 {filename} 时出错：{e}")
            continue
        if not prompt:
            continue
        pairs.append((img_path, prompt))
        if limit and len(pairs) >= limit:
            break
    return pairs


def calculate_clip_scores(image_folder, prompt_file="prompt.txt", limit=None):
    """
    Compute the average CLIP score of the images in a folder.

    Each image is scored against the prompt stored in its own EXIF
    ImageDescription tag; images without such a prompt are ignored.

    Args:
        image_folder: Path of the folder containing the images
            (.jpg / .jpeg / .png, matched case-insensitively).
        prompt_file: Unused. Kept only so that existing callers passing it keep
            working; prompts are taken from the image EXIF data, so the file
            is no longer opened and does not need to exist.
        limit: Maximum number of image/prompt pairs to score. ``None`` (or any
            falsy value) scores every pair found.

    Returns:
        The mean of the per-image ``logits_per_image`` values as a float, or
        ``0.0`` when no image with an embedded prompt was found.
    """
    # Scan first so the large checkpoint is only loaded when there is work to do.
    images_and_prompts = _collect_image_prompt_pairs(image_folder, limit)
    if not images_and_prompts:
        print("未找到有效的图像-prompt对")
        return 0.0

    print(f"找到 {len(images_and_prompts)} 个图像-prompt对")
    if limit:
        print(f"限制处理前 {limit} 个")

    # Load the model and its processor, preferring the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained(_CLIP_MODEL_PATH).to(device)
    processor = CLIPProcessor.from_pretrained(_CLIP_MODEL_PATH)

    clip_scores = []
    for img_path, prompt in tqdm(images_and_prompts, desc="计算CLIP Score"):
        # convert() returns a new, fully loaded image, so the file can be closed
        # before the forward pass.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        inputs = processor(
            text=prompt,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,  # cut over-long prompts to the tokenizer's max length
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # One image and one text per call, so .item() extracts the only logit.
        # NOTE(review): this is the model's image-text logit as returned by
        # CLIPModel, not a raw cosine similarity - confirm it is the intended metric.
        clip_scores.append(outputs.logits_per_image.item())

    average_score = sum(clip_scores) / len(clip_scores)
    print(f"Average CLIP Score (HuggingFace CLIP-ViT-g-14): {average_score:.4f}")

    return average_score

if __name__ == "__main__":
    # Legacy script entry point, kept for backward compatibility: scores
    # img_<i>.jpg in image_folder against line i of prompt.txt.
    expected_prompt_count = 200

    # Read the prompts, one per line.
    with open("prompt.txt", "r", encoding="utf-8") as f:
        prompts = [line.strip() for line in f]

    # Explicit check instead of `assert`, which is stripped when running with -O.
    if len(prompts) != expected_prompt_count:
        raise ValueError(f"Expected {expected_prompt_count} prompts, got {len(prompts)}")

    image_folder = "/root/autodl-tmp/TaylorSeer/TaylorSeer-FLUX/ts/interval6"

    # Local path of the pretrained CLIP checkpoint.
    model_id = "/root/autodl-tmp/pretrained_models/laion/CLIP-ViT-g-14-laion2B-s12B-b42K"

    # Load the model and its processor, preferring the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained(model_id).to(device)
    processor = CLIPProcessor.from_pretrained(model_id)

    clip_scores = []

    # Score every image against the prompt on the line with the same index.
    for i, prompt in enumerate(tqdm(prompts)):
        img_path = os.path.join(image_folder, f"img_{i}.jpg")

        # convert() returns a new, fully loaded image, so the file handle can be
        # released immediately instead of leaking one per iteration.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Build the model inputs.
        inputs = processor(
            text=prompt,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,  # cut over-long prompts to the tokenizer's max length
        ).to(device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = model(**inputs)

        # One image and one text per call, so .item() extracts the only logit.
        clip_scores.append(outputs.logits_per_image.item())

    # Average over all scored images.
    average_score = sum(clip_scores) / len(clip_scores)
    print(f"Average CLIP Score (HuggingFace CLIP-ViT-g-14): {average_score:.4f}")
