from PIL import Image
import io
import numpy as np
import torch
import random
from collections import defaultdict

def jpeg_incompressibility():
    """Build a reward equal to each image's JPEG-encoded size.

    The returned callable takes ``(images, prompts, metadata)``. A tensor batch
    (NCHW, scaled by 255 and quantised to uint8) is converted to NHWC first;
    any other input is iterated and each element handed to ``Image.fromarray``.
    Every image is saved as a quality-95 JPEG, and the result is
    ``(sizes, {})`` where ``sizes`` is a NumPy array of byte counts / 1000.
    """

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            # NCHW -> NHWC so PIL can consume each frame directly.
            images = as_uint8.cpu().numpy().transpose(0, 2, 3, 1)

        def _encoded_kb(frame):
            sink = io.BytesIO()
            Image.fromarray(frame).save(sink, format="JPEG", quality=95)
            return sink.tell() / 1000

        return np.array([_encoded_kb(frame) for frame in images]), {}

    return _fn


def jpeg_compressibility():
    """Build a reward that favours images with small JPEG encodings.

    Wraps ``jpeg_incompressibility`` and returns its per-image sizes negated
    and divided by 500, so smaller files give larger (less negative) rewards.
    The wrapped scorer's metadata dict is passed through unchanged.
    """
    size_reward = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        sizes_kb, info = size_reward(images, prompts, metadata)
        scaled = -sizes_kb / 500
        return scaled, info

    return _fn

def aesthetic_score():
    """Build an aesthetic-quality reward backed by ``AestheticScorer``.

    The scorer is created in float32 and moved to CUDA with ``.cuda()``. The
    returned callable turns a tensor input into uint8 (scaled by 255, rounded,
    clamped). It treats any other input as an NHWC array, transposes it to
    NCHW and wraps it as a uint8 tensor. It returns ``(scores, {})``.
    """
    from flow_grpo.aesthetic_scorer import AestheticScorer

    model = AestheticScorer(dtype=torch.float32).cuda()

    def _fn(images, prompts, metadata):
        if not isinstance(images, torch.Tensor):
            # NHWC -> NCHW, then wrap as a uint8 tensor.
            batch = torch.tensor(images.transpose(0, 3, 1, 2), dtype=torch.uint8)
        else:
            batch = (images * 255).round().clamp(0, 255).to(torch.uint8)
        return model(batch), {}

    return _fn


