import os
import pandas as pd
import torch
import ImageReward as reward
from tqdm import tqdm

# --- Configuration: where inputs come from and where outputs go ---------------
# Prompt table; each row is looked up by its 'prompt' and 'coco_id' columns.
csv_file = "./coco_ablation.csv"
# Folder holding the generated images, named "<coco_id>.png".
image_dir = "/workspace/erase/diffusers/examples/dreambooth/Ablation_data/wo_mask_alpha/coco"
# Folder that receives the per-image CSV and the summary text file.
output_dir = "./imagereward_evaluation_results"

# --- Initialisation ------------------------------------------------------------
# Load the prompts to evaluate.
df = pd.read_csv(csv_file)

# Load the ImageReward scorer from a local checkpoint file.
model = reward.load("./ImageReward.pt")

# Make sure the results folder exists (a pre-existing folder is fine).
os.makedirs(output_dir, exist_ok=True)

# Accumulators: one dict per successfully scored image, plus the running sum
# and count used later to compute the mean score.
results = []
total_score = 0
valid_count = 0

print("Evaluating images with ImageReward...")

# Score every (prompt, image) pair listed in the CSV. The image for a row is
# expected at <image_dir>/<coco_id>.png.
for _, row in tqdm(df.iterrows(), total=len(df), desc="Evaluating"):
    prompt = row['prompt']
    coco_id = row['coco_id']

    image_path = os.path.join(image_dir, f"{coco_id}.png")

    # Rows without an image are skipped: they do not contribute to the mean,
    # but they still count toward len(df) in the "valid" ratio.
    if not os.path.exists(image_path):
        # tqdm.write prints above the progress bar instead of tearing it.
        tqdm.write(f"Warning: Image {coco_id}.png not found, skipping...")
        continue

    # Keep the try-block to the scoring call only, so that bookkeeping never
    # runs for an image that failed, and a failure after bookkeeping can't
    # leave a row counted as valid but reported as an error.
    try:
        # Evaluation never back-propagates; disabling autograd avoids building
        # a graph per call (harmless if score() already does this internally).
        with torch.no_grad():
            # float() guarantees a plain number for summing and formatting.
            score = float(model.score(prompt, image_path))
    except Exception as e:
        # Best-effort evaluation: report the failure and continue with the
        # next image rather than aborting the whole run.
        tqdm.write(f"Error evaluating {coco_id}.png: {e}")
        continue

    results.append({
        'coco_id': coco_id,
        'prompt': prompt,
        'score': score
    })
    total_score += score
    valid_count += 1

    tqdm.write(f"Evaluated {coco_id}.png: score = {score:.4f}")

# Compute the mean ImageReward score over the successfully scored images.
# Abort with a non-zero exit status when nothing could be scored, since there
# is no meaningful average (or result file) to produce.
if valid_count == 0:
    print("No valid evaluations found!")
    # The builtin `exit` is injected by the `site` module for interactive use
    # and is missing under `python -S` or in frozen interpreters; raising
    # SystemExit directly is the portable equivalent.
    raise SystemExit(1)

average_score = total_score / valid_count
print(f"\nAverage ImageReward score: {average_score:.4f}")
print(f"Valid evaluations: {valid_count}/{len(df)}")

# Persist the per-image scores (one row per successfully evaluated image).
# NOTE(review): the "_GA" suffix in both output names is fixed regardless of
# image_dir; confirm it matches the run being evaluated before reusing.
results_df = pd.DataFrame(results)
results_file = os.path.join(output_dir, "image_reward_scores_GA.csv")
results_df.to_csv(results_file, index=False)

# Write a short human-readable summary: the mean score and how many of the
# CSV rows were successfully scored. An explicit encoding keeps the output
# independent of the platform's locale default.
summary_file = os.path.join(output_dir, "image_reward_summary_GA.txt")
with open(summary_file, 'w', encoding='utf-8') as f:
    f.write(f"Average ImageReward score: {average_score:.4f}\n")
    f.write(f"Valid evaluations: {valid_count}/{len(df)}\n")

print("\nResults saved to:")
print(f"  Detailed scores: {results_file}")
print(f"  Summary: {summary_file}")