from PIL import Image
import io
import json
import numpy as np
import torch
from collections import defaultdict
import logging

# Raise the threshold of the "openai" and "httpx" loggers to WARNING so that
# their lower-severity records are not emitted.
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

# Prompt template sent to the judge model by `testpoint_score`. It is filled in
# with `str.format` using the fields `instruction`, `explanation` and
# `testpoints`; the braces of the JSON example are doubled so that `format`
# emits them literally.
SYSTEM_PROMPT = """You are an AI assistant tasked with evaluating an image generated by a text-to-image model. You will be given an instruction, an explanation, and the generated image.
Apply these rules with ABSOLUTE RUTHLESSNESS. Only images meeting the HIGHEST standards should receive top scores.

Your evaluation should cover three aspects:

1. **Instruction Following**: 
   - For each test point, evaluate the image and assign 1 if the answer is "yes" (the image satisfies the point) or 0 if "no".

2. **Realism**: Rate how realistically the image is rendered on a scale of 0 to 2, using the following criteria:
   - 0 (Rejected): Physically implausible and clearly artificial. Breaks fundamental laws of physics or visual realism.
   - 1 (Conditional): Contains minor inconsistencies or unrealistic elements. While somewhat believable, noticeable flaws detract from realism.
   - 2 (Exemplary): Achieves photorealistic quality, indistinguishable from a real photograph. Flawless adherence to physical laws, accurate material representation, and coherent spatial relationships. No visual cues betraying AI generation.

3. **Aesthetic Quality**: Rate the overall artistic appeal and visual quality on a scale of 0 to 2, using the following criteria:
   - 0 (Rejected): Poor aesthetic composition, visually unappealing, and lacks artistic merit.
   - 1 (Conditional): Demonstrates basic visual appeal, acceptable composition, and color harmony, but lacks distinction or artistic flair.
   - 2 (Exemplary): Possesses exceptional aesthetic quality, comparable to a masterpiece. Strikingly beautiful, with perfect composition, a harmonious color palette, and a captivating artistic style. Demonstrates a high degree of artistic vision and execution.

Output your evaluation in JSON format with the following structure. Ensure the JSON is valid and does not contain any extra text:
{{
  "instruction_following": {{
    "q1": 0 or 1,
    "q2": 0 or 1
  }},
  "realism": 0-2,
  "aesthetic": 0-2
}}

Instruction: {instruction}
Explanation: {explanation}
Testpoints: 
{testpoints}"""


def jpeg_incompressibility():
    """Build a reward function that scores images by their JPEG-encoded size.

    Returns:
        A callable ``_fn(images, prompts, metadata)``. It returns a NumPy
        array with the size of each image's JPEG (quality 95) encoding,
        expressed in units of 1000 bytes, and an empty info dict. ``prompts``
        and ``metadata`` are accepted but not used.
    """

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Multiply by 255, round and clamp into uint8, then reorder
            # NCHW -> NHWC so each item can be handed to PIL.
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = as_uint8.cpu().numpy().transpose(0, 2, 3, 1)

        pil_images = [Image.fromarray(frame) for frame in images]
        encoded_sizes = []
        for picture in pil_images:
            stream = io.BytesIO()
            picture.save(stream, format="JPEG", quality=95)
            # After saving, the stream position is the number of bytes written.
            encoded_sizes.append(stream.tell() / 1000)
        return np.array(encoded_sizes), {}

    return _fn


def jpeg_compressibility():
    """Build a reward function that is higher for smaller JPEG encodings.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that evaluates the
        ``jpeg_incompressibility`` reward, negates it and divides it by 500,
        and passes through the info dict that reward produced.
    """
    size_reward = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        sizes, info = size_reward(images, prompts, metadata)
        negated_scaled = -sizes / 500
        return negated_scaled, info

    return _fn

def aesthetic_score():
    """Build a reward function backed by ``AestheticScorer``.

    The scorer is created in float32 and moved to CUDA when this factory runs.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning the scorer's
        output for the uint8 image batch and an empty info dict. ``prompts``
        and ``metadata`` are accepted but not used.
    """
    from flow_grpo.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32).cuda()

    def _fn(images, prompts, metadata):
        if not isinstance(images, torch.Tensor):
            # Array input: reorder NHWC -> NCHW and wrap as a uint8 tensor.
            channels_first = images.transpose(0, 3, 1, 2)
            batch = torch.tensor(channels_first, dtype=torch.uint8)
        else:
            # Tensor input: multiply by 255, round and clamp into uint8.
            batch = (images * 255).round().clamp(0, 255).to(torch.uint8)
        return scorer(batch), {}

    return _fn