def hpsv20_score(device=None, hps_version: str = "v2.0", checkpoint: str | None = None, batch_size: int = 32):
    """Build an HPSv2 (Human Preference Score v2) reward function.

    Loads an open_clip ``ViT-H-14`` model and restores the HPS checkpoint
    (downloaded from the ``xswu/HPSv2`` Hugging Face repo when ``checkpoint``
    is None).

    Args:
        device: Torch device or device string. Defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.
        hps_version: Key into ``hpsv2.utils.hps_version_map`` selecting the
            checkpoint file to download.
        checkpoint: Local checkpoint path; skips the download when given.
        batch_size: Maximum image/prompt pairs per forward pass (values below
            1 are treated as 1).

    Returns:
        ``_fn(images, prompts, metadata=None) -> (scores, {})``.

        - ``images`` is a sequence (or a batched tensor, sliced along dim 0).
          Each element is one of: a ``(C, H, W)`` tensor, an ``(H, W, C)``
          NumPy array, a PIL image, or a file path.
        - ``prompts`` is a sequence of strings of the same length.
        - ``scores`` is a list of Python floats, one per pair.
    """
    from einops import rearrange
    from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
    from hpsv2.utils import hps_version_map
    import huggingface_hub

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dev = torch.device(device)
    # The model is built with precision='amp': run forward passes under CUDA
    # autocast (as the reference HPSv2 inference code does); CPU stays fp32.
    use_amp = (dev.type == "cuda")

    model, _, preprocess_val = create_model_and_transforms(
        'ViT-H-14',
        'laion2B-s32B-b79K',
        precision='amp',
        device=dev,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False
    )

    if checkpoint is None:
        checkpoint = huggingface_hub.hf_hub_download("xswu/HPSv2", hps_version_map[hps_version])
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state = torch.load(checkpoint, map_location=dev)
    model.load_state_dict(state['state_dict'])
    model.to(dev).eval()

    tokenizer = get_tokenizer('ViT-H-14')

    def _to_pil(img):
        """Convert one image (tensor, ndarray, PIL image or path) to RGB PIL."""
        if isinstance(img, torch.Tensor):
            # .float() also covers fp16/bf16 inputs, which .numpy() rejects (bf16).
            t = img.detach().cpu().float()
            # Heuristic: a max above 1.5 means values are on a [0, 255] scale.
            if t.max().item() > 1.5:
                t = t / 255.0
            t = t.clamp(0, 1)
            arr = (rearrange(t, "c h w -> h w c").numpy() * 255).round().astype(np.uint8)
            return Image.fromarray(arr).convert("RGB")
        if isinstance(img, np.ndarray):
            # (H, W, C) array: uint8 goes straight to PIL, other dtypes reuse
            # the tensor path (same scaling heuristic and clamping).
            if img.dtype == np.uint8:
                return Image.fromarray(img).convert("RGB")
            return _to_pil(rearrange(torch.from_numpy(np.ascontiguousarray(img)), "h w c -> c h w"))
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        if isinstance(img, str):
            # Context manager releases the file handle; convert() loads the pixels.
            with Image.open(img) as opened:
                return opened.convert("RGB")
        raise TypeError(f"Unsupported image type: {type(img)}")

    def _score_chunk(imgs, txts):
        """Score one chunk of image/prompt pairs; returns a list of floats."""
        pil_imgs = [_to_pil(im) for im in imgs]
        img_batch = torch.stack([preprocess_val(p) for p in pil_imgs], dim=0).to(device=dev, non_blocking=True)
        txt_batch = tokenizer(list(txts)).to(device=dev, non_blocking=True)
        with torch.no_grad(), torch.autocast(device_type=dev.type, enabled=use_amp):
            out = model(img_batch, txt_batch)
            img_f, txt_f = out["image_features"], out["text_features"]
            # Row-wise dot product == diagonal of img_f @ txt_f.T, without
            # materialising the B x B matrix; accumulate in fp32.
            scores = (img_f.float() * txt_f.float()).sum(dim=-1)
        return scores.cpu().tolist()

    def _fn(images, prompts, metadata=None):
        assert len(images) == len(prompts), "images and prompts must have same length"
        n = len(images)
        if n == 0:
            return [], {}

        bs = max(1, int(batch_size))
        scores = []
        for start in range(0, n, bs):
            scores.extend(_score_chunk(images[start:start + bs], prompts[start:start + bs]))
        return [float(s) for s in scores], {}

    return _fn



def hpsv21_score(device):
    """Build an HPSv2.1 reward backed by ``imscore``'s ``HPSv2`` model.

    Loads the ``RE-N-Y/hpsv21`` weights and moves them to ``device``. Returns
    ``_fn(images, prompts, metadata) -> (scores, {})``, where ``scores`` is a
    list of Python floats, one per (image, prompt) pair. Pairs are scored one
    at a time.

    Each image may be:

    - a ``(C, H, W)`` tensor; values whose max exceeds 1.5 are divided by 255;
    - a PIL image in any mode (it is converted to RGB);
    - a file path.

    Other types raise ``ValueError``.
    """
    from imscore.hps.model import HPSv2
    from einops import rearrange

    hpsv2 = HPSv2.from_pretrained("RE-N-Y/hpsv21")
    hpsv2 = hpsv2.to(device)

    def _pil_to_pixels(pil_img):
        """Return ``pil_img`` as a float32 (1, 3, H, W) tensor in [0, 1] on ``device``."""
        # convert("RGB") so grayscale / RGBA / palette images still give an
        # (H, W, 3) array; without it the "h w c" pattern below fails.
        arr = np.array(pil_img.convert("RGB"))
        pixels = rearrange(torch.tensor(arr), "h w c -> 1 c h w").to(torch.float32) / 255.0
        return pixels.to(device)

    def _tensor_to_pixels(tensor):
        """Return a (C, H, W) tensor as a float (1, C, H, W) batch on ``device``."""
        if not tensor.is_floating_point():
            tensor = tensor.to(torch.float32)
        # Heuristic: a max above 1.5 means values are on a [0, 255] scale.
        try:
            needs_scaling = (tensor.max() > 1.5).item()
        except Exception:
            needs_scaling = False
        if needs_scaling:
            tensor = tensor / 255.0
        return rearrange(tensor, "c h w -> 1 c h w").to(device)

    def _to_pixels(img):
        """Dispatch on input type and return a model-ready pixel batch."""
        if isinstance(img, torch.Tensor):
            return _tensor_to_pixels(img)
        if isinstance(img, Image.Image):
            return _pil_to_pixels(img)
        if isinstance(img, str):
            # Context manager closes the file handle once the pixels are read.
            with Image.open(img) as opened:
                return _pil_to_pixels(opened)
        raise ValueError(f"Unsupported image type: {type(img)}")

    def _fn(images, prompts, metadata):
        all_scores = []
        for img, prompt in zip(images, prompts):
            pixels = _to_pixels(img)
            with torch.no_grad():
                all_scores.append(hpsv2.score(pixels, prompt).item())
        return all_scores, {}

    return _fn


