import os

os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1'
import os.path as osp
import json
import argparse

import numpy as np
import ImageReward as RM

def load_meta_infos(meta_file):
    """Load the list of metadata records from ``meta_file``.

    Accepts either a single JSON document (a list of records, or one record)
    or JSON Lines (one JSON object per non-blank line). The default input is
    named ``metadata.jsonl``, which a plain ``json.load`` cannot parse once it
    holds more than one line.

    Args:
        meta_file: Path to the metadata file.

    Returns:
        A list of record dicts.

    Raises:
        json.JSONDecodeError: If the content is neither valid JSON nor valid
            JSON Lines.
    """
    with open(meta_file, 'r', encoding='utf-8') as f:
        text = f.read()
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Not a single JSON document: fall back to JSON Lines.
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    # A one-record .jsonl file parses as a single dict; wrap it so the caller
    # always iterates over records rather than over dict keys.
    if isinstance(data, dict):
        return [data]
    return data


def _to_list(scores):
    """Return ``scores`` as a list; scorers return a bare float for one image."""
    if isinstance(scores, (int, float)):
        return [scores]
    return list(scores)


def _mean_or_none(values):
    """Return the mean of ``values`` as a plain float, or None when empty.

    Avoids ``np.mean([])`` producing NaN, which ``json.dump`` would write as
    the non-standard token ``NaN``.
    """
    if not values:
        return None
    return float(np.mean(values))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--meta_file", type=str, default="./runtime/cfgTrue_tokenTrue_numsteps32_maxlength128_bsz1_timestamp20250508_112517/metadata.jsonl")
    args = parser.parse_args()

    image_reward_model = RM.load("ImageReward-v1.0")
    clip_model = RM.load_score("CLIP")

    meta_infos = load_meta_infos(args.meta_file)

    results = []
    all_image_rewards = []
    all_clip_scores = []

    for meta in meta_infos:
        image_paths = meta['gen_image_paths']
        prompt = meta['prompt']
        # Both scorers return a bare float when only one image is scored;
        # normalise to lists so per-record and aggregate handling is uniform.
        image_rewards = _to_list(image_reward_model.score(prompt, image_paths))
        _, clip_scores = clip_model.inference_rank(prompt, image_paths)
        clip_scores = _to_list(clip_scores)

        all_image_rewards.extend(image_rewards)
        all_clip_scores.extend(clip_scores)

        # Record the scores for each id.
        results.append({
            'id': meta['id'],
            'image_reward': image_rewards,
            'clip_score': clip_scores
        })

    overall_stats = {
        'prompts': len(meta_infos),
        'images': len(all_image_rewards),
        'average_image_reward': _mean_or_none(all_image_rewards),
        'average_clip_scores': _mean_or_none(all_clip_scores),
        'details': results
    }

    save_file = osp.join(osp.dirname(args.meta_file), 'image_reward_res.json')
    with open(save_file, 'w', encoding='utf-8') as f:
        json.dump(overall_stats, f, indent=4)

    print(f'Saved results to {save_file}')