def clip_score(device):
    """Build a reward function backed by ``ClipScorer``.

    Args:
        device: Passed to the ``ClipScorer`` constructor.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning the scorer's
        output for ``(images, prompts)`` and an empty info dict. ``metadata``
        is accepted but not used.
    """
    from flow_grpo.clip_scorer import ClipScorer

    scorer = ClipScorer(device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Tensors are forwarded to the scorer untouched.
            pixel_batch = images
        else:
            # Arrays are reordered NHWC -> NCHW and divided by 255.
            channels_first = images.transpose(0, 3, 1, 2)
            pixel_batch = torch.tensor(channels_first, dtype=torch.uint8) / 255.0
        return scorer(pixel_batch, prompts), {}

    return _fn

def image_similarity_score(device):
    """Build a function scoring generated images against reference images.

    Args:
        device: Passed to the ``ClipScorer`` constructor.

    Returns:
        A callable ``_fn(images, ref_images)`` returning the result of
        ``scorer.image_similarity`` and an empty info dict.
    """
    from flow_grpo.clip_scorer import ClipScorer

    # NOTE(review): the scorer is built with `device` and then moved with
    # `.cuda()`; confirm whether the explicit `.cuda()` is intended.
    scorer = ClipScorer(device=device).cuda()

    def _nhwc_to_unit_nchw(array):
        # Reorder NHWC -> NCHW, wrap as uint8 and divide by 255.
        return torch.tensor(array.transpose(0, 3, 1, 2), dtype=torch.uint8) / 255.0

    def _fn(images, ref_images):
        if not isinstance(images, torch.Tensor):
            images = _nhwc_to_unit_nchw(images)
        if not isinstance(ref_images, torch.Tensor):
            # Each reference is converted to an array before stacking.
            stacked_refs = np.array([np.array(ref) for ref in ref_images])
            ref_images = _nhwc_to_unit_nchw(stacked_refs)
        return scorer.image_similarity(images, ref_images), {}

    return _fn

def pickscore_score(device):
    """Build a reward function backed by ``PickScoreScorer``.

    Args:
        device: Passed to the ``PickScoreScorer`` constructor.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning the scorer's
        output for ``(prompts, images)`` and an empty info dict. ``metadata``
        is accepted but not used.
    """
    from flow_grpo.pickscore_scorer import PickScoreScorer

    scorer = PickScoreScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Multiply by 255, round and clamp into uint8, reorder
            # NCHW -> NHWC, then wrap every item as a PIL image.
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            channels_last = as_uint8.cpu().numpy().transpose(0, 2, 3, 1)
            images = [Image.fromarray(frame) for frame in channels_last]
        return scorer(prompts, images), {}

    return _fn

def imagereward_score(device):
    """Build a reward function backed by ``ImageRewardScorer``.

    Args:
        device: Passed to the ``ImageRewardScorer`` constructor.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning the scorer's
        output for ``(prompts, images)`` and an empty info dict. ``metadata``
        is accepted but not used.
    """
    from flow_grpo.imagereward_scorer import ImageRewardScorer

    scorer = ImageRewardScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Multiply by 255, round and clamp into uint8, reorder
            # NCHW -> NHWC, then wrap every item as a PIL image.
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            channels_last = as_uint8.cpu().numpy().transpose(0, 2, 3, 1)
            images = [Image.fromarray(frame) for frame in channels_last]
        # The scorer is given a plain list of the prompts.
        prompt_list = list(prompts)
        return scorer(prompt_list, images), {}

    return _fn

def qwenvl_score(device):
    """Build a reward function backed by ``QwenVLScorer``.

    Args:
        device: Passed to the ``QwenVLScorer`` constructor (created in bfloat16).

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning the scorer's
        output for ``(prompts, images)`` and an empty info dict. ``metadata``
        is accepted but not used.
    """
    from flow_grpo.qwenvl import QwenVLScorer

    scorer = QwenVLScorer(dtype=torch.bfloat16, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Multiply by 255, round and clamp into uint8, reorder
            # NCHW -> NHWC, then wrap every item as a PIL image.
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            channels_last = as_uint8.cpu().numpy().transpose(0, 2, 3, 1)
            images = [Image.fromarray(frame) for frame in channels_last]
        # The scorer is given a plain list of the prompts.
        prompt_list = list(prompts)
        return scorer(prompt_list, images), {}

    return _fn

    
def ocr_score(device):
    """Build a reward function backed by ``OcrScorer``.

    Args:
        device: Not used by this factory.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning the scorer's
        output for ``(images, prompts)`` and an empty info dict. ``metadata``
        is accepted but not used.
    """
    from flow_grpo.ocr import OcrScorer

    scorer = OcrScorer()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Multiply by 255, round and clamp into uint8, then reorder
            # NCHW -> NHWC as a NumPy array.
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = as_uint8.cpu().numpy().transpose(0, 2, 3, 1)
        return scorer(images, prompts), {}

    return _fn

def video_ocr_score(device):
    """Build a reward function backed by ``OcrScorer_video_or_image``.

    Args:
        device: Not used by this factory.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning the scorer's
        output for ``(images, prompts)`` and an empty info dict. ``metadata``
        is accepted but not used.
    """
    from flow_grpo.ocr import OcrScorer_video_or_image

    scorer = OcrScorer_video_or_image()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Pick the permutation that moves a size-3 axis to the end:
            # axis 1 for 4-D input, axis 2 for 5-D input. Other layouts are
            # left in their original axis order.
            axis_order = None
            if images.dim() == 4 and images.shape[1] == 3:
                axis_order = (0, 2, 3, 1)
            elif images.dim() == 5 and images.shape[2] == 3:
                axis_order = (0, 1, 3, 4, 2)
            if axis_order is not None:
                images = images.permute(*axis_order)
            # Multiply by 255, round and clamp into uint8, then go to NumPy.
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = as_uint8.cpu().numpy()
        return scorer(images, prompts), {}

    return _fn

def deqa_score_remote(device):
    """Build a reward function that scores images through a remote DeQA server.

    Images are JPEG-encoded, grouped into batches, pickled and POSTed to
    ``http://127.0.0.1:18086``. The pickled reply's ``"outputs"`` list is
    concatenated across batches.

    Args:
        device: Not used by this factory.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(all_scores, {})``. ``prompts`` and ``metadata`` are not used.
        An empty image batch yields ``([], {})``.

    Raises:
        requests.HTTPError: From ``_fn`` when the server answers with a
            4xx/5xx status that was not resolved by the retry policy.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    url = "http://127.0.0.1:18086"
    sess = requests.Session()
    # Retry on HTTP 500 for any method, up to 1000 times with backoff.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        if len(images) == 0:
            # np.array_split rejects a section count of 0, so handle the
            # empty batch explicitly instead of raising ValueError.
            return [], {}

        num_batches = int(np.ceil(len(images) / batch_size))
        all_scores = []
        for image_batch in np.array_split(images, num_batches):
            # Compress the images using JPEG
            jpeg_images = []
            for image in image_batch:
                buffer = BytesIO()
                Image.fromarray(image).save(buffer, format="JPEG")
                jpeg_images.append(buffer.getvalue())

            # Request payload expected by the scoring server
            data_bytes = pickle.dumps({"images": jpeg_images})

            response = sess.post(url, data=data_bytes, timeout=120)
            # Fail with a clear HTTP error rather than trying to unpickle an
            # error page.
            response.raise_for_status()
            # NOTE: pickle.loads can execute arbitrary code; this is only
            # acceptable because the server is a trusted, local service.
            response_data = pickle.loads(response.content)

            all_scores += response_data["outputs"]

        return all_scores, {}

    return _fn

def geneval_score(device):
    """Build a reward function that scores images through a remote GenEval server.

    Images are JPEG-encoded, grouped into batches together with their
    metadata, pickled and POSTed to ``http://127.0.0.1:18085``. Per-image
    results are concatenated across batches and per-group results are merged
    key by key.

    Args:
        device: Not used by this factory.

    Returns:
        A callable ``_fn(images, prompts, metadatas, only_strict)`` returning
        ``(scores, rewards, strict_rewards, group_rewards, group_strict_rewards)``
        where the first three are lists and the last two are dicts of lists.
        ``prompts`` is not used. An empty image batch yields
        ``([], [], [], {}, {})``.

    Raises:
        requests.HTTPError: From ``_fn`` when the server answers with a
            4xx/5xx status that was not resolved by the retry policy.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    url = "http://127.0.0.1:18085"
    sess = requests.Session()
    # Retry on HTTP 500 for any method, up to 1000 times with backoff.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _merge_group_dicts(per_batch_dicts):
        # Concatenate the list stored under each key across all batches.
        merged = defaultdict(list)
        for current_dict in per_batch_dicts:
            for key, value in current_dict.items():
                merged[key].extend(value)
        return dict(merged)

    def _fn(images, prompts, metadatas, only_strict):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        if len(images) == 0:
            # np.array_split rejects a section count of 0, so handle the
            # empty batch explicitly instead of raising ValueError.
            return [], [], [], {}, {}

        images_batched = np.array_split(
            images, int(np.ceil(len(images) / batch_size))
        )
        metadatas_batched = np.array_split(
            metadatas, int(np.ceil(len(metadatas) / batch_size))
        )
        all_scores = []
        all_rewards = []
        all_strict_rewards = []
        all_group_strict_rewards = []
        all_group_rewards = []
        for image_batch, metadata_batched in zip(images_batched, metadatas_batched):
            # Compress the images using JPEG
            jpeg_images = []
            for image in image_batch:
                buffer = BytesIO()
                Image.fromarray(image).save(buffer, format="JPEG")
                jpeg_images.append(buffer.getvalue())

            # Request payload expected by the scoring server
            data = {
                "images": jpeg_images,
                "meta_datas": list(metadata_batched),
                "only_strict": only_strict,
            }
            data_bytes = pickle.dumps(data)

            response = sess.post(url, data=data_bytes, timeout=120)
            # Fail with a clear HTTP error rather than trying to unpickle an
            # error page.
            response.raise_for_status()
            # NOTE: pickle.loads can execute arbitrary code; this is only
            # acceptable because the server is a trusted, local service.
            response_data = pickle.loads(response.content)

            all_scores += response_data["scores"]
            all_rewards += response_data["rewards"]
            all_strict_rewards += response_data["strict_rewards"]
            all_group_strict_rewards.append(response_data["group_strict_rewards"])
            all_group_rewards.append(response_data["group_rewards"])

        all_group_strict_rewards_dict = _merge_group_dicts(all_group_strict_rewards)
        all_group_rewards_dict = _merge_group_dicts(all_group_rewards)

        return all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict

    return _fn

def unifiedreward_score_remote(device):
    """Build a reward function that scores images through a remote UnifiedReward server.

    Images are JPEG-encoded, grouped into batches together with their
    prompts, pickled and POSTed to ``http://10.82.120.15:18085``. The pickled
    reply's ``"outputs"`` list is concatenated across batches.

    Args:
        device: Not used by this factory.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(all_scores, {})``. ``metadata`` is not used. An empty image batch
        yields ``([], {})``.

    Raises:
        requests.HTTPError: From ``_fn`` when the server answers with a
            4xx/5xx status that was not resolved by the retry policy.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    url = "http://10.82.120.15:18085"
    sess = requests.Session()
    # Retry on HTTP 500 for any method, up to 1000 times with backoff.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        if len(images) == 0:
            # np.array_split rejects a section count of 0, so handle the
            # empty batch explicitly instead of raising ValueError.
            return [], {}

        images_batched = np.array_split(
            images, int(np.ceil(len(images) / batch_size))
        )
        prompts_batched = np.array_split(
            prompts, int(np.ceil(len(prompts) / batch_size))
        )

        all_scores = []
        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
            # Compress the images using JPEG
            jpeg_images = []
            for image in image_batch:
                buffer = BytesIO()
                Image.fromarray(image).save(buffer, format="JPEG")
                jpeg_images.append(buffer.getvalue())

            # Request payload expected by the scoring server
            data = {
                "images": jpeg_images,
                "prompts": prompt_batch
            }
            data_bytes = pickle.dumps(data)

            response = sess.post(url, data=data_bytes, timeout=120)
            # Debug-level only: the raw body is pickled bytes and is not
            # worth dumping to stdout for every batch.
            logging.debug("UnifiedReward response: %s", response)
            # Fail with a clear HTTP error rather than trying to unpickle an
            # error page.
            response.raise_for_status()
            # NOTE: pickle.loads can execute arbitrary code; this is only
            # acceptable when the server is a trusted service.
            response_data = pickle.loads(response.content)

            all_scores += response_data["outputs"]

        return all_scores, {}

    return _fn

def unifiedreward_score_sglang(device):
    """Build a reward function that queries a "UnifiedReward" chat model.

    Each image is resized to 512x512, sent with its caption through an
    OpenAI-compatible chat-completions client, and the textual reply is parsed
    for "Alignment", "Coherence" and "Style" scores.

    Args:
        device: Not used by this factory.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(final_score, detailed_scores)``. ``final_score`` is a list of
        ``0.3 * alignment + 0.2 * coherence + 0.5 * style`` with each axis
        divided by 5; ``detailed_scores`` holds the three per-axis lists.
        Axes missing from a reply count as 0. ``metadata`` is not used.
    """
    from openai import OpenAI
    import base64
    from io import BytesIO
    import re
    import time

    kept_axes = ["Alignment", "Style", "Coherence"]

    def pil_image_to_base64(image):
        # Encode as PNG and wrap in a data URL that carries the MIME type and
        # encoding (original author's note: this avoids the backend stalling
        # while it parses the image).
        png_stream = BytesIO()
        image.save(png_stream, format="PNG")
        payload = base64.b64encode(png_stream.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{payload}"

    def _extract_scores(sample_list):
        # One dict per reply; only the three known axes are kept, and a later
        # occurrence of an axis overrides an earlier one.
        pattern = r"(\w+) Score \(1-5\):\s*([0-5](?:\.\d+)?)"
        return [
            {
                axis: float(raw_value)
                for axis, raw_value in re.findall(pattern, reply)
                if axis in kept_axes
            }
            for reply in sample_list
        ]

    # NOTE(review): base_url is an empty string here; presumably it has to be
    # set to the serving endpoint before use — confirm.
    client = OpenAI(
        base_url="",
        api_key="flowgrpo",
        timeout=120.0,  # request timeout
        max_retries=3   # client-level retry count
    )

    def evaluate_image(prompt, image, retry_count=3):
        question = (
            "You are presented with a generated image and its associated text caption. "
            "Your task is to analyze the image across multiple dimensions in relation to the caption. Specifically:\n"
            "Provide overall assessments for the image along the following axes (each rated from 1 to 5):\n"
            "- Alignment Score: How well the image matches the caption in terms of content.\n"
            "- Coherence Score: How logically consistent the image is (absence of visual glitches, object distortions, etc.).\n"
            "- Style Score: How aesthetically appealing the image looks, regardless of caption accuracy.\n\n"
            "Output your evaluation using the format below:\n\n"
            "Alignment Score (1-5): X\n"
            "Coherence Score (1-5): Y\n"
            "Style Score (1-5): Z\n\n"
            "Your task is provided as follows:\n"
            f"Text Caption: [{prompt}]"
        )
        image_url = pil_image_to_base64(image)
        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                },
                {
                    "type": "text",
                    "text": question,
                },
            ],
        }

        # Retry loop around the request.
        final_attempt = retry_count - 1
        for attempt in range(retry_count):
            try:
                response = client.chat.completions.create(
                    model="UnifiedReward",
                    messages=[user_message],
                    temperature=0,
                    max_tokens=256,
                )
                return response.choices[0].message.content
            except Exception as e:
                if attempt >= final_attempt:
                    print(f"All retry attempts failed: {e}")
                    # This text matches no score pattern, so every axis
                    # falls back to 0 downstream.
                    return "Final Score: 0"
                # Linear backoff: 2s, 4s, ...
                wait_time = (attempt + 1) * 2
                print(f"Attempt {attempt + 1} failed, retrying in {wait_time}s: {e}")
                time.sleep(wait_time)

    def evaluate_batch_image(images, prompts, batch_size=4):
        """Evaluate images sequentially in chunks, pausing between chunks."""
        results = []
        total = len(images)
        for start in range(0, total, batch_size):
            stop = start + batch_size
            results.extend(
                evaluate_image(prompt, img)
                for prompt, img in zip(prompts[start:stop], images[start:stop])
            )
            # Short pause between chunks to avoid overloading the server.
            if stop < total:
                time.sleep(0.1)
        return results

    def _fn(images, prompts, metadata):
        # Convert tensor input to a uint8 NHWC array.
        if isinstance(images, torch.Tensor):
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = as_uint8.cpu().numpy().transpose(0, 2, 3, 1)

        # Convert to PIL, force RGB and resize (original author's note:
        # RGBA input caused backend stalls or errors).
        images = [
            Image.fromarray(frame).convert("RGB").resize((512, 512))
            for frame in images
        ]

        # Run the chunked evaluation and parse the replies.
        replies = evaluate_batch_image(images, prompts, batch_size=4)
        parsed = _extract_scores(replies)

        final_score = []
        alignment_scores = []
        coherence_scores = []
        style_scores = []
        for entry in parsed:
            alignment = entry.get("Alignment", 0) / 5.0
            coherence = entry.get("Coherence", 0) / 5.0
            style = entry.get("Style", 0) / 5.0
            final_score.append(0.3 * alignment + 0.2 * coherence + 0.5 * style)
            alignment_scores.append(alignment)
            coherence_scores.append(coherence)
            style_scores.append(style)

        detailed_scores = {
            "alignment": alignment_scores,
            "coherence": coherence_scores,
            "style": style_scores,
        }
        return final_score, detailed_scores

    return _fn

def testpoint_score(device):
    """Build a reward function that asks a vision-language judge to grade images.

    Each image is resized to 512x512 and sent, together with
    ``SYSTEM_PROMPT`` filled in from the prompt and its metadata, to the
    "Qwen3-VL-30B-A3B-Instruct" model through an OpenAI-compatible async
    client. The JSON reply is parsed into instruction-following, realism and
    aesthetic scores.

    Args:
        device: Not used by this factory.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(final_score, detailed_scores)``. ``final_score`` is a list of
        ``0.7 * consistency + 0.2 * realism / 2 + 0.1 * aesthetic / 2``;
        ``detailed_scores`` holds the per-image "consistency", "realism" and
        "aesthetic" lists. Each ``metadata`` item must provide the keys
        ``"explanation"`` and ``"testpoints"``. Replies that cannot be parsed
        into a JSON object score 0 on every axis.
    """
    import asyncio
    from openai import AsyncOpenAI
    import base64
    from io import BytesIO
    import re

    def pil_image_to_base64(image):
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_qwen = f"data:image;base64,{encoded_image_text}"
        return base64_qwen

    def _validated(json_data, idx):
        # Reject anything that is not a JSON object; callers treat the
        # ValueError as a parse failure.
        if not isinstance(json_data, dict):
            raise ValueError(
                f"expected a JSON object, got {type(json_data).__name__}"
            )
        # Replace missing or wrongly typed fields with neutral defaults.
        if not isinstance(json_data.get("instruction_following"), dict):
            print(f"Warning: Invalid instruction_following format in output {idx}")
            json_data["instruction_following"] = {}
        if not isinstance(json_data.get("realism"), (int, float)):
            print(f"Warning: Invalid realism value in output {idx}")
            json_data["realism"] = 0
        if not isinstance(json_data.get("aesthetic"), (int, float)):
            print(f"Warning: Invalid aesthetic value in output {idx}")
            json_data["aesthetic"] = 0
        return json_data

    def _extract_scores(text_outputs):
        scores = []
        pattern = r'```json\s*({.*?})\s*```'
        for idx, text in enumerate(text_outputs):
            try:
                print(text)
                if not isinstance(text, str):
                    # e.g. the API returned no message content
                    raise ValueError(
                        f"expected text output, got {type(text).__name__}"
                    )
                match = re.search(pattern, text, re.DOTALL)
                # Prefer a fenced ```json block; otherwise try to parse the
                # whole reply as JSON.
                raw_json = match.group(1) if match else text
                scores.append(_validated(json.loads(raw_json), idx))
            except (ValueError, json.JSONDecodeError) as e:
                print(f"Error parsing JSON from output {idx}: {e}")
                print(f"Output text: {str(text)[:500]}...")  # first 500 characters only
                scores.append({
                    "instruction_following": {},
                    "realism": 0,
                    "aesthetic": 0
                })
        return scores

    # NOTE(review): base_url is an empty string here; presumably it has to be
    # set to the serving endpoint before use — confirm. The client is also
    # shared across separate `asyncio.run` calls below; verify that reusing
    # it across event loops is safe for the installed openai/httpx versions.
    client = AsyncOpenAI(base_url="", api_key="flowgrpo")

    async def evaluate_image(prompt, image, metadata):
        question = SYSTEM_PROMPT.format(
            instruction=prompt,
            explanation=metadata["explanation"],
            testpoints=metadata["testpoints"]
        )
        images_base64 = pil_image_to_base64(image)
        response = await client.chat.completions.create(
            model="Qwen3-VL-30B-A3B-Instruct",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": question,
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": images_base64},
                        },
                    ],
                },
            ],
            temperature=0.2,
        )
        return response.choices[0].message.content

    async def evaluate_batch_image(images, prompts, metadata):
        # Issue all requests concurrently; results keep the input order.
        tasks = [evaluate_image(prompt, img, meta) for prompt, img, meta in zip(prompts, images, metadata)]
        results = await asyncio.gather(*tasks)
        return results

    def _fn(images, prompts, metadata):
        # Convert tensor input to a uint8 NHWC array.
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        # Convert to PIL images and resize.
        images = [Image.fromarray(image).resize((512, 512)) for image in images]

        # Run the asynchronous batch evaluation.
        text_outputs = asyncio.run(evaluate_batch_image(images, prompts, metadata))
        scores = _extract_scores(text_outputs)
        final_score = []
        detailed_scores = {
            "consistency": [],
            "realism": [],
            "aesthetic": []
        }
        for score in scores:
            # Consistency is the mean of the per-testpoint 0/1 answers, or 0
            # when the judge returned none.
            score["consistency"] = np.mean(list(score["instruction_following"].values())) if score["instruction_following"] else 0.0
            detailed_scores["consistency"].append(score["consistency"])
            detailed_scores["realism"].append(score["realism"] / 2.0)
            detailed_scores["aesthetic"].append(score["aesthetic"] / 2.0)
            final_score.append(
                0.7 * score["consistency"] + 0.2 * score["realism"] / 2.0 + 0.1 * score["aesthetic"] / 2.0
            )
        return final_score, detailed_scores

    return _fn


def uniredit_reward(device, image_size=512, max_tokens=2048, temperature=0.2):
    """Build a reward function that judges image edits with a VLM.

    For every sample the judge model is queried three times:

    1. consistency - original image vs. edited image, ignoring the changes
       the instruction asked for;
    2. reasoning - edited image vs. the ground-truth description and the
       ground-truth reference image;
    3. visual quality - the edited image alone.

    Each reply is expected to end with ``**Final Score:** **X**`` where X is
    1-5. Scores are mapped to [0, 1] via ``(score - 1) / 4`` and combined as
    ``0.3 * consistency + 0.5 * reasoning + 0.2 * visual_quality``. A failed
    request or an unparsable reply counts as 0.0 for that aspect.

    Args:
        device: Not used by this function.
        image_size: Side length (pixels) every image is resized to before it
            is sent to the judge.
        max_tokens: ``max_tokens`` forwarded to the chat-completion request.
        temperature: Sampling temperature forwarded to the request.

    Returns:
        A callable ``_fn(images, prompts, metadata, ref_images=None)`` that
        returns ``(final_scores, detailed_scores)``; ``detailed_scores`` maps
        "consistency", "reasoning" and "visual_quality" to per-sample lists.
    """
    from openai import OpenAI
    import base64
    from io import BytesIO
    import os
    import re
    import time

    def pil_image_to_base64(image):
        """Encode a PIL image as a PNG base64 data URL."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_qwen = f"data:image;base64,{encoded_image_text}"
        return base64_qwen

    prompt_consistency = """You are a highly skilled image evaluator. You will receive two images (an original image and a modified image) along with a specific modification instruction. The second image is known to have been altered based on this instruction, starting from the first image. Your task is to evaluate whether the two images maintain consistency in aspects not related to the given instruction.\n\n## Task\nEvaluate the consistency between the images according to the following scale (1 to 5):\n\n- **5 (Perfect Consistency)**: Apart from changes explicitly required by the instruction, all other details (e.g., personal features, clothing, background, layout, colors, positions of objects) are completely identical between the two images.\n\n- **4 (Minor Differences)**: Apart from changes explicitly required by the instruction, the second image is mostly consistent with the original image but contains a minor discrepancy (such as a missing minor personal feature, accessory, or tattoo).\n\n- **3 (Noticeable Differences)**: Apart from changes explicitly required by the instruction, the second image has one significant difference from the original (such as a noticeable alteration in a person's appearance like hair or skin color, or a significant change in background environment).\n\n- **2 (Significant Differences)**: Apart from changes explicitly required by the instruction, the second image has two or more significant differences or multiple noticeable inconsistencies (such as simultaneous changes in both personal appearance and background environment).\n\n- **1 (Severe Differences)**: Apart from changes explicitly required by the instruction, nearly all key details (e.g., gender, major appearance features, background environment, or scene layout) significantly differ from the original image, clearly deviating from the original.\n\nExample:\n\nOriginal image: A blue-and-white floral vase on a wooden stand near a sunlit window, with beige wall and curtains in 
the background.\nInstruction: \"Throw the baseball towards the vase with sufficient force to make a crack.\"\n\n- **Score 5**: All non-instruction details are identical: same window panes and view, curtain shape, wooden stand design and wood grain, same lighting.\n- **Score 4**: The window changes (e.g., pane geometry, framing, or outside view), while curtains, wooden stand, lighting.\n- **Score 3**: Both the window and the curtains show noticeable changes; wooden stand, lighting.\n- **Score 2**: The background as a whole is clearly different (e.g., wall tone/texture and window/curtain styling or lighting all shift), though the wooden stand still matches.\n- **Score 1**: The background differs and the wooden stand also changes (design/color/material) or is missing; overall scene layout is largely different.\n\nNote: When assigning scores, only consider details unrelated to the instruction. Changes explicitly requested by the instruction should NOT be regarded as inconsistencies.\n\n## Input\n\n**Instruction:** {instruct}\n\n## Output Format\n\nProvide a detailed, step-by-step explanation of your scoring process. Conclude clearly with the final score, formatted as:\n\n**Final Score:** **1-5**"""

    prompt_reasoning = """You are an expert image evaluator. For each task, you will be provided with:\n\n\n1. An **original image**. The image before image editing.\n2. An **instruction** describing how the original image should be modified.\n3. A **ground-truth textual description** that represents the intended result of the modification.\n4. A **reference image**. This shows the intended visual effects of the result. It is a guide for how the result should look.\n5. An **output image**. generated by an assistant.\n\n\nYour task is to assess the output image based on the following evaluation dimension:\n\n## Evaluation Dimension: Alignment Between Image and Reference Description\nAssess how accurately the output image aligns with the literal text of the **ground-truth description**. You should also compare the output image with the reference image to help judge whether the intended visual effects have been successfully achieved.\n\n**Scoring Criteria:**\n- **5**: The image completely matches the description, accurately reflecting every detail and degree.\n- **4**: The image mostly matches the description, with minor discrepancies.\n- **3**: The image partially matches the description but contains differences or lacks some details.\n- **2**: The image contains noticeable difference. Important details are missed or clearly inaccurate.\n- **1**: The image fails to follow the instruction and does not correspond to the description at all.\n\n**Example**\nInstruction: Turn this image into a rainy evening.\nDescription: The street ground has become wet and shiny, and there are visible puddles reflecting lights from the streetlamps and shop windows. The sky is a dark, dim blue, and the streetlamps are turned on.\nReference Image: A photo of a different city street at night, capturing the intended rainy effect. 
It clearly shows a dark, moody blue sky, bright orange streetlights, and strong, sharp reflections of these lights on the wet ground.\n- **5**: Wet/shiny ground, visible puddles with clear reflections, dark blue sky, and streetlamps are on.\n- **4**: Wet/shiny ground and streetlamps are on, but no visible puddles or unclear reflections.\n- **3**: Dark sky and streetlamps are on, but the ground is completely dry.\n- **2**: Only a slightly darkened sky; ground remains dry and streetlamps are off.\n- **1**: Image is still in daytime or incorrectly edited (e.g., shows snow).\n\n## Input\n****\n**original image:**  The first image uploaded\n**Instruction**  {instruct}\n**GroundTruth Description:** {reference}\n**reference image:** The second image uploaded.\n**Output Image:** The third image uploaded.\n\n\n## Output Format\n\nProvide a detailed, step-by-step explanation of your scoring process. Conclude clearly with the final score, formatted as:\n\n**Final Score:** **X**"""

    prompt_quality = """You are an expert image evaluator. For each task, you will be provided with an **output image** generated by an assistant.\n\nYour task is to independently assess the image along the following dimension and assign an integer score from **1 to 5**:\n\n### Evaluation Dimension: Realism and Generation Quality\n\nAssess the overall visual realism and generation fidelity of the image. Consider the image’s clarity, natural appearance, and compliance with physical plausibility and real-world constraints.\n\n**Scoring Guidelines:**\n\n- **5** The image is sharp, visually coherent, and all elements appear highly realistic and physically plausible.\n- **4** The image is clear, with most elements appearing realistic; minor details may show slight unreality.\n- **3** The image is mostly clear, but some significant elements appear unrealistic or physically implausible.\n- **2** The image is noticeably blurry or contains major unrealistic components or visual distortions.\n- **1** The image is extremely blurry, incoherent, or severely unrealistic; realism is nearly absent.\n\n## Output Format\n\nAfter the evaluation, conclude clearly with the final score, formatted as:\n\n**Final Score:** **X**"""

    # NOTE(review): base_url, api_key and the model name used below are empty
    # strings in this file; presumably they are filled in per deployment.
    client = OpenAI(
        base_url="",
        api_key="",
        timeout=120.0,
        max_retries=3,
    )

    def _pil_to_data_url(image):
        """Convert to RGB, resize to ``image_size`` square, and base64-encode."""
        if image.mode in ("RGBA", "P"):
            image = image.convert("RGB")
        if image.size != (image_size, image_size):
            image = image.resize((image_size, image_size))
        return pil_image_to_base64(image)

    def _build_message(prompt_text, images):
        """Build a single-user-turn chat payload: the text, then each image in order."""
        content = [{"type": "text", "text": prompt_text}]
        for img in images:
            content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": _pil_to_data_url(img)},
                }
            )
        return [{"role": "user", "content": content}]

    def _call_model(messages, retries=3):
        """Send ``messages`` to the judge; return the reply text, or "" on failure.

        Up to ``retries`` attempts are made with a capped exponential wait
        (1s, 2s, 4s, ... at most 8s) between them. Failure is best-effort: the
        last error is printed and an empty string is returned.
        """
        last_error = None
        for attempt in range(retries):
            try:
                response = client.chat.completions.create(
                    model="",
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
                return response.choices[0].message.content
            except Exception as exc:
                last_error = exc
                # Only wait when another attempt follows; sleeping after the
                # final failure just delays the error report.
                if attempt < retries - 1:
                    time.sleep(min(2 ** attempt, 8))
        print(f"UniREdit reward request failed: {last_error}")
        return ""

    def _extract_scores(answer):
        """Return the integers following "Final Score" (or "Final Scores"), else None.

        "Final Score" occurrences are searched first; "Final Scores" is only
        consulted when none of them carries a number. The first occurrence
        that contains digits wins.
        """
        if not answer:
            return None
        patterns = (
            r"\*?\*?Final Score\*?\*?:?\s*([\d*\s,\n]*)",
            r"\*?\*?Final Scores\*?\*?:?\s*([\d*\s,\n]*)",
        )
        for pattern in patterns:
            for captured in re.findall(pattern, answer, re.IGNORECASE):
                digits = re.findall(r"\d+", captured.replace("\n", " "))
                if digits:
                    return [int(d) for d in digits]
        return None

    def _normalize(value):
        """Map a 1-5 score to [0, 1]; ``None`` (no score) maps to 0.0."""
        if value is None:
            return 0.0
        return float(np.clip((float(value) - 1.0) / 4.0, 0.0, 1.0))

    def _to_uint8(arr):
        """Return ``arr`` as uint8; floating arrays are scaled by 255 and clipped first."""
        if arr.dtype == np.uint8:
            return arr
        if np.issubdtype(arr.dtype, np.floating):
            return np.clip(arr * 255.0, 0, 255).astype(np.uint8)
        return arr.astype(np.uint8)

    def _ensure_pil_list(batch, label):
        """Convert a tensor, ndarray batch, or list/tuple of images to RGB PIL images.

        Raises:
            ValueError: if ``batch`` is None.
            TypeError: if ``batch`` or one of its items has an unsupported type.
        """
        if batch is None:
            raise ValueError(f"{label} is required for UniREdit reward.")
        if isinstance(batch, torch.Tensor):
            batch = (batch * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            batch = batch.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        if isinstance(batch, np.ndarray):
            return [Image.fromarray(img).convert("RGB") for img in _to_uint8(batch)]
        if isinstance(batch, (list, tuple)):
            pil_batch = []
            for item in batch:
                if isinstance(item, Image.Image):
                    pil_batch.append(item.convert("RGB"))
                elif isinstance(item, np.ndarray):
                    pil_batch.append(Image.fromarray(_to_uint8(item)).convert("RGB"))
                else:
                    raise TypeError(f"Unsupported item in {label}: {type(item)}")
            return pil_batch
        raise TypeError(f"Unsupported data type for {label}: {type(batch)}")

    def _load_gt_image(path):
        """Load the ground-truth image at ``path`` as RGB, resized to ``image_size``.

        Raises:
            FileNotFoundError: if ``path`` is empty or does not exist.
        """
        if not path:
            raise FileNotFoundError("metadata must provide 'gt_image' for UniREdit reward.")
        expanded = os.path.expanduser(path)
        if not os.path.exists(expanded):
            raise FileNotFoundError(f"Ground-truth image not found: {expanded}")
        with Image.open(expanded) as pil_img:
            return pil_img.convert("RGB").copy().resize((image_size, image_size))

    def _prepare_metadata(metadata, batch_size):
        """Return one metadata entry per sample.

        ``None`` becomes empty dicts, a list/tuple must match ``batch_size``
        (falsy entries become ``{}``), and any other value is repeated.
        """
        if metadata is None:
            return [{} for _ in range(batch_size)]
        if isinstance(metadata, (list, tuple)):
            if len(metadata) != batch_size:
                raise ValueError("metadata length must match batch size for UniREdit reward.")
            return [m or {} for m in metadata]
        return [metadata for _ in range(batch_size)]

    def _prepare_prompts(prompts, batch_size):
        """Return one prompt per sample, broadcasting a single prompt if needed."""
        if isinstance(prompts, (str, bytes)):
            return [prompts] * batch_size
        prompt_list = list(prompts)
        if len(prompt_list) == 1 and batch_size > 1:
            return prompt_list * batch_size
        if len(prompt_list) != batch_size:
            raise ValueError("prompts length does not match batch size for UniREdit reward.")
        return prompt_list

    def _fn(images, prompts, metadata, ref_images=None):
        """Score each edited image against its original and ground truth.

        Per-sample metadata keys read here: "gt_image" (required path),
        "instruction" (falls back to the prompt), "rules" (prepended to the
        instruction when present) and "reference_effect" (ground-truth
        description, defaults to "").
        """
        generated = _ensure_pil_list(images, "images")
        originals = _ensure_pil_list(ref_images, "ref_images")

        if len(generated) != len(originals):
            raise ValueError("images and ref_images must share the same batch size for UniREdit reward.")

        num_samples = len(generated)
        metadata_list = _prepare_metadata(metadata, num_samples)
        prompt_list = _prepare_prompts(prompts, num_samples)

        final_scores = []
        detailed_scores = {"consistency": [], "reasoning": [], "visual_quality": []}

        for prompt_text, gen_img, ref_img, meta in zip(prompt_list, generated, originals, metadata_list):
            meta = meta or {}
            if not isinstance(meta, dict):
                meta = {"value": meta}

            # Loaded before any request so a missing ground truth fails fast.
            gt_img = _load_gt_image(meta.get("gt_image"))
            instruction = meta.get("instruction") or prompt_text
            rules = meta.get("rules")
            if rules:
                instruction = f"{rules}\n{instruction}"

            reference_text = meta.get("reference_effect") or ""

            # Image order matters: the prompts refer to the images by position.
            judge1 = _call_model(
                _build_message(
                    prompt_consistency.format(instruct=instruction),
                    [ref_img, gen_img],
                )
            )
            judge2 = _call_model(
                _build_message(
                    prompt_reasoning.format(instruct=instruction, reference=reference_text),
                    [ref_img, gt_img, gen_img],
                )
            )
            judge3 = _call_model(_build_message(prompt_quality, [gen_img]))

            cons_scores = _extract_scores(judge1)
            reas_scores = _extract_scores(judge2)
            qual_scores = _extract_scores(judge3)

            cons_norm = _normalize(cons_scores[0] if cons_scores else None)
            reas_norm = _normalize(reas_scores[0] if reas_scores else None)
            qual_norm = _normalize(qual_scores[0] if qual_scores else None)

            detailed_scores["consistency"].append(cons_norm)
            detailed_scores["reasoning"].append(reas_norm)
            detailed_scores["visual_quality"].append(qual_norm)

            final_scores.append(0.3 * cons_norm + 0.5 * reas_norm + 0.2 * qual_norm)

        return final_scores, detailed_scores

    return _fn

def wise_score(device, n_eval_per_image=1):
    """Build a reward function in which a VLM judge rates generated images.

    The judge is asked for three 0-2 scores per image (Consistency, Realism,
    Aesthetic Quality). Each score is divided by 2 and the final reward is
    ``0.7 * consistency + 0.2 * realism + 0.1 * aesthetic``.

    Args:
        device: Not used by this function.
        n_eval_per_image: Number of judge calls per image. When greater than
            1, the parsed scores of the calls are averaged.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(final_score, detailed_scores)``; ``detailed_scores`` maps
        "consistency", "realism" and "aesthetic" to per-image lists.
    """
    from openai import OpenAI
    import base64
    from io import BytesIO
    import re 
    import time

    def pil_image_to_base64(image):
        """Encode a PIL image as a PNG base64 data URL."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_qwen = f"data:image;base64,{encoded_image_text}"
        return base64_qwen

    def _extract_scores(text_outputs):
        """Parse the three scores out of each judge reply.

        A metric that is missing from a reply is scored 0.0. A reply that is
        not a string (e.g. ``None`` when the API returns no content) yields
        all zeros instead of raising.
        """
        patterns = {
            'consistency': r'Consistency:\s*(\d+(?:\.\d+)?)',
            'realism': r'Realism:\s*(\d+(?:\.\d+)?)',
            'aesthetic_quality': r'Aesthetic Quality:\s*(\d+(?:\.\d+)?)',
        }
        scores = []
        for idx, text in enumerate(text_outputs):
            if not isinstance(text, str):
                # re.search would raise on non-string input; report and
                # fall back to zeros so one bad reply cannot abort the batch.
                print(f"Error parsing output {idx}: expected str, got {type(text).__name__}")
                scores.append({
                    "consistency": 0.0,
                    "realism": 0.0,
                    "aesthetic_quality": 0.0
                })
                continue

            score_dict = {}
            for key, pattern in patterns.items():
                match = re.search(pattern, text, re.IGNORECASE)
                score_dict[key] = float(match.group(1)) if match else 0.0
            scores.append(score_dict)
        return scores

    # NOTE(review): base_url is an empty string in this file; presumably it is
    # filled in per deployment.
    client = OpenAI(
        base_url="", 
        api_key="flowgrpo",
        timeout=120.0,
        max_retries=3
    )
        
    def evaluate_image(prompt, image, metadata, retry_count=3):
        """Synchronously ask the judge to score one image.

        ``metadata`` must provide 'original_prompt' and 'explanation'; a
        missing key raises ``KeyError`` before any request is made. After
        ``retry_count`` failed requests an all-zero reply text is returned.
        ``prompt`` is not used when building the question.
        """
        
        question = f"""Please evaluate strictly and return ONLY the three scores as requested.

# Text-to-Image Quality Evaluation Protocol

## System Instruction
You are an AI quality auditor for text-to-image generation. Apply these rules with ABSOLUTE RUTHLESSNESS. No assuming is allowed. You MUST strictly follow the criteria.
Only images meeting the HIGHEST standards should receive top scores. As long as the image doesn't satisfy the criteria, give lower scores.

**Input Parameters**  
- PROMPT: [User's original prompt to]  
- EXPLANATION: [Further explanation of the original prompt] 
---

## Scoring Criteria

**Consistency (0-2):**  How accurately and completely the image reflects the PROMPT.
* **0 (Rejected):**  Fails to capture key elements of the prompt, or contradicts the prompt.
* **1 (Conditional):** Partially captures the prompt. Some elements are present, but not all, or not accurately.  Noticeable deviations from the prompt's intent.
* **2 (Exemplary):**  Perfectly and completely aligns with the PROMPT.  Every single element and nuance of the prompt is flawlessly represented in the image. The image is an ideal, unambiguous visual realization of the given prompt.

**Realism (0-2):**  How realistically the image is rendered.
* **0 (Rejected):**  Physically implausible and clearly artificial. Breaks fundamental laws of physics or visual realism.
* **1 (Conditional):** Contains minor inconsistencies or unrealistic elements.  While somewhat believable, noticeable flaws detract from realism.
* **2 (Exemplary):**  Achieves photorealistic quality, indistinguishable from a real photograph.  Flawless adherence to physical laws, accurate material representation, and coherent spatial relationships. No visual cues betraying AI generation.

**Aesthetic Quality (0-2):**  The overall artistic appeal and visual quality of the image.
* **0 (Rejected):**  Poor aesthetic composition, visually unappealing, and lacks artistic merit.
* **1 (Conditional):**  Demonstrates basic visual appeal, acceptable composition, and color harmony, but lacks distinction or artistic flair.
* **2 (Exemplary):**  Possesses exceptional aesthetic quality, comparable to a masterpiece.  Strikingly beautiful, with perfect composition, a harmonious color palette, and a captivating artistic style. Demonstrates a high degree of artistic vision and execution.

---

## Output Format

**Do not include any other text, explanations, or labels.** You must return only three lines of text, each containing a metric and the corresponding score, for example:

**Example Output:**
Consistency: 2
Realism: 1
Aesthetic Quality: 0

---

**IMPORTANT Enforcement:**

Be EXTREMELY strict in your evaluation. A score of '2' should be exceedingly rare and reserved only for images that truly excel and meet the highest possible standards in each metric. If there is any doubt, downgrade the score.

For **Consistency**, a score of '2' requires complete and flawless adherence to every aspect of the prompt, leaving no room for misinterpretation or omission.

For **Realism**, a score of '2' means the image is virtually indistinguishable from a real photograph in terms of detail, lighting, physics, and material properties.

For **Aesthetic Quality**, a score of '2' demands exceptional artistic merit, not just pleasant visuals.

--- 
Here are the Prompt and EXPLANATION for this evaluation:
PROMPT: "{metadata['original_prompt']}"
EXPLANATION: "{metadata['explanation']}"
Please strictly adhere to the scoring criteria and follow the template format when providing your results."""
        
        images_base64 = pil_image_to_base64(image)
        
        # Retry loop around the request.
        for attempt in range(retry_count):
            try:
                response = client.chat.completions.create(
                    model="Qwen3-VL-30B-A3B-Instruct",
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": question,
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {"url": images_base64},
                                },
                            ],
                        },
                    ],
                    temperature=0.2,
                )
                return response.choices[0].message.content
            except Exception as e:
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2  # linear backoff: 2s, 4s, ...
                    print(f"Attempt {attempt + 1} failed, retrying in {wait_time}s: {e}")
                    time.sleep(wait_time)
                else:
                    print(f"All retry attempts failed: {e}")
                    return "Consistency: 0\nRealism: 0\nAesthetic Quality: 0"

    def evaluate_batch_image(images, prompts, metadata, n_eval_per_image=1):
        """Synchronously evaluate a batch, optionally several times per image.

        Returns one reply text per image. With a single evaluation the raw
        judge reply is returned; otherwise the parsed scores are averaged and
        re-rendered as text with two decimals.
        """
        # One list of replies per image.
        all_results = [[] for _ in range(len(images))]
        
        # Evaluate every image n_eval_per_image times.
        for eval_round in range(n_eval_per_image):
            for i, (prompt, img, meta) in enumerate(zip(prompts, images, metadata)):
                try:
                    result = evaluate_image(prompt, img, meta)
                    all_results[i].append(result)
                except Exception as e:
                    print(f"Task {i} in round {eval_round + 1} failed with exception: {e}")
                    all_results[i].append("Consistency: 0\nRealism: 0\nAesthetic Quality: 0")
        
        # A single evaluation needs no averaging.
        if n_eval_per_image == 1:
            return [results[0] for results in all_results]
        
        # Several evaluations: average the parsed scores per image.
        averaged_results = []
        for results_per_image in all_results:
            parsed_scores = [_extract_scores([result])[0] for result in results_per_image]
            
            avg_consistency = np.mean([s["consistency"] for s in parsed_scores])
            avg_realism = np.mean([s["realism"] for s in parsed_scores])
            avg_aesthetic = np.mean([s["aesthetic_quality"] for s in parsed_scores])
            
            # Render the averages in the same text format the judge uses so
            # the caller can parse single and averaged results alike.
            averaged_text = f"Consistency: {avg_consistency:.2f}\nRealism: {avg_realism:.2f}\nAesthetic Quality: {avg_aesthetic:.2f}"
            averaged_results.append(averaged_text)
        
        return averaged_results

    def _fn(images, prompts, metadata):
        """Score a batch of images; see ``wise_score`` for the return format."""
        # Tensor input: scale to 0-255, cast to uint8 and reorder axes.
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        
        # Convert to PIL images resized to 512x512.
        images = [Image.fromarray(image).resize((512, 512)) for image in images]

        # Run the synchronous batch evaluation (possibly several rounds).
        text_outputs = evaluate_batch_image(images, prompts, metadata, n_eval_per_image)
        scores = _extract_scores(text_outputs)
        
        final_score = []
        detailed_scores = {
            "consistency": [],
            "realism": [],
            "aesthetic": []
        }
        
        for score in scores:
            # Normalize the 0-2 scores to the 0-1 range.
            consistency_norm = score["consistency"] / 2.0
            realism_norm = score["realism"] / 2.0
            aesthetic_norm = score["aesthetic_quality"] / 2.0
            
            detailed_scores["consistency"].append(consistency_norm)
            detailed_scores["realism"].append(realism_norm)
            detailed_scores["aesthetic"].append(aesthetic_norm)
            
            # Weighted final score: consistency=0.7, realism=0.2, aesthetic=0.1
            final_score.append(
                0.7 * consistency_norm + 0.2 * realism_norm + 0.1 * aesthetic_norm
            )
        
        return final_score, detailed_scores

    return _fn

def wise_consistency_score(device, n_eval_per_image=1):
    """Build a reward function that uses only the judge's Consistency score.

    The judge rates how well each image reflects the prompt on a 0-2 scale;
    the reward is that score divided by 2.

    Args:
        device: Not used by this function.
        n_eval_per_image: Number of judge calls per image. When greater than
            1, the parsed scores of the calls are averaged.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(final_score, detailed_scores)``; ``detailed_scores`` maps
        "consistency" to the per-image list of normalized scores.
    """
    from openai import OpenAI
    import base64
    from io import BytesIO
    import re 
    import time

    def pil_image_to_base64(image):
        """Encode a PIL image as a PNG base64 data URL."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_qwen = f"data:image;base64,{encoded_image_text}"
        return base64_qwen

    def _extract_scores(text_outputs):
        """Parse the Consistency score out of each judge reply.

        A reply without a Consistency line is scored 0.0. A reply that is not
        a string (e.g. ``None`` when the API returns no content) also yields
        0.0 instead of raising.
        """
        scores = []
        for idx, text in enumerate(text_outputs):
            if not isinstance(text, str):
                # re.search would raise on non-string input; report and
                # fall back to zero so one bad reply cannot abort the batch.
                print(f"Error parsing output {idx}: expected str, got {type(text).__name__}")
                scores.append({
                    "consistency": 0.0,
                })
                continue

            consistency_match = re.search(r'Consistency:\s*(\d+(?:\.\d+)?)', text, re.IGNORECASE)
            scores.append({
                'consistency': float(consistency_match.group(1)) if consistency_match else 0.0,
            })
        return scores

    # NOTE(review): base_url is an empty string in this file; presumably it is
    # filled in per deployment.
    client = OpenAI(
        base_url="", 
        api_key="flowgrpo",
        timeout=120.0,
        max_retries=3
    )
        
    def evaluate_image(prompt, image, metadata, retry_count=3):
        """Synchronously ask the judge for the Consistency score of one image.

        ``metadata`` must provide 'original_prompt' and 'explanation'; a
        missing key raises ``KeyError`` before any request is made. After
        ``retry_count`` failed requests a zero-score reply text is returned.
        ``prompt`` is not used when building the question.
        """
        
        # NOTE(review): the wording below still asks for "three scores" and
        # "three lines" although only Consistency is defined and parsed. The
        # text is left untouched because changing it alters what the judge
        # sees - confirm whether it should be reworded.
        question = f"""Please evaluate strictly and return ONLY the three scores as requested.

# Text-to-Image Quality Evaluation Protocol

## System Instruction
You are an AI quality auditor for text-to-image generation. Apply these rules with ABSOLUTE RUTHLESSNESS. No assuming is allowed. You MUST strictly follow the criteria.
Only images meeting the HIGHEST standards should receive top scores. As long as the image doesn't satisfy the criteria, give lower scores.

**Input Parameters**  
- PROMPT: [User's original prompt to]  
- EXPLANATION: [Further explanation of the original prompt] 
---

## Scoring Criteria

**Consistency (0-2):**  How accurately and completely the image reflects the PROMPT.
* **0 (Rejected):**  Fails to capture key elements of the prompt, or contradicts the prompt.
* **1 (Conditional):** Partially captures the prompt. Some elements are present, but not all, or not accurately.  Noticeable deviations from the prompt's intent.
* **2 (Exemplary):**  Perfectly and completely aligns with the PROMPT.  Every single element and nuance of the prompt is flawlessly represented in the image. The image is an ideal, unambiguous visual realization of the given prompt.

---

## Output Format

**Do not include any other text, explanations, or labels.** You must return only three lines of text, each containing a metric and the corresponding score, for example:

**Example Output:**
Consistency: 2

---

**IMPORTANT Enforcement:**

Be EXTREMELY strict in your evaluation. A score of '2' should be exceedingly rare and reserved only for images that truly excel and meet the highest possible standards in each metric. If there is any doubt, downgrade the score.

For **Consistency**, a score of '2' requires complete and flawless adherence to every aspect of the prompt, leaving no room for misinterpretation or omission.

--- 
Here are the Prompt and EXPLANATION for this evaluation:
PROMPT: "{metadata['original_prompt']}"
EXPLANATION: "{metadata['explanation']}"
Please strictly adhere to the scoring criteria and follow the template format when providing your results."""
        
        images_base64 = pil_image_to_base64(image)
        
        # Retry loop around the request.
        for attempt in range(retry_count):
            try:
                response = client.chat.completions.create(
                    model="Qwen3-VL-30B-A3B-Instruct",
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": question,
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {"url": images_base64},
                                },
                            ],
                        },
                    ],
                    temperature=0.2,
                )
                return response.choices[0].message.content
            except Exception as e:
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2  # linear backoff: 2s, 4s, ...
                    print(f"Attempt {attempt + 1} failed, retrying in {wait_time}s: {e}")
                    time.sleep(wait_time)
                else:
                    print(f"All retry attempts failed: {e}")
                    return "Consistency: 0\nRealism: 0\nAesthetic Quality: 0"

    def evaluate_batch_image(images, prompts, metadata, n_eval_per_image=1):
        """Synchronously evaluate a batch, optionally several times per image.

        Returns one reply text per image. With a single evaluation the raw
        judge reply is returned; otherwise the parsed Consistency scores are
        averaged and re-rendered as text with two decimals.
        """
        # One list of replies per image.
        all_results = [[] for _ in range(len(images))]
        
        # Evaluate every image n_eval_per_image times.
        for eval_round in range(n_eval_per_image):
            for i, (prompt, img, meta) in enumerate(zip(prompts, images, metadata)):
                try:
                    result = evaluate_image(prompt, img, meta)
                    all_results[i].append(result)
                except Exception as e:
                    print(f"Task {i} in round {eval_round + 1} failed with exception: {e}")
                    all_results[i].append("Consistency: 0\nRealism: 0\nAesthetic Quality: 0")
        
        # A single evaluation needs no averaging.
        if n_eval_per_image == 1:
            return [results[0] for results in all_results]
        
        # Several evaluations: average the parsed scores per image.
        averaged_results = []
        for results_per_image in all_results:
            parsed_scores = [_extract_scores([result])[0] for result in results_per_image]
            
            avg_consistency = np.mean([s["consistency"] for s in parsed_scores])

            # Render the average in the same text format the judge uses so
            # the caller can parse single and averaged results alike.
            averaged_text = f"Consistency: {avg_consistency:.2f}"
            averaged_results.append(averaged_text)
        
        return averaged_results

    def _fn(images, prompts, metadata):
        """Score a batch of images; see ``wise_consistency_score`` for the return format."""
        # Tensor input: scale to 0-255, cast to uint8 and reorder axes.
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        
        # Convert to PIL images resized to 512x512.
        images = [Image.fromarray(image).resize((512, 512)) for image in images]

        # Run the synchronous batch evaluation (possibly several rounds).
        text_outputs = evaluate_batch_image(images, prompts, metadata, n_eval_per_image)
        scores = _extract_scores(text_outputs)
        
        final_score = []
        detailed_scores = {
            "consistency": [],
        }
        
        for score in scores:
            # Normalize the 0-2 score to the 0-1 range.
            consistency_norm = score["consistency"] / 2.0
            
            detailed_scores["consistency"].append(consistency_norm)
            
            # The final score is the normalized consistency alone (weight 1.0).
            final_score.append(
                1.0 * consistency_norm
            )
        
        return final_score, detailed_scores

    return _fn