def pickscore_score(device):
    """Build a PickScore reward on ``device``.

    A tensor batch is scaled by 255, quantised to uint8, converted NCHW ->
    NHWC and turned into a list of PIL images. Any other input is forwarded
    unchanged. The callable returns ``(PickScoreScorer(prompts, images), {})``.
    """
    from flow_grpo.pickscore_scorer import PickScoreScorer

    pick_scorer = PickScoreScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            quantised = (images * 255).round().clamp(0, 255).to(torch.uint8)
            nhwc = quantised.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
            images = list(map(Image.fromarray, nhwc))
        return pick_scorer(prompts, images), {}

    return _fn

def imagereward_score(device):
    """Build an ImageReward reward on ``device``.

    A tensor batch is scaled by 255, quantised to uint8, converted NCHW ->
    NHWC and turned into PIL images. Prompts are copied into a list before
    calling ``ImageRewardScorer(prompts, images)``. Returns ``(scores, {})``.
    """
    from flow_grpo.imagereward_scorer import ImageRewardScorer

    reward_model = ImageRewardScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            quantised = (images * 255).round().clamp(0, 255).to(torch.uint8)
            nhwc = quantised.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
            images = list(map(Image.fromarray, nhwc))
        prompt_list = list(prompts)
        return reward_model(prompt_list, images), {}

    return _fn

