from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch
import json
import ImageReward as RM
from tqdm import tqdm
import numpy as np

class ImageRewardScorer(torch.nn.Module):
    """Scores (prompt, image) pairs with a frozen ImageReward model."""

    def __init__(self, device="cuda", dtype=torch.float32):
        """Load the "ImageReward-v1.0" model in eval mode and freeze it.

        Args:
            device: Device handed to ``RM.load``; also stored on the instance.
            dtype: Dtype the loaded model is cast to; also stored on the
                instance.
        """
        super().__init__()
        self.model_path = "ImageReward-v1.0"
        self.device = device
        self.dtype = dtype

        reward_model = RM.load(self.model_path, device=device)
        reward_model = reward_model.eval()
        self.model = reward_model.to(dtype=dtype)
        # The model is only used for scoring, so gradients are switched off
        # for everything it holds.
        self.model.requires_grad_(False)

    @torch.no_grad()
    def __call__(self, prompts, images):
        """Return one reward per (prompt, image) pair, in input order.

        Args:
            prompts: Iterable of prompts.
            images: Iterable of images, paired positionally with ``prompts``.

        Returns:
            A list holding, for each pair, the second value returned by
            ``self.model.inference_rank(prompt, [image])``. If the inputs
            differ in length, the extra items of the longer one are ignored.
        """

        def _reward_for(text, picture):
            # Each image is ranked on its own, so the ranking itself carries
            # no information; only the reward is kept.
            _ranking, reward = self.model.inference_rank(text, [picture])
            return reward

        return list(map(_reward_for, prompts, images))

# Usage example
def main(
    result_path="/workspace/user_code/DiffusionDPO/flow_grpo/scripts/inference/result.json",
    device="cuda:1",
):
    """Print the mean ImageReward score of generated and original images.

    The JSON file at ``result_path`` is read as a sequence of records, each
    providing the keys ``"prompt"``, ``"image"`` and ``"ori_image"``. Both
    images of a record are scored against the record's prompt. Records where
    either image consists solely of zero-valued pixels are skipped.

    Prints the number of scored records, then the per-column mean of the
    scores in the order ``[image, ori_image]``.

    Args:
        result_path: Path of the JSON file to evaluate.
        device: Device on which the ImageReward model is loaded.
    """
    # Context manager so the JSON file handle is closed deterministically.
    with open(result_path, "r") as f:
        data = json.load(f)

    scorer = ImageRewardScorer(device=device, dtype=torch.float32)

    result = []
    for item in tqdm(data):
        # Image.open is lazy; the `with` guarantees both files are closed,
        # including when the pair is skipped before the second one is read.
        with Image.open(item["image"]) as image, \
                Image.open(item["ori_image"]) as ori_image:
            if np.all(np.array(image) == 0) or np.all(np.array(ori_image) == 0):
                continue
            result.append(scorer([item["prompt"]] * 2, [image, ori_image]))

    print(f'data number: {len(result)}')
    if not result:
        # Averaging an empty array would only print nan plus a RuntimeWarning.
        print("no valid image pairs to score")
        return
    print(np.array(result).mean(axis=0))

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()