def random_score(device):
    """Build a scorer that assigns every image an independent random score.

    Args:
        device: Unused by this scorer; ``multi_score`` passes it to every
            factory that declares a ``device`` parameter.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(scores, {})`` where ``scores`` is a list of floats drawn uniformly
        from ``[0, 1)``, one per input image.
    """
    def _fn(images, prompts, metadata):
        # Sequences (list, tuple, ndarray) report their batch size via len();
        # anything else (e.g. a tensor) is expected to expose ``shape[0]``.
        if isinstance(images, (list, tuple, np.ndarray)):
            num_images = len(images)
        else:
            num_images = images.shape[0]
        scores = np.random.rand(num_images).tolist()
        return scores, {}
    return _fn

def multi_score(device, score_dict):
    """Combine several reward functions into a single weighted scorer.

    Args:
        device: Passed positionally to every scorer factory whose signature
            declares a ``device`` parameter; other factories are called with
            no arguments.
        score_dict: Mapping of score name -> weight. Each name must be a key
            of the registry below.

    Returns:
        A callable ``_fn(images, prompts, metadata, ref_images=None,
        only_strict=True)`` that returns ``(score_details, {})``.
        ``score_details`` maps each requested score name to its per-image
        scores, includes any extra per-scorer details, and stores the
        per-image weighted sum of all scores under ``'avg'``.

    Raises:
        KeyError: If ``score_dict`` names a scorer that is not registered.
    """
    import inspect

    score_functions = {
        "deqa": deqa_score_remote,
        "ocr": ocr_score,
        "video_ocr": video_ocr_score,
        "imagereward": imagereward_score,
        "pickscore": pickscore_score,
        "qwenvl": qwenvl_score,
        "aesthetic": aesthetic_score,
        "jpeg_compressibility": jpeg_compressibility,
        "unifiedreward": unifiedreward_score_sglang,
        "geneval": geneval_score,
        "clipscore": clip_score,
        "image_similarity": image_similarity_score,
        "wise": wise_score,
        "uniredit": uniredit_reward,
        "random": random_score,
        "wise_consistency": wise_consistency_score,
    }

    # Fail early, with a readable message, before building any scorer.
    unknown = [name for name in score_dict if name not in score_functions]
    if unknown:
        raise KeyError(
            f"Unknown score name(s) {unknown}; available: {sorted(score_functions)}"
        )

    def _build(factory):
        # Inspect the real parameter list: ``__code__.co_varnames`` also lists
        # local variables and is missing on non-function callables.
        if "device" in inspect.signature(factory).parameters:
            return factory(device)
        return factory()

    score_fns = {name: _build(score_functions[name]) for name in score_dict}

    # only_strict is only for geneval. During training, only the strict reward is needed, and non-strict rewards don't need to be computed, reducing reward calculation time.
    def _fn(images, prompts, metadata, ref_images=None, only_strict=True):
        total_scores = []
        score_details = {}

        for score_name, weight in score_dict.items():
            scorer = score_fns[score_name]
            if score_name == "geneval":
                _, rewards, strict_rewards, group_rewards, group_strict_rewards = scorer(
                    images, prompts, metadata, only_strict
                )
                score_details['accuracy'] = rewards
                score_details['strict_accuracy'] = strict_rewards
                for key, value in group_strict_rewards.items():
                    score_details[f'{key}_strict_accuracy'] = value
                for key, value in group_rewards.items():
                    score_details[f'{key}_accuracy'] = value
                # The strict rewards are used as the main geneval score.
                scores = strict_rewards
            elif score_name == "image_similarity":
                scores, _ = scorer(images, ref_images)
            elif score_name in ["wise", "unifiedreward"]:
                scores, detailed_scores = scorer(images, prompts, metadata)
                for key, value in detailed_scores.items():
                    score_details[f'{score_name}_{key}'] = value
            elif score_name == "uniredit":
                scores, detailed_scores = scorer(images, prompts, metadata, ref_images=ref_images)
                for key, value in detailed_scores.items():
                    score_details[f'uniredit_{key}'] = value
            else:
                scores, _ = scorer(images, prompts, metadata)

            score_details[score_name] = scores
            weighted_scores = [weight * score for score in scores]

            if not total_scores:
                total_scores = weighted_scores
            else:
                total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)]

        # Despite the key name, this is the weighted sum, not a mean.
        score_details['avg'] = total_scores
        return score_details, {}

    return _fn