def qwenvl_score(device):
    """Build a local Qwen-VL reward (bfloat16) on ``device``.

    A tensor batch is scaled by 255, quantised to uint8, converted NCHW ->
    NHWC and turned into PIL images. Prompts are copied into a list before
    calling ``QwenVLScorer(prompts, images)``. Returns ``(scores, {})``.
    """
    from flow_grpo.qwenvl import QwenVLScorer

    vl_scorer = QwenVLScorer(dtype=torch.bfloat16, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            quantised = (images * 255).round().clamp(0, 255).to(torch.uint8)
            nhwc = quantised.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
            images = list(map(Image.fromarray, nhwc))
        prompt_list = list(prompts)
        return vl_scorer(prompt_list, images), {}

    return _fn




def qwenvl_thinking_score_remote(device):
    """Build a client for the pairwise "thinking" Qwen-VL reward servers.

    Four servers are expected on http://127.0.0.1:5000-5003. The server used is
    picked from ``device``: its CUDA index (or 0 for non-CUDA devices), taken
    modulo the number of servers. Requests are pickled dicts sent by HTTP
    POST. The timeout in seconds comes from ``VLLM_REWARD_TIMEOUT`` (default
    1000).

    The returned ``_fn(images, prompts, metadata, mode, num_anchors)`` only
    supports ``mode="paired_thinking"``. It JPEG-encodes every image and sends
    them all in one request, with the first prompt (or a generic question
    when ``prompts`` is empty). If the request or response decoding fails,
    the error is printed and every image gets 0.5.

    Raises (from ``_fn``):
        ValueError: if ``mode`` is not ``"paired_thinking"``.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle
    import os

    urls = [
        "http://127.0.0.1:5000",
        "http://127.0.0.1:5001",
        "http://127.0.0.1:5002",
        "http://127.0.0.1:5003",
    ]
    req_timeout = int(os.environ.get("VLLM_REWARD_TIMEOUT", "1000"))

    sess = requests.Session()
    # NOTE(review): urllib3 applies status_forcelist retries only to methods in
    # Retry.allowed_methods, whose default excludes POST, so HTTP 500 responses
    # are NOT retried here (pass allowed_methods=False to change that).
    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500])
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _server_index():
        """Map ``device`` to a server slot: CUDA index, else 0, modulo len(urls)."""
        if getattr(device, "index", None) is not None:
            idx = device.index
        elif str(device).startswith("cuda:"):
            try:
                idx = int(str(device).split(":")[1])
            except Exception:
                idx = 0
        else:
            idx = 0
        return int(idx) % max(1, len(urls))

    server_url = urls[_server_index()]

    def _to_jpeg(img):
        """Encode one HWC uint8 image as JPEG bytes."""
        buffer = BytesIO()
        Image.fromarray(img).save(buffer, format="JPEG")
        return buffer.getvalue()

    def _fn(images, prompts, metadata, mode="paired_thinking", num_anchors=12):
        if mode != "paired_thinking":
            raise ValueError(f"Unsupported mode for qwenvl_thinking_score_remote: {mode}")
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        jpeg_images = [_to_jpeg(img) for img in images]
        prompt_for_comparison = prompts[0] if prompts else "Which image is better?"
        data_bytes = pickle.dumps({
            "images": jpeg_images,
            "prompts": [prompt_for_comparison],
            "mode": mode,
            "num_anchors": num_anchors,
        })
        try:
            response = sess.post(server_url, data=data_bytes, timeout=req_timeout)
            response.raise_for_status()
            # SECURITY: pickle.loads can execute arbitrary code; this is only
            # acceptable because the reward server is a trusted local process.
            result = pickle.loads(response.content)
            outputs = result["outputs"]
            all_scores = outputs["scores"] if "scores" in outputs else outputs
            return all_scores, {}
        except Exception as e:
            # Best-effort: a failed request yields a neutral score per image.
            print(f"QwenVL server request failed: {e}")
            return [0.5] * len(jpeg_images), {}

    return _fn

    
# NOTE: `ocr_score` is defined once, further down in this module. An identical
# earlier copy that lived here was dead code: the later definition rebinds the
# module-level name at import time (ruff F811), so only that one is ever used.




def qwenvl_score_remote(device):
    """Build a client for the Qwen-VL reward server at http://127.0.0.1:5000.

    ``_fn(images, prompts, metadata, mode, num_anchors)`` JPEG-encodes the
    images and POSTs pickled requests to the server:

    * ``mode == "paired"``: all images go in one request, together with only
      the first prompt (or a generic question) and ``num_anchors``.
    * any other mode (including the default ``"paired_fake_thinking"``):
      image/prompt pairs are sent in batches of up to 32, and ``mode`` is
      forwarded as-is.

    Failed requests are printed, and their images get a fallback score of 0.5.
    Returns ``(scores, {})``. ``device`` is unused.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 32
    url = "http://127.0.0.1:5000"
    sess = requests.Session()
    # NOTE(review): without allowed_methods, urllib3 does not retry POSTs on
    # status codes, so status_forcelist=[500] has no effect for these requests.
    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500])
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _to_jpeg(img):
        """Encode one HWC uint8 image as JPEG bytes."""
        buffer = BytesIO()
        Image.fromarray(img).save(buffer, format="JPEG")
        return buffer.getvalue()

    def _request_scores(payload, n_fallback):
        """POST ``payload``; return the server scores or ``[0.5] * n_fallback`` on failure."""
        try:
            response = sess.post(url, data=pickle.dumps(payload), timeout=120)
            response.raise_for_status()
            # SECURITY: pickle.loads can execute arbitrary code; this is only
            # acceptable because the reward server is a trusted local process.
            result = pickle.loads(response.content)
            outputs = result["outputs"]
            return outputs["scores"] if "scores" in outputs else outputs
        except Exception as e:
            print(f"QwenVL server request failed: {e}")
            return [0.5] * n_fallback

    def _fn(images, prompts, metadata, mode="paired_fake_thinking", num_anchors=2):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        if mode == "paired":
            jpeg_images = [_to_jpeg(img) for img in images]
            prompt_for_comparison = prompts[0] if prompts else "Which image is better?"
            payload = {
                "images": jpeg_images,
                "prompts": [prompt_for_comparison],
                "mode": mode,
                "num_anchors": num_anchors,
            }
            return _request_scores(payload, len(jpeg_images)), {}

        # Slice images and prompts with the same indices so pairs stay aligned;
        # an empty input simply yields no batches.
        all_scores = []
        for start in range(0, len(images), batch_size):
            jpeg_images = [_to_jpeg(img) for img in images[start:start + batch_size]]
            payload = {
                "images": jpeg_images,
                "prompts": list(prompts[start:start + batch_size]),
                "mode": mode,
            }
            all_scores.extend(_request_scores(payload, len(jpeg_images)))
        return all_scores, {}

    return _fn


