from PIL import Image
import io
import numpy as np
import torch
from collections import defaultdict

def latent_score_sd3(device, *, model_path=None, scorer_timesteps=10):
    """Build a reward function that scores latents with ``SLRMScorer``.

    The scorer is constructed from a checkpoint, moved to ``device``, put in
    eval mode and has ``requires_grad_(False)`` applied, then wrapped in a
    closure with the ``(scores, info_dict)`` return convention.

    Args:
        device: Target device forwarded to ``scorer.to(device=...)``.
        model_path: Checkpoint path passed to ``SLRMScorer`` as ``ckpt_path``.
            ``None`` (the default) selects the built-in checkpoint path.
        scorer_timesteps: Value passed to ``SLRMScorer`` as ``timesteps``
            (defaults to 10). Its exact meaning is defined by ``SLRMScorer``
            -- TODO confirm against that class.

    Returns:
        A callable ``_fn(pre_latents, prompts, metadata, timesteps)`` that
        forwards its arguments unchanged to the scorer and returns
        ``(scores, {})``. The scorer's second output is discarded.
    """
    import sys

    # Register the scorer package location only once; an unconditional
    # append would add a duplicate sys.path entry on every call.
    scorer_dir = "/workspace/user_code/DiffusionDPO/LPO/lrm/lrm_sd3_score"
    if scorer_dir not in sys.path:
        sys.path.append(scorer_dir)
    from flow_grpo.SLRM_scorer import SLRMScorer

    if model_path is None:
        model_path = "/cfs/cfs-1dafgugv/connorxian/SD3_output/logs/lrm/reward_model/step_sd3_variable-t_lr1e-5_step-8000_nocfg_0901/checkpoint-gstep155000"

    # Move to device and set to eval mode
    scorer = SLRMScorer(ckpt_path=model_path, timesteps=scorer_timesteps)
    scorer = scorer.to(device=device)
    scorer.eval()
    # Presumably freezes the scorer's parameters (nn.Module semantics);
    # verify that SLRMScorer is an nn.Module.
    scorer.requires_grad_(False)

    def _fn(pre_latents, prompts, metadata, timesteps):
        """Score ``pre_latents`` at ``timesteps``; return ``(scores, {})``."""
        scores, _ = scorer(pre_latents, prompts, metadata, timesteps)
        return scores, {}

    return _fn