import torch
import pandas as pd
from PIL import Image
import os
import torchvision.transforms as transforms
from transformers import CLIPProcessor, CLIPModel
import argparse

def _load_clip(device):
    """Load the CLIP model and processor onto ``device``.

    Safetensors weights are tried first; if anything in that attempt fails,
    the default loading path is tried as a fallback.

    Args:
        device: torch device the model is moved to.

    Returns:
        A ``(model, processor)`` tuple, or ``None`` when both attempts fail.
    """
    model_name = "openai/clip-vit-base-patch32"
    try:
        # Prefer safetensors to avoid the security concerns of torch.load.
        model = CLIPModel.from_pretrained(
            model_name,
            use_safetensors=True
        ).to(device)
        processor = CLIPProcessor.from_pretrained(model_name)
        print("成功使用safetensors加载模型")
    except Exception as e:
        print(f"使用safetensors加载失败: {e}")
        try:
            # Safetensors unavailable: fall back to the default loading path.
            model = CLIPModel.from_pretrained(model_name).to(device)
            processor = CLIPProcessor.from_pretrained(model_name)
            print("使用传统方式加载模型")
        except Exception as e2:
            print(f"加载模型失败: {e2}")
            return None
    return model, processor


def _load_batch(batch_df, image_folder):
    """Load the images of one CSV slice together with their prompts.

    Rows whose ``<coco_id>.png`` file is missing or cannot be decoded are
    reported and skipped, so both returned lists stay aligned.

    Args:
        batch_df: DataFrame slice with ``coco_id`` and ``prompt`` columns.
        image_folder: Directory that holds the ``<coco_id>.png`` files.

    Returns:
        A ``(images, prompts)`` tuple of equally long lists.
    """
    images = []
    prompts = []
    # Zipping the columns (instead of iterrows) keeps each column's own dtype
    # and avoids building a Series per row.
    for coco_id, prompt in zip(batch_df['coco_id'], batch_df['prompt']):
        image_path = os.path.join(image_folder, f"{coco_id}.png")

        if not os.path.exists(image_path):
            print(f"图片 {image_path} 不存在，跳过")
            continue

        try:
            # convert() returns a new in-memory image, so the underlying file
            # can be closed right away instead of staying open until GC.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            print(f"处理图片 {image_path} 时出错: {e}")
            continue

        images.append(image)
        prompts.append(prompt)
    return images, prompts


def calculate_clip_scores(image_folder, csv_file, max_rows=10000, batch_size=256):
    """
    Compute the average CLIP score of the images in a folder.

    Each CSV row pairs a ``coco_id`` with a ``prompt``; the image is expected
    at ``<image_folder>/<coco_id>.png``. The score of a pair is the diagonal
    entry of the model's ``logits_per_image`` for its batch. Missing or
    unreadable images are reported and skipped.

    Args:
        image_folder: Path of the folder containing the images.
        csv_file: Path of the CSV file with ``coco_id`` and ``prompt`` columns.
        max_rows: Maximum number of CSV rows to read.
        batch_size: Number of CSV rows handled per model forward pass.

    Returns:
        The mean score over all successfully processed images, or 0 when the
        model could not be loaded or no image was processed.
    """
    # Only the first max_rows rows are read.
    df = pd.read_csv(csv_file, nrows=max_rows)

    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    loaded = _load_clip(device)
    if loaded is None:
        return 0
    model, processor = loaded

    scores = []

    for start in range(0, len(df), batch_size):
        batch_images, batch_prompts = _load_batch(
            df.iloc[start:start + batch_size], image_folder
        )

        # Nothing usable in this slice (all images missing or unreadable).
        if not batch_images:
            continue

        # truncation=True clips prompts that exceed the tokenizer's maximum
        # length; without it an over-long prompt makes the forward pass fail.
        inputs = processor(
            text=batch_prompts,
            images=batch_images,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            # logits_per_image is an images x prompts matrix; entry (k, k) is
            # the score of image k against its own prompt.
            batch_scores = torch.diag(outputs.logits_per_image).cpu().tolist()

        scores.extend(batch_scores)

        batch_avg = sum(batch_scores) / len(batch_scores)
        print(f"处理了 {len(batch_scores)} 张图片，当前批次平均分数: {batch_avg:.4f}")

    if not scores:
        print("没有成功处理任何图片")
        return 0

    average_score = sum(scores) / len(scores)
    print(f"\n总共处理了 {len(scores)} 张图片")
    print(f"平均CLIP分数: {average_score:.4f}")
    return average_score

if __name__ == "__main__":
    # Example usage: score the images of one run selected with --path.
    cli = argparse.ArgumentParser(description="CLIP TEST")
    cli.add_argument("--path", type=str, required=True)
    options = cli.parse_args()

    # The --path value is appended verbatim to this root directory;
    # replace the root with your own image folder location.
    ablation_root = "/workspace/erase/diffusers/examples/dreambooth/Ablation_data/GA/"

    calculate_clip_scores(
        image_folder=ablation_root + options.path,
        csv_file="./coco_ablation.csv",  # CSV file with the image ids and prompts
    )