def qwen_pref_score_remote(device):
    """Build a client for the Qwen preference reward servers.

    One server per GPU: ``cuda:k`` talks to http://127.0.0.1:(5004 + k), and
    other devices use port 5004. The returned
    ``_fn(images, prompts, metadata, mode, num_anchors)`` JPEG-encodes the
    images and POSTs pickled requests:

    * ``mode="pairwise"``: one request holding every image, the first prompt
      (or a generic question) and ``num_anchors``. On failure every image
      gets 0.5.
    * ``mode="binary"`` (default): batches of up to 32 image/prompt pairs. If
      the prompt count differs from the image count, the first prompt is
      repeated for every image. A failed batch gets 0.0 per image.

    Scores are returned as Python floats in ``(scores, {})``.

    Raises (from ``_fn``):
        ValueError: for any other ``mode``.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 32
    sess = requests.Session()
    # NOTE(review): without allowed_methods, urllib3 does not retry POSTs on
    # status codes, so status_forcelist=[500] has no effect for these requests.
    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500])
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    # Resolve the server once: the device is fixed for this scorer's lifetime.
    device_id = 0
    if hasattr(device, 'index') and device.index is not None:
        device_id = device.index
    elif str(device).startswith('cuda:'):
        try:
            device_id = int(str(device).split(':')[1])
        except (ValueError, IndexError):
            pass
    url = f"http://127.0.0.1:{5004 + device_id}"

    def _to_jpeg(img):
        """Encode one HWC uint8 image as JPEG bytes."""
        buffer = BytesIO()
        Image.fromarray(img).save(buffer, format="JPEG")
        return buffer.getvalue()

    def _post(payload):
        """POST a pickled payload and return the server's scores (raises on failure)."""
        response = sess.post(url, data=pickle.dumps(payload), timeout=120)
        response.raise_for_status()
        # SECURITY: pickle.loads can execute arbitrary code; this is only
        # acceptable because the reward server is a trusted local process.
        result = pickle.loads(response.content)
        outputs = result["outputs"]
        return outputs["scores"] if "scores" in outputs else outputs

    def _fn(images, prompts, metadata, mode="binary", num_anchors=4):
        if mode not in ("pairwise", "binary"):
            raise ValueError(f"Unsupported mode for qwen_pref_score_remote: {mode}")

        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        jpeg_images = [_to_jpeg(img) for img in images]

        if mode == "pairwise":
            prompt_for_comparison = prompts[0] if prompts else "Which image is better?"
            payload = {
                "images": jpeg_images,
                "prompts": [prompt_for_comparison],
                "mode": mode,
                "num_anchors": num_anchors,
            }
            try:
                return [float(s) for s in _post(payload)], {}
            except Exception as e:
                print(f"Qwen-Pref server request failed (pairwise): {e}")
                return [0.5] * len(jpeg_images), {}

        # mode == "binary"
        if len(prompts) != len(jpeg_images):
            first_prompt = prompts[0] if prompts else ""
            prompts = [first_prompt] * len(jpeg_images)

        all_scores = []
        # Plain list slicing keeps the JPEG payloads as a list of bytes (a
        # NumPy array would turn them into padded fixed-width byte strings)
        # and handles an empty input without error.
        for start in range(0, len(jpeg_images), batch_size):
            img_batch = jpeg_images[start:start + batch_size]
            payload = {
                "images": img_batch,
                "prompts": list(prompts[start:start + batch_size]),
                "mode": mode,
            }
            try:
                all_scores.extend(_post(payload))
            except Exception as e:
                print(f"Qwen-Pref server request failed (binary): {e}")
                all_scores.extend([0.0] * len(img_batch))
        return [float(s) for s in all_scores], {}

    return _fn

    
