
import os
import torch
from PIL import Image, ExifTags
from tqdm import tqdm
import argparse
import ImageReward as RM
from torch.utils.data import DataLoader
import torchvision.transforms.v2.functional as TF
import torchvision.transforms.v2 as T

# Set TOKENIZERS_PARALLELISM=false for this process before any tokenizer is
# used. NOTE(review): presumably this avoids the HuggingFace tokenizers
# fork/parallelism warning once DataLoader workers are started -- confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

def load_images_and_prompts(image_folder, limit=None):
    """Collect image paths and their prompts from a folder.

    Every ``.jpg``/``.jpeg``/``.png`` file directly inside *image_folder* is
    opened and its Exif ``ImageDescription`` tag is used as the prompt.
    Files without Exif data, without a non-empty description, or that raise
    while being read are reported on stdout and skipped.

    Args:
        image_folder: Directory to scan (not recursive).
        limit: If truthy, stop scanning once this many valid images have
            been collected.

    Returns:
        A ``(images, prompts)`` tuple of two parallel lists: the image file
        paths and the prompt string read from each image.
    """
    images = []
    prompts = []
    # Sort by file name so the processing order is deterministic.
    filenames = sorted(os.listdir(image_folder))

    print(f"在文件夹 {image_folder} 中找到 {len(filenames)} 个文件")
    if limit:
        print(f"限制处理前 {limit} 个有效图像")

    for filename in filenames:
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue

        img_path = os.path.join(image_folder, filename)
        try:
            # Only the metadata is needed here. The context manager closes
            # the file handle immediately instead of leaving one open for
            # every scanned image.
            with Image.open(img_path) as img:
                # NOTE: `_getexif` is a private Pillow helper; the code below
                # relies on it returning None when there is no Exif data.
                exif_data = img._getexif()

            if exif_data is None:
                print(f"图像 {filename} 中无 Exif 元数据。")
                continue

            # Re-key the raw numeric tag ids by their symbolic names; unknown
            # ids are kept as-is.
            exif = {
                ExifTags.TAGS.get(tag_id, tag_id): value
                for tag_id, value in exif_data.items()
            }
            prompt = exif.get("ImageDescription", None)
            if not prompt:
                print(f"图像 {filename} 中未找到提示信息。")
                continue

            images.append(img_path)
            prompts.append(prompt)
        except Exception as e:
            # Best effort: a single unreadable file must not abort the scan.
            print(f"处理图像 {filename} 时出错：{e}")
            continue

        # Stop as soon as the requested number of valid images is reached.
        if limit and len(images) >= limit:
            print(f"已达到限制数量 {limit}，停止加载更多图像")
            break

    print(f"最终加载了 {len(images)} 个有效图像")
    return images, prompts

class ImageFolder(torch.utils.data.Dataset):
    """Dataset over the prompt-carrying images found in one folder.

    The list of usable images and their prompts is built once at
    construction time by ``load_images_and_prompts``. Each item is a dict
    with the keys ``'image'`` (tensor produced by ``pil_to_tensor`` from the
    RGB-converted image), ``'prompt'`` and ``'path'``.
    """

    def __init__(self, image_folder, limit=None):
        """Scan *image_folder*, keeping at most *limit* valid images if given."""
        self.image_folder = image_folder
        found_paths, found_prompts = load_images_and_prompts(image_folder, limit)
        self.images = found_paths
        self.prompts = found_prompts

    def __len__(self):
        """Return the number of valid images that were found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the image at position *idx* and return it with its metadata."""
        path = self.images[idx]
        text = self.prompts[idx]
        # Force three channels before converting to a tensor.
        rgb_image = Image.open(path).convert("RGB")
        tensor = TF.pil_to_tensor(rgb_image)
        return dict(image=tensor, prompt=text, path=path)

def get_reward_transform():
    """Build the preprocessing pipeline applied to images before scoring.

    The steps run in this order: bicubic resize to 224, center crop to 224,
    conversion to an image tensor, conversion to float32 with value scaling,
    and per-channel normalization.

    Returns:
        A ``T.Compose`` holding the five transforms.
    """
    # Per-channel normalization statistics.
    # NOTE(review): these look like the usual CLIP constants -- confirm.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)

def load_imagereward_from_local(local_model_path, device):
    """Load the ImageReward model from a local checkpoint file or directory.

    Three cases are handled:

    * ``.pt`` / ``.pth`` file: a standard ``ImageReward-v1.0`` model is
      created first and the weights from the file are loaded into it. The
      checkpoint may be a bare state dict, a dict holding it under
      ``'state_dict'`` or ``'model'``, or a whole pickled model object.
    * directory: the directory is handed to ``RM.load`` as ``download_root``
      for the registered ``ImageReward-v1.0`` model.
    * any other existing path: passed to ``RM.load`` unchanged.

    If anything goes wrong (including a missing path) the error is printed
    and the model is loaded via ``RM.load("ImageReward-v1.0",
    download_root='.')`` instead.

    Args:
        local_model_path: Path to a checkpoint file or a model directory.
        device: Device the returned model is moved to.

    Returns:
        The loaded reward model, moved to *device*.
    """
    try:
        # Raised here only to be routed into the fallback below.
        if not os.path.exists(local_model_path):
            raise FileNotFoundError(f"本地模型路径不存在: {local_model_path}")

        print(f"正在从本地路径加载模型: {local_model_path}")

        # Checkpoint files need special handling.
        if local_model_path.endswith(('.pt', '.pth')):
            # First create a standard ImageReward model instance ...
            print("检测到.pt文件，正在创建ImageReward模型实例...")
            reward_model = RM.load("ImageReward-v1.0", download_root='.').to(device)

            # ... then load the weights stored in the checkpoint file.
            print("正在加载.pt文件中的权重...")
            # SECURITY: torch.load unpickles the file and can therefore
            # execute arbitrary code. Only point this at trusted checkpoints.
            checkpoint = torch.load(local_model_path, map_location=device)

            if isinstance(checkpoint, dict):
                # Unwrap common containers; otherwise assume the dict itself
                # is the state dict.
                state_dict = checkpoint
                for key in ('state_dict', 'model'):
                    if key in checkpoint:
                        state_dict = checkpoint[key]
                        break
                reward_model.load_state_dict(state_dict)
            else:
                # The checkpoint is itself a model object; use it directly.
                reward_model = checkpoint.to(device)

            print("本地.pt模型加载成功！")
            return reward_model

        if os.path.isdir(local_model_path):
            # ImageReward's `load` resolves `name` as a registered model name
            # or a checkpoint *file*, so a directory cannot be given as the
            # name (that always failed and silently fell back to the online
            # download). Pass the directory as `download_root` instead, where
            # the library looks for / stores the model files.
            reward_model = RM.load("ImageReward-v1.0", download_root=local_model_path).to(device)
        else:
            # Existing file with another extension: let the library treat it
            # as a checkpoint path, as before.
            reward_model = RM.load(local_model_path, download_root=local_model_path).to(device)
        print("本地模型加载成功！")
        return reward_model

    except Exception as e:
        print(f"从本地加载模型失败: {e}")
        print("尝试从在线加载模型...")

        # Local loading failed: fall back to the online model.
        reward_model = RM.load("ImageReward-v1.0", download_root='.').to(device)
        return reward_model