def main():
    """Manually exercise the ``uniredit`` reward on one image pair.

    Loads an original image and a generated image, resizes both to 512x512,
    and scores the generated image with the original as the reference. The
    three path variables below are empty placeholders and must be filled in
    before running.
    """
    import torchvision.transforms as transforms

    ori_image_path = ""

    gen_image_path = ""

    gt_image_path = ""

    transform = transforms.Compose([
        transforms.Resize((512, 512)),  # Resize to 512x512
        transforms.ToTensor(),  # Convert to tensor
    ])

    ori_images = torch.stack([transform(Image.open(ori_image_path).convert('RGB'))])

    gen_images = torch.stack([transform(Image.open(gen_image_path).convert('RGB'))])

    prompts = [
        'Plane shot'
    ]

    metadata = [{
        "gt_image": gt_image_path,
    }]

    score_dict = {
        "uniredit": 1.0
    }

    print(ori_images.shape)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scoring_fn = multi_score(device, score_dict)
    # Score the generated image; the original serves only as the reference.
    # (Previously the original was passed as both, leaving gen_images unused.)
    scores, _ = scoring_fn(gen_images, prompts, metadata, ref_images=ori_images)
    print("Scores:", scores)


    # Alternative example: score a single image with "unifiedreward".
    # image_paths = [
    #     "nasa.jpg",
    # ]

    # images = torch.stack([transform(Image.open(image_path).convert('RGB')) for image_path in image_paths])
    # prompts=[
    #     'A astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
    # ]
    # metadata = {}  # Example metadata
    # score_dict = {
    #     "unifiedreward": 1.0
    # }
    # # Initialize the multi_score function with a device and score_dict
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # scoring_fn = multi_score(device, score_dict)
    # # Get the scores
    # scores, _ = scoring_fn(images, prompts, metadata)
    # # Print the scores
    # print("Scores:", scores)


# Call main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