def ocr_score(device):
    """Build an OCR-based reward backed by ``OcrScorer``.

    A tensor batch is scaled by 255, quantised to uint8 and converted NCHW ->
    NHWC NumPy before scoring. Any other input is forwarded unchanged. Returns
    ``(scores, {})`` with whatever ``OcrScorer(images, prompts)`` produces.
    ``device`` is unused.
    """
    from flow_grpo.ocr import OcrScorer

    ocr = OcrScorer()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            pixels = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = pixels.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
        return ocr(images, prompts), {}

    return _fn


def deqa_score_remote(device):
    """Build a client that scores images with the DeQA server.

    Images are JPEG-encoded and sent by HTTP POST as pickled
    ``{"images": [...]}`` payloads to http://127.0.0.1:18086, in batches of up
    to 64. The ``"outputs"`` lists from all responses are concatenated and
    returned as ``(scores, {})``. Prompts and ``device`` are ignored.

    There is no fallback score: request, HTTP-status and decoding errors
    propagate to the caller.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    url = "http://127.0.0.1:18086"
    sess = requests.Session()
    # allowed_methods=False lets urllib3 retry POSTs that return HTTP 500.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _to_jpeg(image):
        """Encode one HWC uint8 image as JPEG bytes."""
        buffer = BytesIO()
        Image.fromarray(image).save(buffer, format="JPEG")
        return buffer.getvalue()

    def _fn(images, prompts, metadata):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        all_scores = []
        # Plain slicing also handles an empty input (np.array_split needs at
        # least one section and raised for zero images).
        for start in range(0, len(images), batch_size):
            jpeg_images = [_to_jpeg(image) for image in images[start:start + batch_size]]
            response = sess.post(url, data=pickle.dumps({"images": jpeg_images}), timeout=120)
            # Fail with a clear HTTP error instead of trying to unpickle an error page.
            response.raise_for_status()
            # SECURITY: pickle.loads can execute arbitrary code; this is only
            # acceptable because the DeQA server is a trusted local process.
            response_data = pickle.loads(response.content)
            all_scores += response_data["outputs"]

        return all_scores, {}

    return _fn

def geneval_score(device):
    """Build a client for the GenEval server at http://127.0.0.1:18085.

    ``_fn(images, prompts, metadatas, only_strict)`` JPEG-encodes the images
    and sends them by HTTP POST in batches of up to 64. Each request carries
    the matching slice of ``metadatas`` and the ``only_strict`` flag. Prompts
    and ``device`` are ignored. Per-sample lists from every batch are
    concatenated, and the per-group reward dicts are merged key by key.

    Returns:
        ``(scores, rewards, strict_rewards, group_rewards, group_strict_rewards)``:
        three flat lists and two dicts mapping each group key to the
        concatenated list of its values.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    url = "http://127.0.0.1:18085"
    sess = requests.Session()
    # allowed_methods=False lets urllib3 retry POSTs that return HTTP 500.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _to_jpeg(image):
        """Encode one HWC uint8 image as JPEG bytes."""
        buffer = BytesIO()
        Image.fromarray(image).save(buffer, format="JPEG")
        return buffer.getvalue()

    def _merge_groups(per_batch):
        """Merge a list of ``{group: list}`` dicts by concatenating values per key."""
        merged = defaultdict(list)
        for group_dict in per_batch:
            for key, values in group_dict.items():
                merged[key].extend(values)
        return dict(merged)

    def _fn(images, prompts, metadatas, only_strict):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        all_scores = []
        all_rewards = []
        all_strict_rewards = []
        group_reward_batches = []
        group_strict_reward_batches = []
        # Slice images and metadata with the same indices so they stay aligned;
        # an empty input simply yields no batches.
        for start in range(0, len(images), batch_size):
            stop = start + batch_size
            data = {
                "images": [_to_jpeg(image) for image in images[start:stop]],
                "meta_datas": list(metadatas[start:stop]),
                "only_strict": only_strict,
            }
            response = sess.post(url, data=pickle.dumps(data), timeout=120)
            response.raise_for_status()
            # SECURITY: pickle.loads can execute arbitrary code; this is only
            # acceptable because the GenEval server is a trusted local process.
            response_data = pickle.loads(response.content)

            all_scores += response_data["scores"]
            all_rewards += response_data["rewards"]
            all_strict_rewards += response_data["strict_rewards"]
            group_strict_reward_batches.append(response_data["group_strict_rewards"])
            group_reward_batches.append(response_data["group_rewards"])

        return (
            all_scores,
            all_rewards,
            all_strict_rewards,
            _merge_groups(group_reward_batches),
            _merge_groups(group_strict_reward_batches),
        )

    return _fn

