import imp
from turtle import mode
from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch
from omegaconf import OmegaConf
import sys
sys.path.append("/workspace/user_code/DiffusionDPO/LPO/lrm/lrm_sd3_score")
from trainer.models.sd3_base_preference_model import sd3_base_preference_model, SD3BasePreferenceModelConfig

class SLRMScorer(torch.nn.Module):
    """Expose an SD3-based latent preference model as a timestep-aware scorer.

    The wrapped ``sd3_base_preference_model`` is built with ``score_model=True``,
    cast to ``dtype`` and loaded from ``ckpt_path``. Its ``timesteps`` table is
    mirrored in ``self.timesteps`` so that :meth:`forward` can translate
    timestep values into indices into that table.
    """

    def __init__(self, dtype=torch.bfloat16, ckpt_path=None, timesteps=None):
        """Build the preference model and load its weights.

        Args:
            dtype: dtype the wrapped model is cast to (default ``torch.bfloat16``).
            ckpt_path: checkpoint location passed verbatim to ``scorer.load``.
                NOTE(review): ``None`` is forwarded as-is; whether
                ``load(None)`` is a no-op or an error depends on the model
                implementation — confirm.
            timesteps: value written to the config's ``total_timesteps``.
                NOTE(review): this overwrites the config default even when it
                is ``None`` — confirm the model handles that case.
        """
        super().__init__()
        model_cfg = SD3BasePreferenceModelConfig()
        model_cfg.total_timesteps = timesteps
        model_cfg.score_model = True
        self.scorer = sd3_base_preference_model(model_cfg).to(dtype=dtype)
        self.scorer.load(ckpt_path)
        # Table of timestep values; forward() looks requested timesteps up in it.
        self.timesteps = self.scorer.timesteps

    @torch.no_grad()
    def forward(self, images, prompts, metadata, timesteps):
        """Score a batch at the given timestep(s) without tracking gradients.

        Args:
            images: batch handed to ``scorer.get_latent_preference_scores``;
                its first dimension is the batch size. (Presumably latents
                rather than pixel images, given the callee's name — confirm.)
            prompts: forwarded unchanged to the scorer.
            metadata: accepted but not used here (presumably kept so the call
                signature matches other scorers — confirm).
            timesteps: either a single timestep (Python number, 0-d tensor or
                1-element tensor), applied to every sample, or one timestep per
                sample. Each value must exactly equal an entry of
                ``self.timesteps``.

        Returns:
            Whatever ``scorer.get_latent_preference_scores`` returns.

        Raises:
            ValueError: if the number of timesteps is neither 1 nor the batch
                size, or if a timestep does not occur in ``self.timesteps``.
        """
        device = images.device
        batch_size = images.shape[0]
        schedule = self.timesteps.to(device).reshape(-1)
        query = torch.as_tensor(timesteps, device=device).reshape(-1)

        if query.numel() == 1:
            # A single timestep is shared by the whole batch.
            query = query.expand(batch_size)
        elif query.numel() != batch_size:
            raise ValueError(
                f"expected a single timestep or one per sample ({batch_size}), "
                f"got {query.numel()}"
            )

        # Find, for every sample, the index u of its timestep in the schedule.
        # Comparing each query against the whole schedule (B x T) avoids the
        # accidental element-wise comparison that occurs when B == T.
        matches = query.unsqueeze(1) == schedule.unsqueeze(0)
        found = matches.any(dim=1)
        if not bool(found.all()):
            missing = query[~found].tolist()
            raise ValueError(
                f"timestep(s) {missing} not found in the scorer's timestep table"
            )
        # argmax returns the first maximal index, i.e. the first matching entry.
        u = matches.to(torch.int64).argmax(dim=1)

        scores = self.scorer.get_latent_preference_scores(prompts, images, u, generator=None)
        return scores