def imagereward_evaluation(test_folder, batch_size=64, num_workers=32,
                           local_model_path="/root/autodl-tmp/pretrained_models/THUDM/ImageReward/",
                           use_gpu=True, limit=None):
    """Compute the average ImageReward score over the images in a folder.

    Prompts are taken from the images themselves (see ``ImageFolder``). The
    images are scored in batches and the mean score is printed and returned.

    Args:
        test_folder: Folder containing the images to evaluate.
        batch_size: Number of images per batch.
        num_workers: Number of DataLoader workers.
        local_model_path: Local checkpoint file or directory of the model.
        use_gpu: Use CUDA when it is available; otherwise run on the CPU.
        limit: If truthy, evaluate at most this many valid images.

    Returns:
        ``None`` when no valid image was found, otherwise a dict with the
        keys ``"image_reward"`` (mean score) and ``"total_images_processed"``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
    print(f"当前使用设备：{device}")

    # Load the model from the local path.
    reward_model = load_imagereward_from_local(local_model_path, device)
    # The loader may return a model restored directly from a checkpoint
    # object, whose train/eval state is unknown. Force evaluation mode so
    # that train-only layers (e.g. dropout, if present) do not perturb the
    # scores.
    reward_model.eval()
    reward_transform = get_reward_transform()

    data = ImageFolder(test_folder, limit)
    if len(data) == 0:
        print("测试文件夹中没有找到有效的图像文件")
        return None

    dl = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    reward_scores = []

    if limit:
        print(f"开始处理 {len(data)} 张图像（限制：{limit}）...")
    else:
        print(f"开始处理 {len(data)} 张图像...")

    for batch in tqdm(dl, desc="处理批次"):
        images = batch['image'].to(device)
        prompts = batch['prompt']
        inputs = reward_model.blip.tokenizer(
            prompts,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            images_reward = reward_transform(images)
            scores = reward_model.score_gard(inputs.input_ids, inputs.attention_mask, images_reward)
            # Flatten the score tensor to 1-D before converting it to a list.
            reward_scores.extend(scores.reshape(-1).tolist())

    overall_reward_avg = sum(reward_scores) / len(reward_scores) if reward_scores else 0

    print("\n" + "="*50)
    print("ImageReward 评估结果:")
    print("="*50)
    print(f"处理的图像总数: {len(data)}")
    print(f"平均 ImageReward Score: {overall_reward_avg:.4f}")
    print("="*50)

    results = {
        "image_reward": overall_reward_avg,
        "total_images_processed": len(data),
    }

    return results

def build_arg_parser():
    """Create the command-line parser for the ImageReward evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--test_folder``,
        ``--batch_size``, ``--num_workers``, ``--local_model_path``,
        ``--gpu`` / ``--cpu`` and ``--limit``.
    """
    parser = argparse.ArgumentParser(description="仅计算 ImageReward 分数")
    parser.add_argument("--test_folder", type=str, default="/root/autodl-tmp/TaylorSeer/TaylorSeer-FLUX/bdf2/interval6", help="测试图像文件夹的路径")
    parser.add_argument("--batch_size", type=int, default=64, help="批处理大小")
    parser.add_argument("--num_workers", type=int, default=32, help="数据加载的工作线程数")
    parser.add_argument("--local_model_path", type=str, default="/root/autodl-tmp/pretrained_models/THUDM/ImageReward", help="本地ImageReward模型路径")
    parser.add_argument("--gpu", dest="gpu", action='store_true', default=True, help="使用GPU加速计算")
    # `--gpu` is store_true with default=True, so on its own it can never
    # switch the GPU off. `--cpu` writes False to the same destination and
    # gives users an explicit opt-out; `--gpu` stays accepted as before.
    parser.add_argument("--cpu", dest="gpu", action='store_false', help="禁用GPU，强制使用CPU计算")
    parser.add_argument("--limit", type=int, default=None, help="限制处理的图像数量")
    return parser


def main(argv=None):
    """Parse the command line and run the evaluation.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use
            ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)

    imagereward_evaluation(
        test_folder=args.test_folder,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        local_model_path=args.local_model_path,
        use_gpu=args.gpu,
        limit=args.limit
    )


if __name__ == "__main__":
    main()