def unifiedreward_score_remote(device, url: str = "http://10.82.120.15:18085"):
    """Submit images and prompts to a remote UnifiedReward server.

    Images are JPEG-encoded and sent by HTTP POST as pickled
    ``{"images": [...], "prompts": [...]}`` payloads to ``url``, in batches of
    up to 64 pairs. The ``"outputs"`` lists from all responses are concatenated
    and returned as ``(scores, {})``. Errors propagate to the caller.

    Args:
        device: Unused; kept for the common scorer-factory signature.
        url: Server endpoint.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    sess = requests.Session()
    # allowed_methods=False lets urllib3 retry POSTs that return HTTP 500.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _to_jpeg(image):
        """Encode one HWC uint8 image as JPEG bytes."""
        buffer = BytesIO()
        Image.fromarray(image).save(buffer, format="JPEG")
        return buffer.getvalue()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        all_scores = []
        # Slice images and prompts with the same indices so pairs stay aligned;
        # an empty input simply yields no batches.
        for start in range(0, len(images), batch_size):
            stop = start + batch_size
            data = {
                "images": [_to_jpeg(image) for image in images[start:stop]],
                "prompts": list(prompts[start:stop]),
            }
            response = sess.post(url, data=pickle.dumps(data), timeout=120)
            response.raise_for_status()
            # SECURITY: pickle.loads can execute arbitrary code; only use this
            # against a trusted reward server.
            response_data = pickle.loads(response.content)
            all_scores += response_data["outputs"]

        return all_scores, {}

    return _fn

def unifiedreward_score_sglang(device):
    """Build a UnifiedReward scorer that queries an OpenAI-compatible endpoint.

    Each image is resized to 512x512, PNG/base64-encoded and sent with its
    caption to ``UnifiedReward-7b-v1.5`` at http://127.0.0.1:17140/v1. All
    requests in a call run concurrently, with temperature 0.

    The number after ``Final Score:`` in each reply (1-5, optional decimals)
    is divided by 5. Replies without such a number, including empty ones,
    score 0. ``device`` is unused.

    Returns ``_fn(images, prompts, metadata) -> (scores, {})``.
    """
    import asyncio
    from openai import AsyncOpenAI
    import base64
    from io import BytesIO
    import re

    base_url = "http://127.0.0.1:17140/v1"
    api_key = "flowgrpo"
    score_pattern = re.compile(r"Final Score:\s*([1-5](?:\.\d+)?)")

    def pil_image_to_base64(image):
        """Return ``image`` as a base64 PNG data URL."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image;base64,{encoded_image_text}"

    def _extract_scores(text_outputs):
        """Parse the ``Final Score:`` value of each reply; 0.0 when absent."""
        scores = []
        for text in text_outputs:
            # message.content can be None; treat it like a reply with no score.
            match = score_pattern.search(text) if text else None
            scores.append(float(match.group(1)) if match else 0.0)
        return scores

    async def evaluate_image(client, prompt, image):
        """Ask the model to rate one image/caption pair; returns the reply text."""
        question = f"<image>\nYou are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n2. Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\nBased on the above criteria, assign a score from 1 to 5 after \'Final Score:\'.\nYour task is provided as follows:\nText Caption: [{prompt}]"
        images_base64 = pil_image_to_base64(image)
        response = await client.chat.completions.create(
            model="UnifiedReward-7b-v1.5",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": images_base64},
                        },
                        {
                            "type": "text",
                            "text": question,
                        },
                    ],
                },
            ],
            temperature=0,
        )
        return response.choices[0].message.content

    async def evaluate_batch_image(images, prompts):
        # A fresh client per call: _fn drives each batch with asyncio.run(),
        # which creates and then closes its own event loop. An async client's
        # pooled connections are tied to the loop that opened them, so one
        # client shared across asyncio.run() calls can fail with
        # "Event loop is closed".
        async with AsyncOpenAI(base_url=base_url, api_key=api_key) as client:
            tasks = [evaluate_image(client, prompt, img) for prompt, img in zip(prompts, images)]
            return await asyncio.gather(*tasks)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        pil_images = [Image.fromarray(image).resize((512, 512)) for image in images]

        text_outputs = asyncio.run(evaluate_batch_image(pil_images, prompts))
        return [score / 5.0 for score in _extract_scores(text_outputs)], {}

    return _fn

def multi_score(device, score_dict):
    """Combine several reward functions into one weighted reward.

    Args:
        device: Passed to every scorer factory that declares a ``device``
            parameter; other factories are called with no arguments.
        score_dict: Mapping from scorer name (a key of the registry below) to
            its weight.

    Returns:
        ``_fn(images, prompts, metadata, only_strict=True) -> (details, {})``.
        ``details`` maps each scorer name to its raw scores. Its ``"avg"`` key
        holds the per-sample weighted *sum* of all scorers (a mean only if the
        weights sum to 1). For ``"geneval"`` it also holds ``accuracy``,
        ``strict_accuracy`` and per-group ``<group>_accuracy`` /
        ``<group>_strict_accuracy`` entries.
    """
    import inspect

    score_functions = {
        "deqa": deqa_score_remote,
        "ocr": ocr_score,
        "imagereward": imagereward_score,
        "pickscore": pickscore_score,
        "qwenvl": qwenvl_score_remote,
        "qwenvl_thinking": qwenvl_thinking_score_remote,
        "qwenvl_thinking_fake": qwenvl_score_remote,
        "aesthetic": aesthetic_score,
        "jpeg_compressibility": jpeg_compressibility,
        "unifiedreward": unifiedreward_score_sglang,
        "geneval": geneval_score,
        "hpsv2": hpsv20_score,
        "qwen_pref": qwen_pref_score_remote,
    }

    def _build(factory):
        # Inspect real parameters: __code__.co_varnames also lists local
        # variables, so a factory that merely used a local named "device"
        # would wrongly receive an extra positional argument.
        if "device" in inspect.signature(factory).parameters:
            return factory(device)
        return factory()

    score_fns = {score_name: _build(score_functions[score_name]) for score_name in score_dict}

    def _fn(images, prompts, metadata, only_strict=True):
        total_scores = []
        score_details = {}

        for score_name, weight in score_dict.items():
            if score_name == "geneval":
                scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](images, prompts, metadata, only_strict)
                score_details['accuracy'] = rewards
                score_details['strict_accuracy'] = strict_rewards
                for key, value in group_strict_rewards.items():
                    score_details[f'{key}_strict_accuracy'] = value
                for key, value in group_rewards.items():
                    score_details[f'{key}_accuracy'] = value
            else:
                scores, _ = score_fns[score_name](images, prompts, metadata)
            score_details[score_name] = scores
            weighted_scores = [weight * score for score in scores]

            if not total_scores:
                total_scores = weighted_scores
            else:
                total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)]

        score_details['avg'] = total_scores
        return score_details, {}

    return _fn

def main():
    """Score a sample image with the UnifiedReward scorer and print the result.

    Expects ``nasa.jpg`` in the working directory.
    """
    import torchvision.transforms as transforms

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    image_paths = ["nasa.jpg"]
    images = torch.stack(
        [to_tensor(Image.open(path).convert('RGB')) for path in image_paths]
    )
    prompts = [
        'A astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
    ]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    scoring_fn = multi_score(device, {"unifiedreward": 1.0})
    scores, _ = scoring_fn(images, prompts, {})
    print("Scores:", scores)


if __name__ == "__main__":
